Azure 的fschat支持,弃用Langchain-MODEL (#1873)
* 支持了agentlm * 支持了agentlm和相关提示词 * 修改了Agent的一些功能,加入了Embed方面的一个优化 * 修改了部分Agent的工具 * 增加一些Langchain的自带工具 * 修复一些兼容性的bug * 恢复知识库 * 恢复知识库 * 修复Azure问题
This commit is contained in:
parent
27418ba9fa
commit
35a7ca74c0
|
|
@ -112,46 +112,6 @@ HISTORY_LEN = 3
|
||||||
TEMPERATURE = 0.7
|
TEMPERATURE = 0.7
|
||||||
# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
|
# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
|
||||||
|
|
||||||
|
|
||||||
LANGCHAIN_LLM_MODEL = {
|
|
||||||
# 不需要走Fschat封装的,Langchain直接支持的模型。
|
|
||||||
# 调用chatgpt时如果报出: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
|
|
||||||
# Max retries exceeded with url: /v1/chat/completions
|
|
||||||
# 则需要将urllib3版本修改为1.25.11
|
|
||||||
# 如果依然报urllib3.exceptions.MaxRetryError: HTTPSConnectionPool,则将https改为http
|
|
||||||
# 参考https://zhuanlan.zhihu.com/p/350015032
|
|
||||||
|
|
||||||
# 如果报出:raise NewConnectionError(
|
|
||||||
# urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x000001FE4BDB85E0>:
|
|
||||||
# Failed to establish a new connection: [WinError 10060]
|
|
||||||
# 则是因为内地和香港的IP都被OPENAI封了,需要切换为日本、新加坡等地
|
|
||||||
|
|
||||||
# 如果出现WARNING: Retrying langchain.chat_models.openai.acompletion_with_retry.<locals>._completion_with_retry in
|
|
||||||
# 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
|
|
||||||
# 需要添加代理访问(正常开的代理软件可能会拦截不上)需要设置配置openai_proxy 或者 使用环境遍历OPENAI_PROXY 进行设置
|
|
||||||
# 比如: "openai_proxy": 'http://127.0.0.1:4780'
|
|
||||||
|
|
||||||
# 这些配置文件的名字不能改动
|
|
||||||
"Azure-OpenAI": {
|
|
||||||
"deployment_name": "your Azure deployment name",
|
|
||||||
"model_version": "0701",
|
|
||||||
"openai_api_type": "azure",
|
|
||||||
"api_base_url": "https://your Azure point.azure.com",
|
|
||||||
"api_version": "2023-07-01-preview",
|
|
||||||
"api_key": "your Azure api key",
|
|
||||||
"openai_proxy": "",
|
|
||||||
},
|
|
||||||
"OpenAI": {
|
|
||||||
"model_name": "your openai model name(such as gpt-4)",
|
|
||||||
"api_base_url": "https://api.openai.com/v1",
|
|
||||||
"api_key": "your OPENAI_API_KEY",
|
|
||||||
"openai_proxy": "",
|
|
||||||
},
|
|
||||||
"Anthropic": {
|
|
||||||
"model_name": "your claude model name(such as claude2-100k)",
|
|
||||||
"api_key":"your ANTHROPIC_API_KEY",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ONLINE_LLM_MODEL = {
|
ONLINE_LLM_MODEL = {
|
||||||
# 线上模型。请在server_config中为每个在线API设置不同的端口
|
# 线上模型。请在server_config中为每个在线API设置不同的端口
|
||||||
# 具体注册及api key获取请前往 http://open.bigmodel.cn
|
# 具体注册及api key获取请前往 http://open.bigmodel.cn
|
||||||
|
|
@ -205,6 +165,15 @@ ONLINE_LLM_MODEL = {
|
||||||
"secret_key": "",
|
"secret_key": "",
|
||||||
"provider": "BaiChuanWorker",
|
"provider": "BaiChuanWorker",
|
||||||
},
|
},
|
||||||
|
|
||||||
|
# Azure API
|
||||||
|
"azure-api": {
|
||||||
|
"deployment_name": "", # 部署容器的名字
|
||||||
|
"resource_name": "", # https://{resource_name}.openai.azure.com/openai/ 填写resource_name的部分,其他部分不要填写
|
||||||
|
"api_version": "", # API的版本,不是模型版本
|
||||||
|
"api_key": "",
|
||||||
|
"provider": "AzureWorker",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ streamlit-option-menu>=0.3.6
|
||||||
streamlit-antd-components>=0.1.11
|
streamlit-antd-components>=0.1.11
|
||||||
streamlit-chatbox==1.1.10
|
streamlit-chatbox==1.1.10
|
||||||
streamlit-aggrid>=0.3.4.post3
|
streamlit-aggrid>=0.3.4.post3
|
||||||
httpx~=0.24.1
|
httpx[brotli,http2,socks]>=0.25.0
|
||||||
watchdog
|
watchdog
|
||||||
tqdm
|
tqdm
|
||||||
websockets
|
websockets
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from fastapi import Body
|
from fastapi import Body
|
||||||
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT,LANGCHAIN_LLM_MODEL
|
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
|
||||||
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
|
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
|
||||||
get_httpx_client, get_model_worker_config)
|
get_httpx_client, get_model_worker_config)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,3 +5,4 @@ from .qianfan import QianFanWorker
|
||||||
from .fangzhou import FangZhouWorker
|
from .fangzhou import FangZhouWorker
|
||||||
from .qwen import QwenWorker
|
from .qwen import QwenWorker
|
||||||
from .baichuan import BaiChuanWorker
|
from .baichuan import BaiChuanWorker
|
||||||
|
from .azure import AzureWorker
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,117 @@
|
||||||
|
import sys
|
||||||
|
from fastchat.conversation import Conversation
|
||||||
|
from server.model_workers.base import ApiModelWorker
|
||||||
|
from server.utils import get_model_worker_config, get_httpx_client
|
||||||
|
from fastchat import conversation as conv
|
||||||
|
import json
|
||||||
|
from typing import List, Dict
|
||||||
|
from configs import TEMPERATURE
|
||||||
|
|
||||||
|
|
||||||
|
def request_azure_api(
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
resource_name: str = None,
|
||||||
|
api_key: str = None,
|
||||||
|
deployment_name: str = None,
|
||||||
|
api_version: str = "2023-07-01-preview",
|
||||||
|
temperature: float = TEMPERATURE,
|
||||||
|
max_tokens: int = None,
|
||||||
|
model_name: str = "azure-api",
|
||||||
|
):
|
||||||
|
config = get_model_worker_config(model_name)
|
||||||
|
data = dict(
|
||||||
|
messages=messages,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
url = f"https://{resource_name}.openai.azure.com/openai/deployments/{deployment_name}/chat/completions?api-version={api_version}"
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'Accept': 'application/json',
|
||||||
|
'api-key': api_key,
|
||||||
|
}
|
||||||
|
with get_httpx_client() as client:
|
||||||
|
with client.stream("POST", url, headers=headers, json=data) as response:
|
||||||
|
for line in response.iter_lines():
|
||||||
|
if not line.strip() or "[DONE]" in line:
|
||||||
|
continue
|
||||||
|
if line.startswith("data: "):
|
||||||
|
line = line[6:]
|
||||||
|
resp = json.loads(line)
|
||||||
|
yield resp
|
||||||
|
|
||||||
|
class AzureWorker(ApiModelWorker):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
controller_addr: str,
|
||||||
|
worker_addr: str,
|
||||||
|
model_names: List[str] = ["azure-api"],
|
||||||
|
version: str = "gpt-35-turbo",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||||
|
kwargs.setdefault("context_len", 8192)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
self.resource_name = config.get("resource_name")
|
||||||
|
self.api_key = config.get("api_key")
|
||||||
|
self.api_version = config.get("api_version")
|
||||||
|
self.version = version
|
||||||
|
self.deployment_name = config.get("deployment_name")
|
||||||
|
|
||||||
|
def generate_stream_gate(self, params):
|
||||||
|
super().generate_stream_gate(params)
|
||||||
|
|
||||||
|
messages = self.prompt_to_messages(params["prompt"])
|
||||||
|
text = ""
|
||||||
|
for resp in request_azure_api(messages=messages,
|
||||||
|
resource_name=self.resource_name,
|
||||||
|
api_key=self.api_key,
|
||||||
|
api_version=self.api_version,
|
||||||
|
deployment_name=self.deployment_name,
|
||||||
|
temperature=params.get("temperature"),
|
||||||
|
max_tokens=params.get("max_tokens")):
|
||||||
|
if choices := resp["choices"]:
|
||||||
|
if chunk := choices[0].get("delta", {}).get("content"):
|
||||||
|
text += chunk
|
||||||
|
yield json.dumps(
|
||||||
|
{
|
||||||
|
"error_code": 0,
|
||||||
|
"text": text
|
||||||
|
},
|
||||||
|
ensure_ascii=False
|
||||||
|
).encode() + b"\0"
|
||||||
|
|
||||||
|
def get_embeddings(self, params):
|
||||||
|
# TODO: 支持embeddings
|
||||||
|
print("embedding")
|
||||||
|
print(params)
|
||||||
|
|
||||||
|
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
||||||
|
# TODO: 确认模板是否需要修改
|
||||||
|
return conv.Conversation(
|
||||||
|
name=self.model_names[0],
|
||||||
|
system_message="",
|
||||||
|
messages=[],
|
||||||
|
roles=["user", "assistant"],
|
||||||
|
sep="\n### ",
|
||||||
|
stop_str="###",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
from server.utils import MakeFastAPIOffline
|
||||||
|
from fastchat.serve.base_model_worker import app
|
||||||
|
|
||||||
|
worker = AzureWorker(
|
||||||
|
controller_addr="http://127.0.0.1:20001",
|
||||||
|
worker_addr="http://127.0.0.1:21008",
|
||||||
|
)
|
||||||
|
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||||
|
MakeFastAPIOffline(app)
|
||||||
|
uvicorn.run(app, port=21008)
|
||||||
|
|
@ -138,9 +138,9 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
worker = BaiChuanWorker(
|
worker = BaiChuanWorker(
|
||||||
controller_addr="http://127.0.0.1:20001",
|
controller_addr="http://127.0.0.1:20001",
|
||||||
worker_addr="http://127.0.0.1:21001",
|
worker_addr="http://127.0.0.1:21007",
|
||||||
)
|
)
|
||||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||||
MakeFastAPIOffline(app)
|
MakeFastAPIOffline(app)
|
||||||
uvicorn.run(app, port=21001)
|
uvicorn.run(app, port=21007)
|
||||||
# do_request()
|
# do_request()
|
||||||
|
|
@ -1,17 +1,16 @@
|
||||||
|
import sys
|
||||||
from fastchat.conversation import Conversation
|
from fastchat.conversation import Conversation
|
||||||
from server.model_workers.base import ApiModelWorker
|
from server.model_workers.base import ApiModelWorker
|
||||||
from configs.model_config import TEMPERATURE
|
from configs.model_config import TEMPERATURE
|
||||||
from fastchat import conversation as conv
|
from fastchat import conversation as conv
|
||||||
import sys
|
|
||||||
import json
|
import json
|
||||||
import httpx
|
|
||||||
from cachetools import cached, TTLCache
|
from cachetools import cached, TTLCache
|
||||||
from server.utils import get_model_worker_config, get_httpx_client
|
from server.utils import get_model_worker_config, get_httpx_client
|
||||||
from typing import List, Literal, Dict
|
from typing import List, Literal, Dict
|
||||||
|
|
||||||
|
|
||||||
MODEL_VERSIONS = {
|
MODEL_VERSIONS = {
|
||||||
"ernie-bot": "completions",
|
"ernie-bot": "completions",
|
||||||
|
"ernie-bot-4": "completions_pro",
|
||||||
"ernie-bot-turbo": "eb-instant",
|
"ernie-bot-turbo": "eb-instant",
|
||||||
"bloomz-7b": "bloomz_7b1",
|
"bloomz-7b": "bloomz_7b1",
|
||||||
"qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
|
"qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
|
||||||
|
|
@ -46,7 +45,7 @@ MODEL_VERSIONS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次
|
@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次
|
||||||
def get_baidu_access_token(api_key: str, secret_key: str) -> str:
|
def get_baidu_access_token(api_key: str, secret_key: str) -> str:
|
||||||
"""
|
"""
|
||||||
使用 AK,SK 生成鉴权签名(Access Token)
|
使用 AK,SK 生成鉴权签名(Access Token)
|
||||||
|
|
@ -62,12 +61,12 @@ def get_baidu_access_token(api_key: str, secret_key: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
def request_qianfan_api(
|
def request_qianfan_api(
|
||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
temperature: float = TEMPERATURE,
|
temperature: float = TEMPERATURE,
|
||||||
model_name: str = "qianfan-api",
|
model_name: str = "qianfan-api",
|
||||||
version: str = None,
|
version: str = None,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\
|
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
|
||||||
'/{model_version}?access_token={access_token}'
|
'/{model_version}?access_token={access_token}'
|
||||||
config = get_model_worker_config(model_name)
|
config = get_model_worker_config(model_name)
|
||||||
version = version or config.get("version")
|
version = version or config.get("version")
|
||||||
|
|
@ -83,6 +82,7 @@ def request_qianfan_api(
|
||||||
model_version=version_url or MODEL_VERSIONS[version],
|
model_version=version_url or MODEL_VERSIONS[version],
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
|
|
@ -101,6 +101,7 @@ def request_qianfan_api(
|
||||||
if line.startswith("data: "):
|
if line.startswith("data: "):
|
||||||
line = line[6:]
|
line = line[6:]
|
||||||
resp = json.loads(line)
|
resp = json.loads(line)
|
||||||
|
breakpoint()
|
||||||
yield resp
|
yield resp
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -108,6 +109,7 @@ class QianFanWorker(ApiModelWorker):
|
||||||
"""
|
"""
|
||||||
百度千帆
|
百度千帆
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
|
@ -128,7 +130,7 @@ class QianFanWorker(ApiModelWorker):
|
||||||
|
|
||||||
def generate_stream_gate(self, params):
|
def generate_stream_gate(self, params):
|
||||||
messages = self.prompt_to_messages(params["prompt"])
|
messages = self.prompt_to_messages(params["prompt"])
|
||||||
text=""
|
text = ""
|
||||||
for resp in request_qianfan_api(messages,
|
for resp in request_qianfan_api(messages,
|
||||||
temperature=params.get("temperature"),
|
temperature=params.get("temperature"),
|
||||||
model_name=self.model_names[0]):
|
model_name=self.model_names[0]):
|
||||||
|
|
@ -147,7 +149,7 @@ class QianFanWorker(ApiModelWorker):
|
||||||
},
|
},
|
||||||
ensure_ascii=False
|
ensure_ascii=False
|
||||||
).encode() + b"\0"
|
).encode() + b"\0"
|
||||||
|
|
||||||
def get_embeddings(self, params):
|
def get_embeddings(self, params):
|
||||||
# TODO: 支持embeddings
|
# TODO: 支持embeddings
|
||||||
print("embedding")
|
print("embedding")
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from fastapi import FastAPI
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import asyncio
|
import asyncio
|
||||||
from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
|
from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
|
||||||
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, LANGCHAIN_LLM_MODEL, logger, log_verbose,
|
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose,
|
||||||
FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
|
FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
|
||||||
import os
|
import os
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
@ -388,7 +388,6 @@ def list_config_llm_models() -> Dict[str, Dict]:
|
||||||
workers.insert(0, LLM_MODEL)
|
workers.insert(0, LLM_MODEL)
|
||||||
return {
|
return {
|
||||||
"local": MODEL_PATH["llm_model"],
|
"local": MODEL_PATH["llm_model"],
|
||||||
"langchain": LANGCHAIN_LLM_MODEL,
|
|
||||||
"online": ONLINE_LLM_MODEL,
|
"online": ONLINE_LLM_MODEL,
|
||||||
"worker": workers,
|
"worker": workers,
|
||||||
}
|
}
|
||||||
|
|
@ -436,10 +435,6 @@ def get_model_worker_config(model_name: str = None) -> dict:
|
||||||
config.update(ONLINE_LLM_MODEL.get(model_name, {}))
|
config.update(ONLINE_LLM_MODEL.get(model_name, {}))
|
||||||
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
|
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
|
||||||
|
|
||||||
# 在线模型API
|
|
||||||
if model_name in LANGCHAIN_LLM_MODEL:
|
|
||||||
config["langchain_model"] = True
|
|
||||||
config["worker_class"] = ""
|
|
||||||
|
|
||||||
if model_name in ONLINE_LLM_MODEL:
|
if model_name in ONLINE_LLM_MODEL:
|
||||||
config["online_api"] = True
|
config["online_api"] = True
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
root_path = Path(__file__).parent.parent.parent
|
||||||
|
sys.path.append(str(root_path))
|
||||||
|
from server.model_workers.azure import request_azure_api
|
||||||
|
import pytest
|
||||||
|
from configs import TEMPERATURE, ONLINE_LLM_MODEL
|
||||||
|
@pytest.mark.parametrize("version", ["azure-api"])
|
||||||
|
def test_azure(version):
|
||||||
|
messages = [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "介绍一下自己"}]
|
||||||
|
for resp in request_azure_api(
|
||||||
|
messages=messages,
|
||||||
|
api_base_url=ONLINE_LLM_MODEL["azure-api"]["api_base_url"],
|
||||||
|
api_key=ONLINE_LLM_MODEL["azure-api"]["api_key"],
|
||||||
|
deployment_name=ONLINE_LLM_MODEL["azure-api"]["deployment_name"],
|
||||||
|
api_version=ONLINE_LLM_MODEL["azure-api"]["api_version"],
|
||||||
|
temperature=TEMPERATURE,
|
||||||
|
max_tokens=1024,
|
||||||
|
model_name="azure-api"
|
||||||
|
):
|
||||||
|
assert resp["code"] == 200
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_azure("azure-api")
|
||||||
|
|
@ -7,7 +7,7 @@ from server.model_workers.baichuan import request_baichuan_api
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
||||||
|
|
||||||
def test_qwen():
|
def test_baichuan():
|
||||||
messages = [{"role": "user", "content": "hello"}]
|
messages = [{"role": "user", "content": "hello"}]
|
||||||
|
|
||||||
for x in request_baichuan_api(messages):
|
for x in request_baichuan_api(messages):
|
||||||
|
|
|
||||||
|
|
@ -17,4 +17,4 @@ def test_qianfan(version):
|
||||||
pprint(x)
|
pprint(x)
|
||||||
assert isinstance(x, dict)
|
assert isinstance(x, dict)
|
||||||
assert "error_code" not in x
|
assert "error_code" not in x
|
||||||
i += 1
|
i += 1
|
||||||
|
|
@ -4,7 +4,7 @@ from streamlit_chatbox import *
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import os
|
import os
|
||||||
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
|
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
|
||||||
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE,LANGCHAIN_LLM_MODEL)
|
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE)
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -83,7 +83,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
running_models = list(api.list_running_models())
|
running_models = list(api.list_running_models())
|
||||||
# running_models += LANGCHAIN_LLM_MODEL.keys()
|
|
||||||
available_models = []
|
available_models = []
|
||||||
config_models = api.list_config_models()
|
config_models = api.list_config_models()
|
||||||
worker_models = list(config_models.get("worker", {})) # 仅列出在FSCHAT_MODEL_WORKERS中配置的模型
|
worker_models = list(config_models.get("worker", {})) # 仅列出在FSCHAT_MODEL_WORKERS中配置的模型
|
||||||
|
|
@ -93,8 +92,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||||
for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型
|
for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型
|
||||||
if not v.get("provider") and k not in running_models:
|
if not v.get("provider") and k not in running_models:
|
||||||
available_models.append(k)
|
available_models.append(k)
|
||||||
for k, v in config_models.get("langchain", {}).items(): # 列出LANGCHAIN_LLM_MODEL支持的模型
|
|
||||||
available_models.append(k)
|
|
||||||
llm_models = running_models + available_models
|
llm_models = running_models + available_models
|
||||||
index = llm_models.index(st.session_state.get("cur_llm_model", api.get_default_llm_model()[0]))
|
index = llm_models.index(st.session_state.get("cur_llm_model", api.get_default_llm_model()[0]))
|
||||||
llm_model = st.selectbox("选择LLM模型:",
|
llm_model = st.selectbox("选择LLM模型:",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue