diff --git a/startup.py b/startup.py index 4c52f46..591e3b1 100644 --- a/startup.py +++ b/startup.py @@ -210,7 +210,7 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int): q.put(run_seq) -def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO"): +def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None): import uvicorn import httpx from fastapi import Body @@ -223,6 +223,11 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO"): ) _set_app_seq(app, q, run_seq) + @app.on_event("startup") + def on_startup(): + if e is not None: + e.set() + # add interface to release and load model worker @app.post("/release_worker") def release_worker( @@ -517,8 +522,9 @@ async def start_main_server(): signal.signal(signal.SIGTERM, handler("SIGTERM")) mp.set_start_method("spawn") + manager = mp.Manager() - queue = Queue() + queue = manager.Queue() args, parser = parse_args() if args.all_webui: @@ -558,11 +564,12 @@ async def start_main_server(): else: log_level = "INFO" + controller_started = manager.Event() if args.openai_api: process = Process( target=run_controller, - name=f"controller({os.getpid()})", - args=(queue, process_count() + 1, log_level), + name=f"controller", + args=(queue, process_count() + 1, log_level, controller_started), daemon=True, ) @@ -570,7 +577,7 @@ async def start_main_server(): process = Process( target=run_openai_api, - name=f"openai_api({os.getpid()})", + name=f"openai_api", args=(queue, process_count() + 1), daemon=True, ) @@ -581,7 +588,7 @@ async def start_main_server(): if not config.get("online_api"): process = Process( target=run_model_worker, - name=f"model_worker - {args.model_name} ({os.getpid()})", + name=f"model_worker - {args.model_name}", args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level), daemon=True, ) @@ -594,7 +601,7 @@ async def start_main_server(): if config.get("online_api") and config.get("worker_class"): process = Process( target=run_model_worker, - name=f"model_worker - {model_name} ({os.getpid()})", + name=f"model_worker - {model_name}", args=(model_name, args.controller_address, queue, process_count() + 1, log_level), daemon=True, ) @@ -604,7 +611,7 @@ async def start_main_server(): if args.api: process = Process( target=run_api_server, - name=f"API Server{os.getpid()})", + name=f"API Server", args=(queue, process_count() + 1), daemon=True, ) @@ -614,7 +621,7 @@ async def start_main_server(): if args.webui: process = Process( target=run_webui, - name=f"WEBUI Server{os.getpid()})", + name=f"WEBUI Server", args=(queue, process_count() + 1), daemon=True, ) @@ -626,17 +633,30 @@ async def start_main_server(): else: try: # 保证任务收到SIGINT后,能够正常退出 - processes["controller"].start() - # 使用await asyncio.sleep(3)可以让后续代码等待一段时间 - await asyncio.sleep(3) - - processes["openai_api"].start() - processes["model_worker"].start() - for p in processes["online-api"]: + if p:= processes.get("controller"): p.start() + p.name = f"{p.name} ({p.pid})" + controller_started.wait() - processes["api"].start() - processes["webui"].start() + if p:= processes.get("openai_api"): + p.start() + p.name = f"{p.name} ({p.pid})" + + if p:= processes.get("model_worker"): + p.start() + p.name = f"{p.name} ({p.pid})" + + for p in processes.get("online-api", []): + p.start() + p.name = f"{p.name} ({p.pid})" + + if p:= processes.get("api"): + p.start() + p.name = f"{p.name} ({p.pid})" + + if p:= processes.get("webui"): + p.start() + p.name = f"{p.name} ({p.pid})" while True: no = queue.get() @@ -647,17 +667,18 @@ async def start_main_server(): else: queue.put(no) - if model_worker_process := processes.pop("model_worker", None): + if model_worker_process := processes.get("model_worker"): model_worker_process.join() - for process in processes.pop("online-api", []): + for process in processes.get("online-api", []): process.join() for name, process in processes.items(): - if isinstance(p, list): - for work_process in p: - work_process.join() - else: - process.join() - except: + if name not in ["model_worker", "online-api"]: + if isinstance(p, list): + for work_process in p: + work_process.join() + else: + process.join() + except Exception as e: # if model_worker_process := processes.pop("model_worker", None): # model_worker_process.terminate() # for process in processes.pop("online-api", []): @@ -669,6 +690,7 @@ async def start_main_server(): # work_process.terminate() # else: # process.terminate() + logger.error(e) logger.warning("Caught KeyboardInterrupt! Setting stop event...") finally: # Send SIGINT if process doesn't exit quickly enough, and kill it as last resort @@ -703,6 +725,8 @@ if __name__ == "__main__": asyncio.set_event_loop(loop) # 同步调用协程代码 loop.run_until_complete(start_main_server()) + + # 服务启动后接口调用示例: # import openai # openai.api_key = "EMPTY" # Not support yet