update startup.py: (#1416)

1. Restore the model switching feature
2. --model-name accepts multiple names (space separated) to start several models at once
3. Refine the service startup order. Start strictly in sequence: controller -> [openai-api,
   model_worker, api_worker] in parallel -> api.py -> webui.py
4. Fix: switching from an online API model to a local model failed
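For example, assuming the repo's `-a` flag that starts all services, a single command can now bring up two model workers (the model names below are hypothetical placeholders):

    python startup.py -a --model-name chatglm2-6b qwen-7b-chat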
liunux4odoo 2023-09-08 15:18:13 +08:00 committed by GitHub
parent 775870a516
commit f94f2793f8
3 changed files with 153 additions and 118 deletions

.gitignore (vendored)

@@ -7,3 +7,4 @@ __pycache__/
 configs/*.py
 .vscode/
 .pytest_cache/
+*.bak

startup.py

@@ -3,7 +3,8 @@ import multiprocessing as mp
 import os
 import subprocess
 import sys
-from multiprocessing import Process, Queue
+from multiprocessing import Process
+from datetime import datetime
 from pprint import pprint

 # set numexpr's max thread count, defaulting to the number of CPU cores
@@ -19,7 +20,7 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
     logger
 from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
-                                   FSCHAT_OPENAI_API, )
+                                   FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT)
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
                           fschat_openai_api_address, set_httpx_timeout,
                           get_model_worker_config, get_all_model_worker_configs,
@@ -47,7 +48,7 @@ def create_controller_app(
     return app

-def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
+def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
     from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
@@ -188,29 +189,15 @@ def create_openai_api_app(
     return app

-def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
-    if q is None or not isinstance(run_seq, int):
-        return
-    if run_seq == 1:
-        @app.on_event("startup")
-        async def on_startup():
-            set_httpx_timeout()
-            q.put(run_seq)
-    elif run_seq > 1:
-        @app.on_event("startup")
-        async def on_startup():
-            set_httpx_timeout()
-            while True:
-                no = q.get()
-                if no != run_seq - 1:
-                    q.put(no)
-                else:
-                    break
-            q.put(run_seq)
+def _set_app_event(app: FastAPI, started_event: mp.Event = None):
+    @app.on_event("startup")
+    async def on_startup():
+        set_httpx_timeout()
+        if started_event is not None:
+            started_event.set()

-def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
+def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
     import uvicorn
     import httpx
     from fastapi import Body
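The queue-based `_set_app_seq` handshake is replaced by a per-process `mp.Event` that each FastAPI app sets in its startup hook. A minimal self-contained sketch of the pattern (toy server, not this repo's code):

    import multiprocessing as mp

    from fastapi import FastAPI


    def _set_app_event(app: FastAPI, started_event: mp.Event = None):
        # signal readiness once uvicorn has finished starting the app
        @app.on_event("startup")
        async def on_startup():
            if started_event is not None:
                started_event.set()


    def run_server(started_event: mp.Event = None):
        import uvicorn
        app = FastAPI()
        _set_app_event(app, started_event)
        uvicorn.run(app, host="127.0.0.1", port=8000)


    if __name__ == "__main__":
        manager = mp.Manager()
        started = manager.Event()
        p = mp.Process(target=run_server, kwargs=dict(started_event=started), daemon=True)
        p.start()
        started.wait()  # blocks until the child reports it is up
        print("server is ready")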
@@ -221,12 +208,7 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
         dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
         log_level=log_level,
     )
-    _set_app_seq(app, q, run_seq)
-    @app.on_event("startup")
-    def on_startup():
-        if e is not None:
-            e.set()
+    _set_app_event(app, started_event)

     # add interface to release and load model worker
     @app.post("/release_worker")
@@ -266,7 +248,7 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
             return {"code": 500, "msg": msg}

         if new_model_name:
-            timer = 300  # wait 5 minutes for new model_worker register
+            timer = HTTPX_DEFAULT_TIMEOUT * 2  # wait for new model_worker register
             while timer > 0:
                 models = app._controller.list_models()
                 if new_model_name in models:
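The release endpoint now polls the controller until the replacement worker registers, giving up after HTTPX_DEFAULT_TIMEOUT * 2 seconds rather than a hard-coded 5 minutes. The pattern in isolation, assuming the one-second sleep per iteration used by the surrounding code:

    import time

    def wait_until(check, timeout: int) -> bool:
        # poll check() once per second until it succeeds or timeout expires
        while timeout > 0:
            if check():
                return True
            time.sleep(1)
            timeout -= 1
        return False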
@@ -299,9 +281,9 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
 def run_model_worker(
         model_name: str = LLM_MODEL,
         controller_address: str = "",
-        q: Queue = None,
-        run_seq: int = 2,
         log_level: str = "INFO",
+        q: mp.Queue = None,
+        started_event: mp.Event = None,
 ):
     import uvicorn
     from fastapi import Body
@@ -317,7 +299,7 @@ def run_model_worker(
     kwargs["model_path"] = model_path

     app = create_model_worker_app(log_level=log_level, **kwargs)
-    _set_app_seq(app, q, run_seq)
+    _set_app_event(app, started_event)
     if log_level == "ERROR":
         sys.stdout = sys.__stdout__
         sys.stderr = sys.__stderr__
@@ -325,29 +307,29 @@ def run_model_worker(
     # add interface to release and load model
     @app.post("/release")
     def release_model(
-            new_model_name: str = Body(None, description="load this model after the release"),
-            keep_origin: bool = Body(False, description="keep the original model and load the new one")
+        new_model_name: str = Body(None, description="load this model after the release"),
+        keep_origin: bool = Body(False, description="keep the original model and load the new one")
     ) -> Dict:
         if keep_origin:
             if new_model_name:
-                q.put(["start", new_model_name])
+                q.put([model_name, "start", new_model_name])
         else:
             if new_model_name:
-                q.put(["replace", new_model_name])
+                q.put([model_name, "replace", new_model_name])
             else:
-                q.put(["stop"])
+                q.put([model_name, "stop", None])
         return {"code": 200, "msg": "done"}

     uvicorn.run(app, host=host, port=port, log_level=log_level.lower())

-def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
+def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
     import uvicorn
     import sys
     controller_addr = fschat_controller_address()
     app = create_openai_api_app(controller_addr, log_level=log_level)  # TODO: not support keys yet.
-    _set_app_seq(app, q, run_seq)
+    _set_app_event(app, started_event)

     host = FSCHAT_OPENAI_API["host"]
     port = FSCHAT_OPENAI_API["port"]
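Because the queue message now carries the worker's own model name, a worker's /release endpoint can be called directly to stop, replace, or add a model. A hedged example with httpx (port 20002 is only an assumed default model_worker port, and the model name is a placeholder):

    import httpx

    resp = httpx.post(
        "http://127.0.0.1:20002/release",
        json={"new_model_name": "qwen-7b-chat", "keep_origin": False},  # replace the current model
        timeout=300,
    )
    print(resp.json())  # {"code": 200, "msg": "done"} on success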
@@ -357,12 +339,12 @@ def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
     uvicorn.run(app, host=host, port=port)

-def run_api_server(q: Queue, run_seq: int = 4):
+def run_api_server(started_event: mp.Event = None):
     from server.api import create_app
     import uvicorn

     app = create_app()
-    _set_app_seq(app, q, run_seq)
+    _set_app_event(app, started_event)

     host = API_SERVER["host"]
     port = API_SERVER["port"]
@@ -370,21 +352,14 @@ def run_api_server(q: Queue, run_seq: int = 4):
     uvicorn.run(app, host=host, port=port)

-def run_webui(q: Queue, run_seq: int = 5):
+def run_webui(started_event: mp.Event = None):
     host = WEBUI_SERVER["host"]
     port = WEBUI_SERVER["port"]

-    if q is not None and isinstance(run_seq, int):
-        while True:
-            no = q.get()
-            if no != run_seq - 1:
-                q.put(no)
-            else:
-                break
-        q.put(run_seq)
     p = subprocess.Popen(["streamlit", "run", "webui.py",
                           "--server.address", host,
                           "--server.port", str(port)])
+    started_event.set()
     p.wait()
@@ -427,8 +402,9 @@ def parse_args() -> argparse.ArgumentParser:
         "-n",
         "--model-name",
         type=str,
-        default=LLM_MODEL,
-        help="specify model name for model worker.",
+        nargs="+",
+        default=[LLM_MODEL],
+        help="specify model name for model worker. add additional names separated by spaces to start multiple model workers.",
         dest="model_name",
     )
     parser.add_argument(
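With nargs="+", argparse collects one or more space-separated values into a list, which is what lets a single --model-name flag start several workers. A tiny standalone check (not repo code, model names hypothetical):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--model-name", type=str, nargs="+",
                        default=["default-model"], dest="model_name")

    args = parser.parse_args(["-n", "chatglm2-6b", "qwen-7b-chat"])
    print(args.model_name)  # ['chatglm2-6b', 'qwen-7b-chat']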
@@ -483,11 +459,12 @@ def dump_server_info(after_start=False, args=None):
     print(f"langchain version: {langchain.__version__}. fastchat version: {fastchat.__version__}")
     print("\n")

-    model = LLM_MODEL
+    models = [LLM_MODEL]
     if args and args.model_name:
-        model = args.model_name
-    print(f"current LLM model: {model} @ {llm_device()}")
-    pprint(llm_model_dict[model])
+        models = args.model_name
+    print(f"LLM models currently started: {models} @ {llm_device()}")
+    for model in models:
+        pprint(llm_model_dict[model])
     print(f"current embeddings model: {EMBEDDING_MODEL} @ {embedding_device()}")

     if after_start:
@@ -554,10 +531,10 @@ async def start_main_server():
     logger.info(f"starting services:")
     logger.info(f"to view llm_api logs, go to {LOG_PATH}")

-    processes = {"online-api": []}
+    processes = {"online_api": {}, "model_worker": {}}

     def process_count():
-        return len(processes) + len(processes["online-api"]) - 1
+        return len(processes) + len(processes["online_api"]) - 1

     if args.quiet:
         log_level = "ERROR"
@@ -569,63 +546,73 @@ async def start_main_server():
         process = Process(
             target=run_controller,
             name=f"controller",
-            args=(queue, process_count() + 1, log_level, controller_started),
+            kwargs=dict(log_level=log_level, started_event=controller_started),
             daemon=True,
         )
         processes["controller"] = process

         process = Process(
            target=run_openai_api,
            name=f"openai_api",
-           args=(queue, process_count() + 1),
            daemon=True,
        )
        processes["openai_api"] = process

+    model_worker_started = []
     if args.model_worker:
-        config = get_model_worker_config(args.model_name)
-        if not config.get("online_api"):
-            process = Process(
-                target=run_model_worker,
-                name=f"model_worker - {args.model_name}",
-                args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
-                daemon=True,
-            )
-            processes["model_worker"] = process
+        for model_name in args.model_name:
+            config = get_model_worker_config(model_name)
+            if not config.get("online_api"):
+                e = manager.Event()
+                model_worker_started.append(e)
+                process = Process(
+                    target=run_model_worker,
+                    name=f"model_worker - {model_name}",
+                    kwargs=dict(model_name=model_name,
+                                controller_address=args.controller_address,
+                                log_level=log_level,
+                                q=queue,
+                                started_event=e),
+                    daemon=True,
+                )
+                processes["model_worker"][model_name] = process

     if args.api_worker:
         configs = get_all_model_worker_configs()
         for model_name, config in configs.items():
             if config.get("online_api") and config.get("worker_class"):
+                e = manager.Event()
+                model_worker_started.append(e)
                 process = Process(
                     target=run_model_worker,
-                    name=f"model_worker - {model_name}",
-                    args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
+                    name=f"api_worker - {model_name}",
+                    kwargs=dict(model_name=model_name,
+                                controller_address=args.controller_address,
+                                log_level=log_level,
+                                q=queue,
+                                started_event=e),
                     daemon=True,
                 )
-                processes["online-api"].append(process)
+                processes["online_api"][model_name] = process

+    api_started = manager.Event()
     if args.api:
         process = Process(
             target=run_api_server,
             name=f"API Server",
-            args=(queue, process_count() + 1),
+            kwargs=dict(started_event=api_started),
             daemon=True,
         )
         processes["api"] = process

+    webui_started = manager.Event()
     if args.webui:
         process = Process(
             target=run_webui,
             name=f"WEBUI Server",
-            args=(queue, process_count() + 1),
+            kwargs=dict(started_event=webui_started),
             daemon=True,
         )
         processes["webui"] = process

     if process_count() == 0:
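The ordering promised in the commit message is enforced purely by these events: block on the controller's event, start every worker and block on the whole list, then api, then webui. A condensed sketch of that choreography (generic serve() stand-in, not the repo's functions):

    import multiprocessing as mp
    import time


    def serve(name, started_event):
        print(f"{name} starting")
        time.sleep(1)        # stand-in for uvicorn startup
        started_event.set()  # report readiness
        time.sleep(60)       # stand-in for serving requests


    if __name__ == "__main__":
        manager = mp.Manager()

        controller_started = manager.Event()
        mp.Process(target=serve, args=("controller", controller_started), daemon=True).start()
        controller_started.wait()            # 1) controller first

        worker_events = []
        for name in ["openai_api", "model_worker", "api_worker"]:
            e = manager.Event()
            worker_events.append(e)
            mp.Process(target=serve, args=(name, e), daemon=True).start()
        for e in worker_events:              # 2) workers in parallel
            e.wait()

        api_started = manager.Event()
        mp.Process(target=serve, args=("api", api_started), daemon=True).start()
        api_started.wait()                   # 3) then api.py

        webui_started = manager.Event()
        mp.Process(target=serve, args=("webui", webui_started), daemon=True).start()
        webui_started.wait()                 # 4) then webui.py
        print("all services up")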
@@ -636,60 +623,106 @@ async def start_main_server():
         if p := processes.get("controller"):
             p.start()
             p.name = f"{p.name} ({p.pid})"
-            controller_started.wait()
+            controller_started.wait()  # wait for the controller to finish starting

         if p := processes.get("openai_api"):
             p.start()
             p.name = f"{p.name} ({p.pid})"

-        if p := processes.get("model_worker"):
+        for n, p in processes.get("model_worker", {}).items():
             p.start()
             p.name = f"{p.name} ({p.pid})"

-        for p in processes.get("online-api", []):
+        for n, p in processes.get("online_api", []).items():
             p.start()
             p.name = f"{p.name} ({p.pid})"

+        # wait for all model_workers to finish starting
+        for e in model_worker_started:
+            e.wait()

         if p := processes.get("api"):
             p.start()
             p.name = f"{p.name} ({p.pid})"
+            api_started.wait()  # wait for api.py to finish starting

         if p := processes.get("webui"):
             p.start()
             p.name = f"{p.name} ({p.pid})"
+            webui_started.wait()  # wait for webui.py to finish starting
+        dump_server_info(after_start=True, args=args)

         while True:
-            no = queue.get()
-            if no == process_count():
-                time.sleep(0.5)
-                dump_server_info(after_start=True, args=args)
-                break
-            else:
-                queue.put(no)
+            cmd = queue.get()  # received a model-switch message
+            e = manager.Event()
+            if isinstance(cmd, list):
+                model_name, cmd, new_model_name = cmd
+                if cmd == "start":  # run a new model
+                    logger.info(f"about to start new model process: {new_model_name}")
+                    process = Process(
+                        target=run_model_worker,
+                        name=f"model_worker - {new_model_name}",
+                        kwargs=dict(model_name=new_model_name,
+                                    controller_address=args.controller_address,
+                                    log_level=log_level,
+                                    q=queue,
+                                    started_event=e),
+                        daemon=True,
+                    )
+                    process.start()
+                    process.name = f"{process.name} ({process.pid})"
+                    processes["model_worker"][new_model_name] = process
+                    e.wait()
+                    logger.info(f"successfully started new model process: {new_model_name}")
+                elif cmd == "stop":
+                    if process := processes["model_worker"].get(model_name):
+                        time.sleep(1)
+                        process.terminate()
+                        process.join()
+                        logger.info(f"stopped model process: {model_name}")
+                    else:
+                        logger.error(f"model process not found: {model_name}")
+                elif cmd == "replace":
+                    if process := processes["model_worker"].pop(model_name, None):
+                        logger.info(f"stopped model process: {model_name}")
+                        start_time = datetime.now()
+                        time.sleep(1)
+                        process.terminate()
+                        process.join()
+                        process = Process(
+                            target=run_model_worker,
+                            name=f"model_worker - {new_model_name}",
+                            kwargs=dict(model_name=new_model_name,
+                                        controller_address=args.controller_address,
+                                        log_level=log_level,
+                                        q=queue,
+                                        started_event=e),
+                            daemon=True,
+                        )
+                        process.start()
+                        process.name = f"{process.name} ({process.pid})"
+                        processes["model_worker"][new_model_name] = process
+                        e.wait()
+                        timing = datetime.now() - start_time
+                        logger.info(f"successfully started new model process: {new_model_name}. time used: {timing}")
+                    else:
+                        logger.error(f"model process not found: {model_name}")
-        if model_worker_process := processes.get("model_worker"):
-            model_worker_process.join()
-        for process in processes.get("online-api", []):
-            process.join()
-        for name, process in processes.items():
-            if name not in ["model_worker", "online-api"]:
-                if isinstance(p, list):
-                    for work_process in p:
-                        work_process.join()
-                else:
-                    process.join()
+        # for process in processes.get("model_worker", {}).values():
+        #     process.join()
+        # for process in processes.get("online_api", {}).values():
+        #     process.join()
+        # for name, process in processes.items():
+        #     if name not in ["model_worker", "online_api"]:
+        #         if isinstance(p, dict):
+        #             for work_process in p.values():
+        #                 work_process.join()
+        #         else:
+        #             process.join()
     except Exception as e:
         # if model_worker_process := processes.pop("model_worker", None):
         #     model_worker_process.terminate()
         # for process in processes.pop("online-api", []):
         #     process.terminate()
         # for process in processes.values():
         #
         #     if isinstance(process, list):
         #         for work_process in process:
         #             work_process.terminate()
         #     else:
         #         process.terminate()
         logger.error(e)
         logger.warning("Caught KeyboardInterrupt! Setting stop event...")
     finally:
@@ -702,10 +735,9 @@ async def start_main_server():
             # Queues and other inter-process communication primitives can break when
             # process is killed, but we don't care here
-            if isinstance(p, list):
-                for process in p:
+            if isinstance(p, dict):
+                for process in p.values():
                     process.kill()
             else:
                 p.kill()
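The whole switching mechanism rides on one queue message shape: [model_name, cmd, new_model_name], with cmd being "start", "stop", or "replace". A minimal dispatcher sketch of that protocol (print placeholders instead of real process management):

    import multiprocessing as mp


    def handle(msg):
        # msg follows the commit's protocol: [model_name, cmd, new_model_name]
        model_name, cmd, new_model_name = msg
        if cmd == "start":       # start new_model_name alongside model_name
            print(f"start {new_model_name}")
        elif cmd == "stop":      # stop model_name
            print(f"stop {model_name}")
        elif cmd == "replace":   # stop model_name, then start new_model_name
            print(f"replace {model_name} -> {new_model_name}")


    if __name__ == "__main__":
        queue = mp.Manager().Queue()
        queue.put(["chatglm2-6b", "replace", "qwen-7b-chat"])  # hypothetical names
        handle(queue.get())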

webui_pages/dialogue/dialogue.py

@@ -65,7 +65,9 @@ def dialogue_page(api: ApiRequest):
     )

     def on_llm_change():
-        st.session_state["prev_llm_model"] = llm_model
+        config = get_model_worker_config(llm_model)
+        if not config.get("online_api"):  # only local model_workers can switch models
+            st.session_state["prev_llm_model"] = llm_model

     def llm_model_format_func(x):
         if x in running_models:
@@ -91,7 +93,7 @@ def dialogue_page(api: ApiRequest):
     )

     if (st.session_state.get("prev_llm_model") != llm_model
             and not get_model_worker_config(llm_model).get("online_api")):
-        with st.spinner(f"Loading model: {llm_model}"):
+        with st.spinner(f"Loading model: {llm_model}. Please do not interact with or refresh the page."):
             r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
             st.session_state["prev_llm_model"] = llm_model
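The guard works because prev_llm_model is only updated for local workers, so selecting an online API model and then a local one still triggers a proper load. The selectbox wiring follows Streamlit's usual on_change pattern; a generic sketch (toy model list, not this page's code):

    import streamlit as st

    MODELS = ["local-model-a", "online-api-model"]  # hypothetical names


    def on_llm_change():
        # only remember models that a local worker can actually unload later
        if not st.session_state["llm_model"].startswith("online"):
            st.session_state["prev_llm_model"] = st.session_state["llm_model"]


    st.selectbox("LLM model", MODELS, key="llm_model", on_change=on_llm_change)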