Langchain-Chatchat/startup.py

786 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import multiprocessing as mp
import os
import subprocess
import sys
from multiprocessing import Process
from datetime import datetime
from pprint import pprint
# Default numexpr's maximum thread count to the number of detected CPU cores.
try:
    import numexpr

    n_cores = numexpr.utils.detect_number_of_cores()
    os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
except Exception:
    # Best-effort tuning only: any failure here (for example numexpr not being
    # importable) is deliberately ignored. `Exception` rather than a bare
    # `except` so KeyboardInterrupt / SystemExit are not swallowed.
    pass
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
logger, log_verbose, TEXT_SPLITTER
from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT)
from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, set_httpx_timeout,
get_model_worker_config, get_all_model_worker_configs,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
import argparse
from typing import Tuple, List, Dict
from configs import VERSION
def create_controller_app(
        dispatch_method: str,
        log_level: str = "INFO",
) -> FastAPI:
    """Build the FastChat controller app around a freshly created ``Controller``.

    The new controller is installed as the ``controller`` attribute of the
    ``fastchat.serve.controller`` module and is also attached to the returned
    app as ``app._controller``.

    Args:
        dispatch_method: Passed straight through to ``Controller``.
        log_level: Level applied to the fastchat controller logger.

    Returns:
        fastchat's controller app, titled "FastChat Controller".
    """
    import fastchat.constants

    # LOGDIR is assigned before the controller module is imported below
    # (presumably it is read at import time -- TODO confirm).
    fastchat.constants.LOGDIR = LOG_PATH
    from fastchat.serve.controller import Controller, app, logger as controller_logger

    controller_logger.setLevel(log_level)

    new_controller = Controller(dispatch_method)
    # Publish the instance on the fastchat module itself (presumably its request
    # handlers look it up there -- TODO confirm).
    setattr(sys.modules["fastchat.serve.controller"], "controller", new_controller)

    MakeFastAPIOffline(app)
    app.title = "FastChat Controller"
    app._controller = new_controller
    return app
def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
    """Build the FastChat model-worker app for a local model or an online API.

    A namespace of default worker arguments is created and then overridden by
    ``kwargs``. If ``kwargs`` contains a truthy ``worker_class`` it is
    instantiated as the worker (online model API); otherwise a fastchat
    ``ModelWorker`` is created for a local model.

    Args:
        log_level: Level applied to the fastchat model_worker logger.
        **kwargs: Overrides for the worker arguments. ``controller_address``
            and ``worker_address`` have no default here and are always read;
            ``model_path`` and ``device`` are additionally read for local
            models. ``model_names`` must be non-empty because its first entry
            is used in the app title.

    Returns:
        fastchat's model_worker app with the worker attached as ``app._worker``.

    Raises:
        ValueError: If ``num_gpus`` exceeds the number of ids listed in ``gpus``.
    """
    import fastchat.constants
    fastchat.constants.LOGDIR = LOG_PATH
    from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
    import argparse
    import threading
    import fastchat.serve.model_worker
    logger.setLevel(log_level)

    # workaround to make program exit with Ctrl+c
    # it should be deleted after pr is merged by fastchat
    # (the heartbeat thread is started as a daemon thread). Installed once,
    # before either kind of worker is constructed.
    def _new_init_heart_beat(self):
        self.register_to_controller()
        self.heart_beat_thread = threading.Thread(
            target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
        )
        self.heart_beat_thread.start()
    ModelWorker.init_heart_beat = _new_init_heart_beat

    # default args. should be deleted after pr is merged by fastchat
    args = argparse.Namespace(
        gpus=None,
        max_gpu_memory="20GiB",
        load_8bit=False,
        cpu_offloading=None,
        gptq_ckpt=None,
        gptq_wbits=16,
        gptq_groupsize=-1,
        gptq_act_order=False,
        awq_ckpt=None,
        awq_wbits=16,
        awq_groupsize=-1,
        num_gpus=1,
        model_names=[],
        conv_template=None,
        limit_worker_concurrency=5,
        stream_interval=2,
        no_register=False,
        embed_in_truncate=False,
    )

    # Caller-supplied settings win over the defaults above.
    for k, v in kwargs.items():
        setattr(args, k, v)

    if args.gpus:
        if args.num_gpus is None:
            args.num_gpus = len(args.gpus.split(','))
        if len(args.gpus.split(",")) < args.num_gpus:
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    # online model API
    if worker_class := kwargs.get("worker_class"):
        worker = worker_class(model_names=args.model_names,
                              controller_addr=args.controller_address,
                              worker_addr=args.worker_address)
    # local model
    else:
        gptq_config = GptqConfig(
            ckpt=args.gptq_ckpt or args.model_path,
            wbits=args.gptq_wbits,
            groupsize=args.gptq_groupsize,
            act_order=args.gptq_act_order,
        )
        awq_config = AWQConfig(
            ckpt=args.awq_ckpt or args.model_path,
            wbits=args.awq_wbits,
            groupsize=args.awq_groupsize,
        )

        worker = ModelWorker(
            controller_addr=args.controller_address,
            worker_addr=args.worker_address,
            worker_id=worker_id,
            model_path=args.model_path,
            model_names=args.model_names,
            limit_worker_concurrency=args.limit_worker_concurrency,
            no_register=args.no_register,
            device=args.device,
            num_gpus=args.num_gpus,
            max_gpu_memory=args.max_gpu_memory,
            load_8bit=args.load_8bit,
            cpu_offloading=args.cpu_offloading,
            gptq_config=gptq_config,
            awq_config=awq_config,
            stream_interval=args.stream_interval,
            conv_template=args.conv_template,
            embed_in_truncate=args.embed_in_truncate,
        )
        # gptq_config only exists on this branch, so these module globals are
        # only published for local models.
        sys.modules["fastchat.serve.model_worker"].args = args
        sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config

    sys.modules["fastchat.serve.model_worker"].worker = worker

    MakeFastAPIOffline(app)
    app.title = f"FastChat LLM Server ({args.model_names[0]})"
    app._worker = worker
    return app
def create_openai_api_app(
        controller_address: str,
        api_keys: List = None,
        log_level: str = "INFO",
) -> FastAPI:
    """Build fastchat's OpenAI-compatible API app.

    Adds a CORS middleware that allows every origin, method and header,
    replaces the module-level logger of ``fastchat.serve.openai_api_server``
    and stores the controller address and API keys on fastchat's
    ``app_settings``.

    Args:
        controller_address: Stored as ``app_settings.controller_address``.
        api_keys: Stored as ``app_settings.api_keys``. ``None`` (the default)
            means a fresh empty list.
        log_level: Level applied to the newly built "openai_api" logger.

    Returns:
        fastchat's openai_api_server app.
    """
    import fastchat.constants
    fastchat.constants.LOGDIR = LOG_PATH
    from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
    from fastchat.utils import build_logger
    logger = build_logger("openai_api", "openai_api.log")
    logger.setLevel(log_level)

    app.add_middleware(
        CORSMiddleware,
        allow_credentials=True,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    sys.modules["fastchat.serve.openai_api_server"].logger = logger
    app_settings.controller_address = controller_address
    # A `[]` default would be one list object shared by every call and by
    # app_settings; create a new list per call instead.
    app_settings.api_keys = [] if api_keys is None else api_keys

    MakeFastAPIOffline(app)
    # NOTE(review): "OpeanAI" looks like a typo for "OpenAI"; left unchanged
    # because the title is user-visible runtime text.
    app.title = "FastChat OpeanAI API Server"
    return app
def _set_app_event(app: FastAPI, started_event: mp.Event = None):
    """Register a "startup" hook on ``app``.

    When the app starts, the hook calls ``set_httpx_timeout()`` and then, if an
    event was supplied, sets ``started_event``.

    Args:
        app: Application on which the startup handler is registered.
        started_event: Optional event set from the startup hook.
    """
    async def on_startup():
        set_httpx_timeout()
        if started_event is None:
            return
        started_event.set()

    # Explicit call form of the ``@app.on_event("startup")`` decorator.
    app.on_event("startup")(on_startup)
def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
    """Serve the FastChat controller with uvicorn (blocks in ``uvicorn.run``).

    Besides the stock controller app this registers ``POST /release_worker``,
    which asks a model worker (via its ``/release`` endpoint) to stop, or to be
    replaced / complemented by another model.

    Args:
        log_level: Logger level; also passed (lower-cased) to uvicorn. With
            "ERROR", ``sys.stdout`` / ``sys.stderr`` are reset to the
            interpreter's original streams before serving.
        started_event: Optional event set once the app has started.
    """
    import uvicorn
    import httpx
    from fastapi import Body
    import time
    import sys

    app = create_controller_app(
        dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
        log_level=log_level,
    )
    _set_app_event(app, started_event)

    # add interface to release and load model worker
    @app.post("/release_worker")
    def release_worker(
            model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
            # A `worker_address` parameter (alternative to model_name) is left disabled here.
            new_model_name: str = Body(None, description="释放后加载该模型"),
            keep_origin: bool = Body(False, description="不释放原模型，加载新模型")
    ) -> Dict:
        """Release ``model_name`` and optionally load ``new_model_name``.

        Returns a dict with ``code`` (200 on success, 500 on any failure) and
        a human-readable ``msg``.
        """
        available_models = app._controller.list_models()
        if new_model_name in available_models:
            msg = f"要切换的LLM模型 {new_model_name} 已经存在"
            logger.info(msg)
            return {"code": 500, "msg": msg}

        if new_model_name:
            # Separate the two names so the log line is unambiguous.
            logger.info(f"开始切换LLM模型：从 {model_name} 到 {new_model_name}")
        else:
            logger.info(f"即将停止LLM模型 {model_name}")

        if model_name not in available_models:
            msg = f"the model {model_name} is not available"
            logger.error(msg)
            return {"code": 500, "msg": msg}

        worker_address = app._controller.get_worker_address(model_name)
        if not worker_address:
            msg = f"can not find model_worker address for {model_name}"
            logger.error(msg)
            return {"code": 500, "msg": msg}

        try:
            r = httpx.post(worker_address + "/release",
                           json={"new_model_name": new_model_name, "keep_origin": keep_origin})
        except httpx.HTTPError as exc:
            # Report transport failures (worker unreachable, timeout, ...) in
            # the same payload shape as every other failure path.
            msg = f"failed to release model: {model_name}"
            logger.error(f"{msg} ({exc})")
            return {"code": 500, "msg": msg}
        if r.status_code != 200:
            msg = f"failed to release model: {model_name}"
            logger.error(msg)
            return {"code": 500, "msg": msg}

        if new_model_name:
            timer = HTTPX_DEFAULT_TIMEOUT * 2  # wait for new model_worker register
            # Poll once per second; `timer` only reaches 0 when the new model
            # never showed up in the controller's model list.
            while timer > 0:
                models = app._controller.list_models()
                if new_model_name in models:
                    break
                time.sleep(1)
                timer -= 1
            if timer > 0:
                msg = f"sucess change model from {model_name} to {new_model_name}"
                logger.info(msg)
                return {"code": 200, "msg": msg}
            else:
                msg = f"failed change model from {model_name} to {new_model_name}"
                logger.error(msg)
                return {"code": 500, "msg": msg}
        else:
            msg = f"sucess to release model: {model_name}"
            logger.info(msg)
            return {"code": 200, "msg": msg}

    host = FSCHAT_CONTROLLER["host"]
    port = FSCHAT_CONTROLLER["port"]

    if log_level == "ERROR":
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__

    uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
def run_model_worker(
        model_name: str = LLM_MODEL,
        controller_address: str = "",
        log_level: str = "INFO",
        q: mp.Queue = None,
        started_event: mp.Event = None,
):
    """Serve one model worker with uvicorn (blocks in ``uvicorn.run``).

    The worker configuration comes from ``get_model_worker_config(model_name)``.
    A ``POST /release`` endpoint is added that pushes
    ``[model_name, command, new_model_name]`` onto ``q``, where ``command`` is
    "start", "replace" or "stop".

    Args:
        model_name: Model to serve; also the sole entry of ``model_names``.
        controller_address: Controller to register with; an empty value falls
            back to ``fschat_controller_address()``.
        log_level: Logger level; also passed (lower-cased) to uvicorn.
        q: Queue receiving the release commands.
        started_event: Optional event set once the app has started.
    """
    import uvicorn
    from fastapi import Body
    import sys

    worker_config = get_model_worker_config(model_name)
    host = worker_config.pop("host")
    port = worker_config.pop("port")
    worker_config.update(
        model_names=[model_name],
        controller_address=controller_address or fschat_controller_address(),
        worker_address=fschat_model_worker_address(model_name),
    )
    worker_config["model_path"] = worker_config.get("local_model_path", "")

    app = create_model_worker_app(log_level=log_level, **worker_config)
    _set_app_event(app, started_event)

    if log_level == "ERROR":
        sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__

    # add interface to release and load model
    @app.post("/release")
    def release_model(
            new_model_name: str = Body(None, description="释放后加载该模型"),
            keep_origin: bool = Body(False, description="不释放原模型，加载新模型")
    ) -> Dict:
        """Queue the command matching the request; always answers "done"."""
        if new_model_name:
            # Keeping the original means an extra worker; otherwise swap it.
            command = "start" if keep_origin else "replace"
            q.put([model_name, command, new_model_name])
        elif not keep_origin:
            q.put([model_name, "stop", None])
        # keep_origin without a new model: nothing is queued.
        return {"code": 200, "msg": "done"}

    uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
    """Serve fastchat's OpenAI-compatible API with uvicorn (blocking).

    Args:
        log_level: Level for the openai_api logger. With "ERROR",
            ``sys.stdout`` / ``sys.stderr`` are reset to the interpreter's
            original streams before serving.
        started_event: Optional event set once the app has started.
    """
    import uvicorn
    import sys

    app = create_openai_api_app(fschat_controller_address(), log_level=log_level)  # TODO: not support keys yet.
    _set_app_event(app, started_event)

    host, port = FSCHAT_OPENAI_API["host"], FSCHAT_OPENAI_API["port"]

    if log_level == "ERROR":
        sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__

    uvicorn.run(app, host=host, port=port)
def run_api_server(started_event: mp.Event = None):
    """Serve the project's API app (``server.api.create_app``) with uvicorn.

    Args:
        started_event: Optional event set once the app has started.
    """
    import uvicorn
    from server.api import create_app

    api_app = create_app()
    _set_app_event(api_app, started_event)

    bind_host, bind_port = API_SERVER["host"], API_SERVER["port"]
    uvicorn.run(api_app, host=bind_host, port=bind_port)
def run_webui(started_event: mp.Event = None):
    """Launch the Streamlit web UI (``webui.py``) as a subprocess and wait for it.

    Args:
        started_event: Optional event set right after the subprocess has been
            spawned. It signals that streamlit was launched, not that the UI
            is already serving requests.
    """
    host = WEBUI_SERVER["host"]
    port = WEBUI_SERVER["port"]

    p = subprocess.Popen(["streamlit", "run", "webui.py",
                          "--server.address", host,
                          "--server.port", str(port),
                          "--theme.base", "light",
                          "--theme.primaryColor", "#165dff",
                          "--theme.secondaryBackgroundColor", "#f5f5f5",
                          "--theme.textColor", "#000000",
                          ])
    # The parameter defaults to None, so only signal when an event was given.
    if started_event is not None:
        started_event.set()
    p.wait()
def parse_args() -> Tuple[argparse.Namespace, argparse.ArgumentParser]:
    """Parse the launcher's command-line flags from ``sys.argv``.

    Returns:
        ``(args, parser)``: the parsed namespace together with the parser
        itself, so that callers can print the help text.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-a",
        "--all-webui",
        action="store_true",
        help="run fastchat's controller/openai_api/model_worker servers, run api.py and webui.py",
        dest="all_webui",
    )
    parser.add_argument(
        "--all-api",
        action="store_true",
        help="run fastchat's controller/openai_api/model_worker servers, run api.py",
        dest="all_api",
    )
    parser.add_argument(
        "--llm-api",
        action="store_true",
        help="run fastchat's controller/openai_api/model_worker servers",
        dest="llm_api",
    )
    parser.add_argument(
        "-o",
        "--openai-api",
        action="store_true",
        help="run fastchat's controller/openai_api servers",
        dest="openai_api",
    )
    parser.add_argument(
        "-m",
        "--model-worker",
        action="store_true",
        help="run fastchat's model_worker server with specified model name. specify --model-name if not using default LLM_MODEL",
        dest="model_worker",
    )
    parser.add_argument(
        "-n",
        "--model-name",
        type=str,
        nargs="+",
        default=[LLM_MODEL],
        help="specify model name for model worker. add addition names with space seperated to start multiple model workers.",
        dest="model_name",
    )
    parser.add_argument(
        "-c",
        "--controller",
        type=str,
        help="specify controller address the worker is registered to. default is server_config.FSCHAT_CONTROLLER",
        dest="controller_address",
    )
    parser.add_argument(
        "--api",
        action="store_true",
        help="run api.py server",
        dest="api",
    )
    parser.add_argument(
        "-p",
        "--api-worker",
        action="store_true",
        help="run online model api such as zhipuai",
        dest="api_worker",
    )
    parser.add_argument(
        "-w",
        "--webui",
        action="store_true",
        help="run webui.py server",
        dest="webui",
    )
    parser.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        help="减少fastchat服务log信息",
        dest="quiet",
    )
    args = parser.parse_args()
    return args, parser
def dump_server_info(after_start=False, args=None):
    """Print a banner describing the environment and the configured models.

    Args:
        after_start: When true, additionally list the addresses of the
            services enabled in ``args``.
        args: Parsed launcher arguments. Its ``model_name`` (when set)
            replaces the default ``[LLM_MODEL]``; its ``openai_api`` / ``api``
            / ``webui`` flags select the addresses printed after start. May be
            ``None``, in which case no service addresses are listed.
    """
    import platform
    import langchain
    import fastchat
    from server.utils import api_address, webui_address

    print("\n")
    print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
    print(f"操作系统：{platform.platform()}.")
    print(f"python版本{sys.version}")
    print(f"项目版本：{VERSION}")
    print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
    print("\n")

    models = [LLM_MODEL]
    if args and args.model_name:
        models = args.model_name

    print(f"当前使用的分词器：{TEXT_SPLITTER}")
    print(f"当前启动的LLM模型{models} @ {llm_device()}")

    for model in models:
        pprint(llm_model_dict[model])
    print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {embedding_device()}")

    if after_start:
        print("\n")
        print(f"服务端运行信息：")
        # getattr with a default keeps this working when args is None (the
        # declared default) instead of raising AttributeError.
        if getattr(args, "openai_api", False):
            print(f" OpenAI API Server: {fschat_openai_api_address()}/v1")
            print(" (请确认llm_model_dict中配置的api_base_url与上面地址一致。)")
        if getattr(args, "api", False):
            print(f" Chatchat API Server: {api_address()}")
        if getattr(args, "webui", False):
            print(f" Chatchat WEBUI Server: {webui_address()}")
    print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
    print("\n")
async def start_main_server():
    """Launch the services selected on the command line and supervise them.

    Steps:
      1. Install SIGINT / SIGTERM handlers that raise ``KeyboardInterrupt``.
      2. Parse the CLI flags and expand the ``--all-webui`` / ``--all-api`` /
         ``--llm-api`` shortcuts.
      3. Create one daemon ``Process`` per selected service and start them in
         order: controller, openai_api, model workers and online-API workers,
         api, webui -- waiting on the matching "started" events.
      4. Loop forever on the shared queue, handling
         ``[model_name, command, new_model_name]`` messages where ``command``
         is "start", "stop" or "replace".
      5. On interrupt or error, kill every child process that is still alive.

    Although declared ``async``, the body contains no ``await`` and blocks the
    event loop for its whole lifetime.
    """
    import time
    import signal

    def handler(signalname):
        """
        Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed.
        Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose.
        """
        def f(signal_received, frame):
            raise KeyboardInterrupt(f"{signalname} received")
        return f

    # This will be inherited by the child process if it is forked (not spawned)
    signal.signal(signal.SIGINT, handler("SIGINT"))
    signal.signal(signal.SIGTERM, handler("SIGTERM"))

    mp.set_start_method("spawn")
    manager = mp.Manager()

    queue = manager.Queue()
    args, parser = parse_args()

    # Expand the shortcut flags into the individual service switches.
    if args.all_webui:
        args.openai_api = True
        args.model_worker = True
        args.api = True
        args.api_worker = True
        args.webui = True

    elif args.all_api:
        args.openai_api = True
        args.model_worker = True
        args.api = True
        args.api_worker = True
        args.webui = False

    elif args.llm_api:
        args.openai_api = True
        args.model_worker = True
        args.api_worker = True
        args.api = False
        args.webui = False

    dump_server_info(args=args)

    if len(sys.argv) > 1:
        logger.info(f"正在启动服务：")
        logger.info(f"如需查看 llm_api 日志，请前往 {LOG_PATH}")

    # Single services are stored directly under their name; workers are
    # grouped per model in the two nested dicts.
    processes = {"online_api": {}, "model_worker": {}}

    def process_count():
        # "- 2" discounts the two container entries themselves.
        return len(processes) + len(processes["online_api"]) + len(processes["model_worker"]) - 2

    if args.quiet or not log_verbose:
        log_level = "ERROR"
    else:
        log_level = "INFO"

    def make_worker_process(model_name, started_event, kind="model_worker"):
        # Build (without starting) a daemon process that runs run_model_worker.
        return Process(
            target=run_model_worker,
            name=f"{kind} - {model_name}",
            kwargs=dict(model_name=model_name,
                        controller_address=args.controller_address,
                        log_level=log_level,
                        q=queue,
                        started_event=started_event),
            daemon=True,
        )

    def launch(process):
        # Start the process and append its pid to its name for log output.
        process.start()
        process.name = f"{process.name} ({process.pid})"

    controller_started = manager.Event()
    if args.openai_api:
        process = Process(
            target=run_controller,
            name=f"controller",
            kwargs=dict(log_level=log_level, started_event=controller_started),
            daemon=True,
        )
        processes["controller"] = process

        process = Process(
            target=run_openai_api,
            name=f"openai_api",
            # Same log level as the controller and the workers, so the quiet
            # setting applies to this service as well.
            kwargs=dict(log_level=log_level),
            daemon=True,
        )
        processes["openai_api"] = process

    model_worker_started = []
    if args.model_worker:
        for model_name in args.model_name:
            config = get_model_worker_config(model_name)
            if not config.get("online_api"):
                started = manager.Event()
                model_worker_started.append(started)
                processes["model_worker"][model_name] = make_worker_process(model_name, started)

    if args.api_worker:
        configs = get_all_model_worker_configs()
        for model_name, config in configs.items():
            if config.get("online_api") and config.get("worker_class"):
                started = manager.Event()
                model_worker_started.append(started)
                processes["online_api"][model_name] = make_worker_process(
                    model_name, started, kind="api_worker")

    api_started = manager.Event()
    if args.api:
        process = Process(
            target=run_api_server,
            name=f"API Server",
            kwargs=dict(started_event=api_started),
            daemon=True,
        )
        processes["api"] = process

    webui_started = manager.Event()
    if args.webui:
        process = Process(
            target=run_webui,
            name=f"WEBUI Server",
            kwargs=dict(started_event=webui_started),
            daemon=True,
        )
        processes["webui"] = process

    if process_count() == 0:
        parser.print_help()
    else:
        try:
            # Keep everything inside the try so that children are cleaned up
            # after SIGINT / SIGTERM.
            if p := processes.get("controller"):
                launch(p)
                controller_started.wait()  # wait until the controller is up

            if p := processes.get("openai_api"):
                launch(p)

            for p in processes["model_worker"].values():
                launch(p)

            for p in processes["online_api"].values():
                launch(p)

            # wait until every model worker / api worker is up
            for started in model_worker_started:
                started.wait()

            if p := processes.get("api"):
                launch(p)
                api_started.wait()  # wait until api.py is up

            if p := processes.get("webui"):
                launch(p)
                webui_started.wait()  # wait until webui.py was launched

            dump_server_info(after_start=True, args=args)

            while True:
                cmd = queue.get()  # blocks until a worker posts a model-switch command
                if not isinstance(cmd, list):
                    continue
                model_name, cmd, new_model_name = cmd

                if cmd == "start":  # run an additional model
                    logger.info(f"准备启动新模型进程：{new_model_name}")
                    started = manager.Event()
                    process = make_worker_process(new_model_name, started)
                    launch(process)
                    processes["model_worker"][new_model_name] = process
                    started.wait()
                    logger.info(f"成功启动新模型进程：{new_model_name}")
                elif cmd == "stop":
                    # pop (as "replace" does) so no stale entry is kept for a
                    # process that has been terminated.
                    if process := processes["model_worker"].pop(model_name, None):
                        # Short pause before terminating (presumably so the
                        # worker can finish answering /release -- TODO confirm).
                        time.sleep(1)
                        process.terminate()
                        process.join()
                        logger.info(f"停止模型进程：{model_name}")
                    else:
                        logger.error(f"未找到模型进程：{model_name}")
                elif cmd == "replace":
                    if process := processes["model_worker"].pop(model_name, None):
                        logger.info(f"停止模型进程：{model_name}")
                        start_time = datetime.now()
                        time.sleep(1)
                        process.terminate()
                        process.join()
                        started = manager.Event()
                        process = make_worker_process(new_model_name, started)
                        launch(process)
                        processes["model_worker"][new_model_name] = process
                        started.wait()
                        timing = datetime.now() - start_time
                        logger.info(f"成功启动新模型进程：{new_model_name}。用时：{timing}")
                    else:
                        logger.error(f"未找到模型进程：{model_name}")

        except KeyboardInterrupt:
            # Raised by the signal handlers installed above. It is not an
            # Exception subclass, so it needs its own clause.
            logger.warning("Caught KeyboardInterrupt! Setting stop event...")
        except Exception as e:
            logger.exception(e)
        finally:
            # Queues and other inter-process communication primitives can break when
            # process is killed, but we don't care here
            for p in processes.values():
                logger.warning("Sending SIGKILL to %s", p)
                children = p.values() if isinstance(p, dict) else [p]
                for process in children:
                    # A process that was never started has nothing to kill, and
                    # Process.kill() would fail on it; is_alive() is False then.
                    if process.is_alive():
                        process.kill()

            for p in processes.values():
                logger.info("Process status: %s", p)
if __name__ == "__main__":
    # Pick the event loop according to the interpreter version.
    if sys.version_info >= (3, 10):
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: create one.
            loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    else:
        loop = asyncio.get_event_loop()

    # Drive the coroutine to completion synchronously.
    loop.run_until_complete(start_main_server())
# 服务启动后接口调用示例:
# import openai
# openai.api_key = "EMPTY" # Not support yet
# openai.api_base = "http://localhost:8888/v1"
# model = "chatglm2-6b"
# # create a chat completion
# completion = openai.ChatCompletion.create(
# model=model,
# messages=[{"role": "user", "content": "Hello! What is your name?"}]
# )
# # print the completion
# print(completion.choices[0].message.content)