Optimize configs (#1474)

* remove llm_model_dict
* optimize configs
* fix get_model_path
* change some default parameters; add a default config for Qianfan
* Update server_config.py.example
This commit is contained in:
parent 456229c13f
commit f7c73b842a
@@ -1,19 +1,12 @@
-from langchain.chat_models import ChatOpenAI
-from configs.model_config import llm_model_dict, LLM_MODEL
+from server.chat.utils import get_ChatOpenAI
+from configs.model_config import LLM_MODEL, TEMPERATURE
 from langchain import LLMChain
 from langchain.prompts.chat import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
 )
 
-model = ChatOpenAI(
-    streaming=True,
-    verbose=True,
-    # callbacks=[callback],
-    openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-    openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-    model_name=LLM_MODEL
-)
+model = get_ChatOpenAI(model_name=LLM_MODEL, temperature=TEMPERATURE)
 
 human_prompt = "{input}"
@@ -1,4 +1,7 @@
+from .basic_config import *
 from .model_config import *
+from .kb_config import *
 from .server_config import *
 
-VERSION = "v0.2.4"
+VERSION = "v0.2.5-preview"
@@ -0,0 +1,21 @@
+import logging
+import os
+
+
+# Whether to print verbose logs
+log_verbose = False
+
+
+# Usually there is no need to change anything below this line
+
+# Log format
+LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+logging.basicConfig(format=LOG_FORMAT)
+
+
+# Log storage path
+LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
+if not os.path.exists(LOG_PATH):
+    os.mkdir(LOG_PATH)
@@ -0,0 +1,107 @@
+import os
+
+
+# Default vector store type. Options: faiss, milvus, pg.
+DEFAULT_VS_TYPE = "faiss"
+
+# Number of cached vector stores (for FAISS)
+CACHED_VS_NUM = 1
+
+# Length of a single text chunk in the knowledge base (does not apply to MarkdownHeaderTextSplitter)
+CHUNK_SIZE = 500
+
+# Overlap length between adjacent chunks (does not apply to MarkdownHeaderTextSplitter)
+OVERLAP_SIZE = 50
+
+# Number of matched vectors returned by knowledge base retrieval
+VECTOR_SEARCH_TOP_K = 3
+
+# Relevance threshold for knowledge base matching, in the range 0-1. The smaller the SCORE, the higher the relevance; 1 means no filtering. Around 0.5 is recommended.
+SCORE_THRESHOLD = 1
+
+# Number of matched results returned by search engines
+SEARCH_ENGINE_TOP_K = 3
+
+
+# Prompt template for QA over local knowledge (uses Jinja2 syntax: double braces instead of f-string single braces)
+PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
+
+<已知信息>{{ context }}</已知信息>
+
+<问题>{{ question }}</问题>"""
+
+
+# Required variables for Bing search
+# Bing search requires a Bing Subscription Key; apply for a Bing Search trial on the Azure portal.
+# For how to apply, see
+# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource
+# For creating a Bing search instance with the Python API, see:
+# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
+BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
+# Note: this is NOT the Bing Webmaster Tools API key.
+
+# Also, if you see "Failed to establish a new connection: [Errno 110] Connection timed out" on a server,
+# the server is behind a firewall; ask the administrator to add a whitelist entry (on a company server this may not be possible).
+BING_SUBSCRIPTION_KEY = ""
+
+# Whether to enable Chinese title enhancement, plus its related configuration.
+# It detects which text segments are titles and marks them in the metadata,
+# then concatenates each segment with the title one level above it to enrich the text.
+ZH_TITLE_ENHANCE = False
+
+
+# Usually there is no need to change anything below this line
+
+# Default knowledge base storage path
+KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
+if not os.path.exists(KB_ROOT_PATH):
+    os.mkdir(KB_ROOT_PATH)
+
+# Default database storage path.
+# With sqlite you can modify DB_ROOT_PATH directly; for other databases modify SQLALCHEMY_DATABASE_URI instead.
+DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
+SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"
+
+# Available vector store types and their configurations
+kbs_config = {
+    "faiss": {
+    },
+    "milvus": {
+        "host": "127.0.0.1",
+        "port": "19530",
+        "user": "",
+        "password": "",
+        "secure": False,
+    },
+    "pg": {
+        "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
+    }
+}
+
+# TextSplitter configuration; do not modify this if you are unsure what it means.
+text_splitter_dict = {
+    "ChineseRecursiveTextSplitter": {
+        "source": "huggingface",  ## choose tiktoken to use OpenAI's method
+        "tokenizer_name_or_path": "gpt2",
+    },
+    "SpacyTextSplitter": {
+        "source": "huggingface",
+        "tokenizer_name_or_path": "",
+    },
+    "RecursiveCharacterTextSplitter": {
+        "source": "tiktoken",
+        "tokenizer_name_or_path": "cl100k_base",
+    },
+    "MarkdownHeaderTextSplitter": {
+        "headers_to_split_on":
+            [
+                ("#", "head1"),
+                ("##", "head2"),
+                ("###", "head3"),
+                ("####", "head4"),
+            ]
+    },
+}
+
+# Name of the TEXT_SPLITTER to use
+TEXT_SPLITTER_NAME = "SpacyTextSplitter"
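Editor's note: as a quick check of how the Jinja2-style PROMPT_TEMPLATE above renders, here is a minimal sketch using the jinja2 package directly (the project itself feeds the template through LangChain's prompt classes; the sample context and question values are made up):

    from jinja2 import Template
    from configs import PROMPT_TEMPLATE

    # {{ context }} and {{ question }} are the two variables the template defines.
    print(Template(PROMPT_TEMPLATE).render(
        context="LangChain-Chatchat 支持 FAISS、Milvus 和 PGVector。",
        question="支持哪些向量库?",
    ))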
@@ -1,36 +1,46 @@
 import os
-import logging
-# Log format
-LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
-logger = logging.getLogger()
-logger.setLevel(logging.INFO)
-logging.basicConfig(format=LOG_FORMAT)
-# Whether to print verbose logs
-log_verbose = False
 
 
-# Modify the values in the dict below to specify where local embedding models are stored,
-# e.g. change "text2vec": "GanymedeNil/text2vec-large-chinese" to "text2vec": "User/Downloads/text2vec-large-chinese"
-# Use absolute paths here
-embedding_model_dict = {
-    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
-    "ernie-base": "nghuyong/ernie-3.0-base-zh",
-    "text2vec-base": "shibing624/text2vec-base-chinese",
-    "text2vec": "GanymedeNil/text2vec-large-chinese",
-    "text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase",
-    "text2vec-sentence": "shibing624/text2vec-base-chinese-sentence",
-    "text2vec-multilingual": "shibing624/text2vec-base-multilingual",
-    "text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese",
-    "m3e-small": "moka-ai/m3e-small",
-    "m3e-base": "moka-ai/m3e-base",
-    "m3e-large": "moka-ai/m3e-large",
-    "bge-small-zh": "BAAI/bge-small-zh",
-    "bge-base-zh": "BAAI/bge-base-zh",
-    "bge-large-zh": "BAAI/bge-large-zh",
-    "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
-    "piccolo-base-zh": "sensenova/piccolo-base-zh",
-    "piccolo-large-zh": "sensenova/piccolo-large-zh",
-    "text-embedding-ada-002": os.environ.get("OPENAI_API_KEY")
+# You can specify one absolute path under which all Embedding and LLM models are stored.
+# Each model can be a standalone directory, or a second-level subdirectory under some directory.
+MODEL_ROOT_PATH = ""
+
+# Modify the values in the dict below to specify where local embedding models are stored. Three ways are supported:
+# 1. Change the corresponding value to an absolute model path
+# 2. Leave the value here unchanged (taking text2vec as an example):
+#    2.1 If any of the following subdirectories exists under {MODEL_ROOT_PATH}:
+#        - text2vec
+#        - GanymedeNil/text2vec-large-chinese
+#        - text2vec-large-chinese
+#    2.2 If none of the local paths above exists, fall back to the huggingface model
+MODEL_PATH = {
+    "embed_model": {
+        "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
+        "ernie-base": "nghuyong/ernie-3.0-base-zh",
+        "text2vec-base": "shibing624/text2vec-base-chinese",
+        "text2vec": "GanymedeNil/text2vec-large-chinese",
+        "text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase",
+        "text2vec-sentence": "shibing624/text2vec-base-chinese-sentence",
+        "text2vec-multilingual": "shibing624/text2vec-base-multilingual",
+        "text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese",
+        "m3e-small": "moka-ai/m3e-small",
+        "m3e-base": "moka-ai/m3e-base",
+        "m3e-large": "moka-ai/m3e-large",
+        "bge-small-zh": "BAAI/bge-small-zh",
+        "bge-base-zh": "BAAI/bge-base-zh",
+        "bge-large-zh": "BAAI/bge-large-zh",
+        "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
+        "piccolo-base-zh": "sensenova/piccolo-base-zh",
+        "piccolo-large-zh": "sensenova/piccolo-large-zh",
+        "text-embedding-ada-002": "your OPENAI_API_KEY",
+    },
+    # TODO: add all supported llm models
+    "llm_model": {
+        "chatglm-6b": "THUDM/chatglm-6b",
+        "chatglm2-6b": "THUDM/chatglm2-6b",
+        "chatglm2-6b-int4": "THUDM/chatglm2-6b-int4",
+        "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
+    },
 }
 
 # Name of the selected Embedding model
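Editor's note: to make the three options described in the new comments concrete, a hypothetical user edit of MODEL_PATH might look like this (all paths are illustrative only):

    MODEL_PATH = {
        "embed_model": {
            # Option 1: point the value at an absolute local path
            "m3e-base": "/data/models/m3e-base",
        },
        "llm_model": {
            # Option 2: keep the huggingface id. If MODEL_ROOT_PATH = "/data/models" and
            # "/data/models/chatglm2-6b" (the key), "/data/models/THUDM/chatglm2-6b" (the
            # value) or "/data/models/chatglm2-6b" (the value's last path component)
            # exists, that local copy wins (2.1); otherwise the id is passed through and
            # the model is fetched from huggingface (2.2).
            "chatglm2-6b": "THUDM/chatglm2-6b",
        },
    }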
@ -39,25 +49,21 @@ EMBEDDING_MODEL = "m3e-base"
|
||||||
# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
|
# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
|
||||||
EMBEDDING_DEVICE = "auto"
|
EMBEDDING_DEVICE = "auto"
|
||||||
|
|
||||||
llm_model_dict = {
|
# LLM 名称
|
||||||
"chatglm-6b": {
|
LLM_MODEL = "chatglm2-6b"
|
||||||
"local_model_path": "THUDM/chatglm-6b",
|
|
||||||
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
|
|
||||||
"api_key": "EMPTY"
|
|
||||||
},
|
|
||||||
|
|
||||||
"chatglm2-6b": {
|
# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
|
||||||
"local_model_path": "THUDM/chatglm2-6b",
|
LLM_DEVICE = "auto"
|
||||||
"api_base_url": "http://localhost:8888/v1", # URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
|
|
||||||
"api_key": "EMPTY"
|
|
||||||
},
|
|
||||||
|
|
||||||
"chatglm2-6b-32k": {
|
# 历史对话轮数
|
||||||
"local_model_path": "THUDM/chatglm2-6b-32k", # "THUDM/chatglm2-6b-32k",
|
HISTORY_LEN = 3
|
||||||
"api_base_url": "http://localhost:8888/v1", # "URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
|
|
||||||
"api_key": "EMPTY"
|
|
||||||
},
|
|
||||||
|
|
||||||
|
# LLM通用对话参数
|
||||||
|
TEMPERATURE = 0.7
|
||||||
|
# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
|
||||||
|
|
||||||
|
|
||||||
|
ONLINE_LLM_MODEL = {
|
||||||
# 调用chatgpt时如果报出: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
|
# 调用chatgpt时如果报出: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
|
||||||
# Max retries exceeded with url: /v1/chat/completions
|
# Max retries exceeded with url: /v1/chat/completions
|
||||||
# 则需要将urllib3版本修改为1.25.11
|
# 则需要将urllib3版本修改为1.25.11
|
||||||
|
|
@@ -74,29 +80,25 @@ llm_model_dict = {
 # If proxy access is required (ordinary proxy software may fail to intercept these requests), set openai_proxy in the config or use the OPENAI_PROXY environment variable
 # e.g. "openai_proxy": 'http://127.0.0.1:4780'
     "gpt-3.5-turbo": {
-        "api_base_url": "https://api.openai.com/v1",
-        "api_key": "",
-        "openai_proxy": ""
+        "api_key": "your OPENAI_API_KEY",
+        "openai_proxy": "your OPENAI_PROXY",
     },
-    # Online models. Currently supports Zhipu AI.
-    # If no valid local_model_path is set, the model is treated as an online model API.
-    # Assign a distinct port to each online API in server_config
+    # Online models. Assign a distinct port to each online API in server_config
     # For registration and api key, see http://open.bigmodel.cn
     "zhipu-api": {
-        "api_base_url": "http://127.0.0.1:8888/v1",
         "api_key": "",
-        "provider": "ChatGLMWorker",
         "version": "chatglm_pro",  # options include "chatglm_lite", "chatglm_std", "chatglm_pro"
+        "provider": "ChatGLMWorker",
     },
+    # For registration and api key, see https://api.minimax.chat/
     "minimax-api": {
-        "api_base_url": "http://127.0.0.1:8888/v1",
         "group_id": "",
         "api_key": "",
         "is_pro": False,
         "provider": "MiniMaxWorker",
     },
+    # For registration and api key, see https://xinghuo.xfyun.cn/
    "xinghuo-api": {
-        "api_base_url": "http://127.0.0.1:8888/v1",
         "APPID": "",
         "APISecret": "",
         "api_key": "",
@@ -105,140 +107,16 @@ llm_model_dict = {
     },
     # Baidu Qianfan API; for how to apply, see https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf
     "qianfan-api": {
-        "version": "ernie-bot",  # currently "ernie-bot" or "ernie-bot-turbo" are supported; for more, see the Qianfan section of the supported-model list in the docs.
+        "version": "ernie-bot-turbo",  # currently "ernie-bot" or "ernie-bot-turbo" are supported; for more, see the official docs.
-        "version_url": "",  # you may leave version empty and directly fill in the API address of the model you published on Qianfan
+        "version_url": "",  # alternatively, leave version empty and directly fill in the API address of the model you published on Qianfan
-        "api_base_url": "http://127.0.0.1:8888/v1",
         "api_key": "",
         "secret_key": "",
-        "provider": "ErnieWorker",
-    }
-}
-
-# LLM name
-LLM_MODEL = "chatglm2-6b"
-
-# Number of history turns kept in a conversation
-HISTORY_LEN = 3
-
-# Common LLM conversation parameters
-TEMPERATURE = 0.7
-# TOP_P = 0.95  # ChatOpenAI does not support this parameter yet
-
-
-# Device the LLM runs on. "auto" auto-detects; it can also be set manually to "cuda", "mps" or "cpu".
-LLM_DEVICE = "auto"
-
-# TextSplitter
-text_splitter_dict = {
-    "ChineseRecursiveTextSplitter": {
-        "source": "",
-        "tokenizer_name_or_path": "",
-    },
-    "SpacyTextSplitter": {
-        "source": "huggingface",
-        "tokenizer_name_or_path": "gpt2",
-    },
-    "RecursiveCharacterTextSplitter": {
-        "source": "tiktoken",
-        "tokenizer_name_or_path": "cl100k_base",
-    },
-    "MarkdownHeaderTextSplitter": {
-        "headers_to_split_on":
-            [
-                ("#", "head1"),
-                ("##", "head2"),
-                ("###", "head3"),
-                ("####", "head4"),
-            ]
+        "provider": "QianFanWorker",
     },
 }
 
-# Name of the TEXT_SPLITTER to use
-TEXT_SPLITTER = "ChineseRecursiveTextSplitter"
 
-# Length of a single text chunk in the knowledge base (does not apply to MarkdownHeaderTextSplitter)
-CHUNK_SIZE = 250
+# Usually there is no need to change anything below this line
 
-# Overlap length between adjacent chunks (does not apply to MarkdownHeaderTextSplitter)
-OVERLAP_SIZE = 0
-
-
-# Log storage path
-LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
-if not os.path.exists(LOG_PATH):
-    os.mkdir(LOG_PATH)
-
-# Default knowledge base storage path
-KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
-if not os.path.exists(KB_ROOT_PATH):
-    os.mkdir(KB_ROOT_PATH)
-# Default database storage path.
-# With sqlite you can modify DB_ROOT_PATH directly; for other databases modify SQLALCHEMY_DATABASE_URI instead.
-DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
-SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"
-
-
-# Available vector store types and their configurations
-kbs_config = {
-    "faiss": {
-    },
-    "milvus": {
-        "host": "127.0.0.1",
-        "port": "19530",
-        "user": "",
-        "password": "",
-        "secure": False,
-    },
-    "pg": {
-        "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
-    }
-}
-
-# Default vector store type. Options: faiss, milvus, pg.
-DEFAULT_VS_TYPE = "faiss"
-
-# Number of cached vector stores
-CACHED_VS_NUM = 1
-
-# Number of matched vectors returned by knowledge base retrieval
-VECTOR_SEARCH_TOP_K = 3
-
-# Relevance threshold for knowledge base matching, in the range 0-1. The smaller the SCORE, the higher the relevance; 1 means no filtering. Around 0.5 is recommended.
-SCORE_THRESHOLD = 1
-
-# Number of matched results returned by search engines
-SEARCH_ENGINE_TOP_K = 3
 
 # nltk model storage path
 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
 
-# Prompt template for QA over local knowledge (uses Jinja2 syntax: double braces instead of f-string single braces)
-PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
-
-<已知信息>{{ context }}</已知信息>
-
-<问题>{{ question }}</问题>"""
-
-# Whether the API enables CORS; default False, set True to enable
-# is open cross domain
-OPEN_CROSS_DOMAIN = False
-
-# Required variables for Bing search
-# Bing search requires a Bing Subscription Key; apply for a Bing Search trial on the Azure portal.
-# For how to apply, see
-# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource
-# For creating a Bing search instance with the Python API, see:
-# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
-BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
-# Note: this is NOT the Bing Webmaster Tools API key.
-
-# Also, if you see "Failed to establish a new connection: [Errno 110] Connection timed out" on a server,
-# the server is behind a firewall; ask the administrator to add a whitelist entry (on a company server this may not be possible).
-BING_SUBSCRIPTION_KEY = ""
-
-# Whether to enable Chinese title enhancement, plus its related configuration.
-# It detects which text segments are titles and marks them in the metadata,
-# then concatenates each segment with the title one level above it to enrich the text.
-ZH_TITLE_ENHANCE = False
@@ -1,4 +1,4 @@
-from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE
+from configs.model_config import LLM_DEVICE
 import httpx
 
 # Default httpx request timeout in seconds. If loading models or chatting is slow and you hit timeout errors, increase this value.
@@ -8,7 +8,7 @@ HTTPX_DEFAULT_TIMEOUT = 300.0
 # is open cross domain
 OPEN_CROSS_DOMAIN = False
 
-# Default host each server binds to
+# Default host each server binds to. If changed to "0.0.0.0", the host of every XX_SERVER below must be changed as well.
 DEFAULT_BIND_HOST = "127.0.0.1"
 
 # webui.py server
@@ -26,14 +26,14 @@ API_SERVER = {
 # fastchat openai_api server
 FSCHAT_OPENAI_API = {
     "host": DEFAULT_BIND_HOST,
-    "port": 8888,  # the api_base_url configured for models in model_config.llm_model_dict must match this.
+    "port": 20000,
 }
 
 # fastchat model_worker server
-# These models must be correctly configured in model_config.llm_model_dict.
+# These models must be correctly configured in model_config.MODEL_PATH or ONLINE_LLM_MODEL.
 # When launching startup.py you can select a model via `--model-worker --model-name xxxx`; LLM_MODEL is used if unspecified
 FSCHAT_MODEL_WORKERS = {
-    # Default configuration shared by all models; it can be overridden in a model's own section or in llm_model_dict.
+    # Default configuration shared by all models; it can be overridden in a model's own section.
     "default": {
         "host": DEFAULT_BIND_HOST,
         "port": 20002,
@@ -64,17 +64,17 @@ FSCHAT_MODEL_WORKERS = {
     "baichuan-7b": {  # uses the IP and port from "default"
         "device": "cpu",
     },
-    "zhipu-api": {  # assign a distinct port to each online API
-        "port": 20003,
+    "zhipu-api": {  # assign a distinct port to each online API you want to run
+        "port": 21001,
     },
-    "minimax-api": {  # assign a distinct port to each online API
-        "port": 20004,
+    "minimax-api": {
+        "port": 21002,
     },
-    "xinghuo-api": {  # assign a distinct port to each online API
-        "port": 20005,
+    "xinghuo-api": {
+        "port": 21003,
     },
     "qianfan-api": {
-        "port": 20006,
+        "port": 21004,
     },
 }
@@ -1,8 +1,7 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import llm_model_dict, LLM_MODEL, TEMPERATURE
-from server.chat.utils import wrap_done
-from langchain.chat_models import ChatOpenAI
+from configs.model_config import LLM_MODEL, TEMPERATURE
+from server.chat.utils import wrap_done, get_ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
@ -31,18 +30,11 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
||||||
model_name: str = LLM_MODEL,
|
model_name: str = LLM_MODEL,
|
||||||
) -> AsyncIterable[str]:
|
) -> AsyncIterable[str]:
|
||||||
callback = AsyncIteratorCallbackHandler()
|
callback = AsyncIteratorCallbackHandler()
|
||||||
|
model = get_ChatOpenAI(
|
||||||
model = ChatOpenAI(
|
|
||||||
streaming=True,
|
|
||||||
verbose=True,
|
|
||||||
callbacks=[callback],
|
|
||||||
openai_api_key=llm_model_dict[model_name]["api_key"],
|
|
||||||
openai_api_base=llm_model_dict[model_name]["api_base_url"],
|
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
|
callbacks=[callback],
|
||||||
)
|
)
|
||||||
|
|
||||||
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
|
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
|
||||||
chat_prompt = ChatPromptTemplate.from_messages(
|
chat_prompt = ChatPromptTemplate.from_messages(
|
||||||
[i.to_msg_template() for i in history] + [input_msg])
|
[i.to_msg_template() for i in history] + [input_msg])
|
||||||
|
|
|
||||||
|
|
@@ -1,11 +1,10 @@
 from fastapi import Body, Request
 from fastapi.responses import StreamingResponse
-from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
+from configs import (LLM_MODEL, PROMPT_TEMPLATE,
                      VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                      TEMPERATURE)
-from server.chat.utils import wrap_done
-from langchain.chat_models import ChatOpenAI
+from server.chat.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable, List, Optional
@ -50,16 +49,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
||||||
model_name: str = LLM_MODEL,
|
model_name: str = LLM_MODEL,
|
||||||
) -> AsyncIterable[str]:
|
) -> AsyncIterable[str]:
|
||||||
callback = AsyncIteratorCallbackHandler()
|
callback = AsyncIteratorCallbackHandler()
|
||||||
|
model = get_ChatOpenAI(
|
||||||
model = ChatOpenAI(
|
|
||||||
streaming=True,
|
|
||||||
verbose=True,
|
|
||||||
callbacks=[callback],
|
|
||||||
openai_api_key=llm_model_dict[model_name]["api_key"],
|
|
||||||
openai_api_base=llm_model_dict[model_name]["api_base_url"],
|
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
|
callbacks=[callback],
|
||||||
)
|
)
|
||||||
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
|
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
|
||||||
context = "\n".join([doc.page_content for doc in docs])
|
context = "\n".join([doc.page_content for doc in docs])
|
||||||
|
|
|
||||||
|
|
@@ -1,7 +1,8 @@
 from fastapi.responses import StreamingResponse
 from typing import List
 import openai
-from configs.model_config import llm_model_dict, LLM_MODEL, logger, log_verbose
+from configs import LLM_MODEL, logger, log_verbose
+from server.utils import get_model_worker_config, fschat_openai_api_address
 from pydantic import BaseModel
@@ -23,9 +24,10 @@ class OpenAiChatMsgIn(BaseModel):
 
 
 async def openai_chat(msg: OpenAiChatMsgIn):
-    openai.api_key = llm_model_dict[LLM_MODEL]["api_key"]
+    config = get_model_worker_config(msg.model)
+    openai.api_key = config.get("api_key", "EMPTY")
     print(f"{openai.api_key=}")
-    openai.api_base = llm_model_dict[LLM_MODEL]["api_base_url"]
+    openai.api_base = fschat_openai_api_address()
     print(f"{openai.api_base=}")
     print(msg)
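Editor's note: the same pattern works from any OpenAI-compatible client. A minimal sketch, assuming the fastchat openai_api server from server_config is running on its default 127.0.0.1:20000 and the pre-1.0 openai SDK used by this codebase:

    import openai

    openai.api_key = "EMPTY"                       # local fastchat workers need no real key
    openai.api_base = "http://127.0.0.1:20000/v1"  # what fschat_openai_api_address() returns

    resp = openai.ChatCompletion.create(
        model="chatglm2-6b",  # must be a model served by a running model_worker
        messages=[{"role": "user", "content": "你好"}],
    )
    print(resp.choices[0].message.content)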
@@ -1,13 +1,12 @@
 from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
-from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
+from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
+                     LLM_MODEL, SEARCH_ENGINE_TOP_K,
+                     PROMPT_TEMPLATE, TEMPERATURE)
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
-from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K,
-                                  PROMPT_TEMPLATE, TEMPERATURE)
-from server.chat.utils import wrap_done
-from langchain.chat_models import ChatOpenAI
+from server.chat.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
@ -90,15 +89,10 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
|
||||||
model_name: str = LLM_MODEL,
|
model_name: str = LLM_MODEL,
|
||||||
) -> AsyncIterable[str]:
|
) -> AsyncIterable[str]:
|
||||||
callback = AsyncIteratorCallbackHandler()
|
callback = AsyncIteratorCallbackHandler()
|
||||||
model = ChatOpenAI(
|
model = get_ChatOpenAI(
|
||||||
streaming=True,
|
|
||||||
verbose=True,
|
|
||||||
callbacks=[callback],
|
|
||||||
openai_api_key=llm_model_dict[model_name]["api_key"],
|
|
||||||
openai_api_base=llm_model_dict[model_name]["api_base_url"],
|
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
|
callbacks=[callback],
|
||||||
)
|
)
|
||||||
|
|
||||||
docs = await lookup_search_engine(query, search_engine_name, top_k)
|
docs = await lookup_search_engine(query, search_engine_name, top_k)
|
||||||
|
|
|
||||||
|
|
@@ -1,8 +1,29 @@
 import asyncio
-from typing import Awaitable, List, Tuple, Dict, Union
 from pydantic import BaseModel, Field
 from langchain.prompts.chat import ChatMessagePromptTemplate
 from configs import logger, log_verbose
+from server.utils import get_model_worker_config, fschat_openai_api_address
+from langchain.chat_models import ChatOpenAI
+from typing import Awaitable, List, Tuple, Dict, Union, Callable
+
+
+def get_ChatOpenAI(
+        model_name: str,
+        temperature: float,
+        callbacks: List[Callable] = [],
+) -> ChatOpenAI:
+    config = get_model_worker_config(model_name)
+    model = ChatOpenAI(
+        streaming=True,
+        verbose=True,
+        callbacks=callbacks,
+        openai_api_key=config.get("api_key", "EMPTY"),
+        openai_api_base=fschat_openai_api_address(),
+        model_name=model_name,
+        temperature=temperature,
+        openai_proxy=config.get("openai_proxy")
+    )
+    return model
 
 
 async def wrap_done(fn: Awaitable, event: asyncio.Event):
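Editor's note: with this helper, every chat entry point builds its model the same way. A minimal usage sketch, assuming the fastchat stack from startup.py is running and serving chatglm2-6b:

    from server.chat.utils import get_ChatOpenAI

    # get_ChatOpenAI resolves api_key/api_base from the worker config and the
    # fastchat openai_api address, so callers only pass the model name, the
    # sampling temperature and optional callbacks.
    model = get_ChatOpenAI(model_name="chatglm2-6b", temperature=0.7)
    print(model.predict("用一句话介绍你自己"))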
@@ -2,7 +2,7 @@ from sqlalchemy import create_engine
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import sessionmaker
 
-from configs.model_config import SQLALCHEMY_DATABASE_URI
+from configs import SQLALCHEMY_DATABASE_URI
 import json
@@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse
 from server.knowledge_base.utils import validate_kb_name
 from server.knowledge_base.kb_service.base import KBServiceFactory
 from server.db.repository.knowledge_base_repository import list_kbs_from_db
-from configs.model_config import EMBEDDING_MODEL, logger, log_verbose
+from configs import EMBEDDING_MODEL, logger, log_verbose
 from fastapi import Body
@@ -4,9 +4,9 @@ from langchain.embeddings import HuggingFaceBgeEmbeddings
 from langchain.embeddings.base import Embeddings
 from langchain.schema import Document
 import threading
-from configs.model_config import (CACHED_VS_NUM, EMBEDDING_MODEL, CHUNK_SIZE,
-                                  embedding_model_dict, logger, log_verbose)
-from server.utils import embedding_device
+from configs import (EMBEDDING_MODEL, CHUNK_SIZE, CACHED_VS_NUM,
+                     logger, log_verbose)
+from server.utils import embedding_device, get_model_path
 from contextlib import contextmanager
 from collections import OrderedDict
 from typing import List, Any, Union, Tuple
@@ -118,15 +118,15 @@ class EmbeddingsPool(CachePool):
         with item.acquire(msg="初始化"):
             self.atomic.release()
             if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
-                embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
+                embeddings = OpenAIEmbeddings(openai_api_key=get_model_path(model), chunk_size=CHUNK_SIZE)
             elif 'bge-' in model:
-                embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
+                embeddings = HuggingFaceBgeEmbeddings(model_name=get_model_path(model),
                                                       model_kwargs={'device': device},
                                                       query_instruction="为这个句子生成表示以用于检索相关文章:")
                 if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
                     embeddings.query_instruction = ""
             else:
-                embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
+                embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model), model_kwargs={'device': device})
             item.obj = embeddings
             item.finish_loading()
         else:
@@ -1,10 +1,10 @@
 import os
 import urllib
 from fastapi import File, Form, Body, Query, UploadFile
-from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
+from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
                      VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                      CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
                      logger, log_verbose,)
 from server.utils import BaseResponse, ListResponse, run_in_thread_pool
 from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
                                          files2docs_in_thread, KnowledgeFile)
@@ -18,8 +18,8 @@ from server.db.repository.knowledge_file_repository import (
     list_docs_from_db,
 )
 
-from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
+from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                      EMBEDDING_MODEL)
 from server.knowledge_base.utils import (
     get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,
@@ -1,7 +1,7 @@
 import os
 import shutil
 
-from configs.model_config import (
+from configs import (
     KB_ROOT_PATH,
     SCORE_THRESHOLD,
     logger, log_verbose,
@@ -7,7 +7,7 @@ from langchain.schema import Document
 from langchain.vectorstores import Milvus
 from sklearn.preprocessing import normalize
 
-from configs.model_config import SCORE_THRESHOLD, kbs_config
+from configs import SCORE_THRESHOLD, kbs_config
 
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
     score_threshold_process
@@ -7,7 +7,7 @@ from langchain.vectorstores import PGVector
 from langchain.vectorstores.pgvector import DistanceStrategy
 from sqlalchemy import text
 
-from configs.model_config import EMBEDDING_DEVICE, kbs_config
+from configs import kbs_config
 
 from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
     score_threshold_process
@@ -1,5 +1,5 @@
-from configs.model_config import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
+from configs import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
                      logger, log_verbose)
 from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
                                          list_files_from_folder, files2docs_in_thread,
                                          KnowledgeFile,)
@@ -2,7 +2,7 @@ import os
 
 from transformers import AutoTokenizer
 
-from configs.model_config import (
+from configs import (
     EMBEDDING_MODEL,
     KB_ROOT_PATH,
     CHUNK_SIZE,
@@ -23,7 +23,7 @@ from langchain.text_splitter import TextSplitter
 from pathlib import Path
 import json
 from concurrent.futures import ThreadPoolExecutor
-from server.utils import run_in_thread_pool, embedding_device
+from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config
 import io
 from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
@@ -185,6 +185,7 @@ def make_text_splitter(
         splitter_name: str = TEXT_SPLITTER,
         chunk_size: int = CHUNK_SIZE,
         chunk_overlap: int = OVERLAP_SIZE,
+        llm_model: str = LLM_MODEL,
 ):
     """
     Get a specific text splitter according to the parameters
@@ -220,8 +221,9 @@ def make_text_splitter(
         )
     elif text_splitter_dict[splitter_name]["source"] == "huggingface":  ## load from huggingface
         if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
+            config = get_model_worker_config(llm_model)
             text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
-                llm_model_dict[LLM_MODEL]["local_model_path"]
+                config.get("model_path")
 
         if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
             from transformers import GPT2TokenizerFast
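Editor's note: a usage sketch of the updated helper. The SpacyTextSplitter entry in kb_config has an empty tokenizer_name_or_path, so with these defaults the tokenizer falls back to the resolved model_path of the given LLM; the file path below is only an example:

    from server.knowledge_base.utils import make_text_splitter

    # Builds the splitter named by TEXT_SPLITTER_NAME with the given chunking
    # parameters; llm_model is only consulted when the configured tokenizer
    # path is empty.
    splitter = make_text_splitter(chunk_size=500, chunk_overlap=50, llm_model="chatglm2-6b")
    chunks = splitter.split_text(open("README.md", encoding="utf-8").read())
    print(len(chunks), chunks[0][:80])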
@@ -1,4 +1,4 @@
-from configs.model_config import LOG_PATH
+from configs.basic_config import LOG_PATH
 import fastchat.constants
 fastchat.constants.LOGDIR = LOG_PATH
 from fastchat.serve.model_worker import BaseModelWorker
@@ -4,8 +4,10 @@ from typing import List
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio
-from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE, logger, log_verbose
-from configs.server_config import FSCHAT_MODEL_WORKERS
+from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
+                     MODEL_PATH, MODEL_ROOT_PATH,
+                     logger, log_verbose,
+                     FSCHAT_MODEL_WORKERS)
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Literal, Optional, Callable, Generator, Dict, Any
@@ -197,22 +199,56 @@ def MakeFastAPIOffline(
     )
 
 
+# Fetch model information from model_config
+def list_embed_models() -> List[str]:
+    return list(MODEL_PATH["embed_model"])
+
+
+def list_llm_models() -> List[str]:
+    return list(MODEL_PATH["llm_model"])
+
+
+def get_model_path(model_name: str, type: str = None) -> Optional[str]:
+    if type in MODEL_PATH:
+        paths = MODEL_PATH[type]
+    else:
+        paths = {}
+        for v in MODEL_PATH.values():
+            paths.update(v)
+
+    if path_str := paths.get(model_name):  # taking "chatglm-6b": "THUDM/chatglm-6b-new" as an example, all the paths below are supported
+        path = Path(path_str)
+        if path.is_dir():  # any absolute path
+            return str(path)
+
+        root_path = Path(MODEL_ROOT_PATH)
+        if root_path.is_dir():
+            path = root_path / model_name
+            if path.is_dir():  # use key, {MODEL_ROOT_PATH}/chatglm-6b
+                return str(path)
+            path = root_path / path_str
+            if path.is_dir():  # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
+                return str(path)
+            path = root_path / path_str.split("/")[-1]
+            if path.is_dir():  # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
+                return str(path)
+        return path_str  # THUDM/chatglm-6b-new
+
+
 # Fetch service information from server_config
-def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
+def get_model_worker_config(model_name: str = None) -> dict:
     '''
     Load the configuration of a model worker.
-    Priority: FSCHAT_MODEL_WORKERS[model_name] > llm_model_dict[model_name] > FSCHAT_MODEL_WORKERS["default"]
+    Priority: FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
     '''
+    from configs.model_config import ONLINE_LLM_MODEL
     from configs.server_config import FSCHAT_MODEL_WORKERS
     from server import model_workers
-    from configs.model_config import llm_model_dict
 
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
-    config.update(llm_model_dict.get(model_name, {}))
+    config.update(ONLINE_LLM_MODEL.get(model_name, {}))
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
 
-    # If no valid local_model_path is set, the model is treated as an online model API
-    if not os.path.isdir(config.get("local_model_path", "")):
+    # online model APIs
+    if model_name in ONLINE_LLM_MODEL:
         config["online_api"] = True
         if provider := config.get("provider"):
             try:
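Editor's note: a worked illustration of the lookup order get_model_path implements, using the example from the code comments above:

    # Assume MODEL_ROOT_PATH = "/data/models" (hypothetical) and MODEL_PATH contains
    # "chatglm-6b": "THUDM/chatglm-6b-new". Then get_model_path("chatglm-6b") returns:
    #   1. "THUDM/chatglm-6b-new" itself, if it exists locally as a directory;
    #   2. "/data/models/chatglm-6b"            (the key), if that directory exists;
    #   3. "/data/models/THUDM/chatglm-6b-new"  (the value), if that directory exists;
    #   4. "/data/models/chatglm-6b-new"        (last path component of the value);
    #   5. otherwise the raw value "THUDM/chatglm-6b-new", treated as a huggingface repo id.
    from server.utils import get_model_path
    print(get_model_path("chatglm-6b"))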
@ -222,13 +258,14 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
|
||||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||||
exc_info=e if log_verbose else None)
|
exc_info=e if log_verbose else None)
|
||||||
|
|
||||||
config["device"] = llm_device(config.get("device") or LLM_DEVICE)
|
config["model_path"] = get_model_path(model_name)
|
||||||
|
config["device"] = llm_device(config.get("device"))
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def get_all_model_worker_configs() -> dict:
|
def get_all_model_worker_configs() -> dict:
|
||||||
result = {}
|
result = {}
|
||||||
model_names = set(llm_model_dict.keys()) | set(FSCHAT_MODEL_WORKERS.keys())
|
model_names = set(FSCHAT_MODEL_WORKERS.keys())
|
||||||
for name in model_names:
|
for name in model_names:
|
||||||
if name != "default":
|
if name != "default":
|
||||||
result[name] = get_model_worker_config(name)
|
result[name] = get_model_worker_config(name)
|
||||||
|
|
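Editor's note: to see the merge priority in action, a quick sketch assuming the default values shipped in this commit:

    from server.utils import get_model_worker_config

    cfg = get_model_worker_config("zhipu-api")
    # "default" supplies host/port defaults, ONLINE_LLM_MODEL supplies
    # api_key/version/provider, and FSCHAT_MODEL_WORKERS["zhipu-api"] overrides
    # the port (21001). The function also stamps in the resolved model_path
    # and device before returning.
    print(cfg["port"], cfg["provider"], cfg["online_api"])  # 21001 ChatGLMWorker True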
@ -256,7 +293,7 @@ def fschat_openai_api_address() -> str:
|
||||||
|
|
||||||
host = FSCHAT_OPENAI_API["host"]
|
host = FSCHAT_OPENAI_API["host"]
|
||||||
port = FSCHAT_OPENAI_API["port"]
|
port = FSCHAT_OPENAI_API["port"]
|
||||||
return f"http://{host}:{port}"
|
return f"http://{host}:{port}/v1"
|
||||||
|
|
||||||
|
|
||||||
def api_address() -> str:
|
def api_address() -> str:
|
||||||
|
|
@ -302,13 +339,15 @@ def detect_device() -> Literal["cuda", "mps", "cpu"]:
|
||||||
return "cpu"
|
return "cpu"
|
||||||
|
|
||||||
|
|
||||||
def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]:
|
def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
|
||||||
|
device = device or LLM_DEVICE
|
||||||
if device not in ["cuda", "mps", "cpu"]:
|
if device not in ["cuda", "mps", "cpu"]:
|
||||||
device = detect_device()
|
device = detect_device()
|
||||||
return device
|
return device
|
||||||
|
|
||||||
|
|
||||||
def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
|
def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
|
||||||
|
device = device or EMBEDDING_DEVICE
|
||||||
if device not in ["cuda", "mps", "cpu"]:
|
if device not in ["cuda", "mps", "cpu"]:
|
||||||
device = detect_device()
|
device = detect_device()
|
||||||
return device
|
return device
|
||||||
|
|
|
||||||
startup.py (+30 −30)
@@ -17,10 +17,19 @@ except:
     pass
 
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
-    logger, log_verbose, TEXT_SPLITTER
-from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
-                                   FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT)
+from configs import (
+    LOG_PATH,
+    log_verbose,
+    logger,
+    LLM_MODEL,
+    EMBEDDING_MODEL,
+    TEXT_SPLITTER_NAME,
+    FSCHAT_CONTROLLER,
+    FSCHAT_OPENAI_API,
+    API_SERVER,
+    WEBUI_SERVER,
+    HTTPX_DEFAULT_TIMEOUT,
+)
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
                           fschat_openai_api_address, set_httpx_timeout,
                           get_model_worker_config, get_all_model_worker_configs,
@ -216,7 +225,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
|
||||||
@app.post("/release_worker")
|
@app.post("/release_worker")
|
||||||
def release_worker(
|
def release_worker(
|
||||||
model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
|
model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
|
||||||
# worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[fschat_controller_address()]),
|
# worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[FSCHAT_CONTROLLER_address()]),
|
||||||
new_model_name: str = Body(None, description="释放后加载该模型"),
|
new_model_name: str = Body(None, description="释放后加载该模型"),
|
||||||
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
|
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
|
|
@ -250,7 +259,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
|
||||||
return {"code": 500, "msg": msg}
|
return {"code": 500, "msg": msg}
|
||||||
|
|
||||||
if new_model_name:
|
if new_model_name:
|
||||||
timer = HTTPX_DEFAULT_TIMEOUT * 2 # wait for new model_worker register
|
timer = HTTPX_DEFAULT_TIMEOUT # wait for new model_worker register
|
||||||
while timer > 0:
|
while timer > 0:
|
||||||
models = app._controller.list_models()
|
models = app._controller.list_models()
|
||||||
if new_model_name in models:
|
if new_model_name in models:
|
||||||
|
|
@@ -297,7 +306,7 @@ def run_model_worker(
     kwargs["model_names"] = [model_name]
     kwargs["controller_address"] = controller_address or fschat_controller_address()
     kwargs["worker_address"] = fschat_model_worker_address(model_name)
-    model_path = kwargs.get("local_model_path", "")
+    model_path = kwargs.get("model_path", "")
     kwargs["model_path"] = model_path
 
     app = create_model_worker_app(log_level=log_level, **kwargs)
@ -418,7 +427,7 @@ def parse_args() -> argparse.ArgumentParser:
|
||||||
"-c",
|
"-c",
|
||||||
"--controller",
|
"--controller",
|
||||||
type=str,
|
type=str,
|
||||||
help="specify controller address the worker is registered to. default is server_config.FSCHAT_CONTROLLER",
|
help="specify controller address the worker is registered to. default is FSCHAT_CONTROLLER",
|
||||||
dest="controller_address",
|
dest="controller_address",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
@@ -474,15 +483,14 @@ def dump_server_info(after_start=False, args=None):
     print(f"当前启动的LLM模型:{models} @ {llm_device()}")
 
     for model in models:
-        pprint(llm_model_dict[model])
+        pprint(get_model_worker_config(model))
     print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
 
     if after_start:
         print("\n")
         print(f"服务端运行信息:")
         if args.openai_api:
-            print(f"    OpenAI API Server: {fschat_openai_api_address()}/v1")
-            print("     (请确认llm_model_dict中配置的api_base_url与上面地址一致。)")
+            print(f"    OpenAI API Server: {fschat_openai_api_address()}")
         if args.api:
             print(f"    Chatchat API Server: {api_address()}")
         if args.webui:
@@ -6,7 +6,7 @@ from pathlib import Path
 root_path = Path(__file__).parent.parent.parent
 sys.path.append(str(root_path))
 from server.utils import api_address
-from configs.model_config import VECTOR_SEARCH_TOP_K
+from configs import VECTOR_SEARCH_TOP_K
 from server.knowledge_base.utils import get_kb_path, get_file_path
 
 from pprint import pprint
@@ -6,7 +6,7 @@ from pathlib import Path
 root_path = Path(__file__).parent.parent.parent
 sys.path.append(str(root_path))
 from server.utils import api_address
-from configs.model_config import VECTOR_SEARCH_TOP_K
+from configs import VECTOR_SEARCH_TOP_K
 from server.knowledge_base.utils import get_kb_path, get_file_path
 from webui_pages.utils import ApiRequest
@@ -6,21 +6,19 @@ from pathlib import Path
 root_path = Path(__file__).parent.parent.parent
 sys.path.append(str(root_path))
 from configs.server_config import FSCHAT_MODEL_WORKERS
-from configs.model_config import LLM_MODEL, llm_model_dict
+from configs.model_config import LLM_MODEL
 from server.utils import api_address, get_model_worker_config
 
 from pprint import pprint
 import random
+from typing import List
 
 
-def get_configured_models():
+def get_configured_models() -> List[str]:
     model_workers = list(FSCHAT_MODEL_WORKERS)
     if "default" in model_workers:
         model_workers.remove("default")
+    return model_workers
-    llm_dict = list(llm_model_dict)
-
-    return model_workers, llm_dict
 
 
 api_base_url = api_address()
@ -56,12 +54,9 @@ def test_change_model(api="/llm_model/change"):
|
||||||
running_models = get_running_models()
|
running_models = get_running_models()
|
||||||
assert len(running_models) > 0
|
assert len(running_models) > 0
|
||||||
|
|
||||||
model_workers, llm_dict = get_configured_models()
|
model_workers = get_configured_models()
|
||||||
|
|
||||||
availabel_new_models = set(model_workers) - set(running_models)
|
availabel_new_models = list(set(model_workers) - set(running_models))
|
||||||
if len(availabel_new_models) == 0:
|
|
||||||
availabel_new_models = set(llm_dict) - set(running_models)
|
|
||||||
availabel_new_models = list(availabel_new_models)
|
|
||||||
assert len(availabel_new_models) > 0
|
assert len(availabel_new_models) > 0
|
||||||
print(availabel_new_models)
|
print(availabel_new_models)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -4,7 +4,7 @@ import sys
 from pathlib import Path
 
 sys.path.append(str(Path(__file__).parent.parent.parent))
-from configs.model_config import BING_SUBSCRIPTION_KEY
+from configs import BING_SUBSCRIPTION_KEY
 from server.utils import api_address
 
 from pprint import pprint
@@ -4,7 +4,7 @@ from transformers import AutoTokenizer
 import sys
 
 sys.path.append("../..")
-from configs.model_config import (
+from configs import (
     CHUNK_SIZE,
     OVERLAP_SIZE
 )
@@ -1,11 +1,10 @@
 import streamlit as st
-from configs.server_config import FSCHAT_MODEL_WORKERS
 from webui_pages.utils import *
 from streamlit_chatbox import *
 from datetime import datetime
 from server.chat.search_engine_chat import SEARCH_ENGINES
 import os
-from configs.model_config import LLM_MODEL, TEMPERATURE
+from configs import LLM_MODEL, TEMPERATURE
 from server.utils import get_model_worker_config
 from typing import List, Dict
@@ -6,9 +6,10 @@ import pandas as pd
 from server.knowledge_base.utils import get_file_path, LOADER_DICT
 from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
 from typing import Literal, Dict, Tuple
-from configs.model_config import (embedding_model_dict, kbs_config,
+from configs import (kbs_config,
                      EMBEDDING_MODEL, DEFAULT_VS_TYPE,
                      CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
+from server.utils import list_embed_models
 import os
 import time
@@ -94,7 +95,7 @@ def knowledge_base_page(api: ApiRequest):
             key="vs_type",
         )
 
-        embed_models = list(embedding_model_dict.keys())
+        embed_models = list_embed_models()
 
         embed_model = cols[1].selectbox(
             "Embedding 模型",
@@ -1,12 +1,11 @@
 # This file contains common webui utilities shared by different webuis
 from typing import *
 from pathlib import Path
-from configs.model_config import (
+from configs import (
     EMBEDDING_MODEL,
     DEFAULT_VS_TYPE,
     KB_ROOT_PATH,
     LLM_MODEL,
-    llm_model_dict,
     HISTORY_LEN,
     TEMPERATURE,
     SCORE_THRESHOLD,
@@ -15,9 +14,10 @@ from configs.model_config import (
     ZH_TITLE_ENHANCE,
     VECTOR_SEARCH_TOP_K,
     SEARCH_ENGINE_TOP_K,
+    FSCHAT_MODEL_WORKERS,
+    HTTPX_DEFAULT_TIMEOUT,
     logger, log_verbose,
 )
-from configs.server_config import HTTPX_DEFAULT_TIMEOUT
 import httpx
 import asyncio
 from server.chat.openai_chat import OpenAiChatMsgIn
@@ -779,7 +779,10 @@ class ApiRequest:
         '''
         Get the list of models configured in configs
         '''
-        return list(llm_model_dict.keys())
+        models = list(FSCHAT_MODEL_WORKERS.keys())
+        if "default" in models:
+            models.remove("default")
+        return models
 
     def stop_llm_model(
             self,