update startup.py:

1. controller改为自动等待
2. 修复:分部分启动时进程引用错误
This commit is contained in:
liunux4odoo 2023-09-05 09:47:27 +08:00
parent 2275b1d84a
commit 053ebb82bf
1 changed files with 50 additions and 26 deletions

View File

@ -210,7 +210,7 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
q.put(run_seq)
def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO"):
def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
import uvicorn
import httpx
from fastapi import Body
@ -223,6 +223,11 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO"):
)
_set_app_seq(app, q, run_seq)
@app.on_event("startup")
def on_startup():
if e is not None:
e.set()
# add interface to release and load model worker
@app.post("/release_worker")
def release_worker(
@ -517,8 +522,9 @@ async def start_main_server():
signal.signal(signal.SIGTERM, handler("SIGTERM"))
mp.set_start_method("spawn")
manager = mp.Manager()
queue = Queue()
queue = manager.Queue()
args, parser = parse_args()
if args.all_webui:
@ -558,11 +564,12 @@ async def start_main_server():
else:
log_level = "INFO"
controller_started = manager.Event()
if args.openai_api:
process = Process(
target=run_controller,
name=f"controller({os.getpid()})",
args=(queue, process_count() + 1, log_level),
name=f"controller",
args=(queue, process_count() + 1, log_level, controller_started),
daemon=True,
)
@ -570,7 +577,7 @@ async def start_main_server():
process = Process(
target=run_openai_api,
name=f"openai_api({os.getpid()})",
name=f"openai_api",
args=(queue, process_count() + 1),
daemon=True,
)
@ -581,7 +588,7 @@ async def start_main_server():
if not config.get("online_api"):
process = Process(
target=run_model_worker,
name=f"model_worker - {args.model_name} ({os.getpid()})",
name=f"model_worker - {args.model_name}",
args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
daemon=True,
)
@ -594,7 +601,7 @@ async def start_main_server():
if config.get("online_api") and config.get("worker_class"):
process = Process(
target=run_model_worker,
name=f"model_worker - {model_name} ({os.getpid()})",
name=f"model_worker - {model_name}",
args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
daemon=True,
)
@ -604,7 +611,7 @@ async def start_main_server():
if args.api:
process = Process(
target=run_api_server,
name=f"API Server{os.getpid()})",
name=f"API Server",
args=(queue, process_count() + 1),
daemon=True,
)
@ -614,7 +621,7 @@ async def start_main_server():
if args.webui:
process = Process(
target=run_webui,
name=f"WEBUI Server{os.getpid()})",
name=f"WEBUI Server",
args=(queue, process_count() + 1),
daemon=True,
)
@ -626,17 +633,30 @@ async def start_main_server():
else:
try:
# 保证任务收到SIGINT后能够正常退出
processes["controller"].start()
# 使用await asyncio.sleep(3)可以让后续代码等待一段时间
await asyncio.sleep(3)
processes["openai_api"].start()
processes["model_worker"].start()
for p in processes["online-api"]:
if p:= processes.get("controller"):
p.start()
p.name = f"{p.name} ({p.pid})"
controller_started.wait()
processes["api"].start()
processes["webui"].start()
if p:= processes.get("openai_api"):
p.start()
p.name = f"{p.name} ({p.pid})"
if p:= processes.get("model_worker"):
p.start()
p.name = f"{p.name} ({p.pid})"
for p in processes.get("online-api", []):
p.start()
p.name = f"{p.name} ({p.pid})"
if p:= processes.get("api"):
p.start()
p.name = f"{p.name} ({p.pid})"
if p:= processes.get("webui"):
p.start()
p.name = f"{p.name} ({p.pid})"
while True:
no = queue.get()
@ -647,17 +667,18 @@ async def start_main_server():
else:
queue.put(no)
if model_worker_process := processes.pop("model_worker", None):
if model_worker_process := processes.get("model_worker"):
model_worker_process.join()
for process in processes.pop("online-api", []):
for process in processes.get("online-api", []):
process.join()
for name, process in processes.items():
if name not in ["model_worker", "online-api"]:
if isinstance(p, list):
for work_process in p:
work_process.join()
else:
process.join()
except:
except Exception as e:
# if model_worker_process := processes.pop("model_worker", None):
# model_worker_process.terminate()
# for process in processes.pop("online-api", []):
@ -669,6 +690,7 @@ async def start_main_server():
# work_process.terminate()
# else:
# process.terminate()
logger.error(e)
logger.warning("Caught KeyboardInterrupt! Setting stop event...")
finally:
# Send SIGINT if process doesn't exit quickly enough, and kill it as last resort
@ -703,6 +725,8 @@ if __name__ == "__main__":
asyncio.set_event_loop(loop)
# 同步调用协程代码
loop.run_until_complete(start_main_server())
# 服务启动后接口调用示例:
# import openai
# openai.api_key = "EMPTY" # Not support yet