From 6c00c03faa1183f6d0dba715e5a32e5b518f5a46 Mon Sep 17 00:00:00 2001
From: liunux4odoo
Date: Wed, 16 Aug 2023 17:48:55 +0800
Subject: [PATCH] llm_api can be terminated by Ctrl+C; verified on Windows

---
 server/llm_api.py | 46 +++++++++++++++++++++++++++++++---------------
 1 file changed, 31 insertions(+), 15 deletions(-)

diff --git a/server/llm_api.py b/server/llm_api.py
index 49d8292..ab71b3d 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -1,4 +1,5 @@
 from multiprocessing import Process, Queue
+import multiprocessing as mp
 import sys
 import os
 
@@ -12,7 +13,6 @@
 controller_port = 20001
 model_worker_port = 20002
 openai_api_port = 8888
 base_url = "http://127.0.0.1:{}"
-queue = Queue()
 
 def set_httpx_timeout(timeout=60.0):
@@ -64,6 +64,18 @@ def create_model_worker_app(
     fastchat.constants.LOGDIR = LOG_PATH
     from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
     import argparse
+    import threading
+    import fastchat.serve.model_worker
+
+    # workaround so the program can exit on Ctrl+C;
+    # it should be removed once the corresponding PR is merged into fastchat
+    def _new_init_heart_beat(self):
+        self.register_to_controller()
+        self.heart_beat_thread = threading.Thread(
+            target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
+        )
+        self.heart_beat_thread.start()
+    ModelWorker.init_heart_beat = _new_init_heart_beat
 
     parser = argparse.ArgumentParser()
     args = parser.parse_args()
@@ -214,6 +226,8 @@ def run_openai_api(q):
 
 
 if __name__ == "__main__":
+    mp.set_start_method("spawn")
+    queue = Queue()
     logger.info(llm_model_dict[LLM_MODEL])
     model_path = llm_model_dict[LLM_MODEL]["local_model_path"]
 
@@ -230,15 +244,14 @@
     )
     controller_process.start()
 
-    # CUDA cannot be used in fork-started subprocesses
-    # model_worker_process = Process(
-    #     target=run_model_worker,
-    #     name=f"model_worker({os.getpid()})",
-    #     args=(queue,),
-    #     # kwargs={"load_8bit": True},
-    #     daemon=True,
-    # )
-    # model_worker_process.start()
+    model_worker_process = Process(
+        target=run_model_worker,
+        name=f"model_worker({os.getpid()})",
+        args=(queue,),
+        # kwargs={"load_8bit": True},
+        daemon=True,
+    )
+    model_worker_process.start()
 
     openai_api_process = Process(
         target=run_openai_api,
@@ -248,11 +261,14 @@
     )
     openai_api_process.start()
 
-    run_model_worker(queue)
-
-    controller_process.join()
-    # model_worker_process.join()
-    openai_api_process.join()
+    try:
+        model_worker_process.join()
+        controller_process.join()
+        openai_api_process.join()
+    except KeyboardInterrupt:
+        model_worker_process.terminate()
+        controller_process.terminate()
+        openai_api_process.terminate()
 
 # Example API call after the services are started:
 # import openai
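
FastChat's ModelWorker starts its heartbeat thread without daemon=True, so the worker process keeps running after the main thread is interrupted; the patch therefore replaces ModelWorker.init_heart_beat with a version that marks the thread as a daemon. Below is a standalone sketch of that monkeypatching idea using a stand-in Worker class rather than the real ModelWorker (the register_to_controller call is omitted in the stand-in):

    import threading
    import time


    class Worker:
        """Stand-in for fastchat.serve.model_worker.ModelWorker."""

        def init_heart_beat(self):
            # original behaviour: a non-daemon thread keeps the interpreter alive
            self.heart_beat_thread = threading.Thread(target=self._beat)
            self.heart_beat_thread.start()

        def _beat(self):
            while True:
                time.sleep(1)   # would send a heartbeat to the controller


    def _new_init_heart_beat(self):
        # daemon=True lets the process exit once the main thread returns
        self.heart_beat_thread = threading.Thread(target=self._beat, daemon=True)
        self.heart_beat_thread.start()


    # patch the class before any instance starts its heartbeat
    Worker.init_heart_beat = _new_init_heart_beat

    if __name__ == "__main__":
        w = Worker()
        w.init_heart_beat()
        print("main thread exits; the daemon heartbeat no longer blocks shutdown")

The same class-attribute assignment works on the real ModelWorker because Python looks the method up on the class at call time.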
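
The patch also moves queue = Queue() under the __main__ guard and selects the "spawn" start method before any child process is created, which is what lets the previously commented-out model worker run in a real subprocess: CUDA contexts cannot be reused in fork-created children, while "spawn" gives each child a fresh interpreter. A minimal sketch of the pattern, with a placeholder do_gpu_work standing in for run_model_worker:

    import multiprocessing as mp
    from multiprocessing import Process, Queue


    def do_gpu_work(q: Queue):
        # the real worker would import torch and touch CUDA here; with "spawn"
        # the child starts from a clean interpreter instead of a forked one
        q.put("worker started")


    if __name__ == "__main__":
        # must run once, before any Process is created; keeping it under the
        # __main__ guard matters because "spawn" re-imports this module in the child
        mp.set_start_method("spawn")
        queue = Queue()

        worker = Process(target=do_gpu_work, args=(queue,), daemon=True)
        worker.start()
        print(queue.get())      # -> "worker started"
        worker.join()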
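
Finally, the main process now blocks on join() for the controller, model worker, and openai_api children and, when Ctrl+C raises KeyboardInterrupt, terminates all three so one keystroke stops the whole stack. A reduced sketch of that shutdown pattern with dummy children (serve is a placeholder loop, not a function from the patch):

    import time
    from multiprocessing import Process


    def serve():
        # stands in for run_controller / run_model_worker / run_openai_api
        while True:
            time.sleep(1)


    if __name__ == "__main__":
        procs = [Process(target=serve, daemon=True) for _ in range(3)]
        for p in procs:
            p.start()

        try:
            for p in procs:
                p.join()
        except KeyboardInterrupt:
            for p in procs:
                p.terminate()
            for p in procs:
                p.join()        # reap the terminated children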
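
The commented block at the bottom of the file sketches how to call the service once all three processes are up. A hedged example of such a call against openai_api_port (8888), written in the pre-1.0 openai client style current in August 2023; the model name "chatglm2-6b" is only a placeholder for whatever LLM_MODEL is configured to:

    import openai   # openai<1.0 style client

    openai.api_key = "EMPTY"                        # local FastChat server; no real key needed
    openai.api_base = "http://127.0.0.1:8888/v1"    # base_url.format(openai_api_port)

    resp = openai.ChatCompletion.create(
        model="chatglm2-6b",                        # placeholder; match your LLM_MODEL
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(resp.choices[0].message.content)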