update startup.py:
1. controller改为自动等待 2. 修复:分部分启动时进程引用错误
This commit is contained in:
parent
2275b1d84a
commit
053ebb82bf
66
startup.py
66
startup.py
|
|
@ -210,7 +210,7 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
|
|||
q.put(run_seq)
|
||||
|
||||
|
||||
def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO"):
|
||||
def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
|
||||
import uvicorn
|
||||
import httpx
|
||||
from fastapi import Body
|
||||
|
|
@ -223,6 +223,11 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO"):
|
|||
)
|
||||
_set_app_seq(app, q, run_seq)
|
||||
|
||||
@app.on_event("startup")
|
||||
def on_startup():
|
||||
if e is not None:
|
||||
e.set()
|
||||
|
||||
# add interface to release and load model worker
|
||||
@app.post("/release_worker")
|
||||
def release_worker(
|
||||
|
|
@ -517,8 +522,9 @@ async def start_main_server():
|
|||
signal.signal(signal.SIGTERM, handler("SIGTERM"))
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
manager = mp.Manager()
|
||||
|
||||
queue = Queue()
|
||||
queue = manager.Queue()
|
||||
args, parser = parse_args()
|
||||
|
||||
if args.all_webui:
|
||||
|
|
@ -558,11 +564,12 @@ async def start_main_server():
|
|||
else:
|
||||
log_level = "INFO"
|
||||
|
||||
controller_started = manager.Event()
|
||||
if args.openai_api:
|
||||
process = Process(
|
||||
target=run_controller,
|
||||
name=f"controller({os.getpid()})",
|
||||
args=(queue, process_count() + 1, log_level),
|
||||
name=f"controller",
|
||||
args=(queue, process_count() + 1, log_level, controller_started),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
|
|
@ -570,7 +577,7 @@ async def start_main_server():
|
|||
|
||||
process = Process(
|
||||
target=run_openai_api,
|
||||
name=f"openai_api({os.getpid()})",
|
||||
name=f"openai_api",
|
||||
args=(queue, process_count() + 1),
|
||||
daemon=True,
|
||||
)
|
||||
|
|
@ -581,7 +588,7 @@ async def start_main_server():
|
|||
if not config.get("online_api"):
|
||||
process = Process(
|
||||
target=run_model_worker,
|
||||
name=f"model_worker - {args.model_name} ({os.getpid()})",
|
||||
name=f"model_worker - {args.model_name}",
|
||||
args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
|
||||
daemon=True,
|
||||
)
|
||||
|
|
@ -594,7 +601,7 @@ async def start_main_server():
|
|||
if config.get("online_api") and config.get("worker_class"):
|
||||
process = Process(
|
||||
target=run_model_worker,
|
||||
name=f"model_worker - {model_name} ({os.getpid()})",
|
||||
name=f"model_worker - {model_name}",
|
||||
args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
|
||||
daemon=True,
|
||||
)
|
||||
|
|
@ -604,7 +611,7 @@ async def start_main_server():
|
|||
if args.api:
|
||||
process = Process(
|
||||
target=run_api_server,
|
||||
name=f"API Server{os.getpid()})",
|
||||
name=f"API Server",
|
||||
args=(queue, process_count() + 1),
|
||||
daemon=True,
|
||||
)
|
||||
|
|
@ -614,7 +621,7 @@ async def start_main_server():
|
|||
if args.webui:
|
||||
process = Process(
|
||||
target=run_webui,
|
||||
name=f"WEBUI Server{os.getpid()})",
|
||||
name=f"WEBUI Server",
|
||||
args=(queue, process_count() + 1),
|
||||
daemon=True,
|
||||
)
|
||||
|
|
@ -626,17 +633,30 @@ async def start_main_server():
|
|||
else:
|
||||
try:
|
||||
# 保证任务收到SIGINT后,能够正常退出
|
||||
processes["controller"].start()
|
||||
# 使用await asyncio.sleep(3)可以让后续代码等待一段时间
|
||||
await asyncio.sleep(3)
|
||||
|
||||
processes["openai_api"].start()
|
||||
processes["model_worker"].start()
|
||||
for p in processes["online-api"]:
|
||||
if p:= processes.get("controller"):
|
||||
p.start()
|
||||
p.name = f"{p.name} ({p.pid})"
|
||||
controller_started.wait()
|
||||
|
||||
processes["api"].start()
|
||||
processes["webui"].start()
|
||||
if p:= processes.get("openai_api"):
|
||||
p.start()
|
||||
p.name = f"{p.name} ({p.pid})"
|
||||
|
||||
if p:= processes.get("model_worker"):
|
||||
p.start()
|
||||
p.name = f"{p.name} ({p.pid})"
|
||||
|
||||
for p in processes.get("online-api", []):
|
||||
p.start()
|
||||
p.name = f"{p.name} ({p.pid})"
|
||||
|
||||
if p:= processes.get("api"):
|
||||
p.start()
|
||||
p.name = f"{p.name} ({p.pid})"
|
||||
|
||||
if p:= processes.get("webui"):
|
||||
p.start()
|
||||
p.name = f"{p.name} ({p.pid})"
|
||||
|
||||
while True:
|
||||
no = queue.get()
|
||||
|
|
@ -647,17 +667,18 @@ async def start_main_server():
|
|||
else:
|
||||
queue.put(no)
|
||||
|
||||
if model_worker_process := processes.pop("model_worker", None):
|
||||
if model_worker_process := processes.get("model_worker"):
|
||||
model_worker_process.join()
|
||||
for process in processes.pop("online-api", []):
|
||||
for process in processes.get("online-api", []):
|
||||
process.join()
|
||||
for name, process in processes.items():
|
||||
if name not in ["model_worker", "online-api"]:
|
||||
if isinstance(p, list):
|
||||
for work_process in p:
|
||||
work_process.join()
|
||||
else:
|
||||
process.join()
|
||||
except:
|
||||
except Exception as e:
|
||||
# if model_worker_process := processes.pop("model_worker", None):
|
||||
# model_worker_process.terminate()
|
||||
# for process in processes.pop("online-api", []):
|
||||
|
|
@ -669,6 +690,7 @@ async def start_main_server():
|
|||
# work_process.terminate()
|
||||
# else:
|
||||
# process.terminate()
|
||||
logger.error(e)
|
||||
logger.warning("Caught KeyboardInterrupt! Setting stop event...")
|
||||
finally:
|
||||
# Send SIGINT if process doesn't exit quickly enough, and kill it as last resort
|
||||
|
|
@ -703,6 +725,8 @@ if __name__ == "__main__":
|
|||
asyncio.set_event_loop(loop)
|
||||
# 同步调用协程代码
|
||||
loop.run_until_complete(start_main_server())
|
||||
|
||||
|
||||
# 服务启动后接口调用示例:
|
||||
# import openai
|
||||
# openai.api_key = "EMPTY" # Not support yet
|
||||
|
|
|
|||
Loading…
Reference in New Issue