update startup.py: (#1416)
1. 恢复模型切换功能 2. --model-name支持多个名称(空格分开),同时启动多个模型 3. 优化服务启动顺序。严格按照顺序启动:controller -> [openai-api, model_worker, api_worker]并行 -> api.py -> webui.py 4. 修复:从在线API模型切换到本地模型时失败
This commit is contained in:
parent
775870a516
commit
f94f2793f8
|
|
@ -7,3 +7,4 @@ __pycache__/
|
|||
configs/*.py
|
||||
.vscode/
|
||||
.pytest_cache/
|
||||
*.bak
|
||||
|
|
|
|||
264
startup.py
264
startup.py
|
|
@ -3,7 +3,8 @@ import multiprocessing as mp
|
|||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from multiprocessing import Process, Queue
|
||||
from multiprocessing import Process
|
||||
from datetime import datetime
|
||||
from pprint import pprint
|
||||
|
||||
# 设置numexpr最大线程数,默认为CPU核心数
|
||||
|
|
@ -19,7 +20,7 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
|||
from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
|
||||
logger
|
||||
from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
|
||||
FSCHAT_OPENAI_API, )
|
||||
FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT)
|
||||
from server.utils import (fschat_controller_address, fschat_model_worker_address,
|
||||
fschat_openai_api_address, set_httpx_timeout,
|
||||
get_model_worker_config, get_all_model_worker_configs,
|
||||
|
|
@ -47,7 +48,7 @@ def create_controller_app(
|
|||
return app
|
||||
|
||||
|
||||
def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
|
||||
def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
||||
import fastchat.constants
|
||||
fastchat.constants.LOGDIR = LOG_PATH
|
||||
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
|
||||
|
|
@ -188,29 +189,15 @@ def create_openai_api_app(
|
|||
return app
|
||||
|
||||
|
||||
def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
|
||||
if q is None or not isinstance(run_seq, int):
|
||||
return
|
||||
|
||||
if run_seq == 1:
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
set_httpx_timeout()
|
||||
q.put(run_seq)
|
||||
elif run_seq > 1:
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
set_httpx_timeout()
|
||||
while True:
|
||||
no = q.get()
|
||||
if no != run_seq - 1:
|
||||
q.put(no)
|
||||
else:
|
||||
break
|
||||
q.put(run_seq)
|
||||
def _set_app_event(app: FastAPI, started_event: mp.Event = None):
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
set_httpx_timeout()
|
||||
if started_event is not None:
|
||||
started_event.set()
|
||||
|
||||
|
||||
def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
|
||||
def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
|
||||
import uvicorn
|
||||
import httpx
|
||||
from fastapi import Body
|
||||
|
|
@ -221,12 +208,7 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Ev
|
|||
dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
|
||||
log_level=log_level,
|
||||
)
|
||||
_set_app_seq(app, q, run_seq)
|
||||
|
||||
@app.on_event("startup")
|
||||
def on_startup():
|
||||
if e is not None:
|
||||
e.set()
|
||||
_set_app_event(app, started_event)
|
||||
|
||||
# add interface to release and load model worker
|
||||
@app.post("/release_worker")
|
||||
|
|
@ -266,7 +248,7 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Ev
|
|||
return {"code": 500, "msg": msg}
|
||||
|
||||
if new_model_name:
|
||||
timer = 300 # wait 5 minutes for new model_worker register
|
||||
timer = HTTPX_DEFAULT_TIMEOUT * 2 # wait for new model_worker register
|
||||
while timer > 0:
|
||||
models = app._controller.list_models()
|
||||
if new_model_name in models:
|
||||
|
|
@ -299,9 +281,9 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Ev
|
|||
def run_model_worker(
|
||||
model_name: str = LLM_MODEL,
|
||||
controller_address: str = "",
|
||||
q: Queue = None,
|
||||
run_seq: int = 2,
|
||||
log_level: str = "INFO",
|
||||
q: mp.Queue = None,
|
||||
started_event: mp.Event = None,
|
||||
):
|
||||
import uvicorn
|
||||
from fastapi import Body
|
||||
|
|
@ -317,7 +299,7 @@ def run_model_worker(
|
|||
kwargs["model_path"] = model_path
|
||||
|
||||
app = create_model_worker_app(log_level=log_level, **kwargs)
|
||||
_set_app_seq(app, q, run_seq)
|
||||
_set_app_event(app, started_event)
|
||||
if log_level == "ERROR":
|
||||
sys.stdout = sys.__stdout__
|
||||
sys.stderr = sys.__stderr__
|
||||
|
|
@ -325,29 +307,29 @@ def run_model_worker(
|
|||
# add interface to release and load model
|
||||
@app.post("/release")
|
||||
def release_model(
|
||||
new_model_name: str = Body(None, description="释放后加载该模型"),
|
||||
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
|
||||
new_model_name: str = Body(None, description="释放后加载该模型"),
|
||||
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
|
||||
) -> Dict:
|
||||
if keep_origin:
|
||||
if new_model_name:
|
||||
q.put(["start", new_model_name])
|
||||
q.put([model_name, "start", new_model_name])
|
||||
else:
|
||||
if new_model_name:
|
||||
q.put(["replace", new_model_name])
|
||||
q.put([model_name, "replace", new_model_name])
|
||||
else:
|
||||
q.put(["stop"])
|
||||
q.put([model_name, "stop", None])
|
||||
return {"code": 200, "msg": "done"}
|
||||
|
||||
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
|
||||
|
||||
|
||||
def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
|
||||
def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
|
||||
import uvicorn
|
||||
import sys
|
||||
|
||||
controller_addr = fschat_controller_address()
|
||||
app = create_openai_api_app(controller_addr, log_level=log_level) # TODO: not support keys yet.
|
||||
_set_app_seq(app, q, run_seq)
|
||||
_set_app_event(app, started_event)
|
||||
|
||||
host = FSCHAT_OPENAI_API["host"]
|
||||
port = FSCHAT_OPENAI_API["port"]
|
||||
|
|
@ -357,12 +339,12 @@ def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
|
|||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
def run_api_server(q: Queue, run_seq: int = 4):
|
||||
def run_api_server(started_event: mp.Event = None):
|
||||
from server.api import create_app
|
||||
import uvicorn
|
||||
|
||||
app = create_app()
|
||||
_set_app_seq(app, q, run_seq)
|
||||
_set_app_event(app, started_event)
|
||||
|
||||
host = API_SERVER["host"]
|
||||
port = API_SERVER["port"]
|
||||
|
|
@ -370,21 +352,14 @@ def run_api_server(q: Queue, run_seq: int = 4):
|
|||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
def run_webui(q: Queue, run_seq: int = 5):
|
||||
def run_webui(started_event: mp.Event = None):
|
||||
host = WEBUI_SERVER["host"]
|
||||
port = WEBUI_SERVER["port"]
|
||||
|
||||
if q is not None and isinstance(run_seq, int):
|
||||
while True:
|
||||
no = q.get()
|
||||
if no != run_seq - 1:
|
||||
q.put(no)
|
||||
else:
|
||||
break
|
||||
q.put(run_seq)
|
||||
p = subprocess.Popen(["streamlit", "run", "webui.py",
|
||||
"--server.address", host,
|
||||
"--server.port", str(port)])
|
||||
started_event.set()
|
||||
p.wait()
|
||||
|
||||
|
||||
|
|
@ -427,8 +402,9 @@ def parse_args() -> argparse.ArgumentParser:
|
|||
"-n",
|
||||
"--model-name",
|
||||
type=str,
|
||||
default=LLM_MODEL,
|
||||
help="specify model name for model worker.",
|
||||
nargs="+",
|
||||
default=[LLM_MODEL],
|
||||
help="specify model name for model worker. add addition names with space seperated to start multiple model workers.",
|
||||
dest="model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
|
@ -483,11 +459,12 @@ def dump_server_info(after_start=False, args=None):
|
|||
print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
|
||||
print("\n")
|
||||
|
||||
model = LLM_MODEL
|
||||
models = [LLM_MODEL]
|
||||
if args and args.model_name:
|
||||
model = args.model_name
|
||||
print(f"当前LLM模型:{model} @ {llm_device()}")
|
||||
pprint(llm_model_dict[model])
|
||||
models = args.model_name
|
||||
print(f"当前启动的LLM模型:{models} @ {llm_device()}")
|
||||
for model in models:
|
||||
pprint(llm_model_dict[model])
|
||||
print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
|
||||
|
||||
if after_start:
|
||||
|
|
@ -554,10 +531,10 @@ async def start_main_server():
|
|||
logger.info(f"正在启动服务:")
|
||||
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
|
||||
|
||||
processes = {"online-api": []}
|
||||
processes = {"online_api": {}, "model_worker": {}}
|
||||
|
||||
def process_count():
|
||||
return len(processes) + len(processes["online-api"]) - 1
|
||||
return len(processes) + len(processes["online_api"]) - 1
|
||||
|
||||
if args.quiet:
|
||||
log_level = "ERROR"
|
||||
|
|
@ -569,63 +546,73 @@ async def start_main_server():
|
|||
process = Process(
|
||||
target=run_controller,
|
||||
name=f"controller",
|
||||
args=(queue, process_count() + 1, log_level, controller_started),
|
||||
kwargs=dict(log_level=log_level, started_event=controller_started),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
processes["controller"] = process
|
||||
|
||||
process = Process(
|
||||
target=run_openai_api,
|
||||
name=f"openai_api",
|
||||
args=(queue, process_count() + 1),
|
||||
daemon=True,
|
||||
)
|
||||
processes["openai_api"] = process
|
||||
|
||||
model_worker_started = []
|
||||
if args.model_worker:
|
||||
config = get_model_worker_config(args.model_name)
|
||||
if not config.get("online_api"):
|
||||
process = Process(
|
||||
target=run_model_worker,
|
||||
name=f"model_worker - {args.model_name}",
|
||||
args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
processes["model_worker"] = process
|
||||
for model_name in args.model_name:
|
||||
config = get_model_worker_config(model_name)
|
||||
if not config.get("online_api"):
|
||||
e = manager.Event()
|
||||
model_worker_started.append(e)
|
||||
process = Process(
|
||||
target=run_model_worker,
|
||||
name=f"model_worker - {model_name}",
|
||||
kwargs=dict(model_name=model_name,
|
||||
controller_address=args.controller_address,
|
||||
log_level=log_level,
|
||||
q=queue,
|
||||
started_event=e),
|
||||
daemon=True,
|
||||
)
|
||||
processes["model_worker"][model_name] = process
|
||||
|
||||
if args.api_worker:
|
||||
configs = get_all_model_worker_configs()
|
||||
for model_name, config in configs.items():
|
||||
if config.get("online_api") and config.get("worker_class"):
|
||||
e = manager.Event()
|
||||
model_worker_started.append(e)
|
||||
process = Process(
|
||||
target=run_model_worker,
|
||||
name=f"model_worker - {model_name}",
|
||||
args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
|
||||
name=f"api_worker - {model_name}",
|
||||
kwargs=dict(model_name=model_name,
|
||||
controller_address=args.controller_address,
|
||||
log_level=log_level,
|
||||
q=queue,
|
||||
started_event=e),
|
||||
daemon=True,
|
||||
)
|
||||
processes["online_api"][model_name] = process
|
||||
|
||||
processes["online-api"].append(process)
|
||||
|
||||
api_started = manager.Event()
|
||||
if args.api:
|
||||
process = Process(
|
||||
target=run_api_server,
|
||||
name=f"API Server",
|
||||
args=(queue, process_count() + 1),
|
||||
kwargs=dict(started_event=api_started),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
processes["api"] = process
|
||||
|
||||
webui_started = manager.Event()
|
||||
if args.webui:
|
||||
process = Process(
|
||||
target=run_webui,
|
||||
name=f"WEBUI Server",
|
||||
args=(queue, process_count() + 1),
|
||||
kwargs=dict(started_event=webui_started),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
processes["webui"] = process
|
||||
|
||||
if process_count() == 0:
|
||||
|
|
@ -636,60 +623,106 @@ async def start_main_server():
|
|||
if p:= processes.get("controller"):
|
||||
p.start()
|
||||
p.name = f"{p.name} ({p.pid})"
|
||||
controller_started.wait()
|
||||
controller_started.wait() # 等待controller启动完成
|
||||
|
||||
if p:= processes.get("openai_api"):
|
||||
p.start()
|
||||
p.name = f"{p.name} ({p.pid})"
|
||||
|
||||
if p:= processes.get("model_worker"):
|
||||
for n, p in processes.get("model_worker", {}).items():
|
||||
p.start()
|
||||
p.name = f"{p.name} ({p.pid})"
|
||||
|
||||
for p in processes.get("online-api", []):
|
||||
for n, p in processes.get("online_api", []).items():
|
||||
p.start()
|
||||
p.name = f"{p.name} ({p.pid})"
|
||||
|
||||
# 等待所有model_worker启动完成
|
||||
for e in model_worker_started:
|
||||
e.wait()
|
||||
|
||||
if p:= processes.get("api"):
|
||||
p.start()
|
||||
p.name = f"{p.name} ({p.pid})"
|
||||
api_started.wait() # 等待api.py启动完成
|
||||
|
||||
if p:= processes.get("webui"):
|
||||
p.start()
|
||||
p.name = f"{p.name} ({p.pid})"
|
||||
webui_started.wait() # 等待webui.py启动完成
|
||||
|
||||
dump_server_info(after_start=True, args=args)
|
||||
|
||||
while True:
|
||||
no = queue.get()
|
||||
if no == process_count():
|
||||
time.sleep(0.5)
|
||||
dump_server_info(after_start=True, args=args)
|
||||
break
|
||||
else:
|
||||
queue.put(no)
|
||||
cmd = queue.get() # 收到切换模型的消息
|
||||
e = manager.Event()
|
||||
if isinstance(cmd, list):
|
||||
model_name, cmd, new_model_name = cmd
|
||||
if cmd == "start": # 运行新模型
|
||||
logger.info(f"准备启动新模型进程:{new_model_name}")
|
||||
process = Process(
|
||||
target=run_model_worker,
|
||||
name=f"model_worker - {new_model_name}",
|
||||
kwargs=dict(model_name=new_model_name,
|
||||
controller_address=args.controller_address,
|
||||
log_level=log_level,
|
||||
q=queue,
|
||||
started_event=e),
|
||||
daemon=True,
|
||||
)
|
||||
process.start()
|
||||
process.name = f"{process.name} ({process.pid})"
|
||||
processes["model_worker"][new_model_name] = process
|
||||
e.wait()
|
||||
logger.info(f"成功启动新模型进程:{new_model_name}")
|
||||
elif cmd == "stop":
|
||||
if process := processes["model_worker"].get(model_name):
|
||||
time.sleep(1)
|
||||
process.terminate()
|
||||
process.join()
|
||||
logger.info(f"停止模型进程:{model_name}")
|
||||
else:
|
||||
logger.error(f"未找到模型进程:{model_name}")
|
||||
elif cmd == "replace":
|
||||
if process := processes["model_worker"].pop(model_name, None):
|
||||
logger.info(f"停止模型进程:{model_name}")
|
||||
start_time = datetime.now()
|
||||
time.sleep(1)
|
||||
process.terminate()
|
||||
process.join()
|
||||
process = Process(
|
||||
target=run_model_worker,
|
||||
name=f"model_worker - {new_model_name}",
|
||||
kwargs=dict(model_name=new_model_name,
|
||||
controller_address=args.controller_address,
|
||||
log_level=log_level,
|
||||
q=queue,
|
||||
started_event=e),
|
||||
daemon=True,
|
||||
)
|
||||
process.start()
|
||||
process.name = f"{process.name} ({process.pid})"
|
||||
processes["model_worker"][new_model_name] = process
|
||||
e.wait()
|
||||
timing = datetime.now() - start_time
|
||||
logger.info(f"成功启动新模型进程:{new_model_name}。用时:{timing}。")
|
||||
else:
|
||||
logger.error(f"未找到模型进程:{model_name}")
|
||||
|
||||
if model_worker_process := processes.get("model_worker"):
|
||||
model_worker_process.join()
|
||||
for process in processes.get("online-api", []):
|
||||
process.join()
|
||||
for name, process in processes.items():
|
||||
if name not in ["model_worker", "online-api"]:
|
||||
if isinstance(p, list):
|
||||
for work_process in p:
|
||||
work_process.join()
|
||||
else:
|
||||
process.join()
|
||||
|
||||
# for process in processes.get("model_worker", {}).values():
|
||||
# process.join()
|
||||
# for process in processes.get("online_api", {}).values():
|
||||
# process.join()
|
||||
|
||||
# for name, process in processes.items():
|
||||
# if name not in ["model_worker", "online_api"]:
|
||||
# if isinstance(p, dict):
|
||||
# for work_process in p.values():
|
||||
# work_process.join()
|
||||
# else:
|
||||
# process.join()
|
||||
except Exception as e:
|
||||
# if model_worker_process := processes.pop("model_worker", None):
|
||||
# model_worker_process.terminate()
|
||||
# for process in processes.pop("online-api", []):
|
||||
# process.terminate()
|
||||
# for process in processes.values():
|
||||
#
|
||||
# if isinstance(process, list):
|
||||
# for work_process in process:
|
||||
# work_process.terminate()
|
||||
# else:
|
||||
# process.terminate()
|
||||
logger.error(e)
|
||||
logger.warning("Caught KeyboardInterrupt! Setting stop event...")
|
||||
finally:
|
||||
|
|
@ -702,10 +735,9 @@ async def start_main_server():
|
|||
# Queues and other inter-process communication primitives can break when
|
||||
# process is killed, but we don't care here
|
||||
|
||||
if isinstance(p, list):
|
||||
for process in p:
|
||||
if isinstance(p, dict):
|
||||
for process in p.values():
|
||||
process.kill()
|
||||
|
||||
else:
|
||||
p.kill()
|
||||
|
||||
|
|
|
|||
|
|
@ -65,7 +65,9 @@ def dialogue_page(api: ApiRequest):
|
|||
)
|
||||
|
||||
def on_llm_change():
|
||||
st.session_state["prev_llm_model"] = llm_model
|
||||
config = get_model_worker_config(llm_model)
|
||||
if not config.get("online_api"): # 只有本地model_worker可以切换模型
|
||||
st.session_state["prev_llm_model"] = llm_model
|
||||
|
||||
def llm_model_format_func(x):
|
||||
if x in running_models:
|
||||
|
|
@ -91,7 +93,7 @@ def dialogue_page(api: ApiRequest):
|
|||
)
|
||||
if (st.session_state.get("prev_llm_model") != llm_model
|
||||
and not get_model_worker_config(llm_model).get("online_api")):
|
||||
with st.spinner(f"正在加载模型: {llm_model}"):
|
||||
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
|
||||
r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
|
||||
st.session_state["prev_llm_model"] = llm_model
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue