update startup.py: (#1416)

1. Restore the model-switching feature.
2. --model-name accepts multiple names (space-separated), so several models can be started at once.
3. Improve the service startup order. Services now start strictly in sequence: controller -> [openai-api,
   model_worker, api_worker] in parallel -> api.py -> webui.py (a sketch of this handshake follows the file summary below).
4. Fix: switching from an online API model to a local model failed.
liunux4odoo 2023-09-08 15:18:13 +08:00 committed by GitHub
parent 775870a516
commit f94f2793f8
3 changed files with 153 additions and 118 deletions
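
For orientation, here is a minimal, hypothetical sketch of the startup handshake this commit introduces. It is not the real startup.py: run_service and the model names are stand-ins for the actual run_* functions. Each child process sets a multiprocessing Event once its server is up, and the parent waits on that event before launching the next stage.

# Sketch only: run_service and the model names below are placeholders,
# not the real run_controller / run_model_worker / run_api_server / run_webui.
import multiprocessing as mp
import time


def run_service(name: str, started_event=None):
    # Do the startup work, then signal readiness to the parent.
    time.sleep(0.5)  # pretend to bind a port, load a model, etc.
    print(f"{name} is up")
    if started_event is not None:
        started_event.set()


if __name__ == "__main__":
    manager = mp.Manager()

    # Stage 1: the controller must be up before workers can register with it.
    controller_started = manager.Event()
    p = mp.Process(target=run_service,
                   kwargs=dict(name="controller", started_event=controller_started),
                   daemon=True)
    p.start()
    controller_started.wait()

    # Stage 2: openai-api and all model/api workers start in parallel,
    # each with its own readiness event (one worker per --model-name value).
    events = []
    for name in ["openai_api", "model_worker - model-A", "api_worker - model-B"]:
        e = manager.Event()
        events.append(e)
        mp.Process(target=run_service,
                   kwargs=dict(name=name, started_event=e),
                   daemon=True).start()
    for e in events:
        e.wait()

    # Stages 3 and 4: api.py, then webui.py, each gated the same way.
    for name in ["api", "webui"]:
        e = manager.Event()
        mp.Process(target=run_service,
                   kwargs=dict(name=name, started_event=e),
                   daemon=True).start()
        e.wait()

    print("all services started strictly in order")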

.gitignore

@@ -7,3 +7,4 @@ __pycache__/
 configs/*.py
 .vscode/
 .pytest_cache/
+*.bak

startup.py

@@ -3,7 +3,8 @@ import multiprocessing as mp
 import os
 import subprocess
 import sys
-from multiprocessing import Process, Queue
+from multiprocessing import Process
+from datetime import datetime
 from pprint import pprint
 
 # Set numexpr's max thread count; defaults to the number of CPU cores
@@ -19,7 +20,7 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
     logger
 from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
-                                   FSCHAT_OPENAI_API, )
+                                   FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT)
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
                           fschat_openai_api_address, set_httpx_timeout,
                           get_model_worker_config, get_all_model_worker_configs,
@@ -47,7 +48,7 @@ def create_controller_app(
     return app
 
 
-def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
+def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
     from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
@@ -188,29 +189,15 @@ def create_openai_api_app(
     return app
 
 
-def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
-    if q is None or not isinstance(run_seq, int):
-        return
-
-    if run_seq == 1:
-        @app.on_event("startup")
-        async def on_startup():
-            set_httpx_timeout()
-            q.put(run_seq)
-    elif run_seq > 1:
-        @app.on_event("startup")
-        async def on_startup():
-            set_httpx_timeout()
-            while True:
-                no = q.get()
-                if no != run_seq - 1:
-                    q.put(no)
-                else:
-                    break
-            q.put(run_seq)
+def _set_app_event(app: FastAPI, started_event: mp.Event = None):
+    @app.on_event("startup")
+    async def on_startup():
+        set_httpx_timeout()
+        if started_event is not None:
+            started_event.set()
 
 
-def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
+def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
     import uvicorn
     import httpx
     from fastapi import Body
@@ -221,12 +208,7 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
         dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
         log_level=log_level,
     )
-    _set_app_seq(app, q, run_seq)
-
-    @app.on_event("startup")
-    def on_startup():
-        if e is not None:
-            e.set()
+    _set_app_event(app, started_event)
 
     # add interface to release and load model worker
     @app.post("/release_worker")
@@ -266,7 +248,7 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
             return {"code": 500, "msg": msg}
 
         if new_model_name:
-            timer = 300  # wait 5 minutes for new model_worker register
+            timer = HTTPX_DEFAULT_TIMEOUT * 2  # wait for new model_worker register
            while timer > 0:
                models = app._controller.list_models()
                if new_model_name in models:
@@ -299,9 +281,9 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
 def run_model_worker(
         model_name: str = LLM_MODEL,
         controller_address: str = "",
-        q: Queue = None,
-        run_seq: int = 2,
         log_level: str = "INFO",
+        q: mp.Queue = None,
+        started_event: mp.Event = None,
 ):
     import uvicorn
     from fastapi import Body
@@ -317,7 +299,7 @@ def run_model_worker(
     kwargs["model_path"] = model_path
 
     app = create_model_worker_app(log_level=log_level, **kwargs)
-    _set_app_seq(app, q, run_seq)
+    _set_app_event(app, started_event)
     if log_level == "ERROR":
         sys.stdout = sys.__stdout__
         sys.stderr = sys.__stderr__
@@ -325,29 +307,29 @@ def run_model_worker(
     # add interface to release and load model
     @app.post("/release")
     def release_model(
        new_model_name: str = Body(None, description="the model to load after the release"),
        keep_origin: bool = Body(False, description="do not release the original model; load the new model alongside it")
     ) -> Dict:
         if keep_origin:
             if new_model_name:
-                q.put(["start", new_model_name])
+                q.put([model_name, "start", new_model_name])
         else:
             if new_model_name:
-                q.put(["replace", new_model_name])
+                q.put([model_name, "replace", new_model_name])
             else:
-                q.put(["stop"])
+                q.put([model_name, "stop", None])
         return {"code": 200, "msg": "done"}
 
     uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
 
 
-def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
+def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
     import uvicorn
     import sys
 
     controller_addr = fschat_controller_address()
     app = create_openai_api_app(controller_addr, log_level=log_level)  # TODO: not support keys yet.
-    _set_app_seq(app, q, run_seq)
+    _set_app_event(app, started_event)
 
     host = FSCHAT_OPENAI_API["host"]
     port = FSCHAT_OPENAI_API["port"]
@@ -357,12 +339,12 @@ def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
     uvicorn.run(app, host=host, port=port)
 
 
-def run_api_server(q: Queue, run_seq: int = 4):
+def run_api_server(started_event: mp.Event = None):
     from server.api import create_app
     import uvicorn
 
     app = create_app()
-    _set_app_seq(app, q, run_seq)
+    _set_app_event(app, started_event)
 
     host = API_SERVER["host"]
     port = API_SERVER["port"]
@@ -370,21 +352,14 @@ def run_api_server(q: Queue, run_seq: int = 4):
     uvicorn.run(app, host=host, port=port)
 
 
-def run_webui(q: Queue, run_seq: int = 5):
+def run_webui(started_event: mp.Event = None):
     host = WEBUI_SERVER["host"]
     port = WEBUI_SERVER["port"]
 
-    if q is not None and isinstance(run_seq, int):
-        while True:
-            no = q.get()
-            if no != run_seq - 1:
-                q.put(no)
-            else:
-                break
-        q.put(run_seq)
-
     p = subprocess.Popen(["streamlit", "run", "webui.py",
                           "--server.address", host,
                           "--server.port", str(port)])
+    started_event.set()
     p.wait()
@@ -427,8 +402,9 @@ def parse_args() -> argparse.ArgumentParser:
         "-n",
         "--model-name",
         type=str,
-        default=LLM_MODEL,
-        help="specify model name for model worker.",
+        nargs="+",
+        default=[LLM_MODEL],
+        help="specify model name for model worker. Add additional names separated by spaces to start multiple model workers.",
         dest="model_name",
     )
     parser.add_argument(
@@ -483,11 +459,12 @@ def dump_server_info(after_start=False, args=None):
     print(f"langchain version: {langchain.__version__}. fastchat version: {fastchat.__version__}")
     print("\n")
 
-    model = LLM_MODEL
+    models = [LLM_MODEL]
     if args and args.model_name:
-        model = args.model_name
-    print(f"Current LLM model: {model} @ {llm_device()}")
-    pprint(llm_model_dict[model])
+        models = args.model_name
+    print(f"LLM models currently started: {models} @ {llm_device()}")
+    for model in models:
+        pprint(llm_model_dict[model])
     print(f"Current embeddings model: {EMBEDDING_MODEL} @ {embedding_device()}")
 
     if after_start:
@@ -554,10 +531,10 @@ async def start_main_server():
     logger.info(f"Starting services:")
     logger.info(f"To view llm_api logs, go to {LOG_PATH}")
 
-    processes = {"online-api": []}
+    processes = {"online_api": {}, "model_worker": {}}
 
     def process_count():
-        return len(processes) + len(processes["online-api"]) - 1
+        return len(processes) + len(processes["online_api"]) - 1
 
     if args.quiet:
         log_level = "ERROR"
@@ -569,63 +546,73 @@ async def start_main_server():
     process = Process(
         target=run_controller,
         name=f"controller",
-        args=(queue, process_count() + 1, log_level, controller_started),
+        kwargs=dict(log_level=log_level, started_event=controller_started),
         daemon=True,
     )
     processes["controller"] = process
 
     process = Process(
         target=run_openai_api,
         name=f"openai_api",
-        args=(queue, process_count() + 1),
         daemon=True,
     )
     processes["openai_api"] = process
 
+    model_worker_started = []
     if args.model_worker:
-        config = get_model_worker_config(args.model_name)
-        if not config.get("online_api"):
-            process = Process(
-                target=run_model_worker,
-                name=f"model_worker - {args.model_name}",
-                args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
-                daemon=True,
-            )
-
-            processes["model_worker"] = process
+        for model_name in args.model_name:
+            config = get_model_worker_config(model_name)
+            if not config.get("online_api"):
+                e = manager.Event()
+                model_worker_started.append(e)
+                process = Process(
+                    target=run_model_worker,
+                    name=f"model_worker - {model_name}",
+                    kwargs=dict(model_name=model_name,
+                                controller_address=args.controller_address,
+                                log_level=log_level,
+                                q=queue,
+                                started_event=e),
+                    daemon=True,
+                )
+                processes["model_worker"][model_name] = process
 
     if args.api_worker:
         configs = get_all_model_worker_configs()
         for model_name, config in configs.items():
             if config.get("online_api") and config.get("worker_class"):
+                e = manager.Event()
+                model_worker_started.append(e)
                 process = Process(
                     target=run_model_worker,
-                    name=f"model_worker - {model_name}",
-                    args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
+                    name=f"api_worker - {model_name}",
+                    kwargs=dict(model_name=model_name,
+                                controller_address=args.controller_address,
+                                log_level=log_level,
+                                q=queue,
+                                started_event=e),
                     daemon=True,
                 )
-                processes["online-api"].append(process)
+                processes["online_api"][model_name] = process
 
+    api_started = manager.Event()
     if args.api:
         process = Process(
             target=run_api_server,
             name=f"API Server",
-            args=(queue, process_count() + 1),
+            kwargs=dict(started_event=api_started),
             daemon=True,
         )
         processes["api"] = process
 
+    webui_started = manager.Event()
     if args.webui:
         process = Process(
             target=run_webui,
             name=f"WEBUI Server",
-            args=(queue, process_count() + 1),
+            kwargs=dict(started_event=webui_started),
             daemon=True,
         )
         processes["webui"] = process
 
     if process_count() == 0:
@@ -636,60 +623,106 @@ async def start_main_server():
         if p:= processes.get("controller"):
             p.start()
             p.name = f"{p.name} ({p.pid})"
-            controller_started.wait()
+            controller_started.wait()  # wait for the controller to finish starting
 
         if p:= processes.get("openai_api"):
             p.start()
             p.name = f"{p.name} ({p.pid})"
 
-        if p:= processes.get("model_worker"):
+        for n, p in processes.get("model_worker", {}).items():
             p.start()
             p.name = f"{p.name} ({p.pid})"
 
-        for p in processes.get("online-api", []):
+        for n, p in processes.get("online_api", []).items():
             p.start()
             p.name = f"{p.name} ({p.pid})"
 
+        # wait for all model_workers to finish starting
+        for e in model_worker_started:
+            e.wait()
+
         if p:= processes.get("api"):
             p.start()
             p.name = f"{p.name} ({p.pid})"
+            api_started.wait()  # wait for api.py to finish starting
 
         if p:= processes.get("webui"):
             p.start()
             p.name = f"{p.name} ({p.pid})"
+            webui_started.wait()  # wait for webui.py to finish starting
+
+        dump_server_info(after_start=True, args=args)
 
         while True:
-            no = queue.get()
-            if no == process_count():
-                time.sleep(0.5)
-                dump_server_info(after_start=True, args=args)
-                break
-            else:
-                queue.put(no)
+            cmd = queue.get()  # received a model-switch message
+            e = manager.Event()
+            if isinstance(cmd, list):
+                model_name, cmd, new_model_name = cmd
+                if cmd == "start":  # start a new model
+                    logger.info(f"Preparing to start the new model process: {new_model_name}")
+                    process = Process(
+                        target=run_model_worker,
+                        name=f"model_worker - {new_model_name}",
+                        kwargs=dict(model_name=new_model_name,
+                                    controller_address=args.controller_address,
+                                    log_level=log_level,
+                                    q=queue,
+                                    started_event=e),
+                        daemon=True,
+                    )
+                    process.start()
+                    process.name = f"{process.name} ({process.pid})"
+                    processes["model_worker"][new_model_name] = process
+                    e.wait()
+                    logger.info(f"Successfully started the new model process: {new_model_name}")
+                elif cmd == "stop":
+                    if process := processes["model_worker"].get(model_name):
+                        time.sleep(1)
+                        process.terminate()
+                        process.join()
+                        logger.info(f"Stopped the model process: {model_name}")
+                    else:
+                        logger.error(f"Could not find the model process: {model_name}")
+                elif cmd == "replace":
+                    if process := processes["model_worker"].pop(model_name, None):
+                        logger.info(f"Stopping the model process: {model_name}")
+                        start_time = datetime.now()
+                        time.sleep(1)
+                        process.terminate()
+                        process.join()
+                        process = Process(
+                            target=run_model_worker,
+                            name=f"model_worker - {new_model_name}",
+                            kwargs=dict(model_name=new_model_name,
+                                        controller_address=args.controller_address,
+                                        log_level=log_level,
+                                        q=queue,
+                                        started_event=e),
+                            daemon=True,
+                        )
+                        process.start()
+                        process.name = f"{process.name} ({process.pid})"
+                        processes["model_worker"][new_model_name] = process
+                        e.wait()
+                        timing = datetime.now() - start_time
+                        logger.info(f"Successfully started the new model process: {new_model_name}. Time taken: {timing}.")
+                    else:
+                        logger.error(f"Could not find the model process: {model_name}")
 
-        if model_worker_process := processes.get("model_worker"):
-            model_worker_process.join()
-        for process in processes.get("online-api", []):
-            process.join()
-
-        for name, process in processes.items():
-            if name not in ["model_worker", "online-api"]:
-                if isinstance(p, list):
-                    for work_process in p:
-                        work_process.join()
-                else:
-                    process.join()
+        # for process in processes.get("model_worker", {}).values():
+        #     process.join()
+        # for process in processes.get("online_api", {}).values():
+        #     process.join()
+
+        # for name, process in processes.items():
+        #     if name not in ["model_worker", "online_api"]:
+        #         if isinstance(p, dict):
+        #             for work_process in p.values():
+        #                 work_process.join()
+        #         else:
+        #             process.join()
     except Exception as e:
-        # if model_worker_process := processes.pop("model_worker", None):
-        #     model_worker_process.terminate()
-        # for process in processes.pop("online-api", []):
-        #     process.terminate()
-        # for process in processes.values():
-        #
-        #     if isinstance(process, list):
-        #         for work_process in process:
-        #             work_process.terminate()
-        #     else:
-        #         process.terminate()
         logger.error(e)
         logger.warning("Caught KeyboardInterrupt! Setting stop event...")
     finally:
@@ -702,10 +735,9 @@ async def start_main_server():
             # Queues and other inter-process communication primitives can break when
             # process is killed, but we don't care here
-            if isinstance(p, list):
-                for process in p:
+            if isinstance(p, dict):
+                for process in p.values():
                     process.kill()
             else:
                 p.kill()
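
For reference, a small self-contained sketch of the queue protocol behind the model switching above, with the same caveat: the three-element messages mirror what release_model() enqueues, while the handler is a stand-in for the real process management in start_main_server(), and the model names are made up.

# Sketch of the [model_name, cmd, new_model_name] messages the /release
# endpoint puts on the shared queue. The handler only logs; the real loop
# terminates and starts model_worker processes.
import multiprocessing as mp


def handle(msg: list):
    model_name, cmd, new_model_name = msg
    if cmd == "start":      # keep_origin=True: start the new worker alongside
        print(f"start {new_model_name}, keep {model_name} running")
    elif cmd == "replace":  # keep_origin=False: stop the old worker, start the new one
        print(f"replace {model_name} with {new_model_name}")
    elif cmd == "stop":     # no new model requested: just stop the worker
        print(f"stop {model_name}")


if __name__ == "__main__":
    manager = mp.Manager()
    queue = manager.Queue()
    # Messages as release_model() would enqueue them (model names are illustrative):
    queue.put(["model-A", "start", "model-B"])
    queue.put(["model-A", "replace", "model-B"])
    queue.put(["model-B", "stop", None])
    while not queue.empty():
        handle(queue.get())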

webui_pages/dialogue/dialogue.py

@@ -65,7 +65,9 @@ def dialogue_page(api: ApiRequest):
     )
 
     def on_llm_change():
-        st.session_state["prev_llm_model"] = llm_model
+        config = get_model_worker_config(llm_model)
+        if not config.get("online_api"):  # only local model_workers can switch models
+            st.session_state["prev_llm_model"] = llm_model
 
     def llm_model_format_func(x):
         if x in running_models:
@@ -91,7 +93,7 @@ def dialogue_page(api: ApiRequest):
     )
 
     if (st.session_state.get("prev_llm_model") != llm_model
             and not get_model_worker_config(llm_model).get("online_api")):
-        with st.spinner(f"Loading model: {llm_model}"):
+        with st.spinner(f"Loading model: {llm_model}. Please do not interact with or refresh the page"):
             r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
             st.session_state["prev_llm_model"] = llm_model