update startup.py:
1. controller改为自动等待 2. 修复:分部分启动时进程引用错误
This commit is contained in:
parent
2275b1d84a
commit
053ebb82bf
76
startup.py
76
startup.py
|
|
@ -210,7 +210,7 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
|
||||||
q.put(run_seq)
|
q.put(run_seq)
|
||||||
|
|
||||||
|
|
||||||
def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO"):
|
def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import Body
|
from fastapi import Body
|
||||||
|
|
@ -223,6 +223,11 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO"):
|
||||||
)
|
)
|
||||||
_set_app_seq(app, q, run_seq)
|
_set_app_seq(app, q, run_seq)
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
def on_startup():
|
||||||
|
if e is not None:
|
||||||
|
e.set()
|
||||||
|
|
||||||
# add interface to release and load model worker
|
# add interface to release and load model worker
|
||||||
@app.post("/release_worker")
|
@app.post("/release_worker")
|
||||||
def release_worker(
|
def release_worker(
|
||||||
|
|
@ -517,8 +522,9 @@ async def start_main_server():
|
||||||
signal.signal(signal.SIGTERM, handler("SIGTERM"))
|
signal.signal(signal.SIGTERM, handler("SIGTERM"))
|
||||||
|
|
||||||
mp.set_start_method("spawn")
|
mp.set_start_method("spawn")
|
||||||
|
manager = mp.Manager()
|
||||||
|
|
||||||
queue = Queue()
|
queue = manager.Queue()
|
||||||
args, parser = parse_args()
|
args, parser = parse_args()
|
||||||
|
|
||||||
if args.all_webui:
|
if args.all_webui:
|
||||||
|
|
@ -558,11 +564,12 @@ async def start_main_server():
|
||||||
else:
|
else:
|
||||||
log_level = "INFO"
|
log_level = "INFO"
|
||||||
|
|
||||||
|
controller_started = manager.Event()
|
||||||
if args.openai_api:
|
if args.openai_api:
|
||||||
process = Process(
|
process = Process(
|
||||||
target=run_controller,
|
target=run_controller,
|
||||||
name=f"controller({os.getpid()})",
|
name=f"controller",
|
||||||
args=(queue, process_count() + 1, log_level),
|
args=(queue, process_count() + 1, log_level, controller_started),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -570,7 +577,7 @@ async def start_main_server():
|
||||||
|
|
||||||
process = Process(
|
process = Process(
|
||||||
target=run_openai_api,
|
target=run_openai_api,
|
||||||
name=f"openai_api({os.getpid()})",
|
name=f"openai_api",
|
||||||
args=(queue, process_count() + 1),
|
args=(queue, process_count() + 1),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
)
|
)
|
||||||
|
|
@ -581,7 +588,7 @@ async def start_main_server():
|
||||||
if not config.get("online_api"):
|
if not config.get("online_api"):
|
||||||
process = Process(
|
process = Process(
|
||||||
target=run_model_worker,
|
target=run_model_worker,
|
||||||
name=f"model_worker - {args.model_name} ({os.getpid()})",
|
name=f"model_worker - {args.model_name}",
|
||||||
args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
|
args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
)
|
)
|
||||||
|
|
@ -594,7 +601,7 @@ async def start_main_server():
|
||||||
if config.get("online_api") and config.get("worker_class"):
|
if config.get("online_api") and config.get("worker_class"):
|
||||||
process = Process(
|
process = Process(
|
||||||
target=run_model_worker,
|
target=run_model_worker,
|
||||||
name=f"model_worker - {model_name} ({os.getpid()})",
|
name=f"model_worker - {model_name}",
|
||||||
args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
|
args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
)
|
)
|
||||||
|
|
@ -604,7 +611,7 @@ async def start_main_server():
|
||||||
if args.api:
|
if args.api:
|
||||||
process = Process(
|
process = Process(
|
||||||
target=run_api_server,
|
target=run_api_server,
|
||||||
name=f"API Server{os.getpid()})",
|
name=f"API Server",
|
||||||
args=(queue, process_count() + 1),
|
args=(queue, process_count() + 1),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
)
|
)
|
||||||
|
|
@ -614,7 +621,7 @@ async def start_main_server():
|
||||||
if args.webui:
|
if args.webui:
|
||||||
process = Process(
|
process = Process(
|
||||||
target=run_webui,
|
target=run_webui,
|
||||||
name=f"WEBUI Server{os.getpid()})",
|
name=f"WEBUI Server",
|
||||||
args=(queue, process_count() + 1),
|
args=(queue, process_count() + 1),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
)
|
)
|
||||||
|
|
@ -626,17 +633,30 @@ async def start_main_server():
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
# 保证任务收到SIGINT后,能够正常退出
|
# 保证任务收到SIGINT后,能够正常退出
|
||||||
processes["controller"].start()
|
if p:= processes.get("controller"):
|
||||||
# 使用await asyncio.sleep(3)可以让后续代码等待一段时间
|
|
||||||
await asyncio.sleep(3)
|
|
||||||
|
|
||||||
processes["openai_api"].start()
|
|
||||||
processes["model_worker"].start()
|
|
||||||
for p in processes["online-api"]:
|
|
||||||
p.start()
|
p.start()
|
||||||
|
p.name = f"{p.name} ({p.pid})"
|
||||||
|
controller_started.wait()
|
||||||
|
|
||||||
processes["api"].start()
|
if p:= processes.get("openai_api"):
|
||||||
processes["webui"].start()
|
p.start()
|
||||||
|
p.name = f"{p.name} ({p.pid})"
|
||||||
|
|
||||||
|
if p:= processes.get("model_worker"):
|
||||||
|
p.start()
|
||||||
|
p.name = f"{p.name} ({p.pid})"
|
||||||
|
|
||||||
|
for p in processes.get("online-api", []):
|
||||||
|
p.start()
|
||||||
|
p.name = f"{p.name} ({p.pid})"
|
||||||
|
|
||||||
|
if p:= processes.get("api"):
|
||||||
|
p.start()
|
||||||
|
p.name = f"{p.name} ({p.pid})"
|
||||||
|
|
||||||
|
if p:= processes.get("webui"):
|
||||||
|
p.start()
|
||||||
|
p.name = f"{p.name} ({p.pid})"
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
no = queue.get()
|
no = queue.get()
|
||||||
|
|
@ -647,17 +667,18 @@ async def start_main_server():
|
||||||
else:
|
else:
|
||||||
queue.put(no)
|
queue.put(no)
|
||||||
|
|
||||||
if model_worker_process := processes.pop("model_worker", None):
|
if model_worker_process := processes.get("model_worker"):
|
||||||
model_worker_process.join()
|
model_worker_process.join()
|
||||||
for process in processes.pop("online-api", []):
|
for process in processes.get("online-api", []):
|
||||||
process.join()
|
process.join()
|
||||||
for name, process in processes.items():
|
for name, process in processes.items():
|
||||||
if isinstance(p, list):
|
if name not in ["model_worker", "online-api"]:
|
||||||
for work_process in p:
|
if isinstance(p, list):
|
||||||
work_process.join()
|
for work_process in p:
|
||||||
else:
|
work_process.join()
|
||||||
process.join()
|
else:
|
||||||
except:
|
process.join()
|
||||||
|
except Exception as e:
|
||||||
# if model_worker_process := processes.pop("model_worker", None):
|
# if model_worker_process := processes.pop("model_worker", None):
|
||||||
# model_worker_process.terminate()
|
# model_worker_process.terminate()
|
||||||
# for process in processes.pop("online-api", []):
|
# for process in processes.pop("online-api", []):
|
||||||
|
|
@ -669,6 +690,7 @@ async def start_main_server():
|
||||||
# work_process.terminate()
|
# work_process.terminate()
|
||||||
# else:
|
# else:
|
||||||
# process.terminate()
|
# process.terminate()
|
||||||
|
logger.error(e)
|
||||||
logger.warning("Caught KeyboardInterrupt! Setting stop event...")
|
logger.warning("Caught KeyboardInterrupt! Setting stop event...")
|
||||||
finally:
|
finally:
|
||||||
# Send SIGINT if process doesn't exit quickly enough, and kill it as last resort
|
# Send SIGINT if process doesn't exit quickly enough, and kill it as last resort
|
||||||
|
|
@ -703,6 +725,8 @@ if __name__ == "__main__":
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
# 同步调用协程代码
|
# 同步调用协程代码
|
||||||
loop.run_until_complete(start_main_server())
|
loop.run_until_complete(start_main_server())
|
||||||
|
|
||||||
|
|
||||||
# 服务启动后接口调用示例:
|
# 服务启动后接口调用示例:
|
||||||
# import openai
|
# import openai
|
||||||
# openai.api_key = "EMPTY" # Not support yet
|
# openai.api_key = "EMPTY" # Not support yet
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue