diff --git a/.gitignore b/.gitignore
index a7ef90f..af75440 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@ __pycache__/
 configs/*.py
 .vscode/
 .pytest_cache/
+*.bak
diff --git a/startup.py b/startup.py
index 591e3b1..15af03e 100644
--- a/startup.py
+++ b/startup.py
@@ -3,7 +3,8 @@ import multiprocessing as mp
 import os
 import subprocess
 import sys
-from multiprocessing import Process, Queue
+from multiprocessing import Process
+from datetime import datetime
 from pprint import pprint

 # set the numexpr thread limit; defaults to the number of CPU cores
@@ -19,7 +20,7 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
     logger
 from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
-                                   FSCHAT_OPENAI_API, )
+                                   FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT)
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
                           fschat_openai_api_address, set_httpx_timeout,
                           get_model_worker_config, get_all_model_worker_configs,
@@ -47,7 +48,7 @@ def create_controller_app(
     return app


-def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
+def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
     from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
@@ -188,29 +189,15 @@ def create_openai_api_app(
     return app


-def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
-    if q is None or not isinstance(run_seq, int):
-        return
-
-    if run_seq == 1:
-        @app.on_event("startup")
-        async def on_startup():
-            set_httpx_timeout()
-            q.put(run_seq)
-    elif run_seq > 1:
-        @app.on_event("startup")
-        async def on_startup():
-            set_httpx_timeout()
-            while True:
-                no = q.get()
-                if no != run_seq - 1:
-                    q.put(no)
-                else:
-                    break
-            q.put(run_seq)
+def _set_app_event(app: FastAPI, started_event: mp.Event = None):
+    @app.on_event("startup")
+    async def on_startup():
+        set_httpx_timeout()
+        if started_event is not None:
+            started_event.set()


-def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
+def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
     import uvicorn
     import httpx
     from fastapi import Body
@@ -221,12 +208,7 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Ev
         dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
         log_level=log_level,
     )
-    _set_app_seq(app, q, run_seq)
-
-    @app.on_event("startup")
-    def on_startup():
-        if e is not None:
-            e.set()
+    _set_app_event(app, started_event)

     # add interface to release and load model worker
     @app.post("/release_worker")
@@ -266,7 +248,7 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Ev
                 return {"code": 500, "msg": msg}

         if new_model_name:
-            timer = 300 # wait 5 minutes for the new model_worker to register
+            timer = HTTPX_DEFAULT_TIMEOUT * 2 # wait for the new model_worker to register
             while timer > 0:
                 models = app._controller.list_models()
                 if new_model_name in models:
@@ -299,9 +281,9 @@
 def run_model_worker(
         model_name: str = LLM_MODEL,
         controller_address: str = "",
-        q: Queue = None,
-        run_seq: int = 2,
         log_level: str = "INFO",
+        q: mp.Queue = None,
+        started_event: mp.Event = None,
 ):
     import uvicorn
     from fastapi import Body
@@ -317,7 +299,7 @@ def run_model_worker(
     kwargs["model_path"] = model_path

     app = create_model_worker_app(log_level=log_level, **kwargs)
-    _set_app_seq(app, q, run_seq)
+    _set_app_event(app, started_event)
    if log_level == "ERROR":
         sys.stdout = sys.__stdout__
         sys.stderr = sys.__stderr__
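
Note: the hunks above replace the old Queue token-passing startup order (`_set_app_seq`) with one `Event` per process. A minimal sketch of the handshake pattern, independent of FastAPI; the function name and the sleeps are illustrative stand-ins, not project code:

```python
import multiprocessing as mp
import time

def run_server(started_event: mp.Event = None):
    time.sleep(1)                # stand-in for binding sockets / loading a model
    if started_event is not None:
        started_event.set()      # tell the parent we are ready
    time.sleep(60)               # stand-in for the serve-forever loop

if __name__ == "__main__":
    manager = mp.Manager()
    started = manager.Event()    # manager events can be shared with child processes
    p = mp.Process(target=run_server, kwargs=dict(started_event=started), daemon=True)
    p.start()
    started.wait()               # block until the child reports ready
    print(f"server ready (pid={p.pid})")
```

Unlike the removed `_set_app_seq`, each waiter blocks only on the events it actually depends on, so processes no longer shuffle sequence numbers through a shared queue.
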
@@ -325,29 +307,29 @@ def run_model_worker(
     # add interface to release and load model
     @app.post("/release")
     def release_model(
-        new_model_name: str = Body(None, description="load this model after releasing"),
-        keep_origin: bool = Body(False, description="keep the original model and load the new one")
+            new_model_name: str = Body(None, description="load this model after releasing"),
+            keep_origin: bool = Body(False, description="keep the original model and load the new one")
     ) -> Dict:
         if keep_origin:
             if new_model_name:
-                q.put(["start", new_model_name])
+                q.put([model_name, "start", new_model_name])
         else:
             if new_model_name:
-                q.put(["replace", new_model_name])
+                q.put([model_name, "replace", new_model_name])
             else:
-                q.put(["stop"])
+                q.put([model_name, "stop", None])
         return {"code": 200, "msg": "done"}

     uvicorn.run(app, host=host, port=port, log_level=log_level.lower())


-def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
+def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
     import uvicorn
     import sys

     controller_addr = fschat_controller_address()
     app = create_openai_api_app(controller_addr, log_level=log_level)  # TODO: not support keys yet.
-    _set_app_seq(app, q, run_seq)
+    _set_app_event(app, started_event)

     host = FSCHAT_OPENAI_API["host"]
     port = FSCHAT_OPENAI_API["port"]
@@ -357,12 +339,12 @@ def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
     uvicorn.run(app, host=host, port=port)


-def run_api_server(q: Queue, run_seq: int = 4):
+def run_api_server(started_event: mp.Event = None):
     from server.api import create_app
     import uvicorn

     app = create_app()
-    _set_app_seq(app, q, run_seq)
+    _set_app_event(app, started_event)

     host = API_SERVER["host"]
     port = API_SERVER["port"]
@@ -370,21 +352,14 @@ def run_api_server(q: Queue, run_seq: int = 4):
     uvicorn.run(app, host=host, port=port)


-def run_webui(q: Queue, run_seq: int = 5):
+def run_webui(started_event: mp.Event = None):
     host = WEBUI_SERVER["host"]
     port = WEBUI_SERVER["port"]

-    if q is not None and isinstance(run_seq, int):
-        while True:
-            no = q.get()
-            if no != run_seq - 1:
-                q.put(no)
-            else:
-                break
-        q.put(run_seq)
     p = subprocess.Popen(["streamlit", "run", "webui.py",
                           "--server.address", host,
                           "--server.port", str(port)])
+    started_event.set()
     p.wait()
@@ -427,8 +402,9 @@ def parse_args() -> argparse.ArgumentParser:
         "-n",
         "--model-name",
         type=str,
-        default=LLM_MODEL,
-        help="specify model name for model worker.",
+        nargs="+",
+        default=[LLM_MODEL],
+        help="specify model name(s) for the model worker; separate multiple names with spaces to start several model workers.",
         dest="model_name",
     )
     parser.add_argument(
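
Note: with `nargs="+"`, the parsed `args.model_name` is always a list, which is what lets the hunks below iterate over it in `dump_server_info` and the worker-spawning loop. A tiny standalone check of that behavior (model names are made up):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-n", "--model-name", type=str, nargs="+",
                    default=["default-model"], dest="model_name")

# e.g. `python startup.py -n model-a model-b` would start two model workers
args = parser.parse_args(["-n", "model-a", "model-b"])
assert args.model_name == ["model-a", "model-b"]  # always a list now
```
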
@@ -483,11 +459,12 @@ def dump_server_info(after_start=False, args=None):
     print(f"langchain version: {langchain.__version__}. fastchat version: {fastchat.__version__}")
     print("\n")

-    model = LLM_MODEL
+    models = [LLM_MODEL]
     if args and args.model_name:
-        model = args.model_name
-    print(f"current LLM model: {model} @ {llm_device()}")
-    pprint(llm_model_dict[model])
+        models = args.model_name
+    print(f"LLM models to start: {models} @ {llm_device()}")
+    for model in models:
+        pprint(llm_model_dict[model])
     print(f"current embeddings model: {EMBEDDING_MODEL} @ {embedding_device()}")

     if after_start:
@@ -554,10 +531,10 @@ async def start_main_server():
     logger.info(f"starting services:")
     logger.info(f"to view the llm_api logs, see {LOG_PATH}")

-    processes = {"online-api": []}
+    processes = {"online_api": {}, "model_worker": {}}

     def process_count():
-        return len(processes) + len(processes["online-api"]) - 1
+        return len(processes) + len(processes["online_api"]) + len(processes["model_worker"]) - 2

     if args.quiet:
         log_level = "ERROR"
@@ -569,63 +546,73 @@
         process = Process(
             target=run_controller,
             name=f"controller",
-            args=(queue, process_count() + 1, log_level, controller_started),
+            kwargs=dict(log_level=log_level, started_event=controller_started),
             daemon=True,
         )
-
         processes["controller"] = process

         process = Process(
             target=run_openai_api,
             name=f"openai_api",
-            args=(queue, process_count() + 1),
             daemon=True,
         )
         processes["openai_api"] = process

+    model_worker_started = []
     if args.model_worker:
-        config = get_model_worker_config(args.model_name)
-        if not config.get("online_api"):
-            process = Process(
-                target=run_model_worker,
-                name=f"model_worker - {args.model_name}",
-                args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
-                daemon=True,
-            )
-
-            processes["model_worker"] = process
+        for model_name in args.model_name:
+            config = get_model_worker_config(model_name)
+            if not config.get("online_api"):
+                e = manager.Event()
+                model_worker_started.append(e)
+                process = Process(
+                    target=run_model_worker,
+                    name=f"model_worker - {model_name}",
+                    kwargs=dict(model_name=model_name,
+                                controller_address=args.controller_address,
+                                log_level=log_level,
+                                q=queue,
+                                started_event=e),
+                    daemon=True,
+                )
+                processes["model_worker"][model_name] = process

     if args.api_worker:
         configs = get_all_model_worker_configs()
         for model_name, config in configs.items():
             if config.get("online_api") and config.get("worker_class"):
+                e = manager.Event()
+                model_worker_started.append(e)
                 process = Process(
                     target=run_model_worker,
-                    name=f"model_worker - {model_name}",
-                    args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
+                    name=f"api_worker - {model_name}",
+                    kwargs=dict(model_name=model_name,
+                                controller_address=args.controller_address,
+                                log_level=log_level,
+                                q=queue,
+                                started_event=e),
                     daemon=True,
                 )
+                processes["online_api"][model_name] = process
-                processes["online-api"].append(process)
-
+    api_started = manager.Event()
     if args.api:
         process = Process(
             target=run_api_server,
             name=f"API Server",
-            args=(queue, process_count() + 1),
+            kwargs=dict(started_event=api_started),
             daemon=True,
         )
-
         processes["api"] = process

+    webui_started = manager.Event()
     if args.webui:
         process = Process(
             target=run_webui,
             name=f"WEBUI Server",
-            args=(queue, process_count() + 1),
+            kwargs=dict(started_event=webui_started),
             daemon=True,
         )
-
         processes["webui"] = process

     if process_count() == 0:
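
Note: each worker's `/release` endpoint (earlier hunk) enqueues a three-element command, `[model_name, cmd, new_model_name]`, which the main process consumes in the `while True:` loop of the next hunk. A stripped-down sketch of the consumer side of that protocol; the prints stand in for the real `Process` management:

```python
import multiprocessing as mp

def dispatch_commands(queue: mp.Queue):
    # Commands are lists shaped like [model_name, cmd, new_model_name].
    while True:
        msg = queue.get()
        if not isinstance(msg, list):
            continue
        model_name, cmd, new_model_name = msg
        if cmd == "start":      # keep the old worker, start an extra one
            print(f"start a worker for {new_model_name}")
        elif cmd == "stop":     # terminate the worker, start nothing
            print(f"stop the worker for {model_name}")
        elif cmd == "replace":  # terminate the old worker, then start the new one
            print(f"replace {model_name} with {new_model_name}")
```
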
@@ -636,60 +623,106 @@
         if p:= processes.get("controller"):
             p.start()
             p.name = f"{p.name} ({p.pid})"
-            controller_started.wait()
+            controller_started.wait() # wait for the controller to finish starting

         if p:= processes.get("openai_api"):
             p.start()
             p.name = f"{p.name} ({p.pid})"

-        if p:= processes.get("model_worker"):
+        for n, p in processes.get("model_worker", {}).items():
             p.start()
             p.name = f"{p.name} ({p.pid})"

-        for p in processes.get("online-api", []):
+        for n, p in processes.get("online_api", {}).items():
             p.start()
             p.name = f"{p.name} ({p.pid})"

+        # wait for all model workers to finish starting
+        for e in model_worker_started:
+            e.wait()
+
         if p:= processes.get("api"):
             p.start()
             p.name = f"{p.name} ({p.pid})"
+            api_started.wait() # wait for api.py to finish starting

         if p:= processes.get("webui"):
             p.start()
             p.name = f"{p.name} ({p.pid})"
+            webui_started.wait() # wait for webui.py to finish starting
+
+        dump_server_info(after_start=True, args=args)

         while True:
-            no = queue.get()
-            if no == process_count():
-                time.sleep(0.5)
-                dump_server_info(after_start=True, args=args)
-                break
-            else:
-                queue.put(no)
+            cmd = queue.get() # received a model-switch command
+            e = manager.Event()
+            if isinstance(cmd, list):
+                model_name, cmd, new_model_name = cmd
+                if cmd == "start": # run an additional model
+                    logger.info(f"about to start a new model process: {new_model_name}")
+                    process = Process(
+                        target=run_model_worker,
+                        name=f"model_worker - {new_model_name}",
+                        kwargs=dict(model_name=new_model_name,
+                                    controller_address=args.controller_address,
+                                    log_level=log_level,
+                                    q=queue,
+                                    started_event=e),
+                        daemon=True,
+                    )
+                    process.start()
+                    process.name = f"{process.name} ({process.pid})"
+                    processes["model_worker"][new_model_name] = process
+                    e.wait()
+                    logger.info(f"successfully started the new model process: {new_model_name}")
+                elif cmd == "stop":
+                    if process := processes["model_worker"].get(model_name):
+                        time.sleep(1)
+                        process.terminate()
+                        process.join()
+                        logger.info(f"stopped the model process: {model_name}")
+                    else:
+                        logger.error(f"no process found for model: {model_name}")
+                elif cmd == "replace":
+                    if process := processes["model_worker"].pop(model_name, None):
+                        logger.info(f"stopping the model process: {model_name}")
+                        start_time = datetime.now()
+                        time.sleep(1)
+                        process.terminate()
+                        process.join()
+                        process = Process(
+                            target=run_model_worker,
+                            name=f"model_worker - {new_model_name}",
+                            kwargs=dict(model_name=new_model_name,
+                                        controller_address=args.controller_address,
+                                        log_level=log_level,
+                                        q=queue,
+                                        started_event=e),
+                            daemon=True,
+                        )
+                        process.start()
+                        process.name = f"{process.name} ({process.pid})"
+                        processes["model_worker"][new_model_name] = process
+                        e.wait()
+                        timing = datetime.now() - start_time
+                        logger.info(f"successfully started the new model process: {new_model_name}. time elapsed: {timing}.")
+                    else:
+                        logger.error(f"no process found for model: {model_name}")

-        if model_worker_process := processes.get("model_worker"):
-            model_worker_process.join()
-        for process in processes.get("online-api", []):
-            process.join()
-        for name, process in processes.items():
-            if name not in ["model_worker", "online-api"]:
-                if isinstance(p, list):
-                    for work_process in p:
-                        work_process.join()
-                else:
-                    process.join()
+
+        # for process in processes.get("model_worker", {}).values():
+        #     process.join()
+        # for process in processes.get("online_api", {}).values():
+        #     process.join()
+
+        # for name, process in processes.items():
+        #     if name not in ["model_worker", "online_api"]:
+        #         if isinstance(p, dict):
+        #             for work_process in p.values():
+        #                 work_process.join()
+        #         else:
+        #             process.join()

     except Exception as e:
-        # if model_worker_process := processes.pop("model_worker", None):
-        #     model_worker_process.terminate()
-        # for process in processes.pop("online-api", []):
-        #     process.terminate()
-        # for process in processes.values():
-        #
-        #     if isinstance(process, list):
-        #         for work_process in process:
-        #             work_process.terminate()
-        #     else:
-        #         process.terminate()
         logger.error(e)
         logger.warning("Caught KeyboardInterrupt! Setting stop event...")
     finally:
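
Note: to exercise the switch end to end, you can POST to a worker's `/release` endpoint. A hedged sketch: the address and model name below are hypothetical and depend on your server config; only the endpoint, the JSON keys (mirroring the `Body(...)` parameters of `release_model`), and the response shape come from the diff above:

```python
import httpx

worker = "http://127.0.0.1:20002"  # hypothetical model_worker address

# keep_origin=False plus a new_model_name enqueues [model_name, "replace", ...]:
# the main loop terminates the old worker process and spawns the new one.
r = httpx.post(f"{worker}/release",
               json={"new_model_name": "some-new-model", "keep_origin": False},
               timeout=300)
print(r.json())  # expected: {"code": 200, "msg": "done"}
```
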
@@ -702,10 +735,9 @@

             # Queues and other inter-process communication primitives can break when
             # process is killed, but we don't care here
-            if isinstance(p, list):
-                for process in p:
+            if isinstance(p, dict):
+                for process in p.values():
                     process.kill()
-
             else:
                 p.kill()
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 25b885b..947be74 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -65,7 +65,9 @@ def dialogue_page(api: ApiRequest):
     )

     def on_llm_change():
-        st.session_state["prev_llm_model"] = llm_model
+        config = get_model_worker_config(llm_model)
+        if not config.get("online_api"): # only local model_workers can be switched
+            st.session_state["prev_llm_model"] = llm_model

     def llm_model_format_func(x):
         if x in running_models:
@@ -91,7 +93,7 @@ def dialogue_page(api: ApiRequest):
     )
     if (st.session_state.get("prev_llm_model") != llm_model
             and not get_model_worker_config(llm_model).get("online_api")):
-        with st.spinner(f"Loading model: {llm_model}"):
+        with st.spinner(f"Loading model: {llm_model}. Please do not interact with the page or refresh it."):
            r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
     st.session_state["prev_llm_model"] = llm_model
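
Note: the `prev_llm_model` bookkeeping in `dialogue.py` exists because Streamlit reruns the whole script on every widget interaction. A minimal standalone illustration of the guard pattern; the model names are hypothetical and `pass` stands in for the `api.change_llm_model` call shown in the diff:

```python
import streamlit as st

# Streamlit reruns this script top to bottom on each interaction, so the
# expensive model switch must be guarded by a session_state comparison.
llm_model = st.selectbox("LLM model", ["model-a", "model-b"])

if st.session_state.get("prev_llm_model") != llm_model:
    with st.spinner(f"Loading model: {llm_model}. Please do not refresh the page."):
        pass  # the real page calls api.change_llm_model(prev, llm_model) here
    st.session_state["prev_llm_model"] = llm_model
```
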