diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 53097db..dddd401 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -7,6 +7,19 @@ logger.setLevel(logging.INFO)
 logging.basicConfig(format=LOG_FORMAT)

+# 分布式部署时,不运行LLM的机器上可以不装torch
+def default_device():
+    try:
+        import torch
+        if torch.cuda.is_available():
+            return "cuda"
+        if torch.backends.mps.is_available():
+            return "mps"
+    except:
+        pass
+    return "cpu"
+
+
 # 在以下字典中修改属性值,以指定本地embedding模型存储位置
 # 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
 # 此处请写绝对路径
@@ -34,7 +47,6 @@ EMBEDDING_MODEL = "m3e-base"

 # Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
 EMBEDDING_DEVICE = "auto"
-
 llm_model_dict = {
     "chatglm-6b": {
         "local_model_path": "THUDM/chatglm-6b",
@@ -74,6 +86,15 @@ llm_model_dict = {
         "api_key": os.environ.get("OPENAI_API_KEY"),
         "openai_proxy": os.environ.get("OPENAI_PROXY")
     },
+    # 线上模型。当前支持智谱AI。
+    # 如果没有设置有效的local_model_path,则认为是在线模型API。
+    # 请在server_config中为每个在线API设置不同的端口
+    "chatglm-api": {
+        "api_base_url": "http://127.0.0.1:8888/v1",
+        "api_key": os.environ.get("ZHIPUAI_API_KEY"),
+        "provider": "ChatGLMWorker",
+        "version": "chatglm_pro",
+    },
 }

 # LLM 名称
diff --git a/configs/server_config.py.example b/configs/server_config.py.example
index 00f94ea..ad731e3 100644
--- a/configs/server_config.py.example
+++ b/configs/server_config.py.example
@@ -1,4 +1,8 @@
-from .model_config import LLM_MODEL, LLM_DEVICE
+from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE
+import httpx
+
+# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
+HTTPX_DEFAULT_TIMEOUT = 300.0

 # API 是否开启跨域,默认为False,如果需要开启,请设置为True
 # is open cross domain
@@ -29,15 +33,18 @@ FSCHAT_OPENAI_API = {
 # 这些模型必须是在model_config.llm_model_dict中正确配置的。
 # 在启动startup.py时,可用通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL
 FSCHAT_MODEL_WORKERS = {
-    LLM_MODEL: {
+    # 所有模型共用的默认配置,可在模型专项配置或llm_model_dict中进行覆盖。
+    "default": {
         "host": DEFAULT_BIND_HOST,
         "port": 20002,
         "device": LLM_DEVICE,
-        # todo: 多卡加载需要配置的参数
+
+        # 多卡加载需要配置的参数
         # "gpus": None, # 使用的GPU,以str的格式指定,如"0,1"
         # "num_gpus": 1, # 使用GPU的数量
-        # 以下为非常用参数,可根据需要配置
         # "max_gpu_memory": "20GiB", # 每个GPU占用的最大显存
+
+        # 以下为非常用参数,可根据需要配置
         # "load_8bit": False, # 开启8bit量化
         # "cpu_offloading": None,
         # "gptq_ckpt": None,
@@ -53,11 +60,17 @@ FSCHAT_MODEL_WORKERS = {
         # "stream_interval": 2,
         # "no_register": False,
     },
+    "baichuan-7b": { # 使用default中的IP和端口
+        "device": "cpu",
+    },
+    "chatglm-api": { # 请为每个在线API设置不同的端口
+        "port": 20003,
+    },
 }

 # fastchat multi model worker server
 FSCHAT_MULTI_MODEL_WORKERS = {
-    # todo
+    # TODO:
 }

 # fastchat controller server
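A note on how these two config files interact: a worker's effective settings are resolved by layering FSCHAT_MODEL_WORKERS["default"], the llm_model_dict entry, and the model-specific FSCHAT_MODEL_WORKERS entry, in that order — this is what get_model_worker_config() in server/utils.py further down implements. A minimal standalone sketch of that precedence, using trimmed copies of the entries above and a placeholder value for DEFAULT_BIND_HOST:

```python
# Standalone sketch of the config layering (assumption: mirrors server.utils.get_model_worker_config below);
# the dict values are trimmed copies of the entries in this diff, "0.0.0.0" stands in for DEFAULT_BIND_HOST.
llm_model_dict = {
    "chatglm-api": {"api_base_url": "http://127.0.0.1:8888/v1", "provider": "ChatGLMWorker", "version": "chatglm_pro"},
}
FSCHAT_MODEL_WORKERS = {
    "default": {"host": "0.0.0.0", "port": 20002, "device": "auto"},
    "baichuan-7b": {"device": "cpu"},
    "chatglm-api": {"port": 20003},
}

def merged(model_name: str) -> dict:
    config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()   # shared defaults
    config.update(llm_model_dict.get(model_name, {}))          # model_config entry
    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))    # per-model override wins
    return config

print(merged("chatglm-api"))   # port 20003 overrides the default 20002
print(merged("baichuan-7b"))   # keeps the default host/port, forces device="cpu"
```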
diff --git a/server/api.py b/server/api.py
index fe5e156..37954b7 100644
--- a/server/api.py
+++ b/server/api.py
@@ -4,11 +4,12 @@ import os

 sys.path.append(os.path.dirname(os.path.dirname(__file__)))

-from configs.model_config import NLTK_DATA_PATH
-from configs.server_config import OPEN_CROSS_DOMAIN
+from configs.model_config import LLM_MODEL, NLTK_DATA_PATH
+from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT
 from configs import VERSION
 import argparse
 import uvicorn
+from fastapi import Body
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
 from server.chat import (chat, knowledge_base_chat, openai_chat,
@@ -17,7 +18,8 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
                                               update_doc, download_doc, recreate_vector_store,
                                               search_docs, DocumentWithScore)
-from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
+from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
+import httpx
 from typing import List

 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@@ -123,6 +125,75 @@ def create_app():
                  summary="根据content中文档重建向量库,流式输出处理进度。"
                  )(recreate_vector_store)

+    # LLM模型相关接口
+    @app.post("/llm_model/list_models",
+              tags=["LLM Model Management"],
+              summary="列出当前已加载的模型")
+    def list_models(
+        controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
+    ) -> BaseResponse:
+        '''
+        从fastchat controller获取已加载模型列表
+        '''
+        try:
+            controller_address = controller_address or fschat_controller_address()
+            r = httpx.post(controller_address + "/list_models")
+            return BaseResponse(data=r.json()["models"])
+        except Exception as e:
+            return BaseResponse(
+                code=500,
+                data=[],
+                msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
+
+    @app.post("/llm_model/stop",
+              tags=["LLM Model Management"],
+              summary="停止指定的LLM模型(Model Worker)",
+              )
+    def stop_llm_model(
+        model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
+        controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
+    ) -> BaseResponse:
+        '''
+        向fastchat controller请求停止某个LLM模型。
+        注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
+        '''
+        try:
+            controller_address = controller_address or fschat_controller_address()
+            r = httpx.post(
+                controller_address + "/release_worker",
+                json={"model_name": model_name},
+            )
+            return r.json()
+        except Exception as e:
+            return BaseResponse(
+                code=500,
+                msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
+
+    @app.post("/llm_model/change",
+              tags=["LLM Model Management"],
+              summary="切换指定的LLM模型(Model Worker)",
+              )
+    def change_llm_model(
+        model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
+        new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
+        controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
+    ):
+        '''
+        向fastchat controller请求切换LLM模型。
+        '''
+        try:
+            controller_address = controller_address or fschat_controller_address()
+            r = httpx.post(
+                controller_address + "/release_worker",
+                json={"model_name": model_name, "new_model_name": new_model_name},
+                timeout=HTTPX_DEFAULT_TIMEOUT,  # wait for new worker_model
+            )
+            return r.json()
+        except Exception as e:
+            return BaseResponse(
+                code=500,
+                msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
+
     return app
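The three new /llm_model endpoints can be exercised directly with httpx once api.py is running; the base URL below is the project's default API address seen elsewhere in this diff (http://127.0.0.1:7861), and the model names are placeholders:

```python
import httpx

api_base = "http://127.0.0.1:7861"   # default api.py address; adjust to your deployment

# List the models currently registered with the fastchat controller.
running = httpx.post(f"{api_base}/llm_model/list_models").json()["data"]
print(running)

# Ask the controller to swap one running model for another configured one (example names).
r = httpx.post(
    f"{api_base}/llm_model/change",
    json={"model_name": "chatglm-6b", "new_model_name": "chatglm2-6b"},
    timeout=300,   # loading the new model can take a while, cf. HTTPX_DEFAULT_TIMEOUT
)
print(r.json())
```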
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 2e939f1..ba23a5a 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -20,11 +20,13 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
                          {"role": "assistant", "content": "虎头虎脑"}]]
                      ),
          stream: bool = Body(False, description="流式输出"),
+         model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
          ):
     history = [History.from_data(h) for h in history]

     async def chat_iterator(query: str,
                             history: List[History] = [],
+                            model_name: str = LLM_MODEL,
                             ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
@@ -32,10 +34,10 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
             streaming=True,
             verbose=True,
             callbacks=[callback],
-            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-            model_name=LLM_MODEL,
-            openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
+            openai_api_key=llm_model_dict[model_name]["api_key"],
+            openai_api_base=llm_model_dict[model_name]["api_base_url"],
+            model_name=model_name,
+            openai_proxy=llm_model_dict[model_name].get("openai_proxy")
         )

         input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
@@ -61,5 +63,5 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成

         await task

-    return StreamingResponse(chat_iterator(query, history),
+    return StreamingResponse(chat_iterator(query, history, model_name),
                              media_type="text/event-stream")
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 2774569..69ec25d 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -31,6 +31,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                                                  "content": "虎头虎脑"}]]
                                   ),
                         stream: bool = Body(False, description="流式输出"),
+                        model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                         local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
                         request: Request = None,
                         ):
@@ -44,16 +45,17 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                                            kb: KBService,
                                            top_k: int,
                                            history: Optional[List[History]],
+                                           model_name: str = LLM_MODEL,
                                            ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(
             streaming=True,
             verbose=True,
             callbacks=[callback],
-            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-            model_name=LLM_MODEL,
-            openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
+            openai_api_key=llm_model_dict[model_name]["api_key"],
+            openai_api_base=llm_model_dict[model_name]["api_base_url"],
+            model_name=model_name,
+            openai_proxy=llm_model_dict[model_name].get("openai_proxy")
         )
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])
@@ -97,5 +99,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp

         await task

-    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
+    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history, model_name),
                              media_type="text/event-stream")
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index 8a2633b..8fe7dae 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -69,6 +69,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
                                                 "content": "虎头虎脑"}]]
                                  ),
                        stream: bool = Body(False, description="流式输出"),
+                       model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                        ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -82,16 +83,17 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
                                           search_engine_name: str,
                                           top_k: int,
                                           history: Optional[List[History]],
+                                          model_name: str = LLM_MODEL,
                                           ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(
             streaming=True,
             verbose=True,
             callbacks=[callback],
-            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-            model_name=LLM_MODEL,
-            openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
+            openai_api_key=llm_model_dict[model_name]["api_key"],
+            openai_api_base=llm_model_dict[model_name]["api_base_url"],
+            model_name=model_name,
+            openai_proxy=llm_model_dict[model_name].get("openai_proxy")
         )

         docs = lookup_search_engine(query, search_engine_name, top_k)
@@ -129,5 +131,5 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
                                              ensure_ascii=False)

         await task

-    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
+    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history, model_name),
                              media_type="text/event-stream")
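All three chat endpoints now accept an optional model_name in the request body, defaulting to LLM_MODEL. A minimal streaming client call against /chat/chat, with an illustrative server address and model name, could look like this:

```python
import httpx

# Hypothetical call illustrating the new per-request model selection; address and
# model name are only examples and must match your own deployment/configuration.
with httpx.stream(
    "POST",
    "http://127.0.0.1:7861/chat/chat",
    json={
        "query": "你好",
        "history": [],
        "stream": True,
        "model_name": "chatglm-6b",  # any key configured in llm_model_dict
    },
    timeout=300,
) as r:
    for chunk in r.iter_text():
        print(chunk, end="", flush=True)
```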
diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py
index 3e3dd52..9beab3a 100644
--- a/server/knowledge_base/kb_service/pg_kb_service.py
+++ b/server/knowledge_base/kb_service/pg_kb_service.py
@@ -6,7 +6,8 @@ from langchain.vectorstores import PGVector
 from langchain.vectorstores.pgvector import DistanceStrategy
 from sqlalchemy import text

-from configs.model_config import kbs_config
+from configs.model_config import EMBEDDING_DEVICE, kbs_config
+
 from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
     score_threshold_process
 from server.knowledge_base.utils import load_embeddings, KnowledgeFile
diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py
new file mode 100644
index 0000000..932c2f3
--- /dev/null
+++ b/server/model_workers/__init__.py
@@ -0,0 +1 @@
+from .zhipu import ChatGLMWorker
diff --git a/server/model_workers/base.py b/server/model_workers/base.py
new file mode 100644
index 0000000..b72f683
--- /dev/null
+++ b/server/model_workers/base.py
@@ -0,0 +1,71 @@
+from configs.model_config import LOG_PATH
+import fastchat.constants
+fastchat.constants.LOGDIR = LOG_PATH
+from fastchat.serve.model_worker import BaseModelWorker
+import uuid
+import json
+import sys
+from pydantic import BaseModel
+import fastchat
+import threading
+from typing import Dict, List
+
+
+# 恢复被fastchat覆盖的标准输出
+sys.stdout = sys.__stdout__
+sys.stderr = sys.__stderr__
+
+
+class ApiModelOutMsg(BaseModel):
+    error_code: int = 0
+    text: str
+
+class ApiModelWorker(BaseModelWorker):
+    BASE_URL: str
+    SUPPORT_MODELS: List
+
+    def __init__(
+        self,
+        model_names: List[str],
+        controller_addr: str,
+        worker_addr: str,
+        context_len: int = 2048,
+        **kwargs,
+    ):
+        kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
+        kwargs.setdefault("model_path", "")
+        kwargs.setdefault("limit_worker_concurrency", 5)
+        super().__init__(model_names=model_names,
+                         controller_addr=controller_addr,
+                         worker_addr=worker_addr,
+                         **kwargs)
+        self.context_len = context_len
+        self.init_heart_beat()
+
+    def count_token(self, params):
+        # TODO:需要完善
+        print("count token")
+        print(params)
+        prompt = params["prompt"]
+        return {"count": len(str(prompt)), "error_code": 0}
+
+    def generate_stream_gate(self, params):
+        self.call_ct += 1
+
+    def generate_gate(self, params):
+        for x in self.generate_stream_gate(params):
+            pass
+        return json.loads(x[:-1].decode())
+
+    def get_embeddings(self, params):
+        print("embedding")
+        print(params)
+
+    # workaround to make program exit with Ctrl+c
+    # it should be deleted after pr is merged by fastchat
+    def init_heart_beat(self):
+        self.register_to_controller()
+        self.heart_beat_thread = threading.Thread(
+            target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
+        )
+        self.heart_beat_thread.start()
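For a sense of what ApiModelWorker buys a new online API: a hypothetical worker (not part of this PR — the class and model names are made up) only has to fill in generate_stream_gate() with the same "JSON payload + NUL byte" frames used by the ChatGLMWorker in zhipu.py that follows:

```python
# Hypothetical sketch, not project code: a minimal online-API worker built on
# server.model_workers.base.ApiModelWorker. "EchoWorker"/"echo-api" are made-up names.
import json
from typing import List

from server.model_workers.base import ApiModelWorker


class EchoWorker(ApiModelWorker):
    BASE_URL = "http://example.invalid"      # placeholder, no real upstream API
    SUPPORT_MODELS = ["echo"]

    def __init__(self, *, model_names: List[str] = ["echo-api"],
                 controller_addr: str, worker_addr: str, **kwargs):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 2048)
        # Note: the base __init__ registers with a running fastchat controller and
        # starts the heartbeat thread, so instantiation needs a live controller.
        super().__init__(**kwargs)

    def generate_stream_gate(self, params):
        super().generate_stream_gate(params)   # bumps call_ct in the base class
        # Echo the prompt back using the same "JSON + NUL byte" frame format.
        yield json.dumps({"error_code": 0, "text": params["prompt"]}, ensure_ascii=False).encode() + b"\0"
```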
diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py
new file mode 100644
index 0000000..4e4e15e
--- /dev/null
+++ b/server/model_workers/zhipu.py
@@ -0,0 +1,75 @@
+import zhipuai
+from server.model_workers.base import ApiModelWorker
+from fastchat import conversation as conv
+import sys
+import json
+from typing import List, Literal
+
+
+class ChatGLMWorker(ApiModelWorker):
+    BASE_URL = "https://open.bigmodel.cn/api/paas/v3/model-api"
+    SUPPORT_MODELS = ["chatglm_pro", "chatglm_std", "chatglm_lite"]
+
+    def __init__(
+        self,
+        *,
+        model_names: List[str] = ["chatglm-api"],
+        version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std",
+        controller_addr: str,
+        worker_addr: str,
+        **kwargs,
+    ):
+        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
+        kwargs.setdefault("context_len", 32768)
+        super().__init__(**kwargs)
+        self.version = version
+
+        # 这里的是chatglm api的模板,其它API的conv_template需要定制
+        self.conv = conv.Conversation(
+            name="chatglm-api",
+            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+            messages=[],
+            roles=["Human", "Assistant"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
+    def generate_stream_gate(self, params):
+        # TODO: 支持stream参数,维护request_id,传过来的prompt也有问题
+        from server.utils import get_model_worker_config
+
+        super().generate_stream_gate(params)
+        zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
+
+        response = zhipuai.model_api.sse_invoke(
+            model=self.version,
+            prompt=[{"role": "user", "content": params["prompt"]}],
+            temperature=params.get("temperature"),
+            top_p=params.get("top_p"),
+            incremental=False,
+        )
+        for e in response.events():
+            if e.event == "add":
+                yield json.dumps({"error_code": 0, "text": e.data}, ensure_ascii=False).encode() + b"\0"
+            # TODO: 更健壮的消息处理
+            # elif e.event == "finish":
+            #     ...
+
+    def get_embeddings(self, params):
+        # TODO: 支持embeddings
+        print("embedding")
+        print(params)
+
+
+if __name__ == "__main__":
+    import uvicorn
+    from server.utils import MakeFastAPIOffline
+    from fastchat.serve.model_worker import app
+
+    worker = ChatGLMWorker(
+        controller_addr="http://127.0.0.1:20001",
+        worker_addr="http://127.0.0.1:20003",
+    )
+    sys.modules["fastchat.serve.model_worker"].worker = worker
+    MakeFastAPIOffline(app)
+    uvicorn.run(app, port=20003)
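The frames yielded above are a JSON payload terminated by a NUL byte; ApiModelWorker.generate_gate() in base.py keeps the last frame and strips that byte before decoding. A tiny round-trip of the framing, for illustration only:

```python
import json

# Round-trip of the streaming frame format used by generate_stream_gate above.
frame = json.dumps({"error_code": 0, "text": "你好"}, ensure_ascii=False).encode() + b"\0"

# What generate_gate() does with the last frame it receives:
message = json.loads(frame[:-1].decode())
print(message["text"])  # -> 你好
```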
diff --git a/server/utils.py b/server/utils.py
index d716582..20a1562 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -5,13 +5,17 @@ import torch
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio
-from configs.model_config import LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE
-from typing import Literal, Optional
+from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE
+from configs.server_config import FSCHAT_MODEL_WORKERS
+import os
+from server import model_workers
+from typing import Literal, Optional, Any


 class BaseResponse(BaseModel):
     code: int = pydantic.Field(200, description="API status code")
     msg: str = pydantic.Field("success", description="API status message")
+    data: Any = pydantic.Field(None, description="API data")

     class Config:
         schema_extra = {
@@ -201,10 +205,28 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
     config.update(llm_model_dict.get(model_name, {}))
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
-    config["device"] = llm_device(config.get("device"))
+
+    # 如果没有设置有效的local_model_path,则认为是在线模型API
+    if not os.path.isdir(config.get("local_model_path", "")):
+        config["online_api"] = True
+        if provider := config.get("provider"):
+            try:
+                config["worker_class"] = getattr(model_workers, provider)
+            except Exception as e:
+                print(f"在线模型 '{model_name}' 的provider没有正确配置")
+
     return config


+def get_all_model_worker_configs() -> dict:
+    result = {}
+    model_names = set(llm_model_dict.keys()) | set(FSCHAT_MODEL_WORKERS.keys())
+    for name in model_names:
+        if name != "default":
+            result[name] = get_model_worker_config(name)
+    return result
+
+
 def fschat_controller_address() -> str:
     from configs.server_config import FSCHAT_CONTROLLER
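Assuming the chatglm-api entries from the config files at the top of this diff are in place, the resolved worker config exposes the keys that startup.py below relies on; a quick check might look like this (output values depend on your configuration):

```python
from server.utils import get_model_worker_config, get_all_model_worker_configs

cfg = get_model_worker_config("chatglm-api")
print(cfg.get("online_api"))     # True  -> no valid local_model_path, treated as an online API model
print(cfg.get("worker_class"))   # resolved from the "provider" field, e.g. ChatGLMWorker
print(cfg.get("port"))           # 20003, from FSCHAT_MODEL_WORKERS["chatglm-api"]

# One merged config per model known to either config dict (the "default" key is skipped):
print(list(get_all_model_worker_configs()))
```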
diff --git a/startup.py b/startup.py
index 5ef05ce..afd1c6d 100644
--- a/startup.py
+++ b/startup.py
@@ -16,14 +16,14 @@ except:
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
     logger
-from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
+from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
                                    FSCHAT_OPENAI_API, )
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
                           fschat_openai_api_address, set_httpx_timeout,
-                          llm_device, embedding_device, get_model_worker_config)
-from server.utils import MakeFastAPIOffline, FastAPI
+                          get_model_worker_config, get_all_model_worker_configs,
+                          MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
 import argparse
-from typing import Tuple, List
+from typing import Tuple, List, Dict
 from configs import VERSION
@@ -41,6 +41,7 @@ def create_controller_app(

     MakeFastAPIOffline(app)
     app.title = "FastChat Controller"
+    app._controller = controller
     return app

@@ -97,43 +98,62 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse
         )
         os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

-    gptq_config = GptqConfig(
-        ckpt=args.gptq_ckpt or args.model_path,
-        wbits=args.gptq_wbits,
-        groupsize=args.gptq_groupsize,
-        act_order=args.gptq_act_order,
-    )
-    awq_config = AWQConfig(
-        ckpt=args.awq_ckpt or args.model_path,
-        wbits=args.awq_wbits,
-        groupsize=args.awq_groupsize,
-    )
+    # 在线模型API
+    if worker_class := kwargs.get("worker_class"):
+        worker = worker_class(model_names=args.model_names,
+                              controller_addr=args.controller_address,
+                              worker_addr=args.worker_address)
+    # 本地模型
+    else:
+        # workaround to make program exit with Ctrl+c
+        # it should be deleted after pr is merged by fastchat
+        def _new_init_heart_beat(self):
+            self.register_to_controller()
+            self.heart_beat_thread = threading.Thread(
+                target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
+            )
+            self.heart_beat_thread.start()

-    worker = ModelWorker(
-        controller_addr=args.controller_address,
-        worker_addr=args.worker_address,
-        worker_id=worker_id,
-        model_path=args.model_path,
-        model_names=args.model_names,
-        limit_worker_concurrency=args.limit_worker_concurrency,
-        no_register=args.no_register,
-        device=args.device,
-        num_gpus=args.num_gpus,
-        max_gpu_memory=args.max_gpu_memory,
-        load_8bit=args.load_8bit,
-        cpu_offloading=args.cpu_offloading,
-        gptq_config=gptq_config,
-        awq_config=awq_config,
-        stream_interval=args.stream_interval,
-        conv_template=args.conv_template,
-    )
+        ModelWorker.init_heart_beat = _new_init_heart_beat
+
+        gptq_config = GptqConfig(
+            ckpt=args.gptq_ckpt or args.model_path,
+            wbits=args.gptq_wbits,
+            groupsize=args.gptq_groupsize,
+            act_order=args.gptq_act_order,
+        )
+        awq_config = AWQConfig(
+            ckpt=args.awq_ckpt or args.model_path,
+            wbits=args.awq_wbits,
+            groupsize=args.awq_groupsize,
+        )
+
+        worker = ModelWorker(
+            controller_addr=args.controller_address,
+            worker_addr=args.worker_address,
+            worker_id=worker_id,
+            model_path=args.model_path,
+            model_names=args.model_names,
+            limit_worker_concurrency=args.limit_worker_concurrency,
+            no_register=args.no_register,
+            device=args.device,
+            num_gpus=args.num_gpus,
+            max_gpu_memory=args.max_gpu_memory,
+            load_8bit=args.load_8bit,
+            cpu_offloading=args.cpu_offloading,
+            gptq_config=gptq_config,
+            awq_config=awq_config,
+            stream_interval=args.stream_interval,
+            conv_template=args.conv_template,
+        )
+        sys.modules["fastchat.serve.model_worker"].args = args
+        sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config

     sys.modules["fastchat.serve.model_worker"].worker = worker
-    sys.modules["fastchat.serve.model_worker"].args = args
-    sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config

     MakeFastAPIOffline(app)
-    app.title = f"FastChat LLM Server ({LLM_MODEL})"
+    app.title = f"FastChat LLM Server ({args.model_names[0]})"
+    app._worker = worker
     return app

@@ -190,6 +210,9 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):

 def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
     import uvicorn
+    import httpx
+    from fastapi import Body
+    import time
     import sys

     app = create_controller_app(
@@ -198,11 +221,71 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
     )
     _set_app_seq(app, q, run_seq)

+    # add interface to release and load model worker
+    @app.post("/release_worker")
+    def release_worker(
+        model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
+        # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[fschat_controller_address()]),
+        new_model_name: str = Body(None, description="释放后加载该模型"),
+        keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
+    ) -> Dict:
+        available_models = app._controller.list_models()
+        if new_model_name in available_models:
+            msg = f"要切换的LLM模型 {new_model_name} 已经存在"
+            logger.info(msg)
+            return {"code": 500, "msg": msg}
+
+        if new_model_name:
+            logger.info(f"开始切换LLM模型:从 {model_name} 到 {new_model_name}")
+        else:
+            logger.info(f"即将停止LLM模型: {model_name}")
+
+        if model_name not in available_models:
+            msg = f"the model {model_name} is not available"
+            logger.error(msg)
+            return {"code": 500, "msg": msg}
+
+        worker_address = app._controller.get_worker_address(model_name)
+        if not worker_address:
+            msg = f"can not find model_worker address for {model_name}"
+            logger.error(msg)
+            return {"code": 500, "msg": msg}
+
+        r = httpx.post(worker_address + "/release",
+                       json={"new_model_name": new_model_name, "keep_origin": keep_origin})
+        if r.status_code != 200:
+            msg = f"failed to release model: {model_name}"
+            logger.error(msg)
+            return {"code": 500, "msg": msg}
+
+        if new_model_name:
+            timer = 300  # wait 5 minutes for new model_worker register
+            while timer > 0:
+                models = app._controller.list_models()
+                if new_model_name in models:
+                    break
+                time.sleep(1)
+                timer -= 1
+            if timer > 0:
+                msg = f"success to change model from {model_name} to {new_model_name}"
+                logger.info(msg)
+                return {"code": 200, "msg": msg}
+            else:
+                msg = f"failed to change model from {model_name} to {new_model_name}"
+                logger.error(msg)
+                return {"code": 500, "msg": msg}
+        else:
+            msg = f"success to release model: {model_name}"
+            logger.info(msg)
+            return {"code": 200, "msg": msg}
+
     host = FSCHAT_CONTROLLER["host"]
     port = FSCHAT_CONTROLLER["port"]
+
     if log_level == "ERROR":
         sys.stdout = sys.__stdout__
         sys.stderr = sys.__stderr__
+
     uvicorn.run(app, host=host, port=port, log_level=log_level.lower())

@@ -214,16 +297,17 @@ def run_model_worker(
     log_level: str ="INFO",
 ):
     import uvicorn
+    from fastapi import Body
     import sys

     kwargs = get_model_worker_config(model_name)
     host = kwargs.pop("host")
     port = kwargs.pop("port")
-    model_path = llm_model_dict[model_name].get("local_model_path", "")
-    kwargs["model_path"] = model_path
     kwargs["model_names"] = [model_name]
     kwargs["controller_address"] = controller_address or fschat_controller_address()
-    kwargs["worker_address"] = fschat_model_worker_address()
+    kwargs["worker_address"] = fschat_model_worker_address(model_name)
+    model_path = kwargs.get("local_model_path", "")
+    kwargs["model_path"] = model_path

     app = create_model_worker_app(log_level=log_level, **kwargs)
     _set_app_seq(app, q, run_seq)
@@ -231,6 +315,22 @@ def run_model_worker(
         sys.stdout = sys.__stdout__
         sys.stderr = sys.__stderr__

+    # add interface to release and load model
+    @app.post("/release")
+    def release_model(
+        new_model_name: str = Body(None, description="释放后加载该模型"),
+        keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
+    ) -> Dict:
+        if keep_origin:
+            if new_model_name:
+                q.put(["start", new_model_name])
+        else:
+            if new_model_name:
+                q.put(["replace", new_model_name])
+            else:
+                q.put(["stop"])
+        return {"code": 200, "msg": "done"}
+
     uvicorn.run(app, host=host, port=port, log_level=log_level.lower())

@@ -337,6 +437,13 @@ def parse_args() -> argparse.ArgumentParser:
         help="run api.py server",
         dest="api",
     )
+    parser.add_argument(
+        "-p",
+        "--api-worker",
+        action="store_true",
+        help="run online model api such as zhipuai",
+        dest="api_worker",
+    )
     parser.add_argument(
         "-w",
         "--webui",
@@ -368,9 +475,14 @@ def dump_server_info(after_start=False, args=None):
     print(f"项目版本:{VERSION}")
     print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
     print("\n")
-    print(f"当前LLM模型:{LLM_MODEL} @ {llm_device()}")
-    pprint(llm_model_dict[LLM_MODEL])
+
+    model = LLM_MODEL
+    if args and args.model_name:
+        model = args.model_name
+    print(f"当前LLM模型:{model} @ {llm_device()}")
+    pprint(llm_model_dict[model])
     print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
+
     if after_start:
         print("\n")
         print(f"服务端运行信息:")
@@ -396,17 +508,20 @@ if __name__ == "__main__":
         args.openai_api = True
         args.model_worker = True
         args.api = True
+        args.api_worker = True
         args.webui = True

     elif args.all_api:
         args.openai_api = True
         args.model_worker = True
         args.api = True
+        args.api_worker = True
         args.webui = False

     elif args.llm_api:
         args.openai_api = True
         args.model_worker = True
+        args.api_worker = True
         args.api = False
         args.webui = False

@@ -416,7 +531,11 @@ if __name__ == "__main__":
     logger.info(f"正在启动服务:")
     logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")

-    processes = {}
+    processes = {"online-api": []}
+
+    def process_count():
+        return len(processes) + len(processes["online-api"]) -1
+
     if args.quiet:
         log_level = "ERROR"
     else:
@@ -426,7 +545,7 @@ if __name__ == "__main__":
         process = Process(
             target=run_controller,
             name=f"controller({os.getpid()})",
-            args=(queue, len(processes) + 1, log_level),
+            args=(queue, process_count() + 1, log_level),
             daemon=True,
         )
         process.start()
@@ -435,29 +554,42 @@ if __name__ == "__main__":
         process = Process(
             target=run_openai_api,
             name=f"openai_api({os.getpid()})",
-            args=(queue, len(processes) + 1),
+            args=(queue, process_count() + 1),
             daemon=True,
         )
         process.start()
         processes["openai_api"] = process

     if args.model_worker:
-        model_path = llm_model_dict[args.model_name].get("local_model_path", "")
-        if os.path.isdir(model_path):
+        config = get_model_worker_config(args.model_name)
+        if not config.get("online_api"):
             process = Process(
                 target=run_model_worker,
-                name=f"model_worker({os.getpid()})",
-                args=(args.model_name, args.controller_address, queue, len(processes) + 1, log_level),
+                name=f"model_worker - {args.model_name} ({os.getpid()})",
+                args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
                 daemon=True,
             )
             process.start()
             processes["model_worker"] = process

+    if args.api_worker:
+        configs = get_all_model_worker_configs()
+        for model_name, config in configs.items():
+            if config.get("online_api") and config.get("worker_class"):
+                process = Process(
+                    target=run_model_worker,
+                    name=f"model_worker - {model_name} ({os.getpid()})",
+                    args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
+                    daemon=True,
+                )
+                process.start()
+                processes["online-api"].append(process)
+
     if args.api:
         process = Process(
             target=run_api_server,
             name=f"API Server{os.getpid()})",
-            args=(queue, len(processes) + 1),
+            args=(queue, process_count() + 1),
             daemon=True,
         )
         process.start()
@@ -467,37 +599,38 @@ if __name__ == "__main__":
         process = Process(
             target=run_webui,
             name=f"WEBUI Server{os.getpid()})",
-            args=(queue, len(processes) + 1),
+            args=(queue, process_count() + 1),
             daemon=True,
         )
         process.start()
         processes["webui"] = process

-    if len(processes) == 0:
+    if process_count() == 0:
         parser.print_help()
     else:
         try:
-            # log infors
             while True:
                 no = queue.get()
-                if no == len(processes):
+                if no == process_count():
                     time.sleep(0.5)
                     dump_server_info(after_start=True, args=args)
                     break
                 else:
                     queue.put(no)

-            if model_worker_process := processes.get("model_worker"):
+            if model_worker_process := processes.pop("model_worker", None):
                 model_worker_process.join()
+            for process in processes.pop("online-api", []):
+                process.join()
             for name, process in processes.items():
-                if name != "model_worker":
-                    process.join()
+                process.join()
         except:
-            if model_worker_process := processes.get("model_worker"):
+            if model_worker_process := processes.pop("model_worker", None):
                 model_worker_process.terminate()
+            for process in processes.pop("online-api", []):
+                process.terminate()
             for name, process in processes.items():
-                if name != "model_worker":
-                    process.terminate()
+                process.terminate()

 # 服务启动后接口调用示例:
diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py
index cefeec0..51bbac1 100644
--- a/tests/api/test_kb_api.py
+++ b/tests/api/test_kb_api.py
@@ -1,4 +1,3 @@
-from doctest import testfile
 import requests
 import json
 import sys
diff --git a/tests/api/test_llm_api.py b/tests/api/test_llm_api.py
new file mode 100644
index 0000000..f348fe7
--- /dev/null
+++ b/tests/api/test_llm_api.py
@@ -0,0 +1,74 @@
+import requests
+import json
+import sys
+from pathlib import Path
+
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+from configs.server_config import api_address, FSCHAT_MODEL_WORKERS
+from configs.model_config import LLM_MODEL, llm_model_dict
+
+from pprint import pprint
+import random
+
+
+def get_configured_models():
+    model_workers = list(FSCHAT_MODEL_WORKERS)
+    if "default" in model_workers:
+        model_workers.remove("default")
+
+    llm_dict = list(llm_model_dict)
+
+    return model_workers, llm_dict
+
+
+api_base_url = api_address()
+
+
+def get_running_models(api="/llm_model/list_models"):
+    url = api_base_url + api
+    r = requests.post(url)
+    if r.status_code == 200:
+        return r.json()["data"]
+    return []
+
+
+def test_running_models(api="/llm_model/list_models"):
+    url = api_base_url + api
+    r = requests.post(url)
+    assert r.status_code == 200
+    print("\n获取当前正在运行的模型列表:")
+    pprint(r.json())
+    assert isinstance(r.json()["data"], list)
+    assert len(r.json()["data"]) > 0
+
+
+# 不建议使用stop_model功能。按现在的实现,停止了就只能手动再启动
+# def test_stop_model(api="/llm_model/stop"):
+#     url = api_base_url + api
+#     r = requests.post(url, json={""})
+
+
+def test_change_model(api="/llm_model/change"):
+    url = api_base_url + api
+
+    running_models = get_running_models()
+    assert len(running_models) > 0
+
+    model_workers, llm_dict = get_configured_models()
+
+    available_new_models = set(model_workers) - set(running_models)
+    if len(available_new_models) == 0:
+        available_new_models = set(llm_dict) - set(running_models)
+    available_new_models = list(available_new_models)
+    assert len(available_new_models) > 0
+    print(available_new_models)
+
+    model_name = random.choice(running_models)
+    new_model_name = random.choice(available_new_models)
+    print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}")
+    r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name})
+    assert r.status_code == 200
+
+    running_models = get_running_models()
+    assert new_model_name in running_models
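One subtle point in the startup bookkeeping above: processes now starts with an "online-api" key that holds a list rather than a single Process, so process_count() subtracts that key from len(processes) and adds the list's length back. A toy illustration with stand-in objects instead of real Process instances:

```python
# Illustration only (not project code) of the process counting used in startup.py.
processes = {"online-api": []}          # same initial shape as in startup.py
processes["controller"] = object()      # stand-ins for multiprocessing.Process objects
processes["openai_api"] = object()
processes["online-api"].append(object())

def process_count():
    # len(processes) counts the "online-api" key itself, hence the -1.
    return len(processes) + len(processes["online-api"]) - 1

print(process_count())  # 3 -> controller, openai_api, and one online-api worker
```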
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 2d5e260..730b642 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -1,10 +1,14 @@
 import streamlit as st
+from configs.server_config import FSCHAT_MODEL_WORKERS
 from webui_pages.utils import *
 from streamlit_chatbox import *
 from datetime import datetime
 from server.chat.search_engine_chat import SEARCH_ENGINES
-from typing import List, Dict
 import os
+from configs.model_config import llm_model_dict, LLM_MODEL
+from server.utils import get_model_worker_config
+from typing import List, Dict
+

 chat_box = ChatBox(
     assistant_avatar=os.path.join(
@@ -59,6 +63,38 @@ def dialogue_page(api: ApiRequest):
                                      on_change=on_mode_change,
                                      key="dialogue_mode",
                                      )
+
+    def on_llm_change():
+        st.session_state["prev_llm_model"] = llm_model
+
+    def llm_model_format_func(x):
+        if x in running_models:
+            return f"{x} (Running)"
+        return x
+
+    running_models = api.list_running_models()
+    config_models = api.list_config_models()
+    for x in running_models:
+        if x in config_models:
+            config_models.remove(x)
+    llm_models = running_models + config_models
+    if "prev_llm_model" not in st.session_state:
+        index = llm_models.index(LLM_MODEL)
+    else:
+        index = 0
+    llm_model = st.selectbox("选择LLM模型:",
+                             llm_models,
+                             index,
+                             format_func=llm_model_format_func,
+                             on_change=on_llm_change,
+                             # key="llm_model",
+                             )
+    if (st.session_state.get("prev_llm_model") != llm_model
+            and not get_model_worker_config(llm_model).get("online_api")):
+        with st.spinner(f"正在加载模型: {llm_model}"):
+            r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
+            st.session_state["prev_llm_model"] = llm_model
+
     history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)

     def on_kb_change():
@@ -99,7 +135,7 @@ def dialogue_page(api: ApiRequest):
         if dialogue_mode == "LLM 对话":
             chat_box.ai_say("正在思考...")
             text = ""
-            r = api.chat_chat(prompt, history)
+            r = api.chat_chat(prompt, history=history, model=llm_model)
             for t in r:
                 if error_msg := check_error_msg(t):  # check whether error occured
                     st.error(error_msg)
@@ -114,7 +150,7 @@ def dialogue_page(api: ApiRequest):
                 Markdown("...", in_expander=True, title="知识库匹配结果"),
             ])
             text = ""
-            for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history):
+            for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history, model=llm_model):
                 if error_msg := check_error_msg(d):  # check whether error occured
                     st.error(error_msg)
                 text += d["answer"]
@@ -127,8 +163,8 @@ def dialogue_page(api: ApiRequest):
                 Markdown("...", in_expander=True, title="网络搜索结果"),
             ])
             text = ""
-            for d in api.search_engine_chat(prompt, search_engine, se_top_k):
-                if error_msg := check_error_msg(d): # check whether error occured
+            for d in api.search_engine_chat(prompt, search_engine, se_top_k, model=llm_model):
+                if error_msg := check_error_msg(d):  # check whether error occured
                     st.error(error_msg)
                 else:
                     text += d["answer"]
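The selectbox callback above ultimately goes through the ApiRequest helpers added in webui_pages/utils.py below; driven from a plain script, the same model switch could be performed like this (base URL is the project default, model choice is illustrative):

```python
from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")  # default api.py address

running = api.list_running_models()       # models currently served by fastchat
configured = api.list_config_models()     # keys of llm_model_dict
print(running, configured)

# Switch the first running model to any configured-but-not-running one, if there is one.
if running and (candidates := [m for m in configured if m not in running]):
    print(api.change_llm_model(running[0], candidates[0]))
```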
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 58b08e8..0851104 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -6,12 +6,14 @@ from configs.model_config import (
     DEFAULT_VS_TYPE,
     KB_ROOT_PATH,
     LLM_MODEL,
+    llm_model_dict,
     HISTORY_LEN,
     SCORE_THRESHOLD,
     VECTOR_SEARCH_TOP_K,
     SEARCH_ENGINE_TOP_K,
     logger,
 )
+from configs.server_config import HTTPX_DEFAULT_TIMEOUT
 import httpx
 import asyncio
 from server.chat.openai_chat import OpenAiChatMsgIn
@@ -42,7 +44,7 @@ class ApiRequest:
     def __init__(
         self,
         base_url: str = "http://127.0.0.1:7861",
-        timeout: float = 60.0,
+        timeout: float = HTTPX_DEFAULT_TIMEOUT,
         no_remote_api: bool = False,  # call api view function directly
     ):
         self.base_url = base_url
@@ -289,6 +291,7 @@ class ApiRequest:
         query: str,
         history: List[Dict] = [],
         stream: bool = True,
+        model: str = LLM_MODEL,
         no_remote_api: bool = None,
     ):
         '''
@@ -301,6 +304,7 @@ class ApiRequest:
             "query": query,
             "history": history,
             "stream": stream,
+            "model_name": model,
         }

         print(f"received input message:")
@@ -322,6 +326,7 @@ class ApiRequest:
         score_threshold: float = SCORE_THRESHOLD,
         history: List[Dict] = [],
         stream: bool = True,
+        model: str = LLM_MODEL,
         no_remote_api: bool = None,
     ):
         '''
@@ -337,6 +342,7 @@ class ApiRequest:
             "score_threshold": score_threshold,
             "history": history,
             "stream": stream,
+            "model_name": model,
             "local_doc_url": no_remote_api,
         }

@@ -361,6 +367,7 @@ class ApiRequest:
         search_engine_name: str,
         top_k: int = SEARCH_ENGINE_TOP_K,
         stream: bool = True,
+        model: str = LLM_MODEL,
         no_remote_api: bool = None,
     ):
         '''
@@ -374,6 +381,7 @@ class ApiRequest:
             "search_engine_name": search_engine_name,
             "top_k": top_k,
             "stream": stream,
+            "model_name": model,
         }

         print(f"received input message:")
@@ -645,6 +653,84 @@ class ApiRequest:
         )
         return self._httpx_stream2generator(response, as_json=True)

+    def list_running_models(self, controller_address: str = None):
+        '''
+        获取Fastchat中正在运行的模型列表
+        '''
+        r = self.post(
+            "/llm_model/list_models",
+        )
+        return r.json().get("data", [])
+
+    def list_config_models(self):
+        '''
+        获取configs中配置的模型列表
+        '''
+        return list(llm_model_dict.keys())
+
+    def stop_llm_model(
+        self,
+        model_name: str,
+        controller_address: str = None,
+    ):
+        '''
+        停止某个LLM模型。
+        注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
+        '''
+        data = {
+            "model_name": model_name,
+            "controller_address": controller_address,
+        }
+        r = self.post(
+            "/llm_model/stop",
+            json=data,
+        )
+        return r.json()
+
+    def change_llm_model(
+        self,
+        model_name: str,
+        new_model_name: str,
+        controller_address: str = None,
+    ):
+        '''
+        向fastchat controller请求切换LLM模型。
+        '''
+        if not model_name or not new_model_name:
+            return
+
+        if new_model_name == model_name:
+            return {
+                "code": 200,
+                "msg": "什么都不用做"
+            }
+
+        running_models = self.list_running_models()
+        if model_name not in running_models:
+            return {
+                "code": 500,
+                "msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
+            }
+
+        config_models = self.list_config_models()
+        if new_model_name not in config_models:
+            return {
+                "code": 500,
+                "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
+            }
+
+        data = {
+            "model_name": model_name,
+            "new_model_name": new_model_name,
+            "controller_address": controller_address,
+        }
+        r = self.post(
+            "/llm_model/change",
+            json=data,
+            timeout=HTTPX_DEFAULT_TIMEOUT,  # wait for new worker_model
+        )
+        return r.json()
+

 def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
     '''