When the registrar has not started yet, the whole startup chain is terminated by an exception

Using await asyncio.sleep(3) lets the subsequent code wait for a while, but it is not the optimal solution
glide-the 2023-09-02 19:20:41 +08:00
parent 3391155077
commit d7c884e34a
1 changed file with 25 additions and 19 deletions
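The commit description above treats the fixed await asyncio.sleep(3) as a stopgap. A less brittle alternative would be to poll the controller until it actually answers before continuing; here is a minimal sketch under the assumption that the controller address serves plain HTTP once it is up (wait_for_controller is an illustrative helper, not code from this commit):

import asyncio
import httpx

async def wait_for_controller(controller_address: str, timeout: float = 30.0) -> bool:
    # Poll the controller's base URL until any HTTP response comes back,
    # or give up after `timeout` seconds.
    deadline = asyncio.get_running_loop().time() + timeout
    async with httpx.AsyncClient() as client:
        while asyncio.get_running_loop().time() < deadline:
            try:
                await client.get(controller_address)  # any response means the server is listening
                return True
            except httpx.HTTPError:
                await asyncio.sleep(0.5)  # not ready yet, retry shortly
    return False

With a helper like this, start_main_server could await wait_for_controller(fschat_controller_address()) after spawning the controller process instead of hard-coding a 3-second sleep.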


@@ -1,6 +1,7 @@
from multiprocessing import Process, Queue
import multiprocessing as mp
import subprocess
import asyncio
import sys
import os
from pprint import pprint
@@ -8,6 +9,7 @@ from pprint import pprint
# Set the maximum number of threads for numexpr; defaults to the number of CPU cores
try:
import numexpr
n_cores = numexpr.utils.detect_number_of_cores()
os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
except:
@@ -19,9 +21,9 @@ from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG
from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
FSCHAT_OPENAI_API, )
from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, set_httpx_timeout,
get_model_worker_config, get_all_model_worker_configs,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
import argparse
from typing import Tuple, List, Dict
from configs import VERSION
@@ -101,8 +103,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse
# Online model API
if worker_class := kwargs.get("worker_class"):
worker = worker_class(model_names=args.model_names,
controller_addr=args.controller_address,
worker_addr=args.worker_address)
# Local model
else:
# workaround to make program exit with Ctrl+c
@@ -208,7 +210,7 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
q.put(run_seq)
-def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
+def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO"):
import uvicorn
import httpx
from fastapi import Body
@@ -224,10 +226,10 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
# add interface to release and load model worker
@app.post("/release_worker")
def release_worker(
model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
# worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[fschat_controller_address()]),
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
available_models = app._controller.list_models()
if new_model_name in available_models:
@@ -252,14 +254,14 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
return {"code": 500, "msg": msg}
r = httpx.post(worker_address + "/release",
json={"new_model_name": new_model_name, "keep_origin": keep_origin})
json={"new_model_name": new_model_name, "keep_origin": keep_origin})
if r.status_code != 200:
msg = f"failed to release model: {model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
if new_model_name:
-timer = 300 # wait 5 minutes for new model_worker register
+timer = 300  # wait 5 minutes for new model_worker register
while timer > 0:
models = app._controller.list_models()
if new_model_name in models:
@@ -294,7 +296,7 @@ def run_model_worker(
controller_address: str = "",
q: Queue = None,
run_seq: int = 2,
log_level: str ="INFO",
log_level: str = "INFO",
):
import uvicorn
from fastapi import Body
@@ -318,8 +320,8 @@ def run_model_worker(
# add interface to release and load model
@app.post("/release")
def release_model(
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
if keep_origin:
if new_model_name:
@@ -497,15 +499,15 @@ def dump_server_info(after_start=False, args=None):
print("\n")
if __name__ == "__main__":
async def start_main_server():
import time
mp.set_start_method("spawn")
# TODO: the chained-startup queue can indeed be used to control startup order,
# but the framework-independent work processes launched after introducing proxy_worker cannot determine their current position in the chain,
# so when the registrar has not started they cannot register, and the whole startup chain is terminated by an exception.
# `run_seq` would have to be set to MAX_int to guarantee they start last.
# Should a depends_on mechanism be introduced?
# Using await asyncio.sleep(3) lets the subsequent code wait for a while, but it is not the optimal solution
queue = Queue()
args, parser = parse_args()
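The TODO in the hunk above floats a depends_on mechanism as the longer-term fix. One minimal sketch of that idea, using a multiprocessing.Event as the readiness signal (run_controller_stub and run_worker_stub are hypothetical stand-ins, not functions from this repo):

from multiprocessing import Event, Process

def run_controller_stub(started):
    # ... bring up the controller server here ...
    started.set()  # signal readiness only once the server is actually listening

def run_worker_stub(controller_started):
    # depends_on: controller -- block until it has signalled readiness
    if not controller_started.wait(timeout=60):
        raise RuntimeError("controller did not come up in time")
    # ... start the model worker and register it with the controller ...

if __name__ == "__main__":
    controller_started = Event()
    controller = Process(target=run_controller_stub, args=(controller_started,))
    worker = Process(target=run_worker_stub, args=(controller_started,))
    controller.start()
    worker.start()
    worker.join()
    controller.join()

Unlike a fixed run_seq or a hard-coded sleep, a worker started this way only proceeds once its dependency has explicitly signalled that it is ready.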
@@ -539,7 +541,7 @@ if __name__ == "__main__":
processes = {"online-api": []}
def process_count():
-return len(processes) + len(processes["online-api"]) -1
+return len(processes) + len(processes["online-api"]) - 1
if args.quiet:
log_level = "ERROR"
@@ -554,6 +556,7 @@ if __name__ == "__main__":
daemon=True,
)
process.start()
await asyncio.sleep(3)
processes["controller"] = process
process = Process(
@@ -638,6 +641,9 @@ if __name__ == "__main__":
process.terminate()
if __name__ == "__main__":
# Call the coroutine synchronously
asyncio.get_event_loop().run_until_complete(start_main_server())
# Example API calls after the services have started:
# import openai
# openai.api_key = "EMPTY" # Not support yet