实现Api和WEBUI的前后端分离 (#1772)

* update ApiRequest: 删除no_remote_api本地调用模式;支持同步/异步调用
* 实现API和WEBUI的分离:
- API运行服务器上的配置通过/llm_model/get_model_config、/server/configs接口提供,WEBUI运行机器上的配置项仅作为代码内部默认值使用
- 服务器可用的搜索引擎通过/server/list_search_engines提供
- WEBUI可选LLM列表中只列出在FSCHAT_MODEL_WORKERS中配置的模型
- 修改WEBUI中默认LLM_MODEL获取方式,改为从api端读取
- 删除knowledge_base_chat中`local_doc_url`参数

其它修改:
- 删除多余的kb_config.py.exmaple(名称错误)
- server_config中默认关闭vllm
- server_config中默认注释除智谱AI之外的在线模型
- 修改requests从系统获取的代理,避免model worker注册错误

* 修正:
- api.list_config_models返回模型原始配置
- api.list_config_models和api.get_model_config中过滤online api模型的敏感信息
- 将GPT等直接访问的模型列入WEBUI可选模型列表

其它:
- 指定langchain==0.3.313, fschat==0.2.30, langchain-experimental==0.0.30
This commit is contained in:
liunux4odoo 2023-10-17 16:52:07 +08:00 committed by GitHub
parent 94977c7ab1
commit 9ce328fea9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 521 additions and 644 deletions

View File

@ -1,99 +0,0 @@
import os
# 默认向量库类型。可选faiss, milvus, pg.
DEFAULT_VS_TYPE = "faiss"
# 缓存向量库数量针对FAISS
CACHED_VS_NUM = 1
# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter)
CHUNK_SIZE = 250
# 知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter)
OVERLAP_SIZE = 50
# 知识库匹配向量数量
VECTOR_SEARCH_TOP_K = 3
# 知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右
SCORE_THRESHOLD = 1
# 搜索引擎匹配结题数量
SEARCH_ENGINE_TOP_K = 3
# Bing 搜索必备变量
# 使用 Bing 搜索需要使用 Bing Subscription Key,需要在azure port中申请试用bing search
# 具体申请方式请见
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource
# 使用python创建bing api 搜索实例详见:
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
# 注意不是bing Webmaster Tools的api key
# 此外如果是在服务器上报Failed to establish a new connection: [Errno 110] Connection timed out
# 是因为服务器加了防火墙需要联系管理员加白名单如果公司的服务器的话就别想了GG
BING_SUBSCRIPTION_KEY = ""
# 是否开启中文标题加强,以及标题增强的相关配置
# 通过增加标题判断判断哪些文本为标题并在metadata中进行标记
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
ZH_TITLE_ENHANCE = False
# 通常情况下不需要更改以下内容
# 知识库默认存储路径
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
if not os.path.exists(KB_ROOT_PATH):
os.mkdir(KB_ROOT_PATH)
# 数据库默认存储路径。
# 如果使用sqlite可以直接修改DB_ROOT_PATH如果使用其它数据库请直接修改SQLALCHEMY_DATABASE_URI。
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"
# 可选向量库类型及对应配置
kbs_config = {
"faiss": {
},
"milvus": {
"host": "127.0.0.1",
"port": "19530",
"user": "",
"password": "",
"secure": False,
},
"pg": {
"connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
}
}
# TextSplitter配置项如果你不明白其中的含义就不要修改。
text_splitter_dict = {
"ChineseRecursiveTextSplitter": {
"source": "huggingface", ## 选择tiktoken则使用openai的方法
"tokenizer_name_or_path": "gpt2",
},
"SpacyTextSplitter": {
"source": "huggingface",
"tokenizer_name_or_path": "",
},
"RecursiveCharacterTextSplitter": {
"source": "tiktoken",
"tokenizer_name_or_path": "cl100k_base",
},
"MarkdownHeaderTextSplitter": {
"headers_to_split_on":
[
("#", "head1"),
("##", "head2"),
("###", "head3"),
("####", "head4"),
]
},
}
# TEXT_SPLITTER 名称
TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter"

View File

@ -32,6 +32,7 @@ FSCHAT_OPENAI_API = {
# fastchat model_worker server # fastchat model_worker server
# 这些模型必须是在model_config.MODEL_PATH或ONLINE_MODEL中正确配置的。 # 这些模型必须是在model_config.MODEL_PATH或ONLINE_MODEL中正确配置的。
# 在启动startup.py时可用通过`--model-worker --model-name xxxx`指定模型不指定则为LLM_MODEL # 在启动startup.py时可用通过`--model-worker --model-name xxxx`指定模型不指定则为LLM_MODEL
# 必须在这里添加的模型才会出现在WEBUI中可选模型列表里LLM_MODEL会自动添加
FSCHAT_MODEL_WORKERS = { FSCHAT_MODEL_WORKERS = {
# 所有模型共用的默认配置,可在模型专项配置中进行覆盖。 # 所有模型共用的默认配置,可在模型专项配置中进行覆盖。
"default": { "default": {
@ -39,7 +40,8 @@ FSCHAT_MODEL_WORKERS = {
"port": 20002, "port": 20002,
"device": LLM_DEVICE, "device": LLM_DEVICE,
# False,'vllm',使用的推理加速框架,使用vllm如果出现HuggingFace通信问题参见doc/FAQ # False,'vllm',使用的推理加速框架,使用vllm如果出现HuggingFace通信问题参见doc/FAQ
"infer_turbo": "vllm" if sys.platform.startswith("linux") else False, # vllm对一些模型支持还不成熟暂时默认关闭
"infer_turbo": False,
# model_worker多卡加载需要配置的参数 # model_worker多卡加载需要配置的参数
# "gpus": None, # 使用的GPU以str的格式指定如"0,1"如失效请使用CUDA_VISIBLE_DEVICES="0,1"等形式指定 # "gpus": None, # 使用的GPU以str的格式指定如"0,1"如失效请使用CUDA_VISIBLE_DEVICES="0,1"等形式指定
@ -97,24 +99,24 @@ FSCHAT_MODEL_WORKERS = {
"zhipu-api": { # 请为每个要运行的在线API设置不同的端口 "zhipu-api": { # 请为每个要运行的在线API设置不同的端口
"port": 21001, "port": 21001,
}, },
"minimax-api": { # "minimax-api": {
"port": 21002, # "port": 21002,
}, # },
"xinghuo-api": { # "xinghuo-api": {
"port": 21003, # "port": 21003,
}, # },
"qianfan-api": { # "qianfan-api": {
"port": 21004, # "port": 21004,
}, # },
"fangzhou-api": { # "fangzhou-api": {
"port": 21005, # "port": 21005,
}, # },
"qwen-api": { # "qwen-api": {
"port": 21006, # "port": 21006,
}, # },
"baichuan-api": { # "baichuan-api": {
"port": 21007, # "port": 21007,
}, # },
} }
# fastchat multi model worker server # fastchat multi model worker server

View File

@ -1,5 +1,6 @@
langchain>=0.0.310 langchain==0.0.313
fschat[model_worker]>=0.2.30 langchain-experimental==0.0.30
fschat[model_worker]==0.2.30
openai openai
sentence_transformers sentence_transformers
transformers>=4.34 transformers>=4.34

View File

@ -1,5 +1,6 @@
langchain>=0.0.310 langchain==0.0.313
fschat[model_worker]>=0.2.30 langchain-experimental==0.0.30
fschat[model_worker]==0.2.30
openai openai
sentence_transformers>=2.2.2 sentence_transformers>=2.2.2
transformers>=4.34 transformers>=4.34

View File

@ -17,8 +17,10 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs, from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
update_docs, download_doc, recreate_vector_store, update_docs, download_doc, recreate_vector_store,
search_docs, DocumentWithScore) search_docs, DocumentWithScore)
from server.llm_api import list_running_models, list_config_models, change_llm_model, stop_llm_model from server.llm_api import (list_running_models, list_config_models,
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline change_llm_model, stop_llm_model,
get_model_config, list_search_engines)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, get_server_configs
from typing import List from typing import List
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@ -139,6 +141,11 @@ def create_app():
summary="列出configs已配置的模型", summary="列出configs已配置的模型",
)(list_config_models) )(list_config_models)
app.post("/llm_model/get_model_config",
tags=["LLM Model Management"],
summary="获取模型配置(合并后)",
)(get_model_config)
app.post("/llm_model/stop", app.post("/llm_model/stop",
tags=["LLM Model Management"], tags=["LLM Model Management"],
summary="停止指定的LLM模型Model Worker)", summary="停止指定的LLM模型Model Worker)",
@ -149,6 +156,17 @@ def create_app():
summary="切换指定的LLM模型Model Worker)", summary="切换指定的LLM模型Model Worker)",
)(change_llm_model) )(change_llm_model)
# 服务器相关接口
app.post("/server/configs",
tags=["Server State"],
summary="获取服务器原始配置信息",
)(get_server_configs)
app.post("/server/list_search_engines",
tags=["Server State"],
summary="获取服务器支持的搜索引擎",
)(list_search_engines)
return app return app

View File

@ -33,7 +33,6 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"), # TODO: fastchat更新后默认值设为None自动使用LLM支持的最大值。 max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"), # TODO: fastchat更新后默认值设为None自动使用LLM支持的最大值。
prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
request: Request = None, request: Request = None,
): ):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name) kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
@ -74,11 +73,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
source_documents = [] source_documents = []
for inum, doc in enumerate(docs): for inum, doc in enumerate(docs):
filename = os.path.split(doc.metadata["source"])[-1] filename = os.path.split(doc.metadata["source"])[-1]
if local_doc_url: parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
url = "file://" + doc.metadata["source"] url = f"{request.base_url}knowledge_base/download_doc?" + parameters
else:
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
url = f"{request.base_url}knowledge_base/download_doc?" + parameters
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n""" text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
source_documents.append(text) source_documents.append(text)

View File

@ -1,7 +1,7 @@
from fastapi import Body from fastapi import Body
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
from server.utils import BaseResponse, fschat_controller_address, list_llm_models, get_httpx_client from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
get_httpx_client, get_model_worker_config)
def list_running_models( def list_running_models(
@ -9,19 +9,21 @@ def list_running_models(
placeholder: str = Body(None, description="该参数未使用,占位用"), placeholder: str = Body(None, description="该参数未使用,占位用"),
) -> BaseResponse: ) -> BaseResponse:
''' '''
从fastchat controller获取已加载模型列表 从fastchat controller获取已加载模型列表及其配置项
''' '''
try: try:
controller_address = controller_address or fschat_controller_address() controller_address = controller_address or fschat_controller_address()
with get_httpx_client() as client: with get_httpx_client() as client:
r = client.post(controller_address + "/list_models") r = client.post(controller_address + "/list_models")
return BaseResponse(data=r.json()["models"]) models = r.json()["models"]
data = {m: get_model_worker_config(m) for m in models}
return BaseResponse(data=data)
except Exception as e: except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}', logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None) exc_info=e if log_verbose else None)
return BaseResponse( return BaseResponse(
code=500, code=500,
data=[], data={},
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}") msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
@ -29,7 +31,38 @@ def list_config_models() -> BaseResponse:
''' '''
从本地获取configs中配置的模型列表 从本地获取configs中配置的模型列表
''' '''
return BaseResponse(data=list_llm_models()) configs = list_config_llm_models()
# 删除ONLINE_MODEL配置中的敏感信息
for config in configs["online"].values():
del_keys = set(["worker_class"])
for k in config:
if "key" in k.lower() or "secret" in k.lower():
del_keys.add(k)
for k in del_keys:
config.pop(k, None)
return BaseResponse(data=configs)
def get_model_config(
model_name: str = Body(description="配置中LLM模型的名称"),
placeholder: str = Body(description="占位用,无实际效果")
) -> BaseResponse:
'''
获取LLM模型配置项合并后的
'''
config = get_model_worker_config(model_name=model_name)
# 删除ONLINE_MODEL配置中的敏感信息
del_keys = set(["worker_class"])
for k in config:
if "key" in k.lower() or "secret" in k.lower():
del_keys.add(k)
for k in del_keys:
config.pop(k, None)
return BaseResponse(data=config)
def stop_llm_model( def stop_llm_model(
@ -79,3 +112,9 @@ def change_llm_model(
return BaseResponse( return BaseResponse(
code=500, code=500,
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}") msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
def list_search_engines() -> BaseResponse:
from server.chat.search_engine_chat import SEARCH_ENGINES
return BaseResponse(data=list(SEARCH_ENGINES))

View File

@ -258,17 +258,18 @@ def list_embed_models() -> List[str]:
return list(MODEL_PATH["embed_model"]) return list(MODEL_PATH["embed_model"])
def list_llm_models() -> Dict[str, List[str]]: def list_config_llm_models() -> Dict[str, Dict]:
''' '''
get names of configured llm models with different types. get configured llm models with different types.
return [(model_name, config_type), ...] return [(model_name, config_type), ...]
''' '''
workers = list(FSCHAT_MODEL_WORKERS) workers = list(FSCHAT_MODEL_WORKERS)
if "default" in workers: if LLM_MODEL not in workers:
workers.remove("default") workers.insert(0, LLM_MODEL)
return { return {
"local": list(MODEL_PATH["llm_model"]), "local": MODEL_PATH["llm_model"],
"online": list(ONLINE_LLM_MODEL), "online": ONLINE_LLM_MODEL,
"worker": workers, "worker": workers,
} }
@ -306,7 +307,7 @@ def get_model_worker_config(model_name: str = None) -> dict:
加载model worker的配置项 加载model worker的配置项
优先级:FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"] 优先级:FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
''' '''
from configs.model_config import ONLINE_LLM_MODEL from configs.model_config import ONLINE_LLM_MODEL, MODEL_PATH
from configs.server_config import FSCHAT_MODEL_WORKERS from configs.server_config import FSCHAT_MODEL_WORKERS
from server import model_workers from server import model_workers
@ -324,9 +325,10 @@ def get_model_worker_config(model_name: str = None) -> dict:
msg = f"在线模型 {model_name} 的provider没有正确配置" msg = f"在线模型 {model_name} 的provider没有正确配置"
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None) exc_info=e if log_verbose else None)
# 本地模型
config["model_path"] = get_model_path(model_name) if model_name in MODEL_PATH["llm_model"]:
config["device"] = llm_device(config.get("device")) config["model_path"] = get_model_path(model_name)
config["device"] = llm_device(config.get("device"))
return config return config
@ -449,11 +451,11 @@ def set_httpx_config(
# TODO: 简单的清除系统代理不是个好的选择影响太多。似乎修改代理服务器的bypass列表更好。 # TODO: 简单的清除系统代理不是个好的选择影响太多。似乎修改代理服务器的bypass列表更好。
# patch requests to use custom proxies instead of system settings # patch requests to use custom proxies instead of system settings
# def _get_proxies(): def _get_proxies():
# return {} return proxies
# import urllib.request import urllib.request
# urllib.request.getproxies = _get_proxies urllib.request.getproxies = _get_proxies
# 自动检查torch可用的设备。分布式部署时不运行LLM的机器上可以不装torch # 自动检查torch可用的设备。分布式部署时不运行LLM的机器上可以不装torch
@ -557,3 +559,35 @@ def get_httpx_client(
return httpx.AsyncClient(**kwargs) return httpx.AsyncClient(**kwargs)
else: else:
return httpx.Client(**kwargs) return httpx.Client(**kwargs)
def get_server_configs() -> Dict:
'''
获取configs中的原始配置项供前端使用
'''
from configs.kb_config import (
DEFAULT_VS_TYPE,
CHUNK_SIZE,
OVERLAP_SIZE,
SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
ZH_TITLE_ENHANCE,
text_splitter_dict,
TEXT_SPLITTER_NAME,
)
from configs.model_config import (
LLM_MODEL,
EMBEDDING_MODEL,
HISTORY_LEN,
TEMPERATURE,
)
from configs.prompt_config import PROMPT_TEMPLATES
_custom = {
"controller_address": fschat_controller_address(),
"openai_api_address": fschat_openai_api_address(),
"api_address": api_address(),
}
return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}

View File

@ -14,7 +14,7 @@ from pprint import pprint
api_base_url = api_address() api_base_url = api_address()
api: ApiRequest = ApiRequest(api_base_url, no_remote_api=False) api: ApiRequest = ApiRequest(api_base_url)
kb = "kb_for_api_test" kb = "kb_for_api_test"

View File

@ -32,7 +32,7 @@ def get_running_models(api="/llm_model/list_models"):
return [] return []
def test_running_models(api="/llm_model/list_models"): def test_running_models(api="/llm_model/list_running_models"):
url = api_base_url + api url = api_base_url + api
r = requests.post(url) r = requests.post(url)
assert r.status_code == 200 assert r.status_code == 200
@ -48,7 +48,7 @@ def test_running_models(api="/llm_model/list_models"):
# r = requests.post(url, json={""}) # r = requests.post(url, json={""})
def test_change_model(api="/llm_model/change"): def test_change_model(api="/llm_model/change_model"):
url = api_base_url + api url = api_base_url + api
running_models = get_running_models() running_models = get_running_models()

View File

@ -22,9 +22,10 @@ if __name__ == "__main__":
) )
if not chat_box.chat_inited: if not chat_box.chat_inited:
running_models = api.list_running_models()
st.toast( st.toast(
f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n" f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
f"当前使用模型`{LLM_MODEL}`, 您可以开始提问了." f"当前运行中的模型`{running_models}`, 您可以开始提问了."
) )
pages = { pages = {

View File

@ -2,11 +2,11 @@ import streamlit as st
from webui_pages.utils import * from webui_pages.utils import *
from streamlit_chatbox import * from streamlit_chatbox import *
from datetime import datetime from datetime import datetime
from server.chat.search_engine_chat import SEARCH_ENGINES
import os import os
from configs import LLM_MODEL, TEMPERATURE from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN
from server.utils import get_model_worker_config
from typing import List, Dict from typing import List, Dict
chat_box = ChatBox( chat_box = ChatBox(
assistant_avatar=os.path.join( assistant_avatar=os.path.join(
"img", "img",
@ -15,9 +15,6 @@ chat_box = ChatBox(
) )
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]: def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
''' '''
返回消息历史 返回消息历史
@ -38,6 +35,26 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
return chat_box.filter_history(history_len=history_len, filter=filter) return chat_box.filter_history(history_len=history_len, filter=filter)
def get_default_llm_model(api: ApiRequest) -> (str, bool):
'''
从服务器上获取当前运行的LLM模型如果本机配置的LLM_MODEL属于本地模型且在其中则优先返回
返回类型为model_name, is_local_model
'''
running_models = api.list_running_models()
if not running_models:
return "", False
if LLM_MODEL in running_models:
return LLM_MODEL, True
local_models = [k for k, v in running_models.items() if not v.get("online_api")]
if local_models:
return local_models[0], True
return running_models[0], False
def dialogue_page(api: ApiRequest): def dialogue_page(api: ApiRequest):
chat_box.init_session() chat_box.init_session()
@ -51,7 +68,6 @@ def dialogue_page(api: ApiRequest):
if cur_kb: if cur_kb:
text = f"{text} 当前知识库: `{cur_kb}`。" text = f"{text} 当前知识库: `{cur_kb}`。"
st.toast(text) st.toast(text)
# sac.alert(text, description="descp", type="success", closable=True, banner=True)
dialogue_mode = st.selectbox("请选择对话模式:", dialogue_mode = st.selectbox("请选择对话模式:",
["LLM 对话", ["LLM 对话",
@ -65,7 +81,7 @@ def dialogue_page(api: ApiRequest):
) )
def on_llm_change(): def on_llm_change():
config = get_model_worker_config(llm_model) config = api.get_model_config(llm_model)
if not config.get("online_api"): # 只有本地model_worker可以切换模型 if not config.get("online_api"): # 只有本地model_worker可以切换模型
st.session_state["prev_llm_model"] = llm_model st.session_state["prev_llm_model"] = llm_model
st.session_state["cur_llm_model"] = st.session_state.llm_model st.session_state["cur_llm_model"] = st.session_state.llm_model
@ -75,15 +91,20 @@ def dialogue_page(api: ApiRequest):
return f"{x} (Running)" return f"{x} (Running)"
return x return x
running_models = api.list_running_models() running_models = list(api.list_running_models())
available_models = [] available_models = []
config_models = api.list_config_models() config_models = api.list_config_models()
for models in config_models.values(): worker_models = list(config_models.get("worker", {})) # 仅列出在FSCHAT_MODEL_WORKERS中配置的模型
for m in models: for m in worker_models:
if m not in running_models: if m not in running_models and m != "default":
available_models.append(m) available_models.append(m)
for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型如GPT
if not v.get("provider") and k not in running_models:
print(k, v)
available_models.append(k)
llm_models = running_models + available_models llm_models = running_models + available_models
index = llm_models.index(st.session_state.get("cur_llm_model", LLM_MODEL)) index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0]))
llm_model = st.selectbox("选择LLM模型", llm_model = st.selectbox("选择LLM模型",
llm_models, llm_models,
index, index,
@ -92,7 +113,7 @@ def dialogue_page(api: ApiRequest):
key="llm_model", key="llm_model",
) )
if (st.session_state.get("prev_llm_model") != llm_model if (st.session_state.get("prev_llm_model") != llm_model
and not get_model_worker_config(llm_model).get("online_api") and not api.get_model_config(llm_model).get("online_api")
and llm_model not in running_models): and llm_model not in running_models):
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"): with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
prev_model = st.session_state.get("prev_llm_model") prev_model = st.session_state.get("prev_llm_model")
@ -114,7 +135,7 @@ def dialogue_page(api: ApiRequest):
if dialogue_mode == "知识库问答": if dialogue_mode == "知识库问答":
with st.expander("知识库配置", True): with st.expander("知识库配置", True):
kb_list = api.list_knowledge_bases(no_remote_api=True) kb_list = api.list_knowledge_bases()
selected_kb = st.selectbox( selected_kb = st.selectbox(
"请选择知识库:", "请选择知识库:",
kb_list, kb_list,
@ -126,7 +147,7 @@ def dialogue_page(api: ApiRequest):
# chunk_content = st.checkbox("关联上下文", False, disabled=True) # chunk_content = st.checkbox("关联上下文", False, disabled=True)
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True) # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
elif dialogue_mode == "搜索引擎问答": elif dialogue_mode == "搜索引擎问答":
search_engine_list = list(SEARCH_ENGINES.keys()) search_engine_list = api.list_search_engines()
with st.expander("搜索引擎配置", True): with st.expander("搜索引擎配置", True):
search_engine = st.selectbox( search_engine = st.selectbox(
label="请选择搜索引擎", label="请选择搜索引擎",

File diff suppressed because it is too large Load Diff