Langchain-Chatchat/server/llm_api_sh.py

"""
调用示例: python llm_api_sh.py --model-path-address THUDM/chatglm2-6b@localhost@7650 THUDM/chatglm2-6b-32k@localhost@7651
其他fastchat.server.controller/worker/openai_api_server参数可按照fastchat文档调用
但少数非关键参数如--worker-address,--allowed-origins,--allowed-methods,--allowed-headers不支持
"""
import argparse
import os
import re
import subprocess
import sys

# Make the project root importable so "configs" resolves when this file is run directly.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import LOG_PATH, controller_args, worker_args, server_args, parser

args = parser.parse_args()
# The "http://" prefix is required here; without it, requests raises
# "InvalidSchema: No connection adapters were found".
args = argparse.Namespace(**vars(args),
                          **{"controller-address": f"http://{args.controller_host}:{args.controller_port}"})
if args.gpus:
    if len(args.gpus.split(",")) < args.num_gpus:
        raise ValueError(
            f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
        )
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
# Placeholders in base_launch_sh:
#   {0} module to launch: controller, model_worker or openai_api_server
#   {1} command-line options
#   {2} LOG_PATH
#   {3} log file name
base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &"
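# Illustration with hypothetical values (a LOG_PATH of "/logs"):
#   base_launch_sh.format("controller", "--host localhost --port 20001", "/logs", "controller")
# renders the daemonized command
#   nohup python3 -m fastchat.serve.controller --host localhost --port 20001 >/logs/controller.log 2>&1 &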
# Placeholders in base_check_sh:
#   {0} LOG_PATH
#   {1} log file name; must match the one passed to base_launch_sh
#   {2} service name: controller, model_worker or openai_api_server
base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ]; do
    sleep 1s;
    echo "wait {2} running"
done
echo '{2} running'"""
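# Illustration (same hypothetical "/logs"): base_check_sh.format("/logs", "controller", "controller")
# renders a shell loop that greps /logs/controller.log once per second and returns only
# after fastchat has printed "Uvicorn running on ...", i.e. after startup has finished.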
def string_args(args, args_list):
    """Serialize the keys of args that appear in args_list into a CLI argument string."""
    args_str = ""
    for key, value in args._get_kwargs():
        # Keys from args._get_kwargs() use "_" as separator; convert to "-" before
        # checking membership in args_list.
        key = key.replace("_", "-")
        if key not in args_list:
            continue
        # fastchat expects bare --port/--host options, so strip any prefix
        # (e.g. "controller-host" -> "host").
        key = key.split("-")[-1] if re.search("port|host", key) else key
        if not value:
            # Skip falsy values (None, "", False, empty containers).
            continue
        # Check isinstance first: 1 == True in Python, so a bare equality test
        # would also match the integer 1.
        if isinstance(value, bool):
            args_str += f" --{key} "
        elif isinstance(value, (list, tuple, set)):
            value = " ".join(value)
            args_str += f" --{key} {value} "
        else:
            args_str += f" --{key} {value} "
    return args_str
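# Illustration with hypothetical values (the real argument lists live in
# configs.model_config): for parsed args with controller_host="localhost",
# controller_port=20001 and dispatch_method="shortest_queue",
#   string_args(args, ["controller-host", "controller-port", "dispatch-method"])
# returns roughly " --host localhost --port 20001 --dispatch-method shortest_queue ",
# with the "controller-" prefix of host/port stripped as fastchat expects.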
def launch_worker(item):
    log_name = (item.split("/")[-1].split("\\")[-1]
                .replace("-", "_").replace("@", "_").replace(".", "_"))
    # Split model-path-address first, then let string_args() pick up the pieces.
    args.model_path, args.worker_host, args.worker_port = item.split("@")
    print("*" * 80)
    worker_str_args = string_args(args, worker_args)
    print(worker_str_args)
    worker_sh = base_launch_sh.format("model_worker", worker_str_args, LOG_PATH, f"worker_{log_name}")
    worker_check_sh = base_check_sh.format(LOG_PATH, f"worker_{log_name}", "model_worker")
    subprocess.run(worker_sh, shell=True, check=True)
    subprocess.run(worker_check_sh, shell=True, check=True)
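# For the example item "THUDM/chatglm2-6b@localhost@7650" the substitutions above give
# log_name = "chatglm2_6b_localhost_7650", so the worker logs to
# LOG_PATH/worker_chatglm2_6b_localhost_7650.log.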
def launch_all():
    controller_str_args = string_args(args, controller_args)
    controller_sh = base_launch_sh.format("controller", controller_str_args, LOG_PATH, "controller")
    controller_check_sh = base_check_sh.format(LOG_PATH, "controller", "controller")
    subprocess.run(controller_sh, shell=True, check=True)
    subprocess.run(controller_check_sh, shell=True, check=True)
    if isinstance(args.model_path_address, str):
        launch_worker(args.model_path_address)
    else:
        for idx, item in enumerate(args.model_path_address):
            print(f"Loading model {idx}: {item}")
            launch_worker(item)
    server_str_args = string_args(args, server_args)
    server_sh = base_launch_sh.format("openai_api_server", server_str_args, LOG_PATH, "openai_api_server")
    server_check_sh = base_check_sh.format(LOG_PATH, "openai_api_server", "openai_api_server")
    subprocess.run(server_sh, shell=True, check=True)
    subprocess.run(server_check_sh, shell=True, check=True)
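# Net effect: running the usage example from the module docstring leaves daemonized
# processes behind: the controller, one model_worker per --model-path-address item,
# and the openai_api_server, each logging to its own file under LOG_PATH.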
if __name__ == "__main__":
launch_all()