update startup.py: (#1416)
1. Restore the model-switching feature. 2. --model-name now accepts multiple names (space-separated) so several models can be started at once. 3. Tighten the service startup order, strictly: controller -> [openai-api, model_worker, api_worker] in parallel -> api.py -> webui.py. 4. Fix: switching from an online API model to a local model failed.
parent 775870a516
commit f94f2793f8
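Point 3 of the commit message is the heart of the change: startup is now gated by multiprocessing events instead of sequence numbers passed through a queue. A minimal sketch of that ordering, with stand-in worker names and sleeps (not code from this commit):

# Minimal sketch of the event-gated startup order; serve() is a hypothetical
# stand-in for the real run_controller/run_model_worker entry points.
import multiprocessing as mp
import time

def serve(name: str, started_event):
    time.sleep(0.5)       # stand-in for uvicorn boot time
    started_event.set()   # signal readiness, as _set_app_event() does
    time.sleep(5)         # stand-in for serving requests

if __name__ == "__main__":
    manager = mp.Manager()

    # 1. controller starts first and must be up before anything else
    controller_started = manager.Event()
    controller = mp.Process(target=serve, args=("controller", controller_started), daemon=True)
    controller.start()
    controller_started.wait()

    # 2. openai-api and all model/api workers start in parallel
    worker_events = []
    for name in ["openai_api", "model_worker-chatglm2-6b", "api_worker-zhipu-api"]:
        e = manager.Event()
        worker_events.append(e)
        mp.Process(target=serve, args=(name, e), daemon=True).start()
    for e in worker_events:   # 3. block until every worker reports ready
        e.wait()

    # 4. api.py and 5. webui.py would follow here, each gated the same way
    print("all workers started")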
@@ -7,3 +7,4 @@ __pycache__/
 configs/*.py
 .vscode/
 .pytest_cache/
+*.bak
startup.py (264 changed lines)
@@ -3,7 +3,8 @@ import multiprocessing as mp
 import os
 import subprocess
 import sys
-from multiprocessing import Process, Queue
+from multiprocessing import Process
+from datetime import datetime
 from pprint import pprint

 # Set the maximum number of numexpr threads; defaults to the CPU core count
@@ -19,7 +20,7 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
     logger
 from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
-                                   FSCHAT_OPENAI_API, )
+                                   FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT)
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
                           fschat_openai_api_address, set_httpx_timeout,
                           get_model_worker_config, get_all_model_worker_configs,
@@ -47,7 +48,7 @@ def create_controller_app(
     return app


-def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
+def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
     from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
@@ -188,29 +189,15 @@ def create_openai_api_app(
     return app


-def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
-    if q is None or not isinstance(run_seq, int):
-        return
-    if run_seq == 1:
-        @app.on_event("startup")
-        async def on_startup():
-            set_httpx_timeout()
-            q.put(run_seq)
-    elif run_seq > 1:
-        @app.on_event("startup")
-        async def on_startup():
-            set_httpx_timeout()
-            while True:
-                no = q.get()
-                if no != run_seq - 1:
-                    q.put(no)
-                else:
-                    break
-            q.put(run_seq)
+def _set_app_event(app: FastAPI, started_event: mp.Event = None):
+    @app.on_event("startup")
+    async def on_startup():
+        set_httpx_timeout()
+        if started_event is not None:
+            started_event.set()


-def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
+def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
     import uvicorn
     import httpx
     from fastapi import Body
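The new `_set_app_event` helper replaces queue-based sequence numbers with a single readiness event set from FastAPI's startup hook. A self-contained sketch of the pattern, assuming fastapi and uvicorn are installed; host and port are placeholders:

# Sketch of the parent/child readiness handshake behind _set_app_event().
import multiprocessing as mp

def run_app(started_event):
    import uvicorn
    from fastapi import FastAPI

    app = FastAPI()

    @app.on_event("startup")
    async def on_startup():
        started_event.set()  # fires only once the app is actually up

    uvicorn.run(app, host="127.0.0.1", port=8000)

if __name__ == "__main__":
    manager = mp.Manager()
    started = manager.Event()
    p = mp.Process(target=run_app, args=(started,), daemon=True)
    p.start()
    started.wait()  # parent resumes only once the server is serving
    print(f"server ready (pid {p.pid})")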
@@ -221,12 +208,7 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
         dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
         log_level=log_level,
     )
-    _set_app_seq(app, q, run_seq)
-
-    @app.on_event("startup")
-    def on_startup():
-        if e is not None:
-            e.set()
+    _set_app_event(app, started_event)

     # add interface to release and load model worker
     @app.post("/release_worker")
@@ -266,7 +248,7 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
             return {"code": 500, "msg": msg}

         if new_model_name:
-            timer = 300  # wait 5 minutes for the new model_worker to register
+            timer = HTTPX_DEFAULT_TIMEOUT * 2  # wait for the new model_worker to register
             while timer > 0:
                 models = app._controller.list_models()
                 if new_model_name in models:
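The registration wait above is now derived from HTTPX_DEFAULT_TIMEOUT instead of a hard-coded 300 seconds. A rough sketch of the bounded poll; the timeout value is assumed, and `list_models` stands in for `app._controller.list_models()`:

import time

HTTPX_DEFAULT_TIMEOUT = 300  # assumed value; the real one lives in configs/server_config.py

def wait_for_registration(list_models, new_model_name: str) -> bool:
    # poll the controller until the new worker registers or the timer expires
    timer = HTTPX_DEFAULT_TIMEOUT * 2
    while timer > 0:
        if new_model_name in list_models():
            return True
        time.sleep(1)
        timer -= 1
    return False

print(wait_for_registration(lambda: ["chatglm2-6b"], "chatglm2-6b"))  # True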
@@ -299,9 +281,9 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
 def run_model_worker(
         model_name: str = LLM_MODEL,
         controller_address: str = "",
-        q: Queue = None,
-        run_seq: int = 2,
         log_level: str = "INFO",
+        q: mp.Queue = None,
+        started_event: mp.Event = None,
 ):
     import uvicorn
     from fastapi import Body
@@ -317,7 +299,7 @@ def run_model_worker(
     kwargs["model_path"] = model_path

     app = create_model_worker_app(log_level=log_level, **kwargs)
-    _set_app_seq(app, q, run_seq)
+    _set_app_event(app, started_event)
     if log_level == "ERROR":
         sys.stdout = sys.__stdout__
         sys.stderr = sys.__stderr__
@@ -325,29 +307,29 @@ def run_model_worker(
     # add interface to release and load model
     @app.post("/release")
     def release_model(
         new_model_name: str = Body(None, description="model to load after the current one is released"),
         keep_origin: bool = Body(False, description="load the new model without releasing the original one")
     ) -> Dict:
         if keep_origin:
             if new_model_name:
-                q.put(["start", new_model_name])
+                q.put([model_name, "start", new_model_name])
         else:
             if new_model_name:
-                q.put(["replace", new_model_name])
+                q.put([model_name, "replace", new_model_name])
             else:
-                q.put(["stop"])
+                q.put([model_name, "stop", None])
         return {"code": 200, "msg": "done"}

     uvicorn.run(app, host=host, port=port, log_level=log_level.lower())


-def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
+def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
     import uvicorn
     import sys

     controller_addr = fschat_controller_address()
     app = create_openai_api_app(controller_addr, log_level=log_level)  # TODO: keys not supported yet.
-    _set_app_seq(app, q, run_seq)
+    _set_app_event(app, started_event)

     host = FSCHAT_OPENAI_API["host"]
     port = FSCHAT_OPENAI_API["port"]
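The `/release` endpoint now sends three-field messages, so the main process knows which worker a request came from; that is what makes switching work with several workers running. A hedged, runnable sketch of the protocol; the worker dict and state strings are illustrative only:

# Each queue message now carries the sender's model name: [model_name, cmd, new_model_name].
# The old two-field format (["start", name], ["stop"]) could not address one
# worker among several. spawn/stop below are trivial stand-ins, not the real helpers.
from typing import Dict, List

def handle(msg: List, workers: Dict[str, str]) -> None:
    model_name, cmd, new_model_name = msg
    if cmd == "start" and new_model_name:      # keep_origin: add a worker
        workers[new_model_name] = "running"
    elif cmd == "replace" and new_model_name:  # swap old for new
        workers.pop(model_name, None)
        workers[new_model_name] = "running"
    elif cmd == "stop":
        workers.pop(model_name, None)

workers = {"chatglm2-6b": "running"}
handle(["chatglm2-6b", "replace", "baichuan-7b"], workers)
assert list(workers) == ["baichuan-7b"]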
@@ -357,12 +339,12 @@ def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
     uvicorn.run(app, host=host, port=port)


-def run_api_server(q: Queue, run_seq: int = 4):
+def run_api_server(started_event: mp.Event = None):
     from server.api import create_app
     import uvicorn

     app = create_app()
-    _set_app_seq(app, q, run_seq)
+    _set_app_event(app, started_event)

     host = API_SERVER["host"]
     port = API_SERVER["port"]
@@ -370,21 +352,14 @@ def run_api_server(q: Queue, run_seq: int = 4):
     uvicorn.run(app, host=host, port=port)


-def run_webui(q: Queue, run_seq: int = 5):
+def run_webui(started_event: mp.Event = None):
     host = WEBUI_SERVER["host"]
     port = WEBUI_SERVER["port"]

-    if q is not None and isinstance(run_seq, int):
-        while True:
-            no = q.get()
-            if no != run_seq - 1:
-                q.put(no)
-            else:
-                break
-        q.put(run_seq)
     p = subprocess.Popen(["streamlit", "run", "webui.py",
                           "--server.address", host,
                           "--server.port", str(port)])
+    started_event.set()
     p.wait()
@@ -427,8 +402,9 @@ def parse_args() -> argparse.ArgumentParser:
         "-n",
         "--model-name",
         type=str,
-        default=LLM_MODEL,
-        help="specify model name for model worker.",
+        nargs="+",
+        default=[LLM_MODEL],
+        help="specify model name for model worker. Add additional names separated by spaces to start multiple model workers.",
         dest="model_name",
     )
     parser.add_argument(
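How the new `nargs="+"` flag behaves in practice; the model names below are examples, not project defaults:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-n", "--model-name", type=str, nargs="+",
                    default=["chatglm2-6b"], dest="model_name")

args = parser.parse_args(["-n", "chatglm2-6b", "baichuan-7b"])
print(args.model_name)  # ['chatglm2-6b', 'baichuan-7b'] -- always a list now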
@@ -483,11 +459,12 @@ def dump_server_info(after_start=False, args=None):
     print(f"langchain version: {langchain.__version__}. fastchat version: {fastchat.__version__}")
     print("\n")

-    model = LLM_MODEL
+    models = [LLM_MODEL]
     if args and args.model_name:
-        model = args.model_name
-    print(f"Current LLM model: {model} @ {llm_device()}")
-    pprint(llm_model_dict[model])
+        models = args.model_name
+    print(f"LLM models being launched: {models} @ {llm_device()}")
+    for model in models:
+        pprint(llm_model_dict[model])
     print(f"Current embeddings model: {EMBEDDING_MODEL} @ {embedding_device()}")

     if after_start:
@@ -554,10 +531,10 @@ async def start_main_server():
     logger.info(f"Starting services:")
     logger.info(f"To view llm_api logs, see {LOG_PATH}")

-    processes = {"online-api": []}
+    processes = {"online_api": {}, "model_worker": {}}

     def process_count():
-        return len(processes) + len(processes["online-api"]) - 1
+        return len(processes) + len(processes["online_api"]) - 1

     if args.quiet:
         log_level = "ERROR"
@@ -569,63 +546,73 @@ async def start_main_server():
         process = Process(
             target=run_controller,
             name=f"controller",
-            args=(queue, process_count() + 1, log_level, controller_started),
+            kwargs=dict(log_level=log_level, started_event=controller_started),
             daemon=True,
         )

         processes["controller"] = process

         process = Process(
             target=run_openai_api,
             name=f"openai_api",
-            args=(queue, process_count() + 1),
             daemon=True,
         )
         processes["openai_api"] = process

+    model_worker_started = []
     if args.model_worker:
-        config = get_model_worker_config(args.model_name)
-        if not config.get("online_api"):
-            process = Process(
-                target=run_model_worker,
-                name=f"model_worker - {args.model_name}",
-                args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
-                daemon=True,
-            )
-            processes["model_worker"] = process
+        for model_name in args.model_name:
+            config = get_model_worker_config(model_name)
+            if not config.get("online_api"):
+                e = manager.Event()
+                model_worker_started.append(e)
+                process = Process(
+                    target=run_model_worker,
+                    name=f"model_worker - {model_name}",
+                    kwargs=dict(model_name=model_name,
+                                controller_address=args.controller_address,
+                                log_level=log_level,
+                                q=queue,
+                                started_event=e),
+                    daemon=True,
+                )
+                processes["model_worker"][model_name] = process

     if args.api_worker:
         configs = get_all_model_worker_configs()
         for model_name, config in configs.items():
             if config.get("online_api") and config.get("worker_class"):
+                e = manager.Event()
+                model_worker_started.append(e)
                 process = Process(
                     target=run_model_worker,
-                    name=f"model_worker - {model_name}",
-                    args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
+                    name=f"api_worker - {model_name}",
+                    kwargs=dict(model_name=model_name,
+                                controller_address=args.controller_address,
+                                log_level=log_level,
+                                q=queue,
+                                started_event=e),
                     daemon=True,
                 )
-                processes["online-api"].append(process)
+                processes["online_api"][model_name] = process

+    api_started = manager.Event()
     if args.api:
         process = Process(
             target=run_api_server,
             name=f"API Server",
-            args=(queue, process_count() + 1),
+            kwargs=dict(started_event=api_started),
             daemon=True,
         )

         processes["api"] = process

+    webui_started = manager.Event()
     if args.webui:
         process = Process(
             target=run_webui,
             name=f"WEBUI Server",
-            args=(queue, process_count() + 1),
+            kwargs=dict(started_event=webui_started),
             daemon=True,
         )

         processes["webui"] = process

     if process_count() == 0:
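The Process(...) calls above also switch from positional `args` to `kwargs`. A small illustration of why this is safer: keyword arguments keep working when the target's signature is reordered, as `run_model_worker`'s was in this commit (the worker below is a hypothetical stand-in):

from multiprocessing import Process

def worker(model_name: str = "chatglm2-6b", log_level: str = "INFO",
           started_event=None):
    print(model_name, log_level)

if __name__ == "__main__":
    # args=("chatglm2-6b", "INFO") would silently bind to the wrong parameters
    # if the signature changed; kwargs bind by name regardless of position.
    p = Process(target=worker,
                kwargs=dict(model_name="chatglm2-6b", log_level="ERROR"),
                daemon=True)
    p.start()
    p.join()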
@@ -636,60 +623,106 @@ async def start_main_server():
         if p:= processes.get("controller"):
             p.start()
             p.name = f"{p.name} ({p.pid})"
-            controller_started.wait()
+            controller_started.wait()  # wait for the controller to finish starting

         if p:= processes.get("openai_api"):
             p.start()
             p.name = f"{p.name} ({p.pid})"

-        if p:= processes.get("model_worker"):
+        for n, p in processes.get("model_worker", {}).items():
             p.start()
             p.name = f"{p.name} ({p.pid})"

-        for p in processes.get("online-api", []):
+        for n, p in processes.get("online_api", []).items():
             p.start()
             p.name = f"{p.name} ({p.pid})"

+        # wait for all model workers to finish starting
+        for e in model_worker_started:
+            e.wait()

         if p:= processes.get("api"):
             p.start()
             p.name = f"{p.name} ({p.pid})"
+            api_started.wait()  # wait for api.py to finish starting

         if p:= processes.get("webui"):
             p.start()
             p.name = f"{p.name} ({p.pid})"
+            webui_started.wait()  # wait for webui.py to finish starting

+        dump_server_info(after_start=True, args=args)

         while True:
-            no = queue.get()
-            if no == process_count():
-                time.sleep(0.5)
-                dump_server_info(after_start=True, args=args)
-                break
-            else:
-                queue.put(no)
+            cmd = queue.get()  # a model-switch message arrived
+            e = manager.Event()
+            if isinstance(cmd, list):
+                model_name, cmd, new_model_name = cmd
+                if cmd == "start":  # launch a new model
+                    logger.info(f"Preparing to start new model worker process: {new_model_name}")
+                    process = Process(
+                        target=run_model_worker,
+                        name=f"model_worker - {new_model_name}",
+                        kwargs=dict(model_name=new_model_name,
+                                    controller_address=args.controller_address,
+                                    log_level=log_level,
+                                    q=queue,
+                                    started_event=e),
+                        daemon=True,
+                    )
+                    process.start()
+                    process.name = f"{process.name} ({process.pid})"
+                    processes["model_worker"][new_model_name] = process
+                    e.wait()
+                    logger.info(f"Successfully started new model worker process: {new_model_name}")
+                elif cmd == "stop":
+                    if process := processes["model_worker"].get(model_name):
+                        time.sleep(1)
+                        process.terminate()
+                        process.join()
+                        logger.info(f"Stopped model worker process: {model_name}")
+                    else:
+                        logger.error(f"Model worker process not found: {model_name}")
+                elif cmd == "replace":
+                    if process := processes["model_worker"].pop(model_name, None):
+                        logger.info(f"Stopping model worker process: {model_name}")
+                        start_time = datetime.now()
+                        time.sleep(1)
+                        process.terminate()
+                        process.join()
+                        process = Process(
+                            target=run_model_worker,
+                            name=f"model_worker - {new_model_name}",
+                            kwargs=dict(model_name=new_model_name,
+                                        controller_address=args.controller_address,
+                                        log_level=log_level,
+                                        q=queue,
+                                        started_event=e),
+                            daemon=True,
+                        )
+                        process.start()
+                        process.name = f"{process.name} ({process.pid})"
+                        processes["model_worker"][new_model_name] = process
+                        e.wait()
+                        timing = datetime.now() - start_time
+                        logger.info(f"Successfully started new model worker process: {new_model_name}. Elapsed: {timing}.")
+                    else:
+                        logger.error(f"Model worker process not found: {model_name}")

-        if model_worker_process := processes.get("model_worker"):
-            model_worker_process.join()
-        for process in processes.get("online-api", []):
-            process.join()
-        for name, process in processes.items():
-            if name not in ["model_worker", "online-api"]:
-                if isinstance(p, list):
-                    for work_process in p:
-                        work_process.join()
-                else:
-                    process.join()
+        # for process in processes.get("model_worker", {}).values():
+        #     process.join()
+        # for process in processes.get("online_api", {}).values():
+        #     process.join()

+        # for name, process in processes.items():
+        #     if name not in ["model_worker", "online_api"]:
+        #         if isinstance(p, dict):
+        #             for work_process in p.values():
+        #                 work_process.join()
+        #         else:
+        #             process.join()
     except Exception as e:
-        # if model_worker_process := processes.pop("model_worker", None):
-        #     model_worker_process.terminate()
-        # for process in processes.pop("online-api", []):
-        #     process.terminate()
-        # for process in processes.values():
-        #
-        #     if isinstance(process, list):
-        #         for work_process in process:
-        #             work_process.terminate()
-        #     else:
-        #         process.terminate()
         logger.error(e)
         logger.warning("Caught KeyboardInterrupt! Setting stop event...")
     finally:
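A condensed sketch of the `replace` branch above: terminate and reap the old worker, spawn the replacement, and block on its started event before logging success. `fake_worker` and `spawn` are stand-ins, not the commit's helpers:

import multiprocessing as mp
import time
from datetime import datetime

def fake_worker(started_event):
    started_event.set()   # the real worker sets this from its startup hook
    time.sleep(60)

def spawn(manager):
    e = manager.Event()
    p = mp.Process(target=fake_worker, args=(e,), daemon=True)
    p.start()
    return p, e

if __name__ == "__main__":
    manager = mp.Manager()
    old, _ = spawn(manager)

    start_time = datetime.now()
    old.terminate()            # stop the old model worker
    old.join()                 # reap it before starting the replacement
    new, started = spawn(manager)
    started.wait()             # don't report success until it is up
    print(f"replaced in {datetime.now() - start_time}")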
@@ -702,10 +735,9 @@ async def start_main_server():
             # Queues and other inter-process communication primitives can break when
             # process is killed, but we don't care here
-            if isinstance(p, list):
-                for process in p:
+            if isinstance(p, dict):
+                for process in p.values():
                     process.kill()
             else:
                 p.kill()

@@ -65,7 +65,9 @@ def dialogue_page(api: ApiRequest):
         )

     def on_llm_change():
-        st.session_state["prev_llm_model"] = llm_model
+        config = get_model_worker_config(llm_model)
+        if not config.get("online_api"):  # only local model workers can switch models
+            st.session_state["prev_llm_model"] = llm_model

     def llm_model_format_func(x):
         if x in running_models:
@@ -91,7 +93,7 @@ def dialogue_page(api: ApiRequest):
         )
         if (st.session_state.get("prev_llm_model") != llm_model
                 and not get_model_worker_config(llm_model).get("online_api")):
-            with st.spinner(f"Loading model: {llm_model}"):
+            with st.spinner(f"Loading model: {llm_model}. Please do not interact with or refresh the page"):
                 r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
         st.session_state["prev_llm_model"] = llm_model
|
||||||
Loading…
Reference in New Issue