Optimize configs (#1474)

* remove llm_model_dict
* optimize configs
* fix get_model_path
* change some default parameters; add a default Qianfan configuration
* Update server_config.py.example
This commit is contained in:
parent
456229c13f
commit
f7c73b842a
@@ -1,19 +1,12 @@
from langchain.chat_models import ChatOpenAI
from configs.model_config import llm_model_dict, LLM_MODEL
from server.chat.utils import get_ChatOpenAI
from configs.model_config import LLM_MODEL, TEMPERATURE
from langchain import LLMChain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)

model = ChatOpenAI(
    streaming=True,
    verbose=True,
    # callbacks=[callback],
    openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
    openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
    model_name=LLM_MODEL
)
model = get_ChatOpenAI(model_name=LLM_MODEL, temperature=TEMPERATURE)


human_prompt = "{input}"
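The edit above is the template for this whole commit: every hand-built ChatOpenAI(...) is replaced by the get_ChatOpenAI helper added to server/chat/utils.py further down in this diff. A minimal sketch of the new call shape (the sketch itself is not part of the diff):

    # sketch: one-line client construction; api_key, api_base and proxy
    # are now resolved inside the helper via get_model_worker_config()
    model = get_ChatOpenAI(model_name=LLM_MODEL, temperature=TEMPERATURE)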
@@ -1,4 +1,7 @@
from .basic_config import *
from .model_config import *
from .kb_config import *
from .server_config import *

VERSION = "v0.2.4"

VERSION = "v0.2.5-preview"
@@ -0,0 +1,21 @@
import logging
import os


# Whether to show verbose logs
log_verbose = False


# The settings below normally do not need to be changed

# Log format
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)


# Log storage path
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
if not os.path.exists(LOG_PATH):
    os.mkdir(LOG_PATH)
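The new log_verbose flag is consumed elsewhere in this commit to gate full tracebacks in error logs. A minimal sketch of the pattern, assuming only the names exported above (risky_operation is a hypothetical placeholder):

    from configs import logger, log_verbose

    try:
        risky_operation()  # hypothetical placeholder
    except Exception as e:
        # attach the traceback only when verbose logging is enabled
        logger.error(f"{e.__class__.__name__}: {e}",
                     exc_info=e if log_verbose else None)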
@@ -0,0 +1,107 @@
import os


# Default vector store type. Options: faiss, milvus, pg.
DEFAULT_VS_TYPE = "faiss"

# Number of cached vector stores (for FAISS)
CACHED_VS_NUM = 1

# Length of a single text chunk in the knowledge base (does not apply to MarkdownHeaderTextSplitter)
CHUNK_SIZE = 500

# Overlap between adjacent text chunks in the knowledge base (does not apply to MarkdownHeaderTextSplitter)
OVERLAP_SIZE = 50

# Number of matched vectors returned from the knowledge base
VECTOR_SEARCH_TOP_K = 3

# Relevance threshold for knowledge base matching, in the range 0-1. The smaller the SCORE, the higher the relevance; 1 is equivalent to no filtering. A value around 0.5 is recommended.
SCORE_THRESHOLD = 1

# Number of matched search engine results
SEARCH_ENGINE_TOP_K = 3


# Prompt template for QA over local knowledge (uses Jinja2 syntax, i.e. double braces instead of the single braces of an f-string)
PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>

<已知信息>{{ context }}</已知信息>

<问题>{{ question }}</问题>"""


# Required variables for Bing search
# Bing search requires a Bing Subscription Key; apply for a Bing Search trial in the Azure portal.
# For details on how to apply, see
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource
# For creating a Bing search instance with the Python API, see:
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
# Note: this is not the Bing Webmaster Tools API key.

# Also, on a server, the error "Failed to establish a new connection: [Errno 110] Connection timed out"
# means the server sits behind a firewall; ask the administrator to whitelist the address (on a corporate server this may simply not be granted).
BING_SUBSCRIPTION_KEY = ""

# Whether to enable Chinese title enhancement, and its related options.
# Title detection marks which text segments are titles in the document metadata;
# each segment is then concatenated with the title one level above it to enrich the text.
ZH_TITLE_ENHANCE = False


# The settings below normally do not need to be changed

# Default knowledge base storage path
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
if not os.path.exists(KB_ROOT_PATH):
    os.mkdir(KB_ROOT_PATH)

# Default database storage path.
# For sqlite, DB_ROOT_PATH can be modified directly; for any other database, modify SQLALCHEMY_DATABASE_URI instead.
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"

# Available vector store types and their settings
kbs_config = {
    "faiss": {
    },
    "milvus": {
        "host": "127.0.0.1",
        "port": "19530",
        "user": "",
        "password": "",
        "secure": False,
    },
    "pg": {
        "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
    }
}

# TextSplitter settings. If you do not understand what these mean, do not change them.
text_splitter_dict = {
    "ChineseRecursiveTextSplitter": {
        "source": "huggingface",  # choosing "tiktoken" uses the OpenAI method instead
        "tokenizer_name_or_path": "gpt2",
    },
    "SpacyTextSplitter": {
        "source": "huggingface",
        "tokenizer_name_or_path": "",
    },
    "RecursiveCharacterTextSplitter": {
        "source": "tiktoken",
        "tokenizer_name_or_path": "cl100k_base",
    },
    "MarkdownHeaderTextSplitter": {
        "headers_to_split_on":
            [
                ("#", "head1"),
                ("##", "head2"),
                ("###", "head3"),
                ("####", "head4"),
            ]
    },
}

# Name of the TEXT_SPLITTER to use
TEXT_SPLITTER_NAME = "SpacyTextSplitter"
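For orientation, a minimal sketch of how the MarkdownHeaderTextSplitter entry above is consumed through LangChain's standard API (the sketch is not part of the diff):

    from langchain.text_splitter import MarkdownHeaderTextSplitter

    headers = text_splitter_dict["MarkdownHeaderTextSplitter"]["headers_to_split_on"]
    splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers)
    docs = splitter.split_text("# Title\ntext\n## Section\nmore text")
    # each returned Document carries the matched headers,
    # e.g. {"head1": "Title", ...}, in its metadata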
@@ -1,18 +1,20 @@
import os
import logging
# Log format
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)
# Whether to show verbose logs
log_verbose = False


# Modify the values in the dict below to point to local embedding model locations,
# e.g. change "text2vec": "GanymedeNil/text2vec-large-chinese" to "text2vec": "User/Downloads/text2vec-large-chinese"
# Use absolute paths here.
embedding_model_dict = {
# One absolute path can be set here to store all Embedding and LLM models in one place.
# Each model can be a separate directory, or a second-level subdirectory under some directory.
MODEL_ROOT_PATH = ""

# Modify the values in the dict below to point to local embedding model locations. Three ways to set this are supported:
# 1. Change the value to the absolute path of the model.
# 2. Leave the value unchanged (taking text2vec as an example):
#    2.1 If any of the following subdirectories exists under {MODEL_ROOT_PATH}:
#        - text2vec
#        - GanymedeNil/text2vec-large-chinese
#        - text2vec-large-chinese
#    2.2 If none of these local paths exists, the huggingface model is used.
MODEL_PATH = {
    "embed_model": {
        "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
        "ernie-base": "nghuyong/ernie-3.0-base-zh",
        "text2vec-base": "shibing624/text2vec-base-chinese",
@@ -30,7 +32,15 @@ embedding_model_dict = {
        "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
        "piccolo-base-zh": "sensenova/piccolo-base-zh",
        "piccolo-large-zh": "sensenova/piccolo-large-zh",
        "text-embedding-ada-002": os.environ.get("OPENAI_API_KEY")
        "text-embedding-ada-002": "your OPENAI_API_KEY",
    },
    # TODO: add all supported llm models
    "llm_model": {
        "chatglm-6b": "THUDM/chatglm-6b",
        "chatglm2-6b": "THUDM/chatglm2-6b",
        "chatglm2-6b-int4": "THUDM/chatglm2-6b-int4",
        "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
    },
}

# Name of the selected Embedding model
@@ -39,25 +49,21 @@ EMBEDDING_MODEL = "m3e-base"
# Device to run the Embedding model on. "auto" detects automatically; it can also be set manually to one of "cuda", "mps", "cpu".
EMBEDDING_DEVICE = "auto"

llm_model_dict = {
    "chatglm-6b": {
        "local_model_path": "THUDM/chatglm-6b",
        "api_base_url": "http://localhost:8888/v1",  # change "name" to the "api_base_url" of the fastchat service
        "api_key": "EMPTY"
    },
# LLM name
LLM_MODEL = "chatglm2-6b"

    "chatglm2-6b": {
        "local_model_path": "THUDM/chatglm2-6b",
        "api_base_url": "http://localhost:8888/v1",  # the URL must match server_config.FSCHAT_OPENAI_API of the running fastchat server
        "api_key": "EMPTY"
    },
# Device to run the LLM on. "auto" detects automatically; it can also be set manually to one of "cuda", "mps", "cpu".
LLM_DEVICE = "auto"

    "chatglm2-6b-32k": {
        "local_model_path": "THUDM/chatglm2-6b-32k",  # "THUDM/chatglm2-6b-32k",
        "api_base_url": "http://localhost:8888/v1",  # the URL must match server_config.FSCHAT_OPENAI_API of the running fastchat server
        "api_key": "EMPTY"
    },
# Number of history dialogue turns
HISTORY_LEN = 3

# General LLM chat parameters
TEMPERATURE = 0.7
# TOP_P = 0.95  # ChatOpenAI does not support this parameter yet


ONLINE_LLM_MODEL = {
    # If calling ChatGPT raises urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
    # Max retries exceeded with url: /v1/chat/completions
    # then urllib3 must be downgraded to version 1.25.11
@@ -74,29 +80,25 @@ llm_model_dict = {
    # If proxy access is needed (a normally running proxy client may fail to intercept these requests), set openai_proxy in the config or use the OPENAI_PROXY environment variable.
    # e.g.: "openai_proxy": 'http://127.0.0.1:4780'
    "gpt-3.5-turbo": {
        "api_base_url": "https://api.openai.com/v1",
        "api_key": "",
        "openai_proxy": ""
        "api_key": "your OPENAI_API_KEY",
        "openai_proxy": "your OPENAI_PROXY",
    },
    # Online models. Zhipu AI is currently supported.
    # If no valid local_model_path is set, the entry is treated as an online model API.
    # Set a different port for each online API in server_config.
    # Online models. Set a different port for each online API in server_config.
    # For registration and API keys, see http://open.bigmodel.cn
    "zhipu-api": {
        "api_base_url": "http://127.0.0.1:8888/v1",
        "api_key": "",
        "provider": "ChatGLMWorker",
        "version": "chatglm_pro",  # options include "chatglm_lite", "chatglm_std", "chatglm_pro"
        "provider": "ChatGLMWorker",
    },
    # For registration and API keys, see https://api.minimax.chat/
    "minimax-api": {
        "api_base_url": "http://127.0.0.1:8888/v1",
        "group_id": "",
        "api_key": "",
        "is_pro": False,
        "provider": "MiniMaxWorker",
    },
    # For registration and API keys, see https://xinghuo.xfyun.cn/
    "xinghuo-api": {
        "api_base_url": "http://127.0.0.1:8888/v1",
        "APPID": "",
        "APISecret": "",
        "api_key": "",
@@ -105,140 +107,16 @@ llm_model_dict = {
    },
    # Baidu Qianfan API; for how to apply, see https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf
    "qianfan-api": {
        "version": "ernie-bot",  # "ernie-bot" and "ernie-bot-turbo" are currently supported; see the Qianfan section of the supported-model list in the docs for more.
        "version_url": "",  # version may be left empty; fill in the API URL of the model deployment obtained from Qianfan instead
        "api_base_url": "http://127.0.0.1:8888/v1",
        "version": "ernie-bot-turbo",  # "ernie-bot" and "ernie-bot-turbo" are currently supported; see the official docs for more.
        "version_url": "",  # version may also be left empty; fill in the API URL of the model deployment obtained from Qianfan instead
        "api_key": "",
        "secret_key": "",
        "provider": "ErnieWorker",
    }
}

# LLM name
LLM_MODEL = "chatglm2-6b"

# Number of history dialogue turns
HISTORY_LEN = 3

# General LLM chat parameters
TEMPERATURE = 0.7
# TOP_P = 0.95  # ChatOpenAI does not support this parameter yet


# Device to run the LLM on. "auto" detects automatically; it can also be set manually to one of "cuda", "mps", "cpu".
LLM_DEVICE = "auto"

# TextSplitter

text_splitter_dict = {
    "ChineseRecursiveTextSplitter": {
        "source": "",
        "tokenizer_name_or_path": "",
    },
    "SpacyTextSplitter": {
        "source": "huggingface",
        "tokenizer_name_or_path": "gpt2",
    },
    "RecursiveCharacterTextSplitter": {
        "source": "tiktoken",
        "tokenizer_name_or_path": "cl100k_base",
    },

    "MarkdownHeaderTextSplitter": {
        "headers_to_split_on":
            [
                ("#", "head1"),
                ("##", "head2"),
                ("###", "head3"),
                ("####", "head4"),
            ]
        "provider": "QianFanWorker",
    },
}

# TEXT_SPLITTER name
TEXT_SPLITTER = "ChineseRecursiveTextSplitter"

# Length of a single text chunk in the knowledge base (does not apply to MarkdownHeaderTextSplitter)
CHUNK_SIZE = 250

# Overlap between adjacent text chunks in the knowledge base (does not apply to MarkdownHeaderTextSplitter)
OVERLAP_SIZE = 0


# Log storage path
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
if not os.path.exists(LOG_PATH):
    os.mkdir(LOG_PATH)

# Default knowledge base storage path
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
if not os.path.exists(KB_ROOT_PATH):
    os.mkdir(KB_ROOT_PATH)
# Default database storage path.
# For sqlite, DB_ROOT_PATH can be modified directly; for any other database, modify SQLALCHEMY_DATABASE_URI instead.
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"


# Available vector store types and their settings
kbs_config = {
    "faiss": {
    },
    "milvus": {
        "host": "127.0.0.1",
        "port": "19530",
        "user": "",
        "password": "",
        "secure": False,
    },
    "pg": {
        "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
    }
}

# Default vector store type. Options: faiss, milvus, pg.
DEFAULT_VS_TYPE = "faiss"

# Number of cached vector stores
CACHED_VS_NUM = 1

# Number of matched vectors returned from the knowledge base
VECTOR_SEARCH_TOP_K = 3

# Relevance threshold for knowledge base matching, in the range 0-1. The smaller the SCORE, the higher the relevance; 1 is equivalent to no filtering. A value around 0.5 is recommended.
SCORE_THRESHOLD = 1

# Number of matched search engine results
SEARCH_ENGINE_TOP_K = 3
# The settings below normally do not need to be changed

# nltk model storage path
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")

# Prompt template for QA over local knowledge (uses Jinja2 syntax, i.e. double braces instead of the single braces of an f-string)
PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>

<已知信息>{{ context }}</已知信息>

<问题>{{ question }}</问题>"""

# Whether CORS is enabled for the API; False by default. Set to True to enable.
# is open cross domain
OPEN_CROSS_DOMAIN = False

# Required variables for Bing search
# Bing search requires a Bing Subscription Key; apply for a Bing Search trial in the Azure portal.
# For details on how to apply, see
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource
# For creating a Bing search instance with the Python API, see:
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
# Note: this is not the Bing Webmaster Tools API key.

# Also, on a server, the error "Failed to establish a new connection: [Errno 110] Connection timed out"
# means the server sits behind a firewall; ask the administrator to whitelist the address (on a corporate server this may simply not be granted).
BING_SUBSCRIPTION_KEY = ""

# Whether to enable Chinese title enhancement, and its related options.
# Title detection marks which text segments are titles in the document metadata;
# each segment is then concatenated with the title one level above it to enrich the text.
ZH_TITLE_ENHANCE = False
@@ -1,4 +1,4 @@
from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE
from configs.model_config import LLM_DEVICE
import httpx

# Default timeout for httpx requests (seconds). If loading a model or chatting is slow and timeout errors occur, increase this value appropriately.
@@ -8,7 +8,7 @@ HTTPX_DEFAULT_TIMEOUT = 300.0
# is open cross domain
OPEN_CROSS_DOMAIN = False

# Default bind host for each server
# Default bind host for each server. If changed to "0.0.0.0", the host of every XX_SERVER below must be changed as well.
DEFAULT_BIND_HOST = "127.0.0.1"

# webui.py server
@@ -26,14 +26,14 @@ API_SERVER = {
# fastchat openai_api server
FSCHAT_OPENAI_API = {
    "host": DEFAULT_BIND_HOST,
    "port": 8888,  # the api_base_url configured for each model in model_config.llm_model_dict must match this.
    "port": 20000,
}

# fastchat model_worker server
# These models must be correctly configured in model_config.llm_model_dict.
# These models must be correctly configured in model_config.MODEL_PATH or ONLINE_MODEL.
# When launching startup.py, a model can be chosen with `--model-worker --model-name xxxx`; LLM_MODEL is used if not specified.
FSCHAT_MODEL_WORKERS = {
    # Default settings shared by all models, which can be overridden in each model's own section or in llm_model_dict.
    # Default settings shared by all models, which can be overridden in each model's own section.
    "default": {
        "host": DEFAULT_BIND_HOST,
        "port": 20002,
@@ -64,17 +64,17 @@ FSCHAT_MODEL_WORKERS = {
    "baichuan-7b": {  # uses the IP and port from "default"
        "device": "cpu",
    },
    "zhipu-api": {  # set a different port for each online API
        "port": 20003,
    "zhipu-api": {  # set a different port for each online API you plan to run
        "port": 21001,
    },
    "minimax-api": {  # set a different port for each online API
        "port": 20004,
    "minimax-api": {
        "port": 21002,
    },
    "xinghuo-api": {  # set a different port for each online API
        "port": 20005,
    "xinghuo-api": {
        "port": 21003,
    },
    "qianfan-api": {
        "port": 20006,
        "port": 21004,
    },
}
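A minimal sketch of how a worker's effective settings are assembled from these entries (this mirrors the dict-update logic of get_model_worker_config later in this diff; the sketch itself is not part of the commit):

    # "default" supplies host/port; the model's own section overrides them
    config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
    config.update(FSCHAT_MODEL_WORKERS.get("zhipu-api", {}))
    # config["host"] == DEFAULT_BIND_HOST, config["port"] == 21001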
@@ -1,8 +1,7 @@
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs.model_config import llm_model_dict, LLM_MODEL, TEMPERATURE
from server.chat.utils import wrap_done
from langchain.chat_models import ChatOpenAI
from configs.model_config import LLM_MODEL, TEMPERATURE
from server.chat.utils import wrap_done, get_ChatOpenAI
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
@@ -31,18 +30,11 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
    model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
    callback = AsyncIteratorCallbackHandler()

    model = ChatOpenAI(
        streaming=True,
        verbose=True,
        callbacks=[callback],
        openai_api_key=llm_model_dict[model_name]["api_key"],
        openai_api_base=llm_model_dict[model_name]["api_base_url"],
    model = get_ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        openai_proxy=llm_model_dict[model_name].get("openai_proxy")
        callbacks=[callback],
    )

    input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
    chat_prompt = ChatPromptTemplate.from_messages(
        [i.to_msg_template() for i in history] + [input_msg])
@@ -1,11 +1,10 @@
from fastapi import Body, Request
from fastapi.responses import StreamingResponse
from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
from configs import (LLM_MODEL, PROMPT_TEMPLATE,
                     VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                     TEMPERATURE)
from server.chat.utils import wrap_done
from server.chat.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
@@ -50,16 +49,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
    model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
    callback = AsyncIteratorCallbackHandler()

    model = ChatOpenAI(
        streaming=True,
        verbose=True,
        callbacks=[callback],
        openai_api_key=llm_model_dict[model_name]["api_key"],
        openai_api_base=llm_model_dict[model_name]["api_base_url"],
    model = get_ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        openai_proxy=llm_model_dict[model_name].get("openai_proxy")
        callbacks=[callback],
    )
    docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
    context = "\n".join([doc.page_content for doc in docs])
@@ -1,7 +1,8 @@
from fastapi.responses import StreamingResponse
from typing import List
import openai
from configs.model_config import llm_model_dict, LLM_MODEL, logger, log_verbose
from configs import LLM_MODEL, logger, log_verbose
from server.utils import get_model_worker_config, fschat_openai_api_address
from pydantic import BaseModel
@@ -23,9 +24,10 @@ class OpenAiChatMsgIn(BaseModel):


async def openai_chat(msg: OpenAiChatMsgIn):
    openai.api_key = llm_model_dict[LLM_MODEL]["api_key"]
    config = get_model_worker_config(msg.model)
    openai.api_key = config.get("api_key", "EMPTY")
    print(f"{openai.api_key=}")
    openai.api_base = llm_model_dict[LLM_MODEL]["api_base_url"]
    openai.api_base = fschat_openai_api_address()
    print(f"{openai.api_base=}")
    print(msg)
@@ -1,13 +1,12 @@
from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
                     LLM_MODEL, SEARCH_ENGINE_TOP_K,
                     PROMPT_TEMPLATE, TEMPERATURE)
from fastapi import Body
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K,
                                  PROMPT_TEMPLATE, TEMPERATURE)
from server.chat.utils import wrap_done
from server.chat.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
@@ -90,15 +89,10 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
    model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
    callback = AsyncIteratorCallbackHandler()
    model = ChatOpenAI(
        streaming=True,
        verbose=True,
        callbacks=[callback],
        openai_api_key=llm_model_dict[model_name]["api_key"],
        openai_api_base=llm_model_dict[model_name]["api_base_url"],
    model = get_ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        openai_proxy=llm_model_dict[model_name].get("openai_proxy")
        callbacks=[callback],
    )

    docs = await lookup_search_engine(query, search_engine_name, top_k)
@@ -1,8 +1,29 @@
import asyncio
from typing import Awaitable, List, Tuple, Dict, Union
from pydantic import BaseModel, Field
from langchain.prompts.chat import ChatMessagePromptTemplate
from configs import logger, log_verbose
from server.utils import get_model_worker_config, fschat_openai_api_address
from langchain.chat_models import ChatOpenAI
from typing import Awaitable, List, Tuple, Dict, Union, Callable


def get_ChatOpenAI(
        model_name: str,
        temperature: float,
        callbacks: List[Callable] = [],
) -> ChatOpenAI:
    config = get_model_worker_config(model_name)
    model = ChatOpenAI(
        streaming=True,
        verbose=True,
        callbacks=callbacks,
        openai_api_key=config.get("api_key", "EMPTY"),
        openai_api_base=fschat_openai_api_address(),
        model_name=model_name,
        temperature=temperature,
        openai_proxy=config.get("openai_proxy")
    )
    return model


async def wrap_done(fn: Awaitable, event: asyncio.Event):
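Usage note: the chat endpoints in this diff pass an AsyncIteratorCallbackHandler into callbacks so tokens can be streamed. A minimal sketch matching those call sites:

    from langchain.callbacks import AsyncIteratorCallbackHandler

    callback = AsyncIteratorCallbackHandler()
    model = get_ChatOpenAI(
        model_name=LLM_MODEL,
        temperature=TEMPERATURE,
        callbacks=[callback],
    )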
@@ -2,7 +2,7 @@ from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from configs.model_config import SQLALCHEMY_DATABASE_URI
from configs import SQLALCHEMY_DATABASE_URI
import json
@@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_base_repository import list_kbs_from_db
from configs.model_config import EMBEDDING_MODEL, logger, log_verbose
from configs import EMBEDDING_MODEL, logger, log_verbose
from fastapi import Body
@@ -4,9 +4,9 @@ from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
import threading
from configs.model_config import (CACHED_VS_NUM, EMBEDDING_MODEL, CHUNK_SIZE,
                                  embedding_model_dict, logger, log_verbose)
from server.utils import embedding_device
from configs import (EMBEDDING_MODEL, CHUNK_SIZE, CACHED_VS_NUM,
                     logger, log_verbose)
from server.utils import embedding_device, get_model_path
from contextlib import contextmanager
from collections import OrderedDict
from typing import List, Any, Union, Tuple
@@ -118,15 +118,15 @@ class EmbeddingsPool(CachePool):
            with item.acquire(msg="初始化"):
                self.atomic.release()
                if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
                    embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
                    embeddings = OpenAIEmbeddings(openai_api_key=get_model_path(model), chunk_size=CHUNK_SIZE)
                elif 'bge-' in model:
                    embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
                    embeddings = HuggingFaceBgeEmbeddings(model_name=get_model_path(model),
                                                          model_kwargs={'device': device},
                                                          query_instruction="为这个句子生成表示以用于检索相关文章:")
                    if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
                        embeddings.query_instruction = ""
                else:
                    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
                    embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model), model_kwargs={'device': device})
                item.obj = embeddings
                item.finish_loading()
        else:
@@ -1,7 +1,7 @@
import os
import urllib
from fastapi import File, Form, Body, Query, UploadFile
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
                     VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
                     logger, log_verbose,)
@@ -18,7 +18,7 @@ from server.db.repository.knowledge_file_repository import (
    list_docs_from_db,
)

from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                     EMBEDDING_MODEL)
from server.knowledge_base.utils import (
    get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
@@ -1,7 +1,7 @@
import os
import shutil

from configs.model_config import (
from configs import (
    KB_ROOT_PATH,
    SCORE_THRESHOLD,
    logger, log_verbose,
@@ -7,7 +7,7 @@ from langchain.schema import Document
from langchain.vectorstores import Milvus
from sklearn.preprocessing import normalize

from configs.model_config import SCORE_THRESHOLD, kbs_config
from configs import SCORE_THRESHOLD, kbs_config

from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
    score_threshold_process
@@ -7,7 +7,7 @@ from langchain.vectorstores import PGVector
from langchain.vectorstores.pgvector import DistanceStrategy
from sqlalchemy import text

from configs.model_config import EMBEDDING_DEVICE, kbs_config
from configs import kbs_config

from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
    score_threshold_process
@@ -1,4 +1,4 @@
from configs.model_config import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
from configs import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
                     logger, log_verbose)
from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
                                         list_files_from_folder, files2docs_in_thread,
@@ -2,7 +2,7 @@ import os

from transformers import AutoTokenizer

from configs.model_config import (
from configs import (
    EMBEDDING_MODEL,
    KB_ROOT_PATH,
    CHUNK_SIZE,
@@ -23,7 +23,7 @@ from langchain.text_splitter import TextSplitter
from pathlib import Path
import json
from concurrent.futures import ThreadPoolExecutor
from server.utils import run_in_thread_pool, embedding_device
from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config
import io
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
@@ -185,6 +185,7 @@ def make_text_splitter(
        splitter_name: str = TEXT_SPLITTER,
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
        llm_model: str = LLM_MODEL,
):
    """
    Get a specific text splitter according to the parameters.
@@ -220,8 +221,9 @@ def make_text_splitter(
            )
        elif text_splitter_dict[splitter_name]["source"] == "huggingface":  # load from huggingface
            if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
                config = get_model_worker_config(llm_model)
                text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
                    llm_model_dict[LLM_MODEL]["local_model_path"]
                    config.get("model_path")

            if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
                from transformers import GPT2TokenizerFast
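The branch above means a huggingface-sourced splitter with an empty tokenizer_name_or_path borrows the current LLM's own model path as its tokenizer. A short sketch of the fallback, with values assumed from the configs in this commit:

    # e.g. SpacyTextSplitter has tokenizer_name_or_path == "" in kb_config,
    # so the tokenizer falls back to the model path of the running LLM
    config = get_model_worker_config("chatglm2-6b")
    tokenizer_path = config.get("model_path")  # e.g. "THUDM/chatglm2-6b"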
@@ -1,4 +1,4 @@
from configs.model_config import LOG_PATH
from configs.basic_config import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import BaseModelWorker
@@ -4,8 +4,10 @@ from typing import List
from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE, logger, log_verbose
from configs.server_config import FSCHAT_MODEL_WORKERS
from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
                     MODEL_PATH, MODEL_ROOT_PATH,
                     logger, log_verbose,
                     FSCHAT_MODEL_WORKERS)
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Literal, Optional, Callable, Generator, Dict, Any
@@ -197,22 +199,56 @@ def MakeFastAPIOffline(
    )


# Get model information from model_config
def list_embed_models() -> List[str]:
    return list(MODEL_PATH["embed_model"])

def list_llm_models() -> List[str]:
    return list(MODEL_PATH["llm_model"])

def get_model_path(model_name: str, type: str = None) -> Optional[str]:
    if type in MODEL_PATH:
        paths = MODEL_PATH[type]
    else:
        paths = {}
        for v in MODEL_PATH.values():
            paths.update(v)

    if path_str := paths.get(model_name):  # taking "chatglm-6b": "THUDM/chatglm-6b-new" as an example, all of the following paths are supported
        path = Path(path_str)
        if path.is_dir():  # any absolute path
            return str(path)

        root_path = Path(MODEL_ROOT_PATH)
        if root_path.is_dir():
            path = root_path / model_name
            if path.is_dir():  # use key, {MODEL_ROOT_PATH}/chatglm-6b
                return str(path)
            path = root_path / path_str
            if path.is_dir():  # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
                return str(path)
            path = root_path / path_str.split("/")[-1]
            if path.is_dir():  # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
                return str(path)
        return path_str  # hub id, e.g. THUDM/chatglm-6b-new


# Get service information from server_config
def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
def get_model_worker_config(model_name: str = None) -> dict:
    '''
    Load the model worker's config options.
    Priority: FSCHAT_MODEL_WORKERS[model_name] > llm_model_dict[model_name] > FSCHAT_MODEL_WORKERS["default"]
    Priority: FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
    '''
    from configs.model_config import ONLINE_LLM_MODEL
    from configs.server_config import FSCHAT_MODEL_WORKERS
    from server import model_workers
    from configs.model_config import llm_model_dict

    config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
    config.update(llm_model_dict.get(model_name, {}))
    config.update(ONLINE_LLM_MODEL.get(model_name, {}))
    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))

    # If no valid local_model_path is set, treat the entry as an online model API
    if not os.path.isdir(config.get("local_model_path", "")):
    # online model API
    if model_name in ONLINE_LLM_MODEL:
        config["online_api"] = True
        if provider := config.get("provider"):
            try:
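For clarity, the lookup order get_model_path implements, using the text2vec example from the comments in model_config (MODEL_ROOT_PATH is assumed non-empty here; the sketch is not part of the diff):

    # what get_model_path("text2vec") tries, in order:
    # 1. the configured value itself, if it is an existing directory
    # 2. {MODEL_ROOT_PATH}/text2vec                            (the dict key)
    # 3. {MODEL_ROOT_PATH}/GanymedeNil/text2vec-large-chinese  (the full value)
    # 4. {MODEL_ROOT_PATH}/text2vec-large-chinese              (value basename)
    # 5. otherwise the value is returned unchanged and resolved on huggingface
    print(get_model_path("text2vec"))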
@@ -222,13 +258,14 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)

    config["device"] = llm_device(config.get("device") or LLM_DEVICE)
    config["model_path"] = get_model_path(model_name)
    config["device"] = llm_device(config.get("device"))
    return config


def get_all_model_worker_configs() -> dict:
    result = {}
    model_names = set(llm_model_dict.keys()) | set(FSCHAT_MODEL_WORKERS.keys())
    model_names = set(FSCHAT_MODEL_WORKERS.keys())
    for name in model_names:
        if name != "default":
            result[name] = get_model_worker_config(name)
@@ -256,7 +293,7 @@ def fschat_openai_api_address() -> str:

    host = FSCHAT_OPENAI_API["host"]
    port = FSCHAT_OPENAI_API["port"]
    return f"http://{host}:{port}"
    return f"http://{host}:{port}/v1"


def api_address() -> str:
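With the /v1 suffix moved into the helper, callers such as openai_chat.py and get_ChatOpenAI use the address verbatim. Under the default server_config in this commit:

    # FSCHAT_OPENAI_API = {"host": "127.0.0.1", "port": 20000}
    fschat_openai_api_address()  # -> "http://127.0.0.1:20000/v1"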
@@ -302,13 +339,15 @@ def detect_device() -> Literal["cuda", "mps", "cpu"]:
    return "cpu"


def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]:
def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
    device = device or LLM_DEVICE
    if device not in ["cuda", "mps", "cpu"]:
        device = detect_device()
    return device


def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
    device = device or EMBEDDING_DEVICE
    if device not in ["cuda", "mps", "cpu"]:
        device = detect_device()
    return device
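The new None defaults make the config value be read at call time rather than import time, and any value outside "cuda"/"mps"/"cpu" (such as "auto") falls through to detect_device(). A short sketch:

    llm_device("cuda")  # -> "cuda", passed through unchanged
    llm_device("auto")  # -> whatever detect_device() reports
    llm_device(None)    # -> falls back to LLM_DEVICE, then the same check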
startup.py
@@ -17,10 +17,19 @@ except:
    pass

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
    logger, log_verbose, TEXT_SPLITTER
from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
                                   FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT)
from configs import (
    LOG_PATH,
    log_verbose,
    logger,
    LLM_MODEL,
    EMBEDDING_MODEL,
    TEXT_SPLITTER_NAME,
    FSCHAT_CONTROLLER,
    FSCHAT_OPENAI_API,
    API_SERVER,
    WEBUI_SERVER,
    HTTPX_DEFAULT_TIMEOUT,
)
from server.utils import (fschat_controller_address, fschat_model_worker_address,
                          fschat_openai_api_address, set_httpx_timeout,
                          get_model_worker_config, get_all_model_worker_configs,
@@ -216,7 +225,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
    @app.post("/release_worker")
    def release_worker(
        model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
        # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[fschat_controller_address()]),
        # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[FSCHAT_CONTROLLER_address()]),
        new_model_name: str = Body(None, description="释放后加载该模型"),
        keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
    ) -> Dict:
@@ -250,7 +259,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
            return {"code": 500, "msg": msg}

        if new_model_name:
            timer = HTTPX_DEFAULT_TIMEOUT * 2  # wait for new model_worker register
            timer = HTTPX_DEFAULT_TIMEOUT  # wait for new model_worker register
            while timer > 0:
                models = app._controller.list_models()
                if new_model_name in models:
@@ -297,7 +306,7 @@ def run_model_worker(
    kwargs["model_names"] = [model_name]
    kwargs["controller_address"] = controller_address or fschat_controller_address()
    kwargs["worker_address"] = fschat_model_worker_address(model_name)
    model_path = kwargs.get("local_model_path", "")
    model_path = kwargs.get("model_path", "")
    kwargs["model_path"] = model_path

    app = create_model_worker_app(log_level=log_level, **kwargs)
@@ -418,7 +427,7 @@ def parse_args() -> argparse.ArgumentParser:
        "-c",
        "--controller",
        type=str,
        help="specify controller address the worker is registered to. default is server_config.FSCHAT_CONTROLLER",
        help="specify controller address the worker is registered to. default is FSCHAT_CONTROLLER",
        dest="controller_address",
    )
    parser.add_argument(
@@ -474,15 +483,14 @@ def dump_server_info(after_start=False, args=None):
    print(f"当前启动的LLM模型:{models} @ {llm_device()}")

    for model in models:
        pprint(llm_model_dict[model])
        pprint(get_model_worker_config(model))
    print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")

    if after_start:
        print("\n")
        print(f"服务端运行信息:")
        if args.openai_api:
            print(f"    OpenAI API Server: {fschat_openai_api_address()}/v1")
            print("    (请确认llm_model_dict中配置的api_base_url与上面地址一致。)")
            print(f"    OpenAI API Server: {fschat_openai_api_address()}")
        if args.api:
            print(f"    Chatchat API Server: {api_address()}")
        if args.webui:
@@ -6,7 +6,7 @@ from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.utils import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from configs import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path, get_file_path

from pprint import pprint
@@ -6,7 +6,7 @@ from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.utils import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from configs import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path, get_file_path
from webui_pages.utils import ApiRequest
@@ -6,21 +6,19 @@ from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import FSCHAT_MODEL_WORKERS
from configs.model_config import LLM_MODEL, llm_model_dict
from configs.model_config import LLM_MODEL
from server.utils import api_address, get_model_worker_config

from pprint import pprint
import random
from typing import List


def get_configured_models():
def get_configured_models() -> List[str]:
    model_workers = list(FSCHAT_MODEL_WORKERS)
    if "default" in model_workers:
        model_workers.remove("default")

    llm_dict = list(llm_model_dict)

    return model_workers, llm_dict
    return model_workers


api_base_url = api_address()
@@ -56,12 +54,9 @@ def test_change_model(api="/llm_model/change"):
    running_models = get_running_models()
    assert len(running_models) > 0

    model_workers, llm_dict = get_configured_models()
    model_workers = get_configured_models()

    availabel_new_models = set(model_workers) - set(running_models)
    if len(availabel_new_models) == 0:
        availabel_new_models = set(llm_dict) - set(running_models)
    availabel_new_models = list(availabel_new_models)
    availabel_new_models = list(set(model_workers) - set(running_models))
    assert len(availabel_new_models) > 0
    print(availabel_new_models)
@@ -4,7 +4,7 @@ import sys
from pathlib import Path

sys.path.append(str(Path(__file__).parent.parent.parent))
from configs.model_config import BING_SUBSCRIPTION_KEY
from configs import BING_SUBSCRIPTION_KEY
from server.utils import api_address

from pprint import pprint
@@ -4,7 +4,7 @@ from transformers import AutoTokenizer
import sys

sys.path.append("../..")
from configs.model_config import (
from configs import (
    CHUNK_SIZE,
    OVERLAP_SIZE
)
@@ -1,11 +1,10 @@
import streamlit as st
from configs.server_config import FSCHAT_MODEL_WORKERS
from webui_pages.utils import *
from streamlit_chatbox import *
from datetime import datetime
from server.chat.search_engine_chat import SEARCH_ENGINES
import os
from configs.model_config import LLM_MODEL, TEMPERATURE
from configs import LLM_MODEL, TEMPERATURE
from server.utils import get_model_worker_config
from typing import List, Dict
@@ -6,9 +6,10 @@ import pandas as pd
from server.knowledge_base.utils import get_file_path, LOADER_DICT
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
from typing import Literal, Dict, Tuple
from configs.model_config import (embedding_model_dict, kbs_config,
from configs import (kbs_config,
                     EMBEDDING_MODEL, DEFAULT_VS_TYPE,
                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from server.utils import list_embed_models
import os
import time
@@ -94,7 +95,7 @@ def knowledge_base_page(api: ApiRequest):
        key="vs_type",
    )

    embed_models = list(embedding_model_dict.keys())
    embed_models = list_embed_models()

    embed_model = cols[1].selectbox(
        "Embedding 模型",
@@ -1,12 +1,11 @@
# This file contains general webui utilities that can be shared by different webuis
from typing import *
from pathlib import Path
from configs.model_config import (
from configs import (
    EMBEDDING_MODEL,
    DEFAULT_VS_TYPE,
    KB_ROOT_PATH,
    LLM_MODEL,
    llm_model_dict,
    HISTORY_LEN,
    TEMPERATURE,
    SCORE_THRESHOLD,
@@ -15,9 +14,10 @@ from configs.model_config import (
    ZH_TITLE_ENHANCE,
    VECTOR_SEARCH_TOP_K,
    SEARCH_ENGINE_TOP_K,
    FSCHAT_MODEL_WORKERS,
    HTTPX_DEFAULT_TIMEOUT,
    logger, log_verbose,
)
from configs.server_config import HTTPX_DEFAULT_TIMEOUT
import httpx
import asyncio
from server.chat.openai_chat import OpenAiChatMsgIn
@@ -779,7 +779,10 @@ class ApiRequest:
        '''
        Get the list of models configured in configs
        '''
        return list(llm_model_dict.keys())
        models = list(FSCHAT_MODEL_WORKERS.keys())
        if "default" in models:
            models.remove("default")
        return models

    def stop_llm_model(
            self,