Add model switching and support for the Zhipu AI online model (#1342)

* Add LLM model switching; the models available for switching must be configured in server_config
* Add tests for the api.py /llm_model/* endpoints
* Support model switching and the Zhipu AI online model; add a `--api-worker` argument to startup.py that automatically runs all online API model workers (enabled by default with `-a (--all-webui)` and `--all-api`)
* Restore the standard output streams overridden by fastchat
* Finer-grained control over fastchat logging; add a `-q (--quiet)` switch to startup.py to reduce noisy fastchat log output
* Fix the conversation template of the chatglm api

Co-authored-by: liunux4odoo <liunu@qq.com>
parent ab4c8d2e5d
commit 6cb1bdf623
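Once the services are running, the new /llm_model endpoints added to api.py in this commit can be exercised directly. A minimal sketch (assuming the default API address http://127.0.0.1:7861 and the model names used in the sample configs below; adjust both to your deployment):

import requests

api_base_url = "http://127.0.0.1:7861"  # default api.py address

# list the models currently registered with the fastchat controller
r = requests.post(api_base_url + "/llm_model/list_models")
print(r.json()["data"])

# release the running worker and load another configured model
r = requests.post(
    api_base_url + "/llm_model/change",
    json={"model_name": "chatglm-6b", "new_model_name": "chatglm-api"},
)
print(r.json())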
@@ -7,6 +7,19 @@ logger.setLevel(logging.INFO)
 logging.basicConfig(format=LOG_FORMAT)
 
+# 分布式部署时,不运行LLM的机器上可以不装torch
+def default_device():
+    try:
+        import torch
+        if torch.cuda.is_available():
+            return "cuda"
+        if torch.backends.mps.is_available():
+            return "mps"
+    except:
+        pass
+    return "cpu"
+
+
 # 在以下字典中修改属性值,以指定本地embedding模型存储位置
 # 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
 # 此处请写绝对路径

@@ -34,7 +47,6 @@ EMBEDDING_MODEL = "m3e-base"
 # Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
 EMBEDDING_DEVICE = "auto"
 
 
 llm_model_dict = {
     "chatglm-6b": {
         "local_model_path": "THUDM/chatglm-6b",

@@ -74,6 +86,15 @@ llm_model_dict = {
         "api_key": os.environ.get("OPENAI_API_KEY"),
         "openai_proxy": os.environ.get("OPENAI_PROXY")
     },
+    # 线上模型。当前支持智谱AI。
+    # 如果没有设置有效的local_model_path,则认为是在线模型API。
+    # 请在server_config中为每个在线API设置不同的端口
+    "chatglm-api": {
+        "api_base_url": "http://127.0.0.1:8888/v1",
+        "api_key": os.environ.get("ZHIPUAI_API_KEY"),
+        "provider": "ChatGLMWorker",
+        "version": "chatglm_pro",
+    },
 }
 
 # LLM 名称
@@ -1,4 +1,8 @@
-from .model_config import LLM_MODEL, LLM_DEVICE
+from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE
+import httpx
+
+# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
+HTTPX_DEFAULT_TIMEOUT = 300.0
 
 # API 是否开启跨域,默认为False,如果需要开启,请设置为True
 # is open cross domain

@@ -29,15 +33,18 @@ FSCHAT_OPENAI_API = {
 # 这些模型必须是在model_config.llm_model_dict中正确配置的。
 # 在启动startup.py时,可用通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL
 FSCHAT_MODEL_WORKERS = {
-    LLM_MODEL: {
+    # 所有模型共用的默认配置,可在模型专项配置或llm_model_dict中进行覆盖。
+    "default": {
         "host": DEFAULT_BIND_HOST,
         "port": 20002,
         "device": LLM_DEVICE,
-        # todo: 多卡加载需要配置的参数
+
+        # 多卡加载需要配置的参数
         # "gpus": None, # 使用的GPU,以str的格式指定,如"0,1"
         # "num_gpus": 1, # 使用GPU的数量
-        # 以下为非常用参数,可根据需要配置
         # "max_gpu_memory": "20GiB", # 每个GPU占用的最大显存
+
+        # 以下为非常用参数,可根据需要配置
         # "load_8bit": False, # 开启8bit量化
         # "cpu_offloading": None,
         # "gptq_ckpt": None,

@@ -53,11 +60,17 @@ FSCHAT_MODEL_WORKERS = {
         # "stream_interval": 2,
         # "no_register": False,
     },
+    "baichuan-7b": { # 使用default中的IP和端口
+        "device": "cpu",
+    },
+    "chatglm-api": { # 请为每个在线API设置不同的端口
+        "port": 20003,
+    },
 }
 
 # fastchat multi model worker server
 FSCHAT_MULTI_MODEL_WORKERS = {
-    # todo
+    # TODO:
 }
 
 # fastchat controller server
@@ -4,11 +4,12 @@ import os
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
-from configs.model_config import NLTK_DATA_PATH
-from configs.server_config import OPEN_CROSS_DOMAIN
+from configs.model_config import LLM_MODEL, NLTK_DATA_PATH
+from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT
 from configs import VERSION
 import argparse
 import uvicorn
+from fastapi import Body
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
 from server.chat import (chat, knowledge_base_chat, openai_chat,

@@ -17,7 +18,8 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
                                               update_doc, download_doc, recreate_vector_store,
                                               search_docs, DocumentWithScore)
-from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
+from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
+import httpx
 from typing import List
 
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

@@ -123,6 +125,75 @@ def create_app():
              summary="根据content中文档重建向量库,流式输出处理进度。"
              )(recreate_vector_store)
 
+    # LLM模型相关接口
+    @app.post("/llm_model/list_models",
+              tags=["LLM Model Management"],
+              summary="列出当前已加载的模型")
+    def list_models(
+        controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
+    ) -> BaseResponse:
+        '''
+        从fastchat controller获取已加载模型列表
+        '''
+        try:
+            controller_address = controller_address or fschat_controller_address()
+            r = httpx.post(controller_address + "/list_models")
+            return BaseResponse(data=r.json()["models"])
+        except Exception as e:
+            return BaseResponse(
+                code=500,
+                data=[],
+                msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
+
+    @app.post("/llm_model/stop",
+              tags=["LLM Model Management"],
+              summary="停止指定的LLM模型(Model Worker)",
+              )
+    def stop_llm_model(
+        model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
+        controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
+    ) -> BaseResponse:
+        '''
+        向fastchat controller请求停止某个LLM模型。
+        注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
+        '''
+        try:
+            controller_address = controller_address or fschat_controller_address()
+            r = httpx.post(
+                controller_address + "/release_worker",
+                json={"model_name": model_name},
+            )
+            return r.json()
+        except Exception as e:
+            return BaseResponse(
+                code=500,
+                msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
+
+    @app.post("/llm_model/change",
+              tags=["LLM Model Management"],
+              summary="切换指定的LLM模型(Model Worker)",
+              )
+    def change_llm_model(
+        model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
+        new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
+        controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
+    ):
+        '''
+        向fastchat controller请求切换LLM模型。
+        '''
+        try:
+            controller_address = controller_address or fschat_controller_address()
+            r = httpx.post(
+                controller_address + "/release_worker",
+                json={"model_name": model_name, "new_model_name": new_model_name},
+                timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
+            )
+            return r.json()
+        except Exception as e:
+            return BaseResponse(
+                code=500,
+                msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
+
     return app
@@ -20,11 +20,13 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
                       {"role": "assistant", "content": "虎头虎脑"}]]
                   ),
          stream: bool = Body(False, description="流式输出"),
+         model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
          ):
     history = [History.from_data(h) for h in history]
 
     async def chat_iterator(query: str,
                             history: List[History] = [],
+                            model_name: str = LLM_MODEL,
                             ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()

@@ -32,10 +34,10 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
             streaming=True,
             verbose=True,
             callbacks=[callback],
-            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-            model_name=LLM_MODEL,
-            openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
+            openai_api_key=llm_model_dict[model_name]["api_key"],
+            openai_api_base=llm_model_dict[model_name]["api_base_url"],
+            model_name=model_name,
+            openai_proxy=llm_model_dict[model_name].get("openai_proxy")
         )
 
         input_msg = History(role="user", content="{{ input }}").to_msg_template(False)

@@ -61,5 +63,5 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
 
         await task
 
-    return StreamingResponse(chat_iterator(query, history),
+    return StreamingResponse(chat_iterator(query, history, model_name),
                              media_type="text/event-stream")
@@ -31,6 +31,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                                              "content": "虎头虎脑"}]]
                                      ),
                         stream: bool = Body(False, description="流式输出"),
+                        model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                         local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
                         request: Request = None,
                         ):

@@ -44,16 +45,17 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                                            kb: KBService,
                                            top_k: int,
                                            history: Optional[List[History]],
+                                           model_name: str = LLM_MODEL,
                                            ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(
             streaming=True,
             verbose=True,
             callbacks=[callback],
-            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-            model_name=LLM_MODEL,
-            openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
+            openai_api_key=llm_model_dict[model_name]["api_key"],
+            openai_api_base=llm_model_dict[model_name]["api_base_url"],
+            model_name=model_name,
+            openai_proxy=llm_model_dict[model_name].get("openai_proxy")
         )
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])

@@ -97,5 +99,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
 
         await task
 
-    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
+    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history, model_name),
                              media_type="text/event-stream")
@@ -69,6 +69,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
                                              "content": "虎头虎脑"}]]
                                      ),
                           stream: bool = Body(False, description="流式输出"),
+                          model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                           ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")

@@ -82,16 +83,17 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
                                          search_engine_name: str,
                                          top_k: int,
                                          history: Optional[List[History]],
+                                         model_name: str = LLM_MODEL,
                                          ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(
             streaming=True,
             verbose=True,
             callbacks=[callback],
-            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-            model_name=LLM_MODEL,
-            openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
+            openai_api_key=llm_model_dict[model_name]["api_key"],
+            openai_api_base=llm_model_dict[model_name]["api_base_url"],
+            model_name=model_name,
+            openai_proxy=llm_model_dict[model_name].get("openai_proxy")
         )
 
         docs = lookup_search_engine(query, search_engine_name, top_k)

@@ -129,5 +131,5 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
                                      ensure_ascii=False)
         await task
 
-    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
+    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history, model_name),
                              media_type="text/event-stream")
@@ -6,7 +6,8 @@ from langchain.vectorstores import PGVector
 from langchain.vectorstores.pgvector import DistanceStrategy
 from sqlalchemy import text
 
-from configs.model_config import kbs_config
+from configs.model_config import EMBEDDING_DEVICE, kbs_config
+
 from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
     score_threshold_process
 from server.knowledge_base.utils import load_embeddings, KnowledgeFile
@@ -0,0 +1 @@
+from .zhipu import ChatGLMWorker
@@ -0,0 +1,71 @@
+from configs.model_config import LOG_PATH
+import fastchat.constants
+fastchat.constants.LOGDIR = LOG_PATH
+from fastchat.serve.model_worker import BaseModelWorker
+import uuid
+import json
+import sys
+from pydantic import BaseModel
+import fastchat
+import threading
+from typing import Dict, List
+
+
+# 恢复被fastchat覆盖的标准输出
+sys.stdout = sys.__stdout__
+sys.stderr = sys.__stderr__
+
+
+class ApiModelOutMsg(BaseModel):
+    error_code: int = 0
+    text: str
+
+
+class ApiModelWorker(BaseModelWorker):
+    BASE_URL: str
+    SUPPORT_MODELS: List
+
+    def __init__(
+        self,
+        model_names: List[str],
+        controller_addr: str,
+        worker_addr: str,
+        context_len: int = 2048,
+        **kwargs,
+    ):
+        kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
+        kwargs.setdefault("model_path", "")
+        kwargs.setdefault("limit_worker_concurrency", 5)
+        super().__init__(model_names=model_names,
+                         controller_addr=controller_addr,
+                         worker_addr=worker_addr,
+                         **kwargs)
+        self.context_len = context_len
+        self.init_heart_beat()
+
+    def count_token(self, params):
+        # TODO:需要完善
+        print("count token")
+        print(params)
+        prompt = params["prompt"]
+        return {"count": len(str(prompt)), "error_code": 0}
+
+    def generate_stream_gate(self, params):
+        self.call_ct += 1
+
+    def generate_gate(self, params):
+        for x in self.generate_stream_gate(params):
+            pass
+        return json.loads(x[:-1].decode())
+
+    def get_embeddings(self, params):
+        print("embedding")
+        print(params)
+
+    # workaround to make program exit with Ctrl+c
+    # it should be deleted after pr is merged by fastchat
+    def init_heart_beat(self):
+        self.register_to_controller()
+        self.heart_beat_thread = threading.Thread(
+            target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
+        )
+        self.heart_beat_thread.start()
@@ -0,0 +1,75 @@
+import zhipuai
+from server.model_workers.base import ApiModelWorker
+from fastchat import conversation as conv
+import sys
+import json
+from typing import List, Literal
+
+
+class ChatGLMWorker(ApiModelWorker):
+    BASE_URL = "https://open.bigmodel.cn/api/paas/v3/model-api"
+    SUPPORT_MODELS = ["chatglm_pro", "chatglm_std", "chatglm_lite"]
+
+    def __init__(
+        self,
+        *,
+        model_names: List[str] = ["chatglm-api"],
+        version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std",
+        controller_addr: str,
+        worker_addr: str,
+        **kwargs,
+    ):
+        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
+        kwargs.setdefault("context_len", 32768)
+        super().__init__(**kwargs)
+        self.version = version
+
+        # 这里的是chatglm api的模板,其它API的conv_template需要定制
+        self.conv = conv.Conversation(
+            name="chatglm-api",
+            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+            messages=[],
+            roles=["Human", "Assistant"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
+    def generate_stream_gate(self, params):
+        # TODO: 支持stream参数,维护request_id,传过来的prompt也有问题
+        from server.utils import get_model_worker_config
+
+        super().generate_stream_gate(params)
+        zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
+
+        response = zhipuai.model_api.sse_invoke(
+            model=self.version,
+            prompt=[{"role": "user", "content": params["prompt"]}],
+            temperature=params.get("temperature"),
+            top_p=params.get("top_p"),
+            incremental=False,
+        )
+        for e in response.events():
+            if e.event == "add":
+                yield json.dumps({"error_code": 0, "text": e.data}, ensure_ascii=False).encode() + b"\0"
+            # TODO: 更健壮的消息处理
+            # elif e.event == "finish":
+            #     ...
+
+    def get_embeddings(self, params):
+        # TODO: 支持embeddings
+        print("embedding")
+        print(params)
+
+
+if __name__ == "__main__":
+    import uvicorn
+    from server.utils import MakeFastAPIOffline
+    from fastchat.serve.model_worker import app
+
+    worker = ChatGLMWorker(
+        controller_addr="http://127.0.0.1:20001",
+        worker_addr="http://127.0.0.1:20003",
+    )
+    sys.modules["fastchat.serve.model_worker"].worker = worker
+    MakeFastAPIOffline(app)
+    uvicorn.run(app, port=20003)
@@ -5,13 +5,17 @@ import torch
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio
-from configs.model_config import LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE
-from typing import Literal, Optional
+from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE
+from configs.server_config import FSCHAT_MODEL_WORKERS
+import os
+from server import model_workers
+from typing import Literal, Optional, Any
 
 
 class BaseResponse(BaseModel):
     code: int = pydantic.Field(200, description="API status code")
     msg: str = pydantic.Field("success", description="API status message")
+    data: Any = pydantic.Field(None, description="API data")
 
     class Config:
         schema_extra = {

@@ -201,10 +205,28 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
     config.update(llm_model_dict.get(model_name, {}))
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
     config["device"] = llm_device(config.get("device"))
 
+    # 如果没有设置有效的local_model_path,则认为是在线模型API
+    if not os.path.isdir(config.get("local_model_path", "")):
+        config["online_api"] = True
+        if provider := config.get("provider"):
+            try:
+                config["worker_class"] = getattr(model_workers, provider)
+            except Exception as e:
+                print(f"在线模型 ‘{model_name}’ 的provider没有正确配置")
+
     return config
 
 
+def get_all_model_worker_configs() -> dict:
+    result = {}
+    model_names = set(llm_model_dict.keys()) | set(FSCHAT_MODEL_WORKERS.keys())
+    for name in model_names:
+        if name != "default":
+            result[name] = get_model_worker_config(name)
+    return result
+
+
 def fschat_controller_address() -> str:
     from configs.server_config import FSCHAT_CONTROLLER
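For reference, a rough sketch of what the updated get_model_worker_config() resolves for the "chatglm-api" entry configured above; the merge order is FSCHAT_MODEL_WORKERS["default"], then llm_model_dict[name], then FSCHAT_MODEL_WORKERS[name]. The values in the comment are illustrative and depend on local settings:

from server.utils import get_model_worker_config

cfg = get_model_worker_config("chatglm-api")
# roughly:
# {
#     "host": ...,                        # DEFAULT_BIND_HOST from the "default" worker config
#     "port": 20003,                      # per-model override in FSCHAT_MODEL_WORKERS
#     "device": "cuda" / "mps" / "cpu",   # resolved by llm_device()
#     "api_base_url": "http://127.0.0.1:8888/v1",
#     "api_key": ...,                     # ZHIPUAI_API_KEY environment variable
#     "provider": "ChatGLMWorker",
#     "version": "chatglm_pro",
#     "online_api": True,                 # no valid local_model_path, so treated as an online API
#     "worker_class": model_workers.ChatGLMWorker,
# }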
startup.py
@@ -16,14 +16,14 @@ except:
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
     logger
-from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
+from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
                                    FSCHAT_OPENAI_API, )
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
                           fschat_openai_api_address, set_httpx_timeout,
-                          llm_device, embedding_device, get_model_worker_config)
-from server.utils import MakeFastAPIOffline, FastAPI
+                          get_model_worker_config, get_all_model_worker_configs,
+                          MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
 import argparse
-from typing import Tuple, List
+from typing import Tuple, List, Dict
 from configs import VERSION

@@ -41,6 +41,7 @@ def create_controller_app(
 
     MakeFastAPIOffline(app)
     app.title = "FastChat Controller"
+    app._controller = controller
     return app
@@ -97,43 +98,62 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse
         )
         os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
 
-    gptq_config = GptqConfig(
-        ckpt=args.gptq_ckpt or args.model_path,
-        wbits=args.gptq_wbits,
-        groupsize=args.gptq_groupsize,
-        act_order=args.gptq_act_order,
-    )
-    awq_config = AWQConfig(
-        ckpt=args.awq_ckpt or args.model_path,
-        wbits=args.awq_wbits,
-        groupsize=args.awq_groupsize,
-    )
+    # 在线模型API
+    if worker_class := kwargs.get("worker_class"):
+        worker = worker_class(model_names=args.model_names,
+                              controller_addr=args.controller_address,
+                              worker_addr=args.worker_address)
+    # 本地模型
+    else:
+        # workaround to make program exit with Ctrl+c
+        # it should be deleted after pr is merged by fastchat
+        def _new_init_heart_beat(self):
+            self.register_to_controller()
+            self.heart_beat_thread = threading.Thread(
+                target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
+            )
+            self.heart_beat_thread.start()
 
-    worker = ModelWorker(
-        controller_addr=args.controller_address,
-        worker_addr=args.worker_address,
-        worker_id=worker_id,
-        model_path=args.model_path,
-        model_names=args.model_names,
-        limit_worker_concurrency=args.limit_worker_concurrency,
-        no_register=args.no_register,
-        device=args.device,
-        num_gpus=args.num_gpus,
-        max_gpu_memory=args.max_gpu_memory,
-        load_8bit=args.load_8bit,
-        cpu_offloading=args.cpu_offloading,
-        gptq_config=gptq_config,
-        awq_config=awq_config,
-        stream_interval=args.stream_interval,
-        conv_template=args.conv_template,
-    )
+        ModelWorker.init_heart_beat = _new_init_heart_beat
+
+        gptq_config = GptqConfig(
+            ckpt=args.gptq_ckpt or args.model_path,
+            wbits=args.gptq_wbits,
+            groupsize=args.gptq_groupsize,
+            act_order=args.gptq_act_order,
+        )
+        awq_config = AWQConfig(
+            ckpt=args.awq_ckpt or args.model_path,
+            wbits=args.awq_wbits,
+            groupsize=args.awq_groupsize,
+        )
+
+        worker = ModelWorker(
+            controller_addr=args.controller_address,
+            worker_addr=args.worker_address,
+            worker_id=worker_id,
+            model_path=args.model_path,
+            model_names=args.model_names,
+            limit_worker_concurrency=args.limit_worker_concurrency,
+            no_register=args.no_register,
+            device=args.device,
+            num_gpus=args.num_gpus,
+            max_gpu_memory=args.max_gpu_memory,
+            load_8bit=args.load_8bit,
+            cpu_offloading=args.cpu_offloading,
+            gptq_config=gptq_config,
+            awq_config=awq_config,
+            stream_interval=args.stream_interval,
+            conv_template=args.conv_template,
+        )
+        sys.modules["fastchat.serve.model_worker"].args = args
+        sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
 
     sys.modules["fastchat.serve.model_worker"].worker = worker
-    sys.modules["fastchat.serve.model_worker"].args = args
-    sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
 
     MakeFastAPIOffline(app)
-    app.title = f"FastChat LLM Server ({LLM_MODEL})"
+    app.title = f"FastChat LLM Server ({args.model_names[0]})"
+    app._worker = worker
     return app
@@ -190,6 +210,9 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
 
 def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
     import uvicorn
+    import httpx
+    from fastapi import Body
+    import time
     import sys
 
     app = create_controller_app(

@@ -198,11 +221,71 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
     )
     _set_app_seq(app, q, run_seq)
 
+    # add interface to release and load model worker
+    @app.post("/release_worker")
+    def release_worker(
+        model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
+        # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[fschat_controller_address()]),
+        new_model_name: str = Body(None, description="释放后加载该模型"),
+        keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
+    ) -> Dict:
+        available_models = app._controller.list_models()
+        if new_model_name in available_models:
+            msg = f"要切换的LLM模型 {new_model_name} 已经存在"
+            logger.info(msg)
+            return {"code": 500, "msg": msg}
+
+        if new_model_name:
+            logger.info(f"开始切换LLM模型:从 {model_name} 到 {new_model_name}")
+        else:
+            logger.info(f"即将停止LLM模型: {model_name}")
+
+        if model_name not in available_models:
+            msg = f"the model {model_name} is not available"
+            logger.error(msg)
+            return {"code": 500, "msg": msg}
+
+        worker_address = app._controller.get_worker_address(model_name)
+        if not worker_address:
+            msg = f"can not find model_worker address for {model_name}"
+            logger.error(msg)
+            return {"code": 500, "msg": msg}
+
+        r = httpx.post(worker_address + "/release",
+                       json={"new_model_name": new_model_name, "keep_origin": keep_origin})
+        if r.status_code != 200:
+            msg = f"failed to release model: {model_name}"
+            logger.error(msg)
+            return {"code": 500, "msg": msg}
+
+        if new_model_name:
+            timer = 300 # wait 5 minutes for new model_worker register
+            while timer > 0:
+                models = app._controller.list_models()
+                if new_model_name in models:
+                    break
+                time.sleep(1)
+                timer -= 1
+            if timer > 0:
+                msg = f"sucess change model from {model_name} to {new_model_name}"
+                logger.info(msg)
+                return {"code": 200, "msg": msg}
+            else:
+                msg = f"failed change model from {model_name} to {new_model_name}"
+                logger.error(msg)
+                return {"code": 500, "msg": msg}
+        else:
+            msg = f"sucess to release model: {model_name}"
+            logger.info(msg)
+            return {"code": 200, "msg": msg}
+
     host = FSCHAT_CONTROLLER["host"]
     port = FSCHAT_CONTROLLER["port"]
 
+    if log_level == "ERROR":
+        sys.stdout = sys.__stdout__
+        sys.stderr = sys.__stderr__
+
     uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
@@ -214,16 +297,17 @@ def run_model_worker(
         log_level: str ="INFO",
 ):
     import uvicorn
+    from fastapi import Body
     import sys
 
     kwargs = get_model_worker_config(model_name)
     host = kwargs.pop("host")
     port = kwargs.pop("port")
-    model_path = llm_model_dict[model_name].get("local_model_path", "")
-    kwargs["model_path"] = model_path
     kwargs["model_names"] = [model_name]
     kwargs["controller_address"] = controller_address or fschat_controller_address()
-    kwargs["worker_address"] = fschat_model_worker_address()
+    kwargs["worker_address"] = fschat_model_worker_address(model_name)
+    model_path = kwargs.get("local_model_path", "")
+    kwargs["model_path"] = model_path
 
     app = create_model_worker_app(log_level=log_level, **kwargs)
     _set_app_seq(app, q, run_seq)

@@ -231,6 +315,22 @@ def run_model_worker(
         sys.stdout = sys.__stdout__
         sys.stderr = sys.__stderr__
 
+    # add interface to release and load model
+    @app.post("/release")
+    def release_model(
+        new_model_name: str = Body(None, description="释放后加载该模型"),
+        keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
+    ) -> Dict:
+        if keep_origin:
+            if new_model_name:
+                q.put(["start", new_model_name])
+        else:
+            if new_model_name:
+                q.put(["replace", new_model_name])
+            else:
+                q.put(["stop"])
+        return {"code": 200, "msg": "done"}
+
     uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
@@ -337,6 +437,13 @@ def parse_args() -> argparse.ArgumentParser:
         help="run api.py server",
         dest="api",
     )
+    parser.add_argument(
+        "-p",
+        "--api-worker",
+        action="store_true",
+        help="run online model api such as zhipuai",
+        dest="api_worker",
+    )
     parser.add_argument(
         "-w",
         "--webui",

@@ -368,9 +475,14 @@ def dump_server_info(after_start=False, args=None):
     print(f"项目版本:{VERSION}")
     print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
     print("\n")
-    print(f"当前LLM模型:{LLM_MODEL} @ {llm_device()}")
-    pprint(llm_model_dict[LLM_MODEL])
+
+    model = LLM_MODEL
+    if args and args.model_name:
+        model = args.model_name
+    print(f"当前LLM模型:{model} @ {llm_device()}")
+    pprint(llm_model_dict[model])
     print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
 
     if after_start:
         print("\n")
         print(f"服务端运行信息:")

@@ -396,17 +508,20 @@ if __name__ == "__main__":
         args.openai_api = True
         args.model_worker = True
         args.api = True
+        args.api_worker = True
         args.webui = True
 
     elif args.all_api:
         args.openai_api = True
         args.model_worker = True
         args.api = True
+        args.api_worker = True
         args.webui = False
 
     elif args.llm_api:
         args.openai_api = True
         args.model_worker = True
+        args.api_worker = True
         args.api = False
         args.webui = False
@@ -416,7 +531,11 @@ if __name__ == "__main__":
     logger.info(f"正在启动服务:")
     logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
 
-    processes = {}
+    processes = {"online-api": []}
+
+    def process_count():
+        return len(processes) + len(processes["online-api"]) -1
 
     if args.quiet:
         log_level = "ERROR"
     else:

@@ -426,7 +545,7 @@ if __name__ == "__main__":
         process = Process(
             target=run_controller,
             name=f"controller({os.getpid()})",
-            args=(queue, len(processes) + 1, log_level),
+            args=(queue, process_count() + 1, log_level),
             daemon=True,
         )
         process.start()
@@ -435,29 +554,42 @@ if __name__ == "__main__":
         process = Process(
             target=run_openai_api,
             name=f"openai_api({os.getpid()})",
-            args=(queue, len(processes) + 1),
+            args=(queue, process_count() + 1),
             daemon=True,
         )
        process.start()
        processes["openai_api"] = process
 
     if args.model_worker:
-        model_path = llm_model_dict[args.model_name].get("local_model_path", "")
-        if os.path.isdir(model_path):
+        config = get_model_worker_config(args.model_name)
+        if not config.get("online_api"):
             process = Process(
                 target=run_model_worker,
-                name=f"model_worker({os.getpid()})",
-                args=(args.model_name, args.controller_address, queue, len(processes) + 1, log_level),
+                name=f"model_worker - {args.model_name} ({os.getpid()})",
+                args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
                 daemon=True,
             )
             process.start()
             processes["model_worker"] = process
 
+    if args.api_worker:
+        configs = get_all_model_worker_configs()
+        for model_name, config in configs.items():
+            if config.get("online_api") and config.get("worker_class"):
+                process = Process(
+                    target=run_model_worker,
+                    name=f"model_worker - {model_name} ({os.getpid()})",
+                    args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
+                    daemon=True,
+                )
+                process.start()
+                processes["online-api"].append(process)
+
     if args.api:
         process = Process(
             target=run_api_server,
             name=f"API Server{os.getpid()})",
-            args=(queue, len(processes) + 1),
+            args=(queue, process_count() + 1),
             daemon=True,
         )
         process.start()
@@ -467,37 +599,38 @@ if __name__ == "__main__":
         process = Process(
             target=run_webui,
             name=f"WEBUI Server{os.getpid()})",
-            args=(queue, len(processes) + 1),
+            args=(queue, process_count() + 1),
             daemon=True,
         )
         process.start()
         processes["webui"] = process
 
-    if len(processes) == 0:
+    if process_count() == 0:
         parser.print_help()
     else:
         try:
             # log infors
             while True:
                 no = queue.get()
-                if no == len(processes):
+                if no == process_count():
                     time.sleep(0.5)
                     dump_server_info(after_start=True, args=args)
                     break
                 else:
                     queue.put(no)
 
-            if model_worker_process := processes.get("model_worker"):
+            if model_worker_process := processes.pop("model_worker", None):
                 model_worker_process.join()
+            for process in processes.pop("online-api", []):
+                process.join()
             for name, process in processes.items():
-                if name != "model_worker":
-                    process.join()
+                process.join()
         except:
-            if model_worker_process := processes.get("model_worker"):
+            if model_worker_process := processes.pop("model_worker", None):
                 model_worker_process.terminate()
+            for process in processes.pop("online-api", []):
+                process.terminate()
             for name, process in processes.items():
-                if name != "model_worker":
-                    process.terminate()
+                process.terminate()
 
 
 # 服务启动后接口调用示例:
@@ -1,4 +1,3 @@
-from doctest import testfile
 import requests
 import json
 import sys
@@ -0,0 +1,74 @@
+import requests
+import json
+import sys
+from pathlib import Path
+
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+from configs.server_config import api_address, FSCHAT_MODEL_WORKERS
+from configs.model_config import LLM_MODEL, llm_model_dict
+
+from pprint import pprint
+import random
+
+
+def get_configured_models():
+    model_workers = list(FSCHAT_MODEL_WORKERS)
+    if "default" in model_workers:
+        model_workers.remove("default")
+
+    llm_dict = list(llm_model_dict)
+
+    return model_workers, llm_dict
+
+
+api_base_url = api_address()
+
+
+def get_running_models(api="/llm_model/list_models"):
+    url = api_base_url + api
+    r = requests.post(url)
+    if r.status_code == 200:
+        return r.json()["data"]
+    return []
+
+
+def test_running_models(api="/llm_model/list_models"):
+    url = api_base_url + api
+    r = requests.post(url)
+    assert r.status_code == 200
+    print("\n获取当前正在运行的模型列表:")
+    pprint(r.json())
+    assert isinstance(r.json()["data"], list)
+    assert len(r.json()["data"]) > 0
+
+
+# 不建议使用stop_model功能。按现在的实现,停止了就只能手动再启动
+# def test_stop_model(api="/llm_model/stop"):
+#     url = api_base_url + api
+#     r = requests.post(url, json={""})
+
+
+def test_change_model(api="/llm_model/change"):
+    url = api_base_url + api
+
+    running_models = get_running_models()
+    assert len(running_models) > 0
+
+    model_workers, llm_dict = get_configured_models()
+
+    availabel_new_models = set(model_workers) - set(running_models)
+    if len(availabel_new_models) == 0:
+        availabel_new_models = set(llm_dict) - set(running_models)
+    availabel_new_models = list(availabel_new_models)
+    assert len(availabel_new_models) > 0
+    print(availabel_new_models)
+
+    model_name = random.choice(running_models)
+    new_model_name = random.choice(availabel_new_models)
+    print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}")
+    r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name})
+    assert r.status_code == 200
+
+    running_models = get_running_models()
+    assert new_model_name in running_models
@@ -1,10 +1,14 @@
 import streamlit as st
+from configs.server_config import FSCHAT_MODEL_WORKERS
 from webui_pages.utils import *
 from streamlit_chatbox import *
 from datetime import datetime
 from server.chat.search_engine_chat import SEARCH_ENGINES
-from typing import List, Dict
 import os
+from configs.model_config import llm_model_dict, LLM_MODEL
+from server.utils import get_model_worker_config
+from typing import List, Dict
 
 
 chat_box = ChatBox(
     assistant_avatar=os.path.join(

@@ -59,6 +63,38 @@ def dialogue_page(api: ApiRequest):
             on_change=on_mode_change,
             key="dialogue_mode",
         )
+
+        def on_llm_change():
+            st.session_state["prev_llm_model"] = llm_model
+
+        def llm_model_format_func(x):
+            if x in running_models:
+                return f"{x} (Running)"
+            return x
+
+        running_models = api.list_running_models()
+        config_models = api.list_config_models()
+        for x in running_models:
+            if x in config_models:
+                config_models.remove(x)
+        llm_models = running_models + config_models
+        if "prev_llm_model" not in st.session_state:
+            index = llm_models.index(LLM_MODEL)
+        else:
+            index = 0
+        llm_model = st.selectbox("选择LLM模型:",
+                                 llm_models,
+                                 index,
+                                 format_func=llm_model_format_func,
+                                 on_change=on_llm_change,
+                                 # key="llm_model",
+                                 )
+        if (st.session_state.get("prev_llm_model") != llm_model
+                and not get_model_worker_config(llm_model).get("online_api")):
+            with st.spinner(f"正在加载模型: {llm_model}"):
+                r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
+            st.session_state["prev_llm_model"] = llm_model
+
         history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
 
         def on_kb_change():

@@ -99,7 +135,7 @@ def dialogue_page(api: ApiRequest):
         if dialogue_mode == "LLM 对话":
             chat_box.ai_say("正在思考...")
             text = ""
-            r = api.chat_chat(prompt, history)
+            r = api.chat_chat(prompt, history=history, model=llm_model)
             for t in r:
                 if error_msg := check_error_msg(t): # check whether error occured
                     st.error(error_msg)

@@ -114,7 +150,7 @@ def dialogue_page(api: ApiRequest):
                 Markdown("...", in_expander=True, title="知识库匹配结果"),
             ])
             text = ""
-            for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history):
+            for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history, model=llm_model):
                 if error_msg := check_error_msg(d): # check whether error occured
                     st.error(error_msg)
                 text += d["answer"]

@@ -127,8 +163,8 @@ def dialogue_page(api: ApiRequest):
                 Markdown("...", in_expander=True, title="网络搜索结果"),
             ])
             text = ""
-            for d in api.search_engine_chat(prompt, search_engine, se_top_k):
-                if error_msg := check_error_msg(d): # check whether error occured
+            for d in api.search_engine_chat(prompt, search_engine, se_top_k, model=llm_model):
+                if error_msg := check_error_msg(d): # check whether error occured
                     st.error(error_msg)
                 else:
                     text += d["answer"]
@@ -6,12 +6,14 @@ from configs.model_config import (
     DEFAULT_VS_TYPE,
     KB_ROOT_PATH,
     LLM_MODEL,
+    llm_model_dict,
     HISTORY_LEN,
     SCORE_THRESHOLD,
     VECTOR_SEARCH_TOP_K,
     SEARCH_ENGINE_TOP_K,
     logger,
 )
+from configs.server_config import HTTPX_DEFAULT_TIMEOUT
 import httpx
 import asyncio
 from server.chat.openai_chat import OpenAiChatMsgIn
@@ -42,7 +44,7 @@ class ApiRequest:
     def __init__(
         self,
         base_url: str = "http://127.0.0.1:7861",
-        timeout: float = 60.0,
+        timeout: float = HTTPX_DEFAULT_TIMEOUT,
         no_remote_api: bool = False, # call api view function directly
     ):
         self.base_url = base_url

@@ -289,6 +291,7 @@ class ApiRequest:
         query: str,
         history: List[Dict] = [],
         stream: bool = True,
+        model: str = LLM_MODEL,
         no_remote_api: bool = None,
     ):
         '''

@@ -301,6 +304,7 @@ class ApiRequest:
             "query": query,
             "history": history,
             "stream": stream,
+            "model_name": model,
         }
 
         print(f"received input message:")

@@ -322,6 +326,7 @@ class ApiRequest:
         score_threshold: float = SCORE_THRESHOLD,
         history: List[Dict] = [],
         stream: bool = True,
+        model: str = LLM_MODEL,
         no_remote_api: bool = None,
     ):
         '''

@@ -337,6 +342,7 @@ class ApiRequest:
             "score_threshold": score_threshold,
             "history": history,
             "stream": stream,
+            "model_name": model,
             "local_doc_url": no_remote_api,
         }
 

@@ -361,6 +367,7 @@ class ApiRequest:
         search_engine_name: str,
         top_k: int = SEARCH_ENGINE_TOP_K,
         stream: bool = True,
+        model: str = LLM_MODEL,
         no_remote_api: bool = None,
     ):
         '''

@@ -374,6 +381,7 @@ class ApiRequest:
             "search_engine_name": search_engine_name,
             "top_k": top_k,
             "stream": stream,
+            "model_name": model,
         }
 
         print(f"received input message:")
@@ -645,6 +653,84 @@ class ApiRequest:
         )
         return self._httpx_stream2generator(response, as_json=True)
 
+    def list_running_models(self, controller_address: str = None):
+        '''
+        获取Fastchat中正运行的模型列表
+        '''
+        r = self.post(
+            "/llm_model/list_models",
+        )
+        return r.json().get("data", [])
+
+    def list_config_models(self):
+        '''
+        获取configs中配置的模型列表
+        '''
+        return list(llm_model_dict.keys())
+
+    def stop_llm_model(
+        self,
+        model_name: str,
+        controller_address: str = None,
+    ):
+        '''
+        停止某个LLM模型。
+        注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
+        '''
+        data = {
+            "model_name": model_name,
+            "controller_address": controller_address,
+        }
+        r = self.post(
+            "/llm_model/stop",
+            json=data,
+        )
+        return r.json()
+
+    def change_llm_model(
+        self,
+        model_name: str,
+        new_model_name: str,
+        controller_address: str = None,
+    ):
+        '''
+        向fastchat controller请求切换LLM模型。
+        '''
+        if not model_name or not new_model_name:
+            return
+
+        if new_model_name == model_name:
+            return {
+                "code": 200,
+                "msg": "什么都不用做"
+            }
+
+        running_models = self.list_running_models()
+        if model_name not in running_models:
+            return {
+                "code": 500,
+                "msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
+            }
+
+        config_models = self.list_config_models()
+        if new_model_name not in config_models:
+            return {
+                "code": 500,
+                "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
+            }
+
+        data = {
+            "model_name": model_name,
+            "new_model_name": new_model_name,
+            "controller_address": controller_address,
+        }
+        r = self.post(
+            "/llm_model/change",
+            json=data,
+            timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
+        )
+        return r.json()
+
 
 def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
     '''
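From the WebUI side the same flow goes through the new ApiRequest helpers shown above; a minimal usage sketch (the base_url is the default API address and the model names are the ones from the sample configs, both purely illustrative):

from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")
print(api.list_running_models())   # models currently served by fastchat, e.g. ["chatglm-6b"]
print(api.list_config_models())    # everything declared in llm_model_dict
# blocks for up to HTTPX_DEFAULT_TIMEOUT while the new worker registers
print(api.change_llm_model("chatglm-6b", "chatglm-api"))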