From f7c73b842a1af1cc20e1579e897f36a43fb12968 Mon Sep 17 00:00:00 2001
From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
Date: Fri, 15 Sep 2023 17:52:22 +0800
Subject: [PATCH] Optimize configs (#1474)

* remove llm_model_dict
* optimize configs
* fix get_model_path
* change some default values; add a default config for Qianfan
* Update server_config.py.example
---
 chains/llmchain_with_history.py              |  13 +-
 configs/__init__.py                          |   5 +-
 configs/basic_config.py.example              |  21 ++
 configs/kb_config.py.example                 | 107 ++++++++
 configs/model_config.py.example              | 246 +++++-------------
 configs/server_config.py.example             |  24 +-
 server/chat/chat.py                          |  16 +-
 server/chat/knowledge_base_chat.py           |  19 +-
 server/chat/openai_chat.py                   |   8 +-
 server/chat/search_engine_chat.py            |  18 +-
 server/chat/utils.py                         |  23 +-
 server/db/base.py                            |   2 +-
 server/knowledge_base/kb_api.py              |   2 +-
 server/knowledge_base/kb_cache/base.py       |  12 +-
 server/knowledge_base/kb_doc_api.py          |   8 +-
 server/knowledge_base/kb_service/base.py     |   4 +-
 .../kb_service/faiss_kb_service.py           |   2 +-
 .../kb_service/milvus_kb_service.py          |   2 +-
 .../kb_service/pg_kb_service.py              |   2 +-
 server/knowledge_base/migrate.py             |   4 +-
 server/knowledge_base/utils.py               |   8 +-
 server/model_workers/base.py                 |   2 +-
 server/utils.py                              |  65 ++++-
 startup.py                                   |  30 ++-
 tests/api/test_kb_api.py                     |   2 +-
 tests/api/test_kb_api_request.py             |   2 +-
 tests/api/test_llm_api.py                    |  17 +-
 tests/api/test_stream_chat_api.py            |   2 +-
 .../test_different_splitter.py               |   2 +-
 webui_pages/dialogue/dialogue.py             |   3 +-
 webui_pages/knowledge_base/knowledge_base.py |   9 +-
 webui_pages/utils.py                         |  11 +-
 32 files changed, 371 insertions(+), 320 deletions(-)
 create mode 100644 configs/basic_config.py.example
 create mode 100644 configs/kb_config.py.example

diff --git a/chains/llmchain_with_history.py b/chains/llmchain_with_history.py
index 3d36042..044f470 100644
--- a/chains/llmchain_with_history.py
+++ b/chains/llmchain_with_history.py
@@ -1,19 +1,12 @@
-from langchain.chat_models import ChatOpenAI
-from configs.model_config import llm_model_dict, LLM_MODEL
+from server.chat.utils import get_ChatOpenAI
+from configs.model_config import LLM_MODEL, TEMPERATURE
 from langchain import LLMChain
 from langchain.prompts.chat import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
 )
 
-model = ChatOpenAI(
-    streaming=True,
-    verbose=True,
-    # callbacks=[callback],
-    openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-    openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-    model_name=LLM_MODEL
-)
+model = get_ChatOpenAI(model_name=LLM_MODEL, temperature=TEMPERATURE)
 
 human_prompt = "{input}"

diff --git a/configs/__init__.py b/configs/__init__.py
index 41169e8..f4c1866 100644
--- a/configs/__init__.py
+++ b/configs/__init__.py
@@ -1,4 +1,7 @@
+from .basic_config import *
 from .model_config import *
+from .kb_config import *
 from .server_config import *
-VERSION = "v0.2.4"
+
+VERSION = "v0.2.5-preview"
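The one-file config is now split into basic_config, kb_config, model_config and server_config, all star-imported by configs/__init__.py, so call sites can pull any setting from the package root (the per-file diffs below follow this pattern). A minimal sketch of the resulting import style, assuming the *.py.example files have been copied to *.py:

    # configs/__init__.py re-exports every setting, so this works regardless of
    # which of the four config modules actually defines each name:
    from configs import LLM_MODEL, EMBEDDING_MODEL, SCORE_THRESHOLD, LOG_PATH, VERSION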
"logs") +if not os.path.exists(LOG_PATH): + os.mkdir(LOG_PATH) diff --git a/configs/kb_config.py.exmaple b/configs/kb_config.py.exmaple new file mode 100644 index 0000000..7df3f38 --- /dev/null +++ b/configs/kb_config.py.exmaple @@ -0,0 +1,107 @@ +import os + + +# 默认向量库类型。可选:faiss, milvus, pg. +DEFAULT_VS_TYPE = "faiss" + +# 缓存向量库数量(针对FAISS) +CACHED_VS_NUM = 1 + +# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter) +CHUNK_SIZE = 500 + +# 知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter) +OVERLAP_SIZE = 50 + +# 知识库匹配向量数量 +VECTOR_SEARCH_TOP_K = 3 + +# 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右 +SCORE_THRESHOLD = 1 + +# 搜索引擎匹配结题数量 +SEARCH_ENGINE_TOP_K = 3 + + +# 基于本地知识问答的提示词模版(使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号 +PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 + +<已知信息>{{ context }} + +<问题>{{ question }}""" + + +# Bing 搜索必备变量 +# 使用 Bing 搜索需要使用 Bing Subscription Key,需要在azure port中申请试用bing search +# 具体申请方式请见 +# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource +# 使用python创建bing api 搜索实例详见: +# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python +BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search" +# 注意不是bing Webmaster Tools的api key, + +# 此外,如果是在服务器上,报Failed to establish a new connection: [Errno 110] Connection timed out +# 是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG +BING_SUBSCRIPTION_KEY = "" + +# 是否开启中文标题加强,以及标题增强的相关配置 +# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记; +# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。 +ZH_TITLE_ENHANCE = False + + +# 通常情况下不需要更改以下内容 + +# 知识库默认存储路径 +KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base") +if not os.path.exists(KB_ROOT_PATH): + os.mkdir(KB_ROOT_PATH) + +# 数据库默认存储路径。 +# 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。 +DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db") +SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}" + +# 可选向量库类型及对应配置 +kbs_config = { + "faiss": { + }, + "milvus": { + "host": "127.0.0.1", + "port": "19530", + "user": "", + "password": "", + "secure": False, + }, + "pg": { + "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat", + } +} + +# TextSplitter配置项,如果你不明白其中的含义,就不要修改。 +text_splitter_dict = { + "ChineseRecursiveTextSplitter": { + "source": "huggingface", ## 选择tiktoken则使用openai的方法 + "tokenizer_name_or_path": "gpt2", + }, + "SpacyTextSplitter": { + "source": "huggingface", + "tokenizer_name_or_path": "", + }, + "RecursiveCharacterTextSplitter": { + "source": "tiktoken", + "tokenizer_name_or_path": "cl100k_base", + }, + "MarkdownHeaderTextSplitter": { + "headers_to_split_on": + [ + ("#", "head1"), + ("##", "head2"), + ("###", "head3"), + ("####", "head4"), + ] + }, +} + +# TEXT_SPLITTER 名称 +TEXT_SPLITTER_NAME = "SpacyTextSplitter" diff --git a/configs/model_config.py.example b/configs/model_config.py.example index 0aa0cf5..728cddd 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -1,36 +1,46 @@ import os -import logging -# 日志格式 -LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" -logger = logging.getLogger() -logger.setLevel(logging.INFO) -logging.basicConfig(format=LOG_FORMAT) -# 是否显示详细日志 -log_verbose = False -# 在以下字典中修改属性值,以指定本地embedding模型存储位置 -# 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese" -# 此处请写绝对路径 -embedding_model_dict = { - "ernie-tiny": "nghuyong/ernie-3.0-nano-zh", - 
"ernie-base": "nghuyong/ernie-3.0-base-zh", - "text2vec-base": "shibing624/text2vec-base-chinese", - "text2vec": "GanymedeNil/text2vec-large-chinese", - "text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase", - "text2vec-sentence": "shibing624/text2vec-base-chinese-sentence", - "text2vec-multilingual": "shibing624/text2vec-base-multilingual", - "text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese", - "m3e-small": "moka-ai/m3e-small", - "m3e-base": "moka-ai/m3e-base", - "m3e-large": "moka-ai/m3e-large", - "bge-small-zh": "BAAI/bge-small-zh", - "bge-base-zh": "BAAI/bge-base-zh", - "bge-large-zh": "BAAI/bge-large-zh", - "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct", - "piccolo-base-zh": "sensenova/piccolo-base-zh", - "piccolo-large-zh": "sensenova/piccolo-large-zh", - "text-embedding-ada-002": os.environ.get("OPENAI_API_KEY") +# 可以指定一个绝对路径,统一存放所有的Embedding和LLM模型。 +# 每个模型可以是一个单独的目录,也可以是某个目录下的二级子目录 +MODEL_ROOT_PATH = "" + +# 在以下字典中修改属性值,以指定本地embedding模型存储位置。支持3种设置方法: +# 1、将对应的值修改为模型绝对路径 +# 2、不修改此处的值(以 text2vec 为例): +# 2.1 如果{MODEL_ROOT_PATH}下存在如下任一子目录: +# - text2vec +# - GanymedeNil/text2vec-large-chinese +# - text2vec-large-chinese +# 2.2 如果以上本地路径不存在,则使用huggingface模型 +MODEL_PATH = { + "embed_model": { + "ernie-tiny": "nghuyong/ernie-3.0-nano-zh", + "ernie-base": "nghuyong/ernie-3.0-base-zh", + "text2vec-base": "shibing624/text2vec-base-chinese", + "text2vec": "GanymedeNil/text2vec-large-chinese", + "text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase", + "text2vec-sentence": "shibing624/text2vec-base-chinese-sentence", + "text2vec-multilingual": "shibing624/text2vec-base-multilingual", + "text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese", + "m3e-small": "moka-ai/m3e-small", + "m3e-base": "moka-ai/m3e-base", + "m3e-large": "moka-ai/m3e-large", + "bge-small-zh": "BAAI/bge-small-zh", + "bge-base-zh": "BAAI/bge-base-zh", + "bge-large-zh": "BAAI/bge-large-zh", + "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct", + "piccolo-base-zh": "sensenova/piccolo-base-zh", + "piccolo-large-zh": "sensenova/piccolo-large-zh", + "text-embedding-ada-002": "your OPENAI_API_KEY", + }, + # TODO: add all supported llm models + "llm_model": { + "chatglm-6b": "THUDM/chatglm-6b", + "chatglm2-6b": "THUDM/chatglm2-6b", + "chatglm2-6b-int4": "THUDM/chatglm2-6b-int4", + "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k", + }, } # 选用的 Embedding 名称 @@ -39,25 +49,21 @@ EMBEDDING_MODEL = "m3e-base" # Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。 EMBEDDING_DEVICE = "auto" -llm_model_dict = { - "chatglm-6b": { - "local_model_path": "THUDM/chatglm-6b", - "api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url" - "api_key": "EMPTY" - }, +# LLM 名称 +LLM_MODEL = "chatglm2-6b" - "chatglm2-6b": { - "local_model_path": "THUDM/chatglm2-6b", - "api_base_url": "http://localhost:8888/v1", # URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致 - "api_key": "EMPTY" - }, +# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。 +LLM_DEVICE = "auto" - "chatglm2-6b-32k": { - "local_model_path": "THUDM/chatglm2-6b-32k", # "THUDM/chatglm2-6b-32k", - "api_base_url": "http://localhost:8888/v1", # "URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致 - "api_key": "EMPTY" - }, +# 历史对话轮数 +HISTORY_LEN = 3 +# LLM通用对话参数 +TEMPERATURE = 0.7 +# TOP_P = 0.95 # ChatOpenAI暂不支持该参数 + + +ONLINE_LLM_MODEL = { # 调用chatgpt时如果报出: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443): # Max retries exceeded with url: 
@@ -39,25 +49,21 @@ EMBEDDING_MODEL = "m3e-base"
 
 # Embedding model device. "auto" detects automatically; it can also be set to "cuda", "mps" or "cpu".
 EMBEDDING_DEVICE = "auto"
 
-llm_model_dict = {
-    "chatglm-6b": {
-        "local_model_path": "THUDM/chatglm-6b",
-        "api_base_url": "http://localhost:8888/v1",  # "name" set to the "api_base_url" of the fastchat service
-        "api_key": "EMPTY"
-    },
+# LLM name
+LLM_MODEL = "chatglm2-6b"
 
-    "chatglm2-6b": {
-        "local_model_path": "THUDM/chatglm2-6b",
-        "api_base_url": "http://localhost:8888/v1",  # URL must match server_config.FSCHAT_OPENAI_API of the running fastchat service
-        "api_key": "EMPTY"
-    },
+# LLM device. "auto" detects automatically; it can also be set to "cuda", "mps" or "cpu".
+LLM_DEVICE = "auto"
 
-    "chatglm2-6b-32k": {
-        "local_model_path": "THUDM/chatglm2-6b-32k",  # "THUDM/chatglm2-6b-32k",
-        "api_base_url": "http://localhost:8888/v1",  # URL must match server_config.FSCHAT_OPENAI_API of the running fastchat service
-        "api_key": "EMPTY"
-    },
+# Number of history turns kept in the dialogue
+HISTORY_LEN = 3
 
+# Common LLM dialogue parameters
+TEMPERATURE = 0.7
+# TOP_P = 0.95  # ChatOpenAI does not support this parameter yet
+
+
+ONLINE_LLM_MODEL = {
     # If calling chatgpt raises
     # "urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
     #  Max retries exceeded with url: /v1/chat/completions",
     # downgrade urllib3 to 1.25.11.
@@ -74,29 +80,25 @@ llm_model_dict = {
     # If proxy access is needed (a normally running proxy app may fail to intercept these requests),
     # set "openai_proxy" here, or set the OPENAI_PROXY environment variable,
     # e.g. "openai_proxy": 'http://127.0.0.1:4780'
     "gpt-3.5-turbo": {
-        "api_base_url": "https://api.openai.com/v1",
-        "api_key": "",
-        "openai_proxy": ""
+        "api_key": "your OPENAI_API_KEY",
+        "openai_proxy": "your OPENAI_PROXY",
     },
-    # Online models. Currently Zhipu AI is supported.
-    # A model is treated as an online API if no valid local_model_path is set.
-    # Set a different port in server_config for each online API.
+    # Online models. Set a different port in server_config for each online API.
     # For registration and api keys see http://open.bigmodel.cn
     "zhipu-api": {
-        "api_base_url": "http://127.0.0.1:8888/v1",
         "api_key": "",
-        "provider": "ChatGLMWorker",
         "version": "chatglm_pro",  # options include "chatglm_lite", "chatglm_std", "chatglm_pro"
+        "provider": "ChatGLMWorker",
     },
+    # For registration and api keys see https://api.minimax.chat/
     "minimax-api": {
-        "api_base_url": "http://127.0.0.1:8888/v1",
         "group_id": "",
         "api_key": "",
         "is_pro": False,
         "provider": "MiniMaxWorker",
     },
+    # For registration and api keys see https://xinghuo.xfyun.cn/
     "xinghuo-api": {
-        "api_base_url": "http://127.0.0.1:8888/v1",
         "APPID": "",
         "APISecret": "",
         "api_key": "",
     },
     # Baidu Qianfan API; for how to apply see https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf
@@ -105,140 +107,16 @@ llm_model_dict = {
     "qianfan-api": {
-        "version": "ernie-bot",  # currently "ernie-bot" or "ernie-bot-turbo"; for more see the Qianfan section of the supported-model list in the docs.
-        "version_url": "",  # version may be left empty if the API url published for your Qianfan model is filled in here
-        "api_base_url": "http://127.0.0.1:8888/v1",
+        "version": "ernie-bot-turbo",  # currently "ernie-bot" or "ernie-bot-turbo"; for more see the official docs.
+        "version_url": "",  # alternatively leave version empty and fill in the API url published for your Qianfan model
         "api_key": "",
         "secret_key": "",
-        "provider": "ErnieWorker",
-    }
-}
+        "provider": "QianFanWorker",
+    },
+}
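Each ONLINE_LLM_MODEL key is expected to have a same-named entry with a dedicated port in server_config's FSCHAT_MODEL_WORKERS (see the server_config.py.example diff below). A hedged sketch of wiring up a hypothetical additional provider; "some-api" and SomeWorker do not exist in the project:

    # in model_config.py: the worker class named by "provider" must exist in server.model_workers
    ONLINE_LLM_MODEL["some-api"] = {
        "api_key": "",
        "provider": "SomeWorker",  # hypothetical worker class
    }
    # in server_config.py: give the online API its own port
    FSCHAT_MODEL_WORKERS["some-api"] = {"port": 21005}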
-DEFAULT_VS_TYPE = "faiss" - -# 缓存向量库数量 -CACHED_VS_NUM = 1 - -# 知识库匹配向量数量 -VECTOR_SEARCH_TOP_K = 3 - -# 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右 -SCORE_THRESHOLD = 1 - -# 搜索引擎匹配结题数量 -SEARCH_ENGINE_TOP_K = 3 +# 通常情况下不需要更改以下内容 # nltk 模型存储路径 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data") - -# 基于本地知识问答的提示词模版(使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号 -PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 - -<已知信息>{{ context }} - -<问题>{{ question }}""" - -# API 是否开启跨域,默认为False,如果需要开启,请设置为True -# is open cross domain -OPEN_CROSS_DOMAIN = False - -# Bing 搜索必备变量 -# 使用 Bing 搜索需要使用 Bing Subscription Key,需要在azure port中申请试用bing search -# 具体申请方式请见 -# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource -# 使用python创建bing api 搜索实例详见: -# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python -BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search" -# 注意不是bing Webmaster Tools的api key, - -# 此外,如果是在服务器上,报Failed to establish a new connection: [Errno 110] Connection timed out -# 是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG -BING_SUBSCRIPTION_KEY = "" - -# 是否开启中文标题加强,以及标题增强的相关配置 -# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记; -# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。 -ZH_TITLE_ENHANCE = False diff --git a/configs/server_config.py.example b/configs/server_config.py.example index 51f53dc..b5040bf 100644 --- a/configs/server_config.py.example +++ b/configs/server_config.py.example @@ -1,4 +1,4 @@ -from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE +from configs.model_config import LLM_DEVICE import httpx # httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。 @@ -8,7 +8,7 @@ HTTPX_DEFAULT_TIMEOUT = 300.0 # is open cross domain OPEN_CROSS_DOMAIN = False -# 各服务器默认绑定host +# 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host DEFAULT_BIND_HOST = "127.0.0.1" # webui.py server @@ -26,14 +26,14 @@ API_SERVER = { # fastchat openai_api server FSCHAT_OPENAI_API = { "host": DEFAULT_BIND_HOST, - "port": 8888, # model_config.llm_model_dict中模型配置的api_base_url需要与这里一致。 + "port": 20000, } # fastchat model_worker server -# 这些模型必须是在model_config.llm_model_dict中正确配置的。 +# 这些模型必须是在model_config.MODEL_PATH或ONLINE_MODEL中正确配置的。 # 在启动startup.py时,可用通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL FSCHAT_MODEL_WORKERS = { - # 所有模型共用的默认配置,可在模型专项配置或llm_model_dict中进行覆盖。 + # 所有模型共用的默认配置,可在模型专项配置中进行覆盖。 "default": { "host": DEFAULT_BIND_HOST, "port": 20002, @@ -64,17 +64,17 @@ FSCHAT_MODEL_WORKERS = { "baichuan-7b": { # 使用default中的IP和端口 "device": "cpu", }, - "zhipu-api": { # 请为每个在线API设置不同的端口 - "port": 20003, + "zhipu-api": { # 请为每个要运行的在线API设置不同的端口 + "port": 21001, }, - "minimax-api": { # 请为每个在线API设置不同的端口 - "port": 20004, + "minimax-api": { + "port": 21002, }, - "xinghuo-api": { # 请为每个在线API设置不同的端口 - "port": 20005, + "xinghuo-api": { + "port": 21003, }, "qianfan-api": { - "port": 20006, + "port": 21004, }, } diff --git a/server/chat/chat.py b/server/chat/chat.py index c025c3c..e7564bc 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -1,8 +1,7 @@ from fastapi import Body from fastapi.responses import StreamingResponse -from configs.model_config import llm_model_dict, LLM_MODEL, TEMPERATURE -from server.chat.utils import wrap_done -from langchain.chat_models import ChatOpenAI +from configs.model_config import LLM_MODEL, TEMPERATURE +from server.chat.utils import wrap_done, get_ChatOpenAI from langchain import LLMChain from langchain.callbacks import 
diff --git a/server/chat/chat.py b/server/chat/chat.py
index c025c3c..e7564bc 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -1,8 +1,7 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import llm_model_dict, LLM_MODEL, TEMPERATURE
-from server.chat.utils import wrap_done
-from langchain.chat_models import ChatOpenAI
+from configs.model_config import LLM_MODEL, TEMPERATURE
+from server.chat.utils import wrap_done, get_ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
@@ -31,18 +30,11 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         model_name: str = LLM_MODEL,
     ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
-
-        model = ChatOpenAI(
-            streaming=True,
-            verbose=True,
-            callbacks=[callback],
-            openai_api_key=llm_model_dict[model_name]["api_key"],
-            openai_api_base=llm_model_dict[model_name]["api_base_url"],
+        model = get_ChatOpenAI(
             model_name=model_name,
             temperature=temperature,
-            openai_proxy=llm_model_dict[model_name].get("openai_proxy")
+            callbacks=[callback],
         )
-
         input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])

diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index b26f2cb..62bd31a 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -1,11 +1,10 @@
 from fastapi import Body, Request
 from fastapi.responses import StreamingResponse
-from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
-                                  VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                                  TEMPERATURE)
-from server.chat.utils import wrap_done
+from configs import (LLM_MODEL, PROMPT_TEMPLATE,
+                     VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
+                     TEMPERATURE)
+from server.chat.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse
-from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable, List, Optional
@@ -50,16 +49,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         model_name: str = LLM_MODEL,
     ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
-
-        model = ChatOpenAI(
-            streaming=True,
-            verbose=True,
-            callbacks=[callback],
-            openai_api_key=llm_model_dict[model_name]["api_key"],
-            openai_api_base=llm_model_dict[model_name]["api_base_url"],
+        model = get_ChatOpenAI(
             model_name=model_name,
             temperature=temperature,
-            openai_proxy=llm_model_dict[model_name].get("openai_proxy")
+            callbacks=[callback],
         )
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])

diff --git a/server/chat/openai_chat.py b/server/chat/openai_chat.py
index 857ac97..a7d14c9 100644
--- a/server/chat/openai_chat.py
+++ b/server/chat/openai_chat.py
@@ -1,7 +1,8 @@
 from fastapi.responses import StreamingResponse
 from typing import List
 import openai
-from configs.model_config import llm_model_dict, LLM_MODEL, logger, log_verbose
+from configs import LLM_MODEL, logger, log_verbose
+from server.utils import get_model_worker_config, fschat_openai_api_address
 from pydantic import BaseModel
 
 
@@ -23,9 +24,10 @@ class OpenAiChatMsgIn(BaseModel):
 
 
 async def openai_chat(msg: OpenAiChatMsgIn):
-    openai.api_key = llm_model_dict[LLM_MODEL]["api_key"]
+    config = get_model_worker_config(msg.model)
+    openai.api_key = config.get("api_key", "EMPTY")
     print(f"{openai.api_key=}")
-    openai.api_base = llm_model_dict[LLM_MODEL]["api_base_url"]
+    openai.api_base = fschat_openai_api_address()
     print(f"{openai.api_base=}")
     print(msg)
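Since openai_chat.py now derives its endpoint from the running fastchat service, any client using the classic openai<1.0 SDK can be pointed at it the same way; a minimal sketch with the default host and port (the model is assumed to be a loaded worker):

    import openai

    openai.api_key = "EMPTY"                       # local fastchat workers ignore the key
    openai.api_base = "http://127.0.0.1:20000/v1"  # fschat_openai_api_address()
    resp = openai.ChatCompletion.create(
        model="chatglm2-6b",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(resp.choices[0].message.content)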
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index f8e4ebe..ffeff86 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -1,13 +1,12 @@
 from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
-from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
+from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
+                     LLM_MODEL, SEARCH_ENGINE_TOP_K,
+                     PROMPT_TEMPLATE, TEMPERATURE)
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
-from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K,
-                                  PROMPT_TEMPLATE, TEMPERATURE)
-from server.chat.utils import wrap_done
+from server.chat.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse
-from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
@@ -90,15 +89,10 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
         model_name: str = LLM_MODEL,
     ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
-        model = ChatOpenAI(
-            streaming=True,
-            verbose=True,
-            callbacks=[callback],
-            openai_api_key=llm_model_dict[model_name]["api_key"],
-            openai_api_base=llm_model_dict[model_name]["api_base_url"],
+        model = get_ChatOpenAI(
             model_name=model_name,
             temperature=temperature,
-            openai_proxy=llm_model_dict[model_name].get("openai_proxy")
+            callbacks=[callback],
         )
 
         docs = await lookup_search_engine(query, search_engine_name, top_k)

diff --git a/server/chat/utils.py b/server/chat/utils.py
index a80648b..0d2e877 100644
--- a/server/chat/utils.py
+++ b/server/chat/utils.py
@@ -1,8 +1,29 @@
 import asyncio
-from typing import Awaitable, List, Tuple, Dict, Union
 from pydantic import BaseModel, Field
 from langchain.prompts.chat import ChatMessagePromptTemplate
 from configs import logger, log_verbose
+from server.utils import get_model_worker_config, fschat_openai_api_address
+from langchain.chat_models import ChatOpenAI
+from typing import Awaitable, List, Tuple, Dict, Union, Callable
+
+
+def get_ChatOpenAI(
+        model_name: str,
+        temperature: float,
+        callbacks: List[Callable] = [],
+) -> ChatOpenAI:
+    config = get_model_worker_config(model_name)
+    model = ChatOpenAI(
+        streaming=True,
+        verbose=True,
+        callbacks=callbacks,
+        openai_api_key=config.get("api_key", "EMPTY"),
+        openai_api_base=fschat_openai_api_address(),
+        model_name=model_name,
+        temperature=temperature,
+        openai_proxy=config.get("openai_proxy")
+    )
+    return model
 
 
 async def wrap_done(fn: Awaitable, event: asyncio.Event):

diff --git a/server/db/base.py b/server/db/base.py
index 1d911c0..ae42ac0 100644
--- a/server/db/base.py
+++ b/server/db/base.py
@@ -2,7 +2,7 @@ from sqlalchemy import create_engine
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import sessionmaker
 
-from configs.model_config import SQLALCHEMY_DATABASE_URI
+from configs import SQLALCHEMY_DATABASE_URI
 import json

diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py
index c7b703e..f50d8a7 100644
--- a/server/knowledge_base/kb_api.py
+++ b/server/knowledge_base/kb_api.py
@@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse
 from server.knowledge_base.utils import validate_kb_name
 from server.knowledge_base.kb_service.base import KBServiceFactory
 from server.db.repository.knowledge_base_repository import list_kbs_from_db
-from configs.model_config import EMBEDDING_MODEL, logger, log_verbose
+from configs import EMBEDDING_MODEL, logger, log_verbose
 from fastapi import Body

diff --git a/server/knowledge_base/kb_cache/base.py b/server/knowledge_base/kb_cache/base.py
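get_ChatOpenAI centralizes what chat.py, knowledge_base_chat.py and search_engine_chat.py previously each built by hand; typical usage reduces to one call (sketch):

    from server.chat.utils import get_ChatOpenAI

    model = get_ChatOpenAI(model_name="chatglm2-6b", temperature=0.7)
    # streaming endpoints pass their AsyncIteratorCallbackHandler through:
    # model = get_ChatOpenAI(model_name=model_name, temperature=temperature, callbacks=[callback])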
index f3e6d65..d6b72c7 100644
--- a/server/knowledge_base/kb_cache/base.py
+++ b/server/knowledge_base/kb_cache/base.py
@@ -4,9 +4,9 @@ from langchain.embeddings import HuggingFaceBgeEmbeddings
 from langchain.embeddings.base import Embeddings
 from langchain.schema import Document
 import threading
-from configs.model_config import (CACHED_VS_NUM, EMBEDDING_MODEL, CHUNK_SIZE,
-                                  embedding_model_dict, logger, log_verbose)
-from server.utils import embedding_device
+from configs import (EMBEDDING_MODEL, CHUNK_SIZE, CACHED_VS_NUM,
+                     logger, log_verbose)
+from server.utils import embedding_device, get_model_path
 from contextlib import contextmanager
 from collections import OrderedDict
 from typing import List, Any, Union, Tuple
@@ -118,15 +118,15 @@ class EmbeddingsPool(CachePool):
             with item.acquire(msg="初始化"):
                 self.atomic.release()
                 if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
-                    embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
+                    embeddings = OpenAIEmbeddings(openai_api_key=get_model_path(model), chunk_size=CHUNK_SIZE)
                 elif 'bge-' in model:
-                    embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
+                    embeddings = HuggingFaceBgeEmbeddings(model_name=get_model_path(model),
                                                           model_kwargs={'device': device},
                                                           query_instruction="为这个句子生成表示以用于检索相关文章:")
                     if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
                         embeddings.query_instruction = ""
                 else:
-                    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
+                    embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model), model_kwargs={'device': device})
                 item.obj = embeddings
                 item.finish_loading()
         else:

diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index 02ad222..0a349f7 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -1,10 +1,10 @@
 import os
 import urllib
 from fastapi import File, Form, Body, Query, UploadFile
-from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
-                                  VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                                  CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
-                                  logger, log_verbose,)
+from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
+                     VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
+                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
+                     logger, log_verbose,)
 from server.utils import BaseResponse, ListResponse, run_in_thread_pool
 from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
                                          files2docs_in_thread, KnowledgeFile)

diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py
index c97f8cc..a725a78 100644
--- a/server/knowledge_base/kb_service/base.py
+++ b/server/knowledge_base/kb_service/base.py
@@ -18,8 +18,8 @@ from server.db.repository.knowledge_file_repository import (
     list_docs_from_db,
 )
 
-from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                                  EMBEDDING_MODEL)
+from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
+                     EMBEDDING_MODEL)
 from server.knowledge_base.utils import (
     get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,
diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py
index 6e20acf..2671f55 100644
--- a/server/knowledge_base/kb_service/faiss_kb_service.py
+++ b/server/knowledge_base/kb_service/faiss_kb_service.py
@@ -1,7 +1,7 @@
 import os
 import shutil
 
-from configs.model_config import (
+from configs import (
     KB_ROOT_PATH,
     SCORE_THRESHOLD,
     logger, log_verbose,

diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py
index 5ca425b..afb8331 100644
--- a/server/knowledge_base/kb_service/milvus_kb_service.py
+++ b/server/knowledge_base/kb_service/milvus_kb_service.py
@@ -7,7 +7,7 @@ from langchain.schema import Document
 from langchain.vectorstores import Milvus
 from sklearn.preprocessing import normalize
 
-from configs.model_config import SCORE_THRESHOLD, kbs_config
+from configs import SCORE_THRESHOLD, kbs_config
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
     score_threshold_process

diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py
index fa832ab..9c17e80 100644
--- a/server/knowledge_base/kb_service/pg_kb_service.py
+++ b/server/knowledge_base/kb_service/pg_kb_service.py
@@ -7,7 +7,7 @@ from langchain.vectorstores import PGVector
 from langchain.vectorstores.pgvector import DistanceStrategy
 from sqlalchemy import text
 
-from configs.model_config import EMBEDDING_DEVICE, kbs_config
+from configs import kbs_config
 from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
     score_threshold_process

diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py
index b2073a1..1b7c7ae 100644
--- a/server/knowledge_base/migrate.py
+++ b/server/knowledge_base/migrate.py
@@ -1,5 +1,5 @@
-from configs.model_config import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
-                                  logger, log_verbose)
+from configs import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
+                     logger, log_verbose)
 from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
                                          list_files_from_folder, files2docs_in_thread,
                                          KnowledgeFile,)

diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index ab4c2f9..c804a6a 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -2,7 +2,7 @@ import os
 
 from transformers import AutoTokenizer
 
-from configs.model_config import (
+from configs import (
     EMBEDDING_MODEL,
     KB_ROOT_PATH,
     CHUNK_SIZE,
@@ -23,7 +23,7 @@ from langchain.text_splitter import TextSplitter
 from pathlib import Path
 import json
 from concurrent.futures import ThreadPoolExecutor
-from server.utils import run_in_thread_pool, embedding_device
+from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config
 import io
 from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
 
@@ -185,6 +185,7 @@ def make_text_splitter(
         splitter_name: str = TEXT_SPLITTER,
         chunk_size: int = CHUNK_SIZE,
         chunk_overlap: int = OVERLAP_SIZE,
+        llm_model: str = LLM_MODEL,
 ):
     """
     Get a specific text splitter based on the parameters
     """
@@ -220,8 +221,9 @@ def make_text_splitter(
         )
     elif text_splitter_dict[splitter_name]["source"] == "huggingface":  # load from huggingface
         if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
+            config = get_model_worker_config(llm_model)
             text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
-                llm_model_dict[LLM_MODEL]["local_model_path"]
+                config.get("model_path")
 
         if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
             from transformers import GPT2TokenizerFast
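make_text_splitter can now fall back to the running LLM's own tokenizer whenever a huggingface-sourced splitter leaves tokenizer_name_or_path empty (as SpacyTextSplitter does in the new kb_config). A condensed sketch of that branch, with config names as in the diffs above:

    from transformers import AutoTokenizer
    from langchain.text_splitter import SpacyTextSplitter
    from configs import text_splitter_dict, LLM_MODEL, CHUNK_SIZE, OVERLAP_SIZE
    from server.utils import get_model_worker_config

    tokenizer_path = (text_splitter_dict["SpacyTextSplitter"]["tokenizer_name_or_path"]
                      or get_model_worker_config(LLM_MODEL).get("model_path"))
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
    # from_huggingface_tokenizer is inherited from langchain's TextSplitter base class
    text_splitter = SpacyTextSplitter.from_huggingface_tokenizer(
        tokenizer, chunk_size=CHUNK_SIZE, chunk_overlap=OVERLAP_SIZE)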
diff --git a/server/model_workers/base.py b/server/model_workers/base.py
index df5fbfc..653e29b 100644
--- a/server/model_workers/base.py
+++ b/server/model_workers/base.py
@@ -1,4 +1,4 @@
-from configs.model_config import LOG_PATH
+from configs.basic_config import LOG_PATH
 import fastchat.constants
 fastchat.constants.LOGDIR = LOG_PATH
 from fastchat.serve.model_worker import BaseModelWorker

diff --git a/server/utils.py b/server/utils.py
index 48e7435..ec51ae0 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -4,8 +4,10 @@ from typing import List
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio
-from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE, logger, log_verbose
-from configs.server_config import FSCHAT_MODEL_WORKERS
+from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
+                     MODEL_PATH, MODEL_ROOT_PATH,
+                     logger, log_verbose,
+                     FSCHAT_MODEL_WORKERS)
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Literal, Optional, Callable, Generator, Dict, Any
@@ -197,22 +199,56 @@ def MakeFastAPIOffline(
     )
 
 
+# Read model information from model_config
+def list_embed_models() -> List[str]:
+    return list(MODEL_PATH["embed_model"])
+
+
+def list_llm_models() -> List[str]:
+    return list(MODEL_PATH["llm_model"])
+
+
+def get_model_path(model_name: str, type: str = None) -> Optional[str]:
+    if type in MODEL_PATH:
+        paths = MODEL_PATH[type]
+    else:
+        paths = {}
+        for v in MODEL_PATH.values():
+            paths.update(v)
+
+    if path_str := paths.get(model_name):  # taking "chatglm-6b": "THUDM/chatglm-6b-new" as an example, all of the following paths are supported
+        path = Path(path_str)
+        if path.is_dir():  # any absolute path
+            return str(path)
+
+        root_path = Path(MODEL_ROOT_PATH)
+        if root_path.is_dir():
+            path = root_path / model_name
+            if path.is_dir():  # use the key: {MODEL_ROOT_PATH}/chatglm-6b
+                return str(path)
+            path = root_path / path_str
+            if path.is_dir():  # use the value: {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
+                return str(path)
+            path = root_path / path_str.split("/")[-1]
+            if path.is_dir():  # use the last segment of the value: {MODEL_ROOT_PATH}/chatglm-6b-new
+                return str(path)
+        return path_str  # fall back to the huggingface repo id, e.g. THUDM/chatglm-6b-new
+
+
 # Read service information from server_config
-def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
+def get_model_worker_config(model_name: str = None) -> dict:
     '''
     Load the configuration for a model worker.
-    Priority: FSCHAT_MODEL_WORKERS[model_name] > llm_model_dict[model_name] > FSCHAT_MODEL_WORKERS["default"]
+    Priority: FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
     '''
+    from configs.model_config import ONLINE_LLM_MODEL
     from configs.server_config import FSCHAT_MODEL_WORKERS
     from server import model_workers
-    from configs.model_config import llm_model_dict
 
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
-    config.update(llm_model_dict.get(model_name, {}))
+    config.update(ONLINE_LLM_MODEL.get(model_name, {}))
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
 
-    # A model is treated as an online API if no valid local_model_path is set
-    if not os.path.isdir(config.get("local_model_path", "")):
+    # online model APIs
+    if model_name in ONLINE_LLM_MODEL:
         config["online_api"] = True
         if provider := config.get("provider"):
             try:
@@ -222,13 +258,14 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
                 logger.error(f'{e.__class__.__name__}: {msg}',
                              exc_info=e if log_verbose else None)
 
-    config["device"] = llm_device(config.get("device") or LLM_DEVICE)
+    config["model_path"] = get_model_path(model_name)
+    config["device"] = llm_device(config.get("device"))
     return config
 
 
 def get_all_model_worker_configs() -> dict:
     result = {}
-    model_names = set(llm_model_dict.keys()) | set(FSCHAT_MODEL_WORKERS.keys())
+    model_names = set(FSCHAT_MODEL_WORKERS.keys())
     for name in model_names:
         if name != "default":
             result[name] = get_model_worker_config(name)
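get_model_worker_config therefore layers three sources, later ones winning, and tags online APIs by membership in ONLINE_LLM_MODEL rather than by a missing local path. A worked example under the configs shown in this patch:

    cfg = get_model_worker_config("zhipu-api")
    # cfg["host"]       == "127.0.0.1"      from FSCHAT_MODEL_WORKERS["default"]
    # cfg["provider"]   == "ChatGLMWorker"  from ONLINE_LLM_MODEL["zhipu-api"]
    # cfg["port"]       == 21001            FSCHAT_MODEL_WORKERS["zhipu-api"] overrides the default 20002
    # cfg["online_api"] is True             because "zhipu-api" is a key of ONLINE_LLM_MODEL
    # cfg["model_path"] is None             get_model_path finds no local dir for an online API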
name != "default": result[name] = get_model_worker_config(name) @@ -256,7 +293,7 @@ def fschat_openai_api_address() -> str: host = FSCHAT_OPENAI_API["host"] port = FSCHAT_OPENAI_API["port"] - return f"http://{host}:{port}" + return f"http://{host}:{port}/v1" def api_address() -> str: @@ -302,13 +339,15 @@ def detect_device() -> Literal["cuda", "mps", "cpu"]: return "cpu" -def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]: +def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]: + device = device or LLM_DEVICE if device not in ["cuda", "mps", "cpu"]: device = detect_device() return device -def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]: +def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]: + device = device or EMBEDDING_DEVICE if device not in ["cuda", "mps", "cpu"]: device = detect_device() return device diff --git a/startup.py b/startup.py index b309444..d5be6e5 100644 --- a/startup.py +++ b/startup.py @@ -17,10 +17,19 @@ except: pass sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \ - logger, log_verbose, TEXT_SPLITTER -from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER, - FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT) +from configs import ( + LOG_PATH, + log_verbose, + logger, + LLM_MODEL, + EMBEDDING_MODEL, + TEXT_SPLITTER_NAME, + FSCHAT_CONTROLLER, + FSCHAT_OPENAI_API, + API_SERVER, + WEBUI_SERVER, + HTTPX_DEFAULT_TIMEOUT, +) from server.utils import (fschat_controller_address, fschat_model_worker_address, fschat_openai_api_address, set_httpx_timeout, get_model_worker_config, get_all_model_worker_configs, @@ -216,7 +225,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None): @app.post("/release_worker") def release_worker( model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]), - # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[fschat_controller_address()]), + # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[FSCHAT_CONTROLLER_address()]), new_model_name: str = Body(None, description="释放后加载该模型"), keep_origin: bool = Body(False, description="不释放原模型,加载新模型") ) -> Dict: @@ -250,7 +259,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None): return {"code": 500, "msg": msg} if new_model_name: - timer = HTTPX_DEFAULT_TIMEOUT * 2 # wait for new model_worker register + timer = HTTPX_DEFAULT_TIMEOUT # wait for new model_worker register while timer > 0: models = app._controller.list_models() if new_model_name in models: @@ -297,7 +306,7 @@ def run_model_worker( kwargs["model_names"] = [model_name] kwargs["controller_address"] = controller_address or fschat_controller_address() kwargs["worker_address"] = fschat_model_worker_address(model_name) - model_path = kwargs.get("local_model_path", "") + model_path = kwargs.get("model_path", "") kwargs["model_path"] = model_path app = create_model_worker_app(log_level=log_level, **kwargs) @@ -418,7 +427,7 @@ def parse_args() -> argparse.ArgumentParser: "-c", "--controller", type=str, - help="specify controller address the worker is registered to. default is server_config.FSCHAT_CONTROLLER", + help="specify controller address the worker is registered to. 
default is FSCHAT_CONTROLLER", dest="controller_address", ) parser.add_argument( @@ -474,15 +483,14 @@ def dump_server_info(after_start=False, args=None): print(f"当前启动的LLM模型:{models} @ {llm_device()}") for model in models: - pprint(llm_model_dict[model]) + pprint(get_model_worker_config(model)) print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}") if after_start: print("\n") print(f"服务端运行信息:") if args.openai_api: - print(f" OpenAI API Server: {fschat_openai_api_address()}/v1") - print(" (请确认llm_model_dict中配置的api_base_url与上面地址一致。)") + print(f" OpenAI API Server: {fschat_openai_api_address()}") if args.api: print(f" Chatchat API Server: {api_address()}") if args.webui: diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py index ed4e8b2..975f8bc 100644 --- a/tests/api/test_kb_api.py +++ b/tests/api/test_kb_api.py @@ -6,7 +6,7 @@ from pathlib import Path root_path = Path(__file__).parent.parent.parent sys.path.append(str(root_path)) from server.utils import api_address -from configs.model_config import VECTOR_SEARCH_TOP_K +from configs import VECTOR_SEARCH_TOP_K from server.knowledge_base.utils import get_kb_path, get_file_path from pprint import pprint diff --git a/tests/api/test_kb_api_request.py b/tests/api/test_kb_api_request.py index 8645528..3c115f1 100644 --- a/tests/api/test_kb_api_request.py +++ b/tests/api/test_kb_api_request.py @@ -6,7 +6,7 @@ from pathlib import Path root_path = Path(__file__).parent.parent.parent sys.path.append(str(root_path)) from server.utils import api_address -from configs.model_config import VECTOR_SEARCH_TOP_K +from configs import VECTOR_SEARCH_TOP_K from server.knowledge_base.utils import get_kb_path, get_file_path from webui_pages.utils import ApiRequest diff --git a/tests/api/test_llm_api.py b/tests/api/test_llm_api.py index af5ced8..8957981 100644 --- a/tests/api/test_llm_api.py +++ b/tests/api/test_llm_api.py @@ -6,21 +6,19 @@ from pathlib import Path root_path = Path(__file__).parent.parent.parent sys.path.append(str(root_path)) from configs.server_config import FSCHAT_MODEL_WORKERS -from configs.model_config import LLM_MODEL, llm_model_dict +from configs.model_config import LLM_MODEL from server.utils import api_address, get_model_worker_config from pprint import pprint import random +from typing import List -def get_configured_models(): +def get_configured_models() -> List[str]: model_workers = list(FSCHAT_MODEL_WORKERS) if "default" in model_workers: model_workers.remove("default") - - llm_dict = list(llm_model_dict) - - return model_workers, llm_dict + return model_workers api_base_url = api_address() @@ -56,12 +54,9 @@ def test_change_model(api="/llm_model/change"): running_models = get_running_models() assert len(running_models) > 0 - model_workers, llm_dict = get_configured_models() + model_workers = get_configured_models() - availabel_new_models = set(model_workers) - set(running_models) - if len(availabel_new_models) == 0: - availabel_new_models = set(llm_dict) - set(running_models) - availabel_new_models = list(availabel_new_models) + availabel_new_models = list(set(model_workers) - set(running_models)) assert len(availabel_new_models) > 0 print(availabel_new_models) diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py index 1431485..7989499 100644 --- a/tests/api/test_stream_chat_api.py +++ b/tests/api/test_stream_chat_api.py @@ -4,7 +4,7 @@ import sys from pathlib import Path sys.path.append(str(Path(__file__).parent.parent.parent)) -from configs.model_config import 
diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py
index 1431485..7989499 100644
--- a/tests/api/test_stream_chat_api.py
+++ b/tests/api/test_stream_chat_api.py
@@ -4,7 +4,7 @@ import sys
 from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent.parent))
 
-from configs.model_config import BING_SUBSCRIPTION_KEY
+from configs import BING_SUBSCRIPTION_KEY
 from server.utils import api_address
 
 from pprint import pprint

diff --git a/tests/custom_splitter/test_different_splitter.py b/tests/custom_splitter/test_different_splitter.py
index fea597e..2111bae 100644
--- a/tests/custom_splitter/test_different_splitter.py
+++ b/tests/custom_splitter/test_different_splitter.py
@@ -4,7 +4,7 @@ from transformers import AutoTokenizer
 import sys
 
 sys.path.append("../..")
-from configs.model_config import (
+from configs import (
     CHUNK_SIZE,
     OVERLAP_SIZE
 )

diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 4b347df..50ad423 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -1,11 +1,10 @@
 import streamlit as st
-from configs.server_config import FSCHAT_MODEL_WORKERS
 from webui_pages.utils import *
 from streamlit_chatbox import *
 from datetime import datetime
 from server.chat.search_engine_chat import SEARCH_ENGINES
 import os
-from configs.model_config import LLM_MODEL, TEMPERATURE
+from configs import LLM_MODEL, TEMPERATURE
 from server.utils import get_model_worker_config
 from typing import List, Dict

diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py
index c71da7e..bf8f089 100644
--- a/webui_pages/knowledge_base/knowledge_base.py
+++ b/webui_pages/knowledge_base/knowledge_base.py
@@ -6,9 +6,10 @@ import pandas as pd
 from server.knowledge_base.utils import get_file_path, LOADER_DICT
 from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
 from typing import Literal, Dict, Tuple
-from configs.model_config import (embedding_model_dict, kbs_config,
-                                  EMBEDDING_MODEL, DEFAULT_VS_TYPE,
-                                  CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
+from configs import (kbs_config,
+                     EMBEDDING_MODEL, DEFAULT_VS_TYPE,
+                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
+from server.utils import list_embed_models
 import os
 import time
 
@@ -94,7 +95,7 @@ def knowledge_base_page(api: ApiRequest):
             key="vs_type",
         )
 
-        embed_models = list(embedding_model_dict.keys())
+        embed_models = list_embed_models()
 
         embed_model = cols[1].selectbox(
             "Embedding 模型",

diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 26e5320..c23a8ea 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -1,12 +1,11 @@
 # This file contains shared webui utilities, usable by different webuis
 from typing import *
 from pathlib import Path
-from configs.model_config import (
+from configs import (
     EMBEDDING_MODEL,
     DEFAULT_VS_TYPE,
     KB_ROOT_PATH,
     LLM_MODEL,
-    llm_model_dict,
     HISTORY_LEN,
     TEMPERATURE,
     SCORE_THRESHOLD,
@@ -15,9 +14,10 @@ from configs.model_config import (
     ZH_TITLE_ENHANCE,
     VECTOR_SEARCH_TOP_K,
     SEARCH_ENGINE_TOP_K,
+    FSCHAT_MODEL_WORKERS,
+    HTTPX_DEFAULT_TIMEOUT,
     logger, log_verbose,
 )
-from configs.server_config import HTTPX_DEFAULT_TIMEOUT
 import httpx
 import asyncio
 from server.chat.openai_chat import OpenAiChatMsgIn
@@ -779,7 +779,10 @@ class ApiRequest:
         '''
         Get the list of models configured in configs
         '''
-        return list(llm_model_dict.keys())
+        models = list(FSCHAT_MODEL_WORKERS.keys())
+        if "default" in models:
+            models.remove("default")
+        return models
 
     def stop_llm_model(
         self,