启动进程放入 try catch 保证任务收到SIGINT后,能够正常退出

This commit is contained in:
glide-the 2023-09-04 23:03:56 +08:00
parent aa4a5ad224
commit 35f2c04535
1 changed files with 25 additions and 13 deletions

View File

@ -565,9 +565,7 @@ async def start_main_server():
args=(queue, process_count() + 1, log_level), args=(queue, process_count() + 1, log_level),
daemon=True, daemon=True,
) )
process.start()
# 使用await asyncio.sleep(3)可以让后续代码等待一段时间
await asyncio.sleep(3)
processes["controller"] = process processes["controller"] = process
process = Process( process = Process(
@ -576,7 +574,6 @@ async def start_main_server():
args=(queue, process_count() + 1), args=(queue, process_count() + 1),
daemon=True, daemon=True,
) )
process.start()
processes["openai_api"] = process processes["openai_api"] = process
if args.model_worker: if args.model_worker:
@ -588,7 +585,7 @@ async def start_main_server():
args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level), args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
daemon=True, daemon=True,
) )
process.start()
processes["model_worker"] = process processes["model_worker"] = process
if args.api_worker: if args.api_worker:
@ -601,7 +598,7 @@ async def start_main_server():
args=(model_name, args.controller_address, queue, process_count() + 1, log_level), args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
daemon=True, daemon=True,
) )
process.start()
processes["online-api"].append(process) processes["online-api"].append(process)
if args.api: if args.api:
@ -611,7 +608,7 @@ async def start_main_server():
args=(queue, process_count() + 1), args=(queue, process_count() + 1),
daemon=True, daemon=True,
) )
process.start()
processes["api"] = process processes["api"] = process
if args.webui: if args.webui:
@ -621,13 +618,26 @@ async def start_main_server():
args=(queue, process_count() + 1), args=(queue, process_count() + 1),
daemon=True, daemon=True,
) )
process.start()
processes["webui"] = process processes["webui"] = process
if process_count() == 0: if process_count() == 0:
parser.print_help() parser.print_help()
else: else:
try: try:
# 保证任务收到SIGINT后能够正常退出
processes["controller"].start()
# 使用await asyncio.sleep(3)可以让后续代码等待一段时间
await asyncio.sleep(3)
processes["openai_api"].start()
processes["model_worker"].start()
for p in processes["online-api"]:
p.start()
processes["api"].start()
processes["webui"].start()
while True: while True:
no = queue.get() no = queue.get()
if no == process_count(): if no == process_count():
@ -652,16 +662,19 @@ async def start_main_server():
# model_worker_process.terminate() # model_worker_process.terminate()
# for process in processes.pop("online-api", []): # for process in processes.pop("online-api", []):
# process.terminate() # process.terminate()
# for name, process in processes.items(): # for process in processes.values():
# process.terminate() #
# if isinstance(process, list):
# for work_process in process:
# work_process.terminate()
# else:
# process.terminate()
logger.warning("Caught KeyboardInterrupt! Setting stop event...") logger.warning("Caught KeyboardInterrupt! Setting stop event...")
finally: finally:
# Send SIGINT if process doesn't exit quickly enough, and kill it as last resort # Send SIGINT if process doesn't exit quickly enough, and kill it as last resort
# .is_alive() also implicitly joins the process (good practice in linux) # .is_alive() also implicitly joins the process (good practice in linux)
# while alive_procs := [p for p in processes.values() if p.is_alive()]: # while alive_procs := [p for p in processes.values() if p.is_alive()]:
for p in processes.values():
logger.info("Process status: %s", p)
for p in processes.values(): for p in processes.values():
logger.warning("Sending SIGKILL to %s", p) logger.warning("Sending SIGKILL to %s", p)
# Queues and other inter-process communication primitives can break when # Queues and other inter-process communication primitives can break when
@ -674,7 +687,6 @@ async def start_main_server():
else: else:
p.kill() p.kill()
sleep(.01)
for p in processes.values(): for p in processes.values():
logger.info("Process status: %s", p) logger.info("Process status: %s", p)