update format of api_allinone.py and llm_api_launch.py

This commit is contained in:
imClumsyPanda 2023-08-16 22:24:29 +08:00
parent b9fa84635d
commit 8396b57101
2 changed files with 16 additions and 15 deletions

View File

@ -11,10 +11,11 @@ python server/api_allinone.py --model-path-address model@host@port --num-gpus 2
""" """
import sys import sys
import os import os
sys.path.append(os.path.dirname(__file__)) sys.path.append(os.path.dirname(__file__))
sys.path.append(os.path.dirname(os.path.dirname(__file__))) sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from llm_api_launch import launch_all,parser,controller_args,worker_args,server_args from llm_api_launch import launch_all, parser, controller_args, worker_args, server_args
from api import create_app from api import create_app
import uvicorn import uvicorn
@ -23,8 +24,8 @@ parser.add_argument("--api-port", type=int, default=7861)
parser.add_argument("--ssl_keyfile", type=str) parser.add_argument("--ssl_keyfile", type=str)
parser.add_argument("--ssl_certfile", type=str) parser.add_argument("--ssl_certfile", type=str)
api_args = ["api-host", "api-port", "ssl_keyfile", "ssl_certfile"]
api_args = ["api-host","api-port","ssl_keyfile","ssl_certfile"]
def run_api(host, port, **kwargs): def run_api(host, port, **kwargs):
app = create_app() app = create_app()
@ -38,13 +39,14 @@ def run_api(host, port, **kwargs):
else: else:
uvicorn.run(app, host=host, port=port) uvicorn.run(app, host=host, port=port)
if __name__ == "__main__": if __name__ == "__main__":
print("Luanching api_allinoneit would take a while, please be patient...") print("Luanching api_allinoneit would take a while, please be patient...")
print("正在启动api_allinoneLLM服务启动约3-10分钟请耐心等待...") print("正在启动api_allinoneLLM服务启动约3-10分钟请耐心等待...")
# 初始化消息 # 初始化消息
args = parser.parse_args() args = parser.parse_args()
args_dict = vars(args) args_dict = vars(args)
launch_all(args=args,controller_args=controller_args,worker_args=worker_args,server_args=server_args) launch_all(args=args, controller_args=controller_args, worker_args=worker_args, server_args=server_args)
run_api( run_api(
host=args.api_host, host=args.api_host,
port=args.api_port, port=args.api_port,

View File

@ -132,7 +132,7 @@ worker_args = [
"gptq-ckpt", "gptq-wbits", "gptq-groupsize", "gptq-ckpt", "gptq-wbits", "gptq-groupsize",
"gptq-act-order", "model-names", "limit-worker-concurrency", "gptq-act-order", "model-names", "limit-worker-concurrency",
"stream-interval", "no-register", "stream-interval", "no-register",
"controller-address","worker-address" "controller-address", "worker-address"
] ]
# -----------------openai server--------------------------- # -----------------openai server---------------------------
@ -159,8 +159,6 @@ server_args = ["server-host", "server-port", "allow-credentials", "api-keys",
"controller-address" "controller-address"
] ]
# 0,controller, model_worker, openai_api_server # 0,controller, model_worker, openai_api_server
# 1, 命令行选项 # 1, 命令行选项
# 2,LOG_PATH # 2,LOG_PATH
@ -201,7 +199,7 @@ def string_args(args, args_list):
return args_str return args_str
def launch_worker(item,args,worker_args=worker_args): def launch_worker(item, args, worker_args=worker_args):
log_name = item.split("/")[-1].split("\\")[-1].replace("-", "_").replace("@", "_").replace(".", "_") log_name = item.split("/")[-1].split("\\")[-1].replace("-", "_").replace("@", "_").replace(".", "_")
# 先分割model-path-address,在传到string_args中分析参数 # 先分割model-path-address,在传到string_args中分析参数
args.model_path, args.worker_host, args.worker_port = item.split("@") args.model_path, args.worker_host, args.worker_port = item.split("@")
@ -230,11 +228,11 @@ def launch_all(args,
subprocess.run(controller_check_sh, shell=True, check=True) subprocess.run(controller_check_sh, shell=True, check=True)
print(f"worker启动时间视设备不同而不同约需3-10分钟请耐心等待...") print(f"worker启动时间视设备不同而不同约需3-10分钟请耐心等待...")
if isinstance(args.model_path_address, str): if isinstance(args.model_path_address, str):
launch_worker(args.model_path_address,args=args,worker_args=worker_args) launch_worker(args.model_path_address, args=args, worker_args=worker_args)
else: else:
for idx, item in enumerate(args.model_path_address): for idx, item in enumerate(args.model_path_address):
print(f"开始加载第{idx}个模型:{item}") print(f"开始加载第{idx}个模型:{item}")
launch_worker(item,args=args,worker_args=worker_args) launch_worker(item, args=args, worker_args=worker_args)
server_str_args = string_args(args, server_args) server_str_args = string_args(args, server_args)
server_sh = base_launch_sh.format("openai_api_server", server_str_args, LOG_PATH, "openai_api_server") server_sh = base_launch_sh.format("openai_api_server", server_str_args, LOG_PATH, "openai_api_server")
@ -244,6 +242,7 @@ def launch_all(args,
print("Launching LLM service done!") print("Launching LLM service done!")
print("LLM服务启动完毕。") print("LLM服务启动完毕。")
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
# 必须要加http//:否则InvalidSchema: No connection adapters were found # 必须要加http//:否则InvalidSchema: No connection adapters were found