diff --git a/startup.py b/startup.py
index 3680e37..4a55cd0 100644
--- a/startup.py
+++ b/startup.py
@@ -1,6 +1,7 @@
 from multiprocessing import Process, Queue
 import multiprocessing as mp
 import subprocess
+import asyncio
 import sys
 import os
 from pprint import pprint
@@ -8,6 +9,7 @@ from pprint import pprint
 # set numexpr's maximum thread count; defaults to the number of CPU cores
 try:
     import numexpr
+
     n_cores = numexpr.utils.detect_number_of_cores()
     os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
 except:
@@ -19,9 +21,9 @@ from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG
 from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
                                    FSCHAT_OPENAI_API, )
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
-                        fschat_openai_api_address, set_httpx_timeout,
-                        get_model_worker_config, get_all_model_worker_configs,
-                        MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
+                          fschat_openai_api_address, set_httpx_timeout,
+                          get_model_worker_config, get_all_model_worker_configs,
+                          MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
 import argparse
 from typing import Tuple, List, Dict
 from configs import VERSION
@@ -101,8 +103,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse
     # online model API
     if worker_class := kwargs.get("worker_class"):
         worker = worker_class(model_names=args.model_names,
-                            controller_addr=args.controller_address,
-                            worker_addr=args.worker_address)
+                              controller_addr=args.controller_address,
+                              worker_addr=args.worker_address)
     # local model
     else:
         # workaround to make program exit with Ctrl+c
@@ -208,7 +210,7 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
         q.put(run_seq)
 
 
-def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
+def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO"):
     import uvicorn
     import httpx
     from fastapi import Body
@@ -224,10 +226,10 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
     # add interface to release and load model worker
     @app.post("/release_worker")
     def release_worker(
-        model_name: str = Body(..., description="name of the model to release", samples=["chatglm-6b"]),
-        # worker_address: str = Body(None, description="address of the model to release (alternative to model name)", samples=[fschat_controller_address()]),
-        new_model_name: str = Body(None, description="model to load after the release"),
-        keep_origin: bool = Body(False, description="load the new model without releasing the original one")
+            model_name: str = Body(..., description="name of the model to release", samples=["chatglm-6b"]),
+            # worker_address: str = Body(None, description="address of the model to release (alternative to model name)", samples=[fschat_controller_address()]),
+            new_model_name: str = Body(None, description="model to load after the release"),
+            keep_origin: bool = Body(False, description="load the new model without releasing the original one")
     ) -> Dict:
         available_models = app._controller.list_models()
         if new_model_name in available_models:
@@ -252,14 +254,14 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
             return {"code": 500, "msg": msg}
 
         r = httpx.post(worker_address + "/release",
-                    json={"new_model_name": new_model_name, "keep_origin": keep_origin})
+                       json={"new_model_name": new_model_name, "keep_origin": keep_origin})
         if r.status_code != 200:
             msg = f"failed to release model: {model_name}"
             logger.error(msg)
             return {"code": 500, "msg": msg}
 
         if new_model_name:
-            timer = 300 # wait 5 minutes for new model_worker register
+            timer = 300  # wait 5 minutes for new model_worker register
             while timer > 0:
                 models = app._controller.list_models()
                 if new_model_name in models:
@@ -294,7 +296,7 @@ def run_model_worker(
         model_name: str = LLM_MODEL,
         controller_address: str = "",
         q: Queue = None,
         run_seq: int = 2,
-        log_level: str ="INFO",
+        log_level: str = "INFO",
 ):
     import uvicorn
     from fastapi import Body
@@ -318,8 +320,8 @@ def run_model_worker(
     # add interface to release and load model
     @app.post("/release")
     def release_model(
-        new_model_name: str = Body(None, description="model to load after the release"),
-        keep_origin: bool = Body(False, description="load the new model without releasing the original one")
+            new_model_name: str = Body(None, description="model to load after the release"),
+            keep_origin: bool = Body(False, description="load the new model without releasing the original one")
     ) -> Dict:
         if keep_origin:
             if new_model_name:
@@ -497,15 +499,15 @@ def dump_server_info(after_start=False, args=None):
     print("\n")
 
 
-if __name__ == "__main__":
+async def start_main_server():
     import time
     mp.set_start_method("spawn")
 
     # TODO: the chained startup queue can indeed be used to control startup order,
     # but since proxy_worker was introduced, worker processes started outside the framework cannot tell their position in the chain,
     # so they cannot register while the registry is not up yet, and the whole startup chain is aborted by the exception
-    # `run_seq` would have to be set to MAX_int to guarantee starting last.
-    # Should a depends_on mechanism be introduced?
+    # await asyncio.sleep(3) makes the subsequent code wait for a while, but it is not the optimal solution
+
     queue = Queue()
 
     args, parser = parse_args()
@@ -539,7 +541,7 @@
     processes = {"online-api": []}
 
     def process_count():
-        return len(processes) + len(processes["online-api"]) -1
+        return len(processes) + len(processes["online-api"]) - 1
 
     if args.quiet:
         log_level = "ERROR"
@@ -554,6 +556,7 @@
             daemon=True,
         )
         process.start()
+        await asyncio.sleep(3)
         processes["controller"] = process
 
         process = Process(
@@ -638,6 +641,9 @@
             process.terminate()
 
+if __name__ == "__main__":
+    # invoke the coroutine synchronously
+    asyncio.get_event_loop().run_until_complete(start_main_server())
 
 # Example API call after the services start:
 # import openai
 # openai.api_key = "EMPTY"  # Not supported yet
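
For reference, a minimal sketch of calling the /release_worker endpoint this patch adds to the controller app. The address is an assumption (the real one comes from fschat_controller_address(), with FSCHAT_CONTROLLER defined in configs.server_config), and both model names are placeholders:

    import httpx

    controller = "http://127.0.0.1:20001"  # assumed controller host/port

    # ask the controller to release chatglm-6b and load chatglm2-6b in its place
    r = httpx.post(
        controller + "/release_worker",
        json={
            "model_name": "chatglm-6b",       # model to release
            "new_model_name": "chatglm2-6b",  # placeholder; model to load afterwards
            "keep_origin": False,             # False: replace rather than keep both
        },
        timeout=300,  # the controller may wait up to 5 minutes for the new worker
    )
    print(r.json())  # e.g. {"code": 200, "msg": "..."}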
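
A note on the new entrypoint: asyncio.get_event_loop().run_until_complete(...) works here, but calling get_event_loop() outside a running loop is deprecated on recent Python versions. A minimal sketch of an equivalent entrypoint using asyncio.run (available since Python 3.7); this is an alternative, not what the patch ships:

    import asyncio

    if __name__ == "__main__":
        # asyncio.run creates a fresh event loop, runs the coroutine to
        # completion, and closes the loop afterwards; equivalent here to
        # asyncio.get_event_loop().run_until_complete(start_main_server())
        asyncio.run(start_main_server())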