Optimize configs (#1474)

* remove llm_model_dict

* optimize configs

* fix get_model_path

* Change some default parameters and add a default config for Qianfan

* Update server_config.py.example
liunux4odoo 2023-09-15 17:52:22 +08:00 committed by GitHub
parent 456229c13f
commit f7c73b842a
32 changed files with 371 additions and 320 deletions

View File

@@ -1,19 +1,12 @@
-from langchain.chat_models import ChatOpenAI
-from configs.model_config import llm_model_dict, LLM_MODEL
+from server.chat.utils import get_ChatOpenAI
+from configs.model_config import LLM_MODEL, TEMPERATURE
 from langchain import LLMChain
 from langchain.prompts.chat import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
 )
-model = ChatOpenAI(
-    streaming=True,
-    verbose=True,
-    # callbacks=[callback],
-    openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-    openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-    model_name=LLM_MODEL
-)
+model = get_ChatOpenAI(model_name=LLM_MODEL, temperature=TEMPERATURE)
 human_prompt = "{input}"

View File

@@ -1,4 +1,7 @@
+from .basic_config import *
 from .model_config import *
+from .kb_config import *
 from .server_config import *
-VERSION = "v0.2.4"
+VERSION = "v0.2.5-preview"

View File

@@ -0,0 +1,21 @@
import logging
import os
# 是否显示详细日志
log_verbose = False
# 通常情况下不需要更改以下内容
# 日志格式
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)
# 日志存储路径
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
if not os.path.exists(LOG_PATH):
    os.mkdir(LOG_PATH)

View File

@@ -0,0 +1,107 @@
import os
# 默认向量库类型。可选faiss, milvus, pg.
DEFAULT_VS_TYPE = "faiss"
# 缓存向量库数量针对FAISS
CACHED_VS_NUM = 1
# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter)
CHUNK_SIZE = 500
# 知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter)
OVERLAP_SIZE = 50
# 知识库匹配向量数量
VECTOR_SEARCH_TOP_K = 3
# 知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右
SCORE_THRESHOLD = 1
# 搜索引擎匹配结题数量
SEARCH_ENGINE_TOP_K = 3
# 基于本地知识问答的提示词模版使用Jinja2语法简单点就是用双大括号代替f-string的单大括号
PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
<已知信息>{{ context }}</已知信息>
<问题>{{ question }}</问题>"""
# Bing 搜索必备变量
# 使用 Bing 搜索需要使用 Bing Subscription Key,需要在azure port中申请试用bing search
# 具体申请方式请见
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource
# 使用python创建bing api 搜索实例详见:
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
# 注意不是bing Webmaster Tools的api key
# 此外如果是在服务器上报Failed to establish a new connection: [Errno 110] Connection timed out
# 是因为服务器加了防火墙需要联系管理员加白名单如果公司的服务器的话就别想了GG
BING_SUBSCRIPTION_KEY = ""
# 是否开启中文标题加强,以及标题增强的相关配置
# 通过增加标题判断判断哪些文本为标题并在metadata中进行标记
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
ZH_TITLE_ENHANCE = False
# 通常情况下不需要更改以下内容
# 知识库默认存储路径
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
if not os.path.exists(KB_ROOT_PATH):
    os.mkdir(KB_ROOT_PATH)
# 数据库默认存储路径。
# 如果使用sqlite可以直接修改DB_ROOT_PATH如果使用其它数据库请直接修改SQLALCHEMY_DATABASE_URI。
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"
# 可选向量库类型及对应配置
kbs_config = {
    "faiss": {
    },
    "milvus": {
        "host": "127.0.0.1",
        "port": "19530",
        "user": "",
        "password": "",
        "secure": False,
    },
    "pg": {
        "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
    }
}
# TextSplitter配置项如果你不明白其中的含义就不要修改。
text_splitter_dict = {
    "ChineseRecursiveTextSplitter": {
        "source": "huggingface",  ## 选择tiktoken则使用openai的方法
        "tokenizer_name_or_path": "gpt2",
    },
    "SpacyTextSplitter": {
        "source": "huggingface",
        "tokenizer_name_or_path": "",
    },
    "RecursiveCharacterTextSplitter": {
        "source": "tiktoken",
        "tokenizer_name_or_path": "cl100k_base",
    },
    "MarkdownHeaderTextSplitter": {
        "headers_to_split_on":
            [
                ("#", "head1"),
                ("##", "head2"),
                ("###", "head3"),
                ("####", "head4"),
            ]
    },
}
# TEXT_SPLITTER 名称
TEXT_SPLITTER_NAME = "SpacyTextSplitter"
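For reference, a minimal sketch (not part of this commit) of what the Jinja2-style double-brace placeholders in the PROMPT_TEMPLATE above expand to; it assumes the jinja2 package is installed and uses made-up values for context and question:

from jinja2 import Template
from configs.kb_config import PROMPT_TEMPLATE

# render the knowledge-base QA prompt defined above with hypothetical inputs
prompt = Template(PROMPT_TEMPLATE).render(
    context="（检索到的文档片段）",   # hypothetical retrieved chunks
    question="什么是向量库？",        # hypothetical user question
)
print(prompt)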

View File

@@ -1,36 +1,46 @@
 import os
-import logging
-# 日志格式
-LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
-logger = logging.getLogger()
-logger.setLevel(logging.INFO)
-logging.basicConfig(format=LOG_FORMAT)
-# 是否显示详细日志
-log_verbose = False
-# 在以下字典中修改属性值以指定本地embedding模型存储位置
-# 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
-# 此处请写绝对路径
-embedding_model_dict = {
-    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
-    "ernie-base": "nghuyong/ernie-3.0-base-zh",
-    "text2vec-base": "shibing624/text2vec-base-chinese",
-    "text2vec": "GanymedeNil/text2vec-large-chinese",
-    "text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase",
-    "text2vec-sentence": "shibing624/text2vec-base-chinese-sentence",
-    "text2vec-multilingual": "shibing624/text2vec-base-multilingual",
-    "text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese",
-    "m3e-small": "moka-ai/m3e-small",
-    "m3e-base": "moka-ai/m3e-base",
-    "m3e-large": "moka-ai/m3e-large",
-    "bge-small-zh": "BAAI/bge-small-zh",
-    "bge-base-zh": "BAAI/bge-base-zh",
-    "bge-large-zh": "BAAI/bge-large-zh",
-    "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
-    "piccolo-base-zh": "sensenova/piccolo-base-zh",
-    "piccolo-large-zh": "sensenova/piccolo-large-zh",
-    "text-embedding-ada-002": os.environ.get("OPENAI_API_KEY")
+# 可以指定一个绝对路径统一存放所有的Embedding和LLM模型。
+# 每个模型可以是一个单独的目录,也可以是某个目录下的二级子目录
+MODEL_ROOT_PATH = ""
+# 在以下字典中修改属性值以指定本地embedding模型存储位置。支持3种设置方法
+# 1、将对应的值修改为模型绝对路径
+# 2、不修改此处的值(以 text2vec 为例):
+#     2.1 如果{MODEL_ROOT_PATH}下存在如下任一子目录:
+#         - text2vec
+#         - GanymedeNil/text2vec-large-chinese
+#         - text2vec-large-chinese
+#     2.2 如果以上本地路径不存在则使用huggingface模型
+MODEL_PATH = {
+    "embed_model": {
+        "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
+        "ernie-base": "nghuyong/ernie-3.0-base-zh",
+        "text2vec-base": "shibing624/text2vec-base-chinese",
+        "text2vec": "GanymedeNil/text2vec-large-chinese",
+        "text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase",
+        "text2vec-sentence": "shibing624/text2vec-base-chinese-sentence",
+        "text2vec-multilingual": "shibing624/text2vec-base-multilingual",
+        "text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese",
+        "m3e-small": "moka-ai/m3e-small",
+        "m3e-base": "moka-ai/m3e-base",
+        "m3e-large": "moka-ai/m3e-large",
+        "bge-small-zh": "BAAI/bge-small-zh",
+        "bge-base-zh": "BAAI/bge-base-zh",
+        "bge-large-zh": "BAAI/bge-large-zh",
+        "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
+        "piccolo-base-zh": "sensenova/piccolo-base-zh",
+        "piccolo-large-zh": "sensenova/piccolo-large-zh",
+        "text-embedding-ada-002": "your OPENAI_API_KEY",
+    },
+    # TODO: add all supported llm models
+    "llm_model": {
+        "chatglm-6b": "THUDM/chatglm-6b",
+        "chatglm2-6b": "THUDM/chatglm2-6b",
+        "chatglm2-6b-int4": "THUDM/chatglm2-6b-int4",
+        "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
+    },
 }
 # 选用的 Embedding 名称
@@ -39,25 +49,21 @@ EMBEDDING_MODEL = "m3e-base"
 # Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
 EMBEDDING_DEVICE = "auto"
-llm_model_dict = {
-    "chatglm-6b": {
-        "local_model_path": "THUDM/chatglm-6b",
-        "api_base_url": "http://localhost:8888/v1",  # "name"修改为fastchat服务中的"api_base_url"
-        "api_key": "EMPTY"
-    },
-    "chatglm2-6b": {
-        "local_model_path": "THUDM/chatglm2-6b",
-        "api_base_url": "http://localhost:8888/v1",  # URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
-        "api_key": "EMPTY"
-    },
-    "chatglm2-6b-32k": {
-        "local_model_path": "THUDM/chatglm2-6b-32k",  # "THUDM/chatglm2-6b-32k",
-        "api_base_url": "http://localhost:8888/v1",  # "URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
-        "api_key": "EMPTY"
-    },
+# LLM 名称
+LLM_MODEL = "chatglm2-6b"
+# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
+LLM_DEVICE = "auto"
+# 历史对话轮数
+HISTORY_LEN = 3
+# LLM通用对话参数
+TEMPERATURE = 0.7
+# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
+ONLINE_LLM_MODEL = {
     # 调用chatgpt时如果报出 urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
     #  Max retries exceeded with url: /v1/chat/completions
     # 则需要将urllib3版本修改为1.25.11
@@ -74,29 +80,25 @@ llm_model_dict = {
     # 需要添加代理访问(正常开的代理软件可能会拦截不上)需要设置配置openai_proxy 或者 使用环境遍历OPENAI_PROXY 进行设置
     # 比如: "openai_proxy": 'http://127.0.0.1:4780'
     "gpt-3.5-turbo": {
-        "api_base_url": "https://api.openai.com/v1",
-        "api_key": "",
-        "openai_proxy": ""
+        "api_key": "your OPENAI_API_KEY",
+        "openai_proxy": "your OPENAI_PROXY",
     },
-    # 线上模型。当前支持智谱AI。
-    # 如果没有设置有效的local_model_path则认为是在线模型API。
-    # 请在server_config中为每个在线API设置不同的端口
+    # 线上模型。请在server_config中为每个在线API设置不同的端口
     # 具体注册及api key获取请前往 http://open.bigmodel.cn
     "zhipu-api": {
-        "api_base_url": "http://127.0.0.1:8888/v1",
         "api_key": "",
-        "provider": "ChatGLMWorker",
         "version": "chatglm_pro",  # 可选包括 "chatglm_lite", "chatglm_std", "chatglm_pro"
+        "provider": "ChatGLMWorker",
     },
+    # 具体注册及api key获取请前往 https://api.minimax.chat/
     "minimax-api": {
-        "api_base_url": "http://127.0.0.1:8888/v1",
         "group_id": "",
         "api_key": "",
         "is_pro": False,
         "provider": "MiniMaxWorker",
     },
+    # 具体注册及api key获取请前往 https://xinghuo.xfyun.cn/
    "xinghuo-api": {
-        "api_base_url": "http://127.0.0.1:8888/v1",
        "APPID": "",
        "APISecret": "",
        "api_key": "",
@@ -105,140 +107,16 @@ llm_model_dict = {
     },
     # 百度千帆 API申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf
     "qianfan-api": {
-        "version": "ernie-bot",  # 当前支持 "ernie-bot" 或 "ernie-bot-turbo" 更多的见文档模型支持列表中千帆部分。
-        "version_url": "",  # 可以不填写version直接填写在千帆申请模型发布的API地址
-        "api_base_url": "http://127.0.0.1:8888/v1",
+        "version": "ernie-bot-turbo",  # 当前支持 "ernie-bot" 或 "ernie-bot-turbo" 更多的见官方文档。
+        "version_url": "",  # 也可以不填写version直接填写在千帆申请模型发布的API地址
         "api_key": "",
         "secret_key": "",
-        "provider": "ErnieWorker",
+        "provider": "QianFanWorker",
-    }
-}
-# LLM 名称
-LLM_MODEL = "chatglm2-6b"
-# 历史对话轮数
-HISTORY_LEN = 3
-# LLM通用对话参数
-TEMPERATURE = 0.7
-# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
-# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
-LLM_DEVICE = "auto"
-# TextSplitter
-text_splitter_dict = {
-    "ChineseRecursiveTextSplitter": {
-        "source": "",
-        "tokenizer_name_or_path": "",
-    },
-    "SpacyTextSplitter": {
-        "source": "huggingface",
-        "tokenizer_name_or_path": "gpt2",
-    },
-    "RecursiveCharacterTextSplitter": {
-        "source": "tiktoken",
-        "tokenizer_name_or_path": "cl100k_base",
-    },
-    "MarkdownHeaderTextSplitter": {
-        "headers_to_split_on":
-            [
-                ("#", "head1"),
-                ("##", "head2"),
-                ("###", "head3"),
-                ("####", "head4"),
-            ]
     },
 }
-# TEXT_SPLITTER 名称
-TEXT_SPLITTER = "ChineseRecursiveTextSplitter"
-# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter)
+# 通常情况下不需要更改以下内容
-CHUNK_SIZE = 250
-# 知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter)
-OVERLAP_SIZE = 0
-# 日志存储路径
-LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
-if not os.path.exists(LOG_PATH):
-    os.mkdir(LOG_PATH)
-# 知识库默认存储路径
-KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
-if not os.path.exists(KB_ROOT_PATH):
-    os.mkdir(KB_ROOT_PATH)
-# 数据库默认存储路径。
-# 如果使用sqlite可以直接修改DB_ROOT_PATH如果使用其它数据库请直接修改SQLALCHEMY_DATABASE_URI。
-DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
-SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"
-# 可选向量库类型及对应配置
-kbs_config = {
-    "faiss": {
-    },
-    "milvus": {
-        "host": "127.0.0.1",
-        "port": "19530",
-        "user": "",
-        "password": "",
-        "secure": False,
-    },
-    "pg": {
-        "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
-    }
-}
-# 默认向量库类型。可选faiss, milvus, pg.
-DEFAULT_VS_TYPE = "faiss"
-# 缓存向量库数量
-CACHED_VS_NUM = 1
-# 知识库匹配向量数量
-VECTOR_SEARCH_TOP_K = 3
-# 知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右
-SCORE_THRESHOLD = 1
-# 搜索引擎匹配结题数量
-SEARCH_ENGINE_TOP_K = 3
 # nltk 模型存储路径
 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
-# 基于本地知识问答的提示词模版使用Jinja2语法简单点就是用双大括号代替f-string的单大括号
-PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
-<已知信息>{{ context }}</已知信息>
-<问题>{{ question }}</问题>"""
-# API 是否开启跨域默认为False如果需要开启请设置为True
-# is open cross domain
-OPEN_CROSS_DOMAIN = False
-# Bing 搜索必备变量
-# 使用 Bing 搜索需要使用 Bing Subscription Key,需要在azure port中申请试用bing search
-# 具体申请方式请见
-# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource
-# 使用python创建bing api 搜索实例详见:
-# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
-BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
-# 注意不是bing Webmaster Tools的api key
-# 此外如果是在服务器上报Failed to establish a new connection: [Errno 110] Connection timed out
-# 是因为服务器加了防火墙需要联系管理员加白名单如果公司的服务器的话就别想了GG
-BING_SUBSCRIPTION_KEY = ""
-# 是否开启中文标题加强,以及标题增强的相关配置
-# 通过增加标题判断判断哪些文本为标题并在metadata中进行标记
-# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
-ZH_TITLE_ENHANCE = False
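Not from the diff itself: a hedged illustration of the three ways the comments in the new MODEL_PATH section describe for pointing at model weights, with made-up local paths:

# Hypothetical values for illustration only
MODEL_ROOT_PATH = "/data/models"          # optional common root for all local models

MODEL_PATH = {
    "embed_model": {
        # 1. absolute path: used directly if the directory exists
        "m3e-base": "/data/models/moka-ai/m3e-base",
    },
    "llm_model": {
        # 2. keep the default repo id: first resolved under MODEL_ROOT_PATH
        #    (as the key, the value, or the last segment of the value),
        #    and only downloaded from Hugging Face if no such directory exists
        "chatglm2-6b": "THUDM/chatglm2-6b",
    },
}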

View File

@@ -1,4 +1,4 @@
-from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE
+from configs.model_config import LLM_DEVICE
 import httpx
 # httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
@@ -8,7 +8,7 @@ HTTPX_DEFAULT_TIMEOUT = 300.0
 # is open cross domain
 OPEN_CROSS_DOMAIN = False
-# 各服务器默认绑定host
+# 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host
 DEFAULT_BIND_HOST = "127.0.0.1"
 # webui.py server
@@ -26,14 +26,14 @@ API_SERVER = {
 # fastchat openai_api server
 FSCHAT_OPENAI_API = {
     "host": DEFAULT_BIND_HOST,
-    "port": 8888,  # model_config.llm_model_dict中模型配置的api_base_url需要与这里一致。
+    "port": 20000,
 }
 # fastchat model_worker server
-# 这些模型必须是在model_config.llm_model_dict中正确配置的。
+# 这些模型必须是在model_config.MODEL_PATH或ONLINE_MODEL中正确配置的。
 # 在启动startup.py时可用通过`--model-worker --model-name xxxx`指定模型不指定则为LLM_MODEL
 FSCHAT_MODEL_WORKERS = {
-    # 所有模型共用的默认配置,可在模型专项配置或llm_model_dict中进行覆盖。
+    # 所有模型共用的默认配置,可在模型专项配置中进行覆盖。
     "default": {
         "host": DEFAULT_BIND_HOST,
         "port": 20002,
@@ -64,17 +64,17 @@ FSCHAT_MODEL_WORKERS = {
     "baichuan-7b": {  # 使用default中的IP和端口
         "device": "cpu",
     },
-    "zhipu-api": {  # 请为每个在线API设置不同的端口
-        "port": 20003,
+    "zhipu-api": {  # 请为每个要运行的在线API设置不同的端口
+        "port": 21001,
     },
-    "minimax-api": {  # 请为每个在线API设置不同的端口
-        "port": 20004,
+    "minimax-api": {
+        "port": 21002,
     },
-    "xinghuo-api": {  # 请为每个在线API设置不同的端口
-        "port": 20005,
+    "xinghuo-api": {
+        "port": 21003,
     },
     "qianfan-api": {
-        "port": 20006,
+        "port": 21004,
     },
 }
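The port renumbering above pairs with the new lookup order documented in server/utils.get_model_worker_config (see that file's diff further down). A simplified sketch of that merge, assuming plain dicts (the real function also fills model_path, device and the online-API provider):

def merged_worker_config(model_name: str, fschat_model_workers: dict, online_llm_model: dict) -> dict:
    # lowest priority: the shared "default" entry
    config = fschat_model_workers.get("default", {}).copy()
    # then any online-API settings (zhipu-api, minimax-api, xinghuo-api, qianfan-api, ...)
    config.update(online_llm_model.get(model_name, {}))
    # highest priority: the per-model entry in FSCHAT_MODEL_WORKERS, e.g. its port
    config.update(fschat_model_workers.get(model_name, {}))
    return config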

View File

@@ -1,8 +1,7 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import llm_model_dict, LLM_MODEL, TEMPERATURE
-from server.chat.utils import wrap_done
-from langchain.chat_models import ChatOpenAI
+from configs.model_config import LLM_MODEL, TEMPERATURE
+from server.chat.utils import wrap_done, get_ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
@@ -31,18 +30,11 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                model_name: str = LLM_MODEL,
                ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
-        model = ChatOpenAI(
-            streaming=True,
-            verbose=True,
-            callbacks=[callback],
-            openai_api_key=llm_model_dict[model_name]["api_key"],
-            openai_api_base=llm_model_dict[model_name]["api_base_url"],
+        model = get_ChatOpenAI(
             model_name=model_name,
             temperature=temperature,
-            openai_proxy=llm_model_dict[model_name].get("openai_proxy")
+            callbacks=[callback],
         )
         input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])

View File

@@ -1,11 +1,10 @@
 from fastapi import Body, Request
 from fastapi.responses import StreamingResponse
-from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
+from configs import (LLM_MODEL, PROMPT_TEMPLATE,
                      VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                      TEMPERATURE)
-from server.chat.utils import wrap_done
+from server.chat.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse
-from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable, List, Optional
@@ -50,16 +49,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                model_name: str = LLM_MODEL,
                ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
-        model = ChatOpenAI(
-            streaming=True,
-            verbose=True,
-            callbacks=[callback],
-            openai_api_key=llm_model_dict[model_name]["api_key"],
-            openai_api_base=llm_model_dict[model_name]["api_base_url"],
+        model = get_ChatOpenAI(
             model_name=model_name,
             temperature=temperature,
-            openai_proxy=llm_model_dict[model_name].get("openai_proxy")
+            callbacks=[callback],
         )
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])

View File

@@ -1,7 +1,8 @@
 from fastapi.responses import StreamingResponse
 from typing import List
 import openai
-from configs.model_config import llm_model_dict, LLM_MODEL, logger, log_verbose
+from configs import LLM_MODEL, logger, log_verbose
+from server.utils import get_model_worker_config, fschat_openai_api_address
 from pydantic import BaseModel
@@ -23,9 +24,10 @@ class OpenAiChatMsgIn(BaseModel):
 async def openai_chat(msg: OpenAiChatMsgIn):
-    openai.api_key = llm_model_dict[LLM_MODEL]["api_key"]
+    config = get_model_worker_config(msg.model)
+    openai.api_key = config.get("api_key", "EMPTY")
     print(f"{openai.api_key=}")
-    openai.api_base = llm_model_dict[LLM_MODEL]["api_base_url"]
+    openai.api_base = fschat_openai_api_address()
     print(f"{openai.api_base=}")
     print(msg)

View File

@@ -1,13 +1,12 @@
 from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
-from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
+from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
+                     LLM_MODEL, SEARCH_ENGINE_TOP_K,
+                     PROMPT_TEMPLATE, TEMPERATURE)
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
-from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K,
-                                  PROMPT_TEMPLATE, TEMPERATURE)
-from server.chat.utils import wrap_done
+from server.chat.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse
-from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
@@ -90,15 +89,10 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                model_name: str = LLM_MODEL,
                ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
-        model = ChatOpenAI(
-            streaming=True,
-            verbose=True,
-            callbacks=[callback],
-            openai_api_key=llm_model_dict[model_name]["api_key"],
-            openai_api_base=llm_model_dict[model_name]["api_base_url"],
+        model = get_ChatOpenAI(
             model_name=model_name,
             temperature=temperature,
-            openai_proxy=llm_model_dict[model_name].get("openai_proxy")
+            callbacks=[callback],
         )
         docs = await lookup_search_engine(query, search_engine_name, top_k)

View File

@@ -1,8 +1,29 @@
 import asyncio
-from typing import Awaitable, List, Tuple, Dict, Union
 from pydantic import BaseModel, Field
 from langchain.prompts.chat import ChatMessagePromptTemplate
 from configs import logger, log_verbose
+from server.utils import get_model_worker_config, fschat_openai_api_address
+from langchain.chat_models import ChatOpenAI
+from typing import Awaitable, List, Tuple, Dict, Union, Callable
+
+
+def get_ChatOpenAI(
+        model_name: str,
+        temperature: float,
+        callbacks: List[Callable] = [],
+) -> ChatOpenAI:
+    config = get_model_worker_config(model_name)
+    model = ChatOpenAI(
+        streaming=True,
+        verbose=True,
+        callbacks=callbacks,
+        openai_api_key=config.get("api_key", "EMPTY"),
+        openai_api_base=fschat_openai_api_address(),
+        model_name=model_name,
+        temperature=temperature,
+        openai_proxy=config.get("openai_proxy")
+    )
+    return model
+
+
 async def wrap_done(fn: Awaitable, event: asyncio.Event):
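For context, a trimmed sketch (not part of the diff) of how the chat handlers in this commit call the new helper; the model name and temperature below are made-up values:

from langchain.callbacks import AsyncIteratorCallbackHandler
from server.chat.utils import get_ChatOpenAI

callback = AsyncIteratorCallbackHandler()
model = get_ChatOpenAI(
    model_name="chatglm2-6b",   # hypothetical; any model served behind FSCHAT_OPENAI_API
    temperature=0.7,
    callbacks=[callback],
)
# api_key, api_base (fschat_openai_api_address()) and openai_proxy are now looked up
# from get_model_worker_config() instead of the removed llm_model_dict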

View File

@@ -2,7 +2,7 @@ from sqlalchemy import create_engine
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import sessionmaker
-from configs.model_config import SQLALCHEMY_DATABASE_URI
+from configs import SQLALCHEMY_DATABASE_URI
 import json

View File

@@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse
 from server.knowledge_base.utils import validate_kb_name
 from server.knowledge_base.kb_service.base import KBServiceFactory
 from server.db.repository.knowledge_base_repository import list_kbs_from_db
-from configs.model_config import EMBEDDING_MODEL, logger, log_verbose
+from configs import EMBEDDING_MODEL, logger, log_verbose
 from fastapi import Body

View File

@@ -4,9 +4,9 @@ from langchain.embeddings import HuggingFaceBgeEmbeddings
 from langchain.embeddings.base import Embeddings
 from langchain.schema import Document
 import threading
-from configs.model_config import (CACHED_VS_NUM, EMBEDDING_MODEL, CHUNK_SIZE,
-                                  embedding_model_dict, logger, log_verbose)
-from server.utils import embedding_device
+from configs import (EMBEDDING_MODEL, CHUNK_SIZE, CACHED_VS_NUM,
+                     logger, log_verbose)
+from server.utils import embedding_device, get_model_path
 from contextlib import contextmanager
 from collections import OrderedDict
 from typing import List, Any, Union, Tuple
@@ -118,15 +118,15 @@ class EmbeddingsPool(CachePool):
             with item.acquire(msg="初始化"):
                 self.atomic.release()
                 if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
-                    embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
+                    embeddings = OpenAIEmbeddings(openai_api_key=get_model_path(model), chunk_size=CHUNK_SIZE)
                 elif 'bge-' in model:
-                    embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
+                    embeddings = HuggingFaceBgeEmbeddings(model_name=get_model_path(model),
                                                           model_kwargs={'device': device},
                                                           query_instruction="为这个句子生成表示以用于检索相关文章:")
                     if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
                         embeddings.query_instruction = ""
                 else:
-                    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
+                    embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model), model_kwargs={'device': device})
                 item.obj = embeddings
                 item.finish_loading()
             else:

View File

@@ -1,10 +1,10 @@
 import os
 import urllib
 from fastapi import File, Form, Body, Query, UploadFile
-from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
+from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
                      VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                      CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
                      logger, log_verbose,)
 from server.utils import BaseResponse, ListResponse, run_in_thread_pool
 from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
                                          files2docs_in_thread, KnowledgeFile)

View File

@@ -18,8 +18,8 @@ from server.db.repository.knowledge_file_repository import (
     list_docs_from_db,
 )
-from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
+from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                      EMBEDDING_MODEL)
 from server.knowledge_base.utils import (
     get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,

View File

@@ -1,7 +1,7 @@
 import os
 import shutil
-from configs.model_config import (
+from configs import (
     KB_ROOT_PATH,
     SCORE_THRESHOLD,
     logger, log_verbose,

View File

@@ -7,7 +7,7 @@ from langchain.schema import Document
 from langchain.vectorstores import Milvus
 from sklearn.preprocessing import normalize
-from configs.model_config import SCORE_THRESHOLD, kbs_config
+from configs import SCORE_THRESHOLD, kbs_config
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
     score_threshold_process

View File

@@ -7,7 +7,7 @@ from langchain.vectorstores import PGVector
 from langchain.vectorstores.pgvector import DistanceStrategy
 from sqlalchemy import text
-from configs.model_config import EMBEDDING_DEVICE, kbs_config
+from configs import kbs_config
 from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
     score_threshold_process

View File

@@ -1,5 +1,5 @@
-from configs.model_config import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
+from configs import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
                      logger, log_verbose)
 from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
                                          list_files_from_folder, files2docs_in_thread,
                                          KnowledgeFile,)

View File

@@ -2,7 +2,7 @@ import os
 from transformers import AutoTokenizer
-from configs.model_config import (
+from configs import (
     EMBEDDING_MODEL,
     KB_ROOT_PATH,
     CHUNK_SIZE,
@@ -23,7 +23,7 @@ from langchain.text_splitter import TextSplitter
 from pathlib import Path
 import json
 from concurrent.futures import ThreadPoolExecutor
-from server.utils import run_in_thread_pool, embedding_device
+from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config
 import io
 from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
@@ -185,6 +185,7 @@ def make_text_splitter(
         splitter_name: str = TEXT_SPLITTER,
         chunk_size: int = CHUNK_SIZE,
         chunk_overlap: int = OVERLAP_SIZE,
+        llm_model: str = LLM_MODEL,
 ):
     """
     根据参数获取特定的分词器
@@ -220,8 +221,9 @@ def make_text_splitter(
             )
         elif text_splitter_dict[splitter_name]["source"] == "huggingface":  ## 从huggingface加载
             if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
+                config = get_model_worker_config(llm_model)
                 text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
-                    llm_model_dict[LLM_MODEL]["local_model_path"]
+                    config.get("model_path")
             if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
                 from transformers import GPT2TokenizerFast
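A simplified sketch (not part of the diff) of the huggingface fallback added above: when a splitter has no tokenizer configured, the tokenizer of the currently selected LLM is used. The helper name below is hypothetical; only get_model_worker_config and the dict layout come from the diff:

from transformers import AutoTokenizer
from server.utils import get_model_worker_config

def tokenizer_for_splitter(splitter_name: str, text_splitter_dict: dict, llm_model: str):
    name_or_path = text_splitter_dict[splitter_name]["tokenizer_name_or_path"]
    if name_or_path == "":
        # fall back to the local path of the running LLM, as make_text_splitter now does
        name_or_path = get_model_worker_config(llm_model).get("model_path")
    return AutoTokenizer.from_pretrained(name_or_path, trust_remote_code=True)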

View File

@@ -1,4 +1,4 @@
-from configs.model_config import LOG_PATH
+from configs.basic_config import LOG_PATH
 import fastchat.constants
 fastchat.constants.LOGDIR = LOG_PATH
 from fastchat.serve.model_worker import BaseModelWorker

View File

@@ -4,8 +4,10 @@ from typing import List
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio
-from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE, logger, log_verbose
-from configs.server_config import FSCHAT_MODEL_WORKERS
+from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
+                     MODEL_PATH, MODEL_ROOT_PATH,
+                     logger, log_verbose,
+                     FSCHAT_MODEL_WORKERS)
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Literal, Optional, Callable, Generator, Dict, Any
@@ -197,22 +199,56 @@ def MakeFastAPIOffline(
     )
+# 从model_config中获取模型信息
+def list_embed_models() -> List[str]:
+    return list(MODEL_PATH["embed_model"])
+
+
+def list_llm_models() -> List[str]:
+    return list(MODEL_PATH["llm_model"])
+
+
+def get_model_path(model_name: str, type: str = None) -> Optional[str]:
+    if type in MODEL_PATH:
+        paths = MODEL_PATH[type]
+    else:
+        paths = {}
+        for v in MODEL_PATH.values():
+            paths.update(v)
+
+    if path_str := paths.get(model_name):  # 以 "chatglm-6b": "THUDM/chatglm-6b-new" 为例,以下都是支持的路径
+        path = Path(path_str)
+        if path.is_dir():  # 任意绝对路径
+            return str(path)
+
+        root_path = Path(MODEL_ROOT_PATH)
+        if root_path.is_dir():
+            path = root_path / model_name
+            if path.is_dir():  # use key, {MODEL_ROOT_PATH}/chatglm-6b
+                return str(path)
+            path = root_path / path_str
+            if path.is_dir():  # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
+                return str(path)
+            path = root_path / path_str.split("/")[-1]
+            if path.is_dir():  # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
                return str(path)
+        return path_str  # THUDM/chatglm06b
+
+
 # 从server_config中获取服务信息
-def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
+def get_model_worker_config(model_name: str = None) -> dict:
     '''
     加载model worker的配置项
-    优先级:FSCHAT_MODEL_WORKERS[model_name] > llm_model_dict[model_name] > FSCHAT_MODEL_WORKERS["default"]
+    优先级:FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
     '''
+    from configs.model_config import ONLINE_LLM_MODEL
     from configs.server_config import FSCHAT_MODEL_WORKERS
     from server import model_workers
-    from configs.model_config import llm_model_dict
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
-    config.update(llm_model_dict.get(model_name, {}))
+    config.update(ONLINE_LLM_MODEL.get(model_name, {}))
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
-    # 如果没有设置有效的local_model_path则认为是在线模型API
-    if not os.path.isdir(config.get("local_model_path", "")):
+    # 在线模型API
+    if model_name in ONLINE_LLM_MODEL:
         config["online_api"] = True
         if provider := config.get("provider"):
             try:
@@ -222,13 +258,14 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
                 logger.error(f'{e.__class__.__name__}: {msg}',
                              exc_info=e if log_verbose else None)
-    config["device"] = llm_device(config.get("device") or LLM_DEVICE)
+    config["model_path"] = get_model_path(model_name)
+    config["device"] = llm_device(config.get("device"))
     return config
 def get_all_model_worker_configs() -> dict:
     result = {}
-    model_names = set(llm_model_dict.keys()) | set(FSCHAT_MODEL_WORKERS.keys())
+    model_names = set(FSCHAT_MODEL_WORKERS.keys())
     for name in model_names:
         if name != "default":
             result[name] = get_model_worker_config(name)
@@ -256,7 +293,7 @@ def fschat_openai_api_address() -> str:
     host = FSCHAT_OPENAI_API["host"]
     port = FSCHAT_OPENAI_API["port"]
-    return f"http://{host}:{port}"
+    return f"http://{host}:{port}/v1"
 def api_address() -> str:
@@ -302,13 +339,15 @@ def detect_device() -> Literal["cuda", "mps", "cpu"]:
     return "cpu"
-def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+    device = device or LLM_DEVICE
     if device not in ["cuda", "mps", "cpu"]:
         device = detect_device()
     return device
-def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+    device = device or EMBEDDING_DEVICE
     if device not in ["cuda", "mps", "cpu"]:
         device = detect_device()
     return device
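An illustration (not part of the diff) of the resolution order implemented by get_model_path above, using made-up paths:

from server.utils import get_model_path

# Assuming MODEL_ROOT_PATH = "/data/models" and
# MODEL_PATH["llm_model"]["chatglm2-6b"] = "THUDM/chatglm2-6b" (hypothetical values),
# the call below tries, in order:
#   1. "THUDM/chatglm2-6b" itself, if it is a local directory
#   2. /data/models/chatglm2-6b        (MODEL_ROOT_PATH + dict key)
#   3. /data/models/THUDM/chatglm2-6b  (MODEL_ROOT_PATH + dict value)
#   4. /data/models/chatglm2-6b        (MODEL_ROOT_PATH + last segment of the value)
# and otherwise returns "THUDM/chatglm2-6b" to be treated as a Hugging Face hub id.
print(get_model_path("chatglm2-6b", type="llm_model"))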

View File

@@ -17,10 +17,19 @@ except:
     pass
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
-    logger, log_verbose, TEXT_SPLITTER
-from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
-                                   FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT)
+from configs import (
+    LOG_PATH,
+    log_verbose,
+    logger,
+    LLM_MODEL,
+    EMBEDDING_MODEL,
+    TEXT_SPLITTER_NAME,
+    FSCHAT_CONTROLLER,
+    FSCHAT_OPENAI_API,
+    API_SERVER,
+    WEBUI_SERVER,
+    HTTPX_DEFAULT_TIMEOUT,
+)
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
                           fschat_openai_api_address, set_httpx_timeout,
                           get_model_worker_config, get_all_model_worker_configs,
@@ -216,7 +225,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
     @app.post("/release_worker")
     def release_worker(
         model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
-        # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[fschat_controller_address()]),
+        # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[FSCHAT_CONTROLLER_address()]),
         new_model_name: str = Body(None, description="释放后加载该模型"),
         keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
     ) -> Dict:
@@ -250,7 +259,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
             return {"code": 500, "msg": msg}
         if new_model_name:
-            timer = HTTPX_DEFAULT_TIMEOUT * 2  # wait for new model_worker register
+            timer = HTTPX_DEFAULT_TIMEOUT  # wait for new model_worker register
             while timer > 0:
                 models = app._controller.list_models()
                 if new_model_name in models:
@@ -297,7 +306,7 @@ def run_model_worker(
     kwargs["model_names"] = [model_name]
     kwargs["controller_address"] = controller_address or fschat_controller_address()
     kwargs["worker_address"] = fschat_model_worker_address(model_name)
-    model_path = kwargs.get("local_model_path", "")
+    model_path = kwargs.get("model_path", "")
     kwargs["model_path"] = model_path
     app = create_model_worker_app(log_level=log_level, **kwargs)
@@ -418,7 +427,7 @@ def parse_args() -> argparse.ArgumentParser:
         "-c",
         "--controller",
         type=str,
-        help="specify controller address the worker is registered to. default is server_config.FSCHAT_CONTROLLER",
+        help="specify controller address the worker is registered to. default is FSCHAT_CONTROLLER",
         dest="controller_address",
     )
     parser.add_argument(
@@ -474,15 +483,14 @@ def dump_server_info(after_start=False, args=None):
     print(f"当前启动的LLM模型{models} @ {llm_device()}")
     for model in models:
-        pprint(llm_model_dict[model])
+        pprint(get_model_worker_config(model))
     print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {embedding_device()}")
     if after_start:
         print("\n")
         print(f"服务端运行信息:")
         if args.openai_api:
-            print(f"    OpenAI API Server: {fschat_openai_api_address()}/v1")
-            print("     (请确认llm_model_dict中配置的api_base_url与上面地址一致。)")
+            print(f"    OpenAI API Server: {fschat_openai_api_address()}")
         if args.api:
             print(f"    Chatchat API Server: {api_address()}")
         if args.webui:

View File

@@ -6,7 +6,7 @@ from pathlib import Path
 root_path = Path(__file__).parent.parent.parent
 sys.path.append(str(root_path))
 from server.utils import api_address
-from configs.model_config import VECTOR_SEARCH_TOP_K
+from configs import VECTOR_SEARCH_TOP_K
 from server.knowledge_base.utils import get_kb_path, get_file_path
 from pprint import pprint

View File

@@ -6,7 +6,7 @@ from pathlib import Path
 root_path = Path(__file__).parent.parent.parent
 sys.path.append(str(root_path))
 from server.utils import api_address
-from configs.model_config import VECTOR_SEARCH_TOP_K
+from configs import VECTOR_SEARCH_TOP_K
 from server.knowledge_base.utils import get_kb_path, get_file_path
 from webui_pages.utils import ApiRequest

View File

@@ -6,21 +6,19 @@ from pathlib import Path
 root_path = Path(__file__).parent.parent.parent
 sys.path.append(str(root_path))
 from configs.server_config import FSCHAT_MODEL_WORKERS
-from configs.model_config import LLM_MODEL, llm_model_dict
+from configs.model_config import LLM_MODEL
 from server.utils import api_address, get_model_worker_config
 from pprint import pprint
 import random
+from typing import List
-def get_configured_models():
+def get_configured_models() -> List[str]:
     model_workers = list(FSCHAT_MODEL_WORKERS)
     if "default" in model_workers:
         model_workers.remove("default")
-    llm_dict = list(llm_model_dict)
-    return model_workers, llm_dict
+    return model_workers
 api_base_url = api_address()
@@ -56,12 +54,9 @@ def test_change_model(api="/llm_model/change"):
     running_models = get_running_models()
     assert len(running_models) > 0
-    model_workers, llm_dict = get_configured_models()
-    availabel_new_models = set(model_workers) - set(running_models)
-    if len(availabel_new_models) == 0:
-        availabel_new_models = set(llm_dict) - set(running_models)
-    availabel_new_models = list(availabel_new_models)
+    model_workers = get_configured_models()
+    availabel_new_models = list(set(model_workers) - set(running_models))
     assert len(availabel_new_models) > 0
     print(availabel_new_models)

View File

@@ -4,7 +4,7 @@ import sys
 from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent.parent))
-from configs.model_config import BING_SUBSCRIPTION_KEY
+from configs import BING_SUBSCRIPTION_KEY
 from server.utils import api_address
 from pprint import pprint

View File

@@ -4,7 +4,7 @@ from transformers import AutoTokenizer
 import sys
 sys.path.append("../..")
-from configs.model_config import (
+from configs import (
     CHUNK_SIZE,
     OVERLAP_SIZE
 )

View File

@@ -1,11 +1,10 @@
 import streamlit as st
-from configs.server_config import FSCHAT_MODEL_WORKERS
 from webui_pages.utils import *
 from streamlit_chatbox import *
 from datetime import datetime
 from server.chat.search_engine_chat import SEARCH_ENGINES
 import os
-from configs.model_config import LLM_MODEL, TEMPERATURE
+from configs import LLM_MODEL, TEMPERATURE
 from server.utils import get_model_worker_config
 from typing import List, Dict

View File

@@ -6,9 +6,10 @@ import pandas as pd
 from server.knowledge_base.utils import get_file_path, LOADER_DICT
 from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
 from typing import Literal, Dict, Tuple
-from configs.model_config import (embedding_model_dict, kbs_config,
+from configs import (kbs_config,
                      EMBEDDING_MODEL, DEFAULT_VS_TYPE,
                      CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
+from server.utils import list_embed_models
 import os
 import time
@@ -94,7 +95,7 @@ def knowledge_base_page(api: ApiRequest):
             key="vs_type",
         )
-        embed_models = list(embedding_model_dict.keys())
+        embed_models = list_embed_models()
         embed_model = cols[1].selectbox(
             "Embedding 模型",

View File

@@ -1,12 +1,11 @@
 # 该文件包含webui通用工具可以被不同的webui使用
 from typing import *
 from pathlib import Path
-from configs.model_config import (
+from configs import (
     EMBEDDING_MODEL,
     DEFAULT_VS_TYPE,
     KB_ROOT_PATH,
     LLM_MODEL,
-    llm_model_dict,
     HISTORY_LEN,
     TEMPERATURE,
     SCORE_THRESHOLD,
@@ -15,9 +14,10 @@ from configs.model_config import (
     ZH_TITLE_ENHANCE,
     VECTOR_SEARCH_TOP_K,
     SEARCH_ENGINE_TOP_K,
+    FSCHAT_MODEL_WORKERS,
+    HTTPX_DEFAULT_TIMEOUT,
     logger, log_verbose,
 )
-from configs.server_config import HTTPX_DEFAULT_TIMEOUT
 import httpx
 import asyncio
 from server.chat.openai_chat import OpenAiChatMsgIn
@@ -779,7 +779,10 @@ class ApiRequest:
         '''
         获取configs中配置的模型列表
         '''
-        return list(llm_model_dict.keys())
+        models = list(FSCHAT_MODEL_WORKERS.keys())
+        if "default" in models:
+            models.remove("default")
+        return models
     def stop_llm_model(
         self,