From e920cd00645fd6aea81f4c560e3b055d660c4ca9 Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Fri, 20 Oct 2023 18:13:55 +0800
Subject: [PATCH] Merge branch, support (#1808)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* Beijing hackathon update
Knowledge base support:
Support for the zilliz database
Agent support:
The following tool calls are supported:
1. Internet calls from the agent
2. Knowledge base calls from the agent
3. Travel assistant tool (not yet uploaded)
Knowledge base updates:
1. Knowledge base introductions, used for agent selection
2. UI support for knowledge base introductions
Prompt selection:
1. The UI and templates support switching between prompt templates
* Fixed the introduction issue on database updates
* Models natively supported by Langchain:
1. Fixed a bug where OpenAI could not be called
2. Added support for the Azure OpenAI and Claude models, as sketched below
(due to a priority issue, the model-switch UI may list them among the other online models)
3. Fixed the 422 issue with an alternative approach.
4. Updated some dependencies
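A minimal sketch of the resulting dispatch for Langchain-native models (it mirrors the get_ChatOpenAI changes in server/utils.py below; the helper name and config values are placeholders):
    from langchain.chat_models import AzureChatOpenAI, ChatAnthropic, ChatOpenAI
    def build_native_model(name: str, cfg: dict):
        # Names listed in LANGCHAIN_LLM_MODEL bypass the fschat wrapper and
        # are constructed directly from langchain.chat_models.
        if name == "Azure-OpenAI":
            return AzureChatOpenAI(deployment_name=cfg["deployment_name"],
                                   openai_api_type=cfg["openai_api_type"],
                                   openai_api_base=cfg["api_base_url"],
                                   openai_api_version=cfg["api_version"],
                                   openai_api_key=cfg["api_key"])
        if name == "OpenAI":
            return ChatOpenAI(model_name=cfg["model_name"],
                              openai_api_base=cfg["api_base_url"],
                              openai_api_key=cfg["api_key"])
        if name == "Anthropic":
            return ChatAnthropic(model_name=cfg["model_name"],
                                 anthropic_api_key=cfg["api_key"])
        raise ValueError(f"{name} is not configured as a Langchain-native model")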
---
configs/kb_config.py.example | 4 --
configs/model_config.py.example | 25 +++++++--
init_database.py | 2 +
requirements.txt | 11 ++--
requirements_api.txt | 11 ++--
requirements_webui.txt | 8 +--
server/agent/callbacks.py | 1 -
server/agent/custom_template.py | 3 +-
server/agent/tools/__init__.py | 1 -
server/knowledge_base/utils.py | 2 -
server/llm_api.py | 3 +-
server/utils.py | 93 ++++++++++++++++++++++++--------
startup.py | 14 +++--
webui_pages/dialogue/dialogue.py | 25 +++++----
14 files changed, 135 insertions(+), 68 deletions(-)
diff --git a/configs/kb_config.py.example b/configs/kb_config.py.example
index 15abb28..40165b9 100644
--- a/configs/kb_config.py.example
+++ b/configs/kb_config.py.example
@@ -1,6 +1,5 @@
import os
-
# 默认使用的知识库
DEFAULT_KNOWLEDGE_BASE = "samples"
@@ -57,14 +56,11 @@ KB_INFO = {
"知识库名称": "知识库介绍",
"samples": "关于本项目issue的解答",
}
-
# 通常情况下不需要更改以下内容
-
# 知识库默认存储路径
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
if not os.path.exists(KB_ROOT_PATH):
os.mkdir(KB_ROOT_PATH)
-
# 数据库默认存储路径。
# 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
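KB_INFO above maps each knowledge base name to a one-line introduction that the agent uses when deciding where to search; adding a knowledge base is one more entry (a sketch, where the "travel" entry is a hypothetical example):
    KB_INFO = {
        "samples": "关于本项目issue的解答",
        "travel": "Destination guides for the travel assistant",  # hypothetical entry
    }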
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 54c3f0f..78a10e9 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -112,7 +112,8 @@ TEMPERATURE = 0.7
# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
-ONLINE_LLM_MODEL = {
+LANGCHAIN_LLM_MODEL = {
+ # Models directly supported by Langchain; these do not need the Fschat wrapper.
# 调用chatgpt时如果报出: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
# Max retries exceeded with url: /v1/chat/completions
# 则需要将urllib3版本修改为1.25.11
@@ -128,11 +129,29 @@ ONLINE_LLM_MODEL = {
# 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
# 需要添加代理访问(正常开的代理软件可能会拦截不上)需要设置配置openai_proxy 或者 使用环境遍历OPENAI_PROXY 进行设置
# 比如: "openai_proxy": 'http://127.0.0.1:4780'
- "gpt-3.5-turbo": {
+
+ # The names of these config entries must not be changed
+ "Azure-OpenAI": {
+ "deployment_name": "your Azure deployment name",
+ "model_version": "0701",
+ "openai_api_type": "azure",
+ "api_base_url": "https://your Azure point.azure.com",
+ "api_version": "2023-07-01-preview",
+ "api_key": "your Azure api key",
+ "openai_proxy": "",
+ },
+ "OpenAI": {
+ "model_name": "your openai model name(such as gpt-4)",
"api_base_url": "https://api.openai.com/v1",
"api_key": "your OPENAI_API_KEY",
- "openai_proxy": "your OPENAI_PROXY",
+ "openai_proxy": "",
},
+ "Anthropic": {
+ "model_name": "your claude model name(such as claude2-100k)",
+ "api_key":"your ANTHROPIC_API_KEY",
+ }
+}
+ONLINE_LLM_MODEL = {
# 线上模型。请在server_config中为每个在线API设置不同的端口
# 具体注册及api key获取请前往 http://open.bigmodel.cn
"zhipu-api": {
diff --git a/init_database.py b/init_database.py
index 9e807a8..c2cd1d4 100644
--- a/init_database.py
+++ b/init_database.py
@@ -1,3 +1,5 @@
+import sys
+sys.path.append(".")
from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
from configs.model_config import NLTK_DATA_PATH
import nltk
diff --git a/requirements.txt b/requirements.txt
index e1b0437..3da9c84 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,14 +1,15 @@
-langchain==0.0.317
-langchain-experimental==0.0.30
+langchain>=0.0.319
+langchain-experimental>=0.0.30
fschat[model_worker]==0.2.31
-openai
+xformers>=0.0.22.post4
+openai>=0.28.1
sentence_transformers
transformers>=4.34
torch>=2.0.1 # 推荐2.1
torchvision
torchaudio
-fastapi>=0.103.2
-nltk~=3.8.1
+fastapi>=0.104
+nltk>=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0
pydantic~=1.10.11
diff --git a/requirements_api.txt b/requirements_api.txt
index 3bcb1b5..b8a7f6d 100644
--- a/requirements_api.txt
+++ b/requirements_api.txt
@@ -1,13 +1,14 @@
-langchain==0.0.317
-langchain-experimental==0.0.30
+langchain>=0.0.319
+langchain-experimental>=0.0.30
fschat[model_worker]==0.2.31
-openai
+xformers>=0.0.22.post4
+openai>=0.28.1
sentence_transformers>=2.2.2
transformers>=4.34
-torch>=2.0.1
+torch>=2.1
torchvision
torchaudio
-fastapi>=0.103.1
+fastapi>=0.104
nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0
diff --git a/requirements_webui.txt b/requirements_webui.txt
index 561b10a..a36fec4 100644
--- a/requirements_webui.txt
+++ b/requirements_webui.txt
@@ -1,11 +1,11 @@
numpy~=1.24.4
pandas~=2.0.3
-streamlit>=1.26.0
+streamlit>=1.27.2
streamlit-option-menu>=0.3.6
-streamlit-antd-components>=0.1.11
+streamlit-antd-components>=0.2.3
streamlit-chatbox==1.1.10
streamlit-aggrid>=0.3.4.post3
-httpx~=0.24.1
-nltk
+httpx>=0.25.0
+nltk>=3.8.1
watchdog
websockets
diff --git a/server/agent/callbacks.py b/server/agent/callbacks.py
index ddc6ffe..49ce973 100644
--- a/server/agent/callbacks.py
+++ b/server/agent/callbacks.py
@@ -97,7 +97,6 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
llm_token="",
)
self.queue.put_nowait(dumps(self.cur_tool))
-
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
diff --git a/server/agent/custom_template.py b/server/agent/custom_template.py
index 22469c6..fdac6e2 100644
--- a/server/agent/custom_template.py
+++ b/server/agent/custom_template.py
@@ -4,7 +4,6 @@ from langchain.prompts import StringPromptTemplate
from typing import List
from langchain.schema import AgentAction, AgentFinish
from server.agent import model_container
-begin = False
class CustomPromptTemplate(StringPromptTemplate):
# The template to use
template: str
@@ -38,7 +37,7 @@ class CustomOutputParser(AgentOutputParser):
def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
# Check if agent should finish
- support_agent = ["gpt","Qwen","qwen-api","baichuan-api"]
+ support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型
if not any(agent in model_container.MODEL for agent in support_agent) and self.begin:
self.begin = False
stop_words = ["Observation:"]
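The agent-support check above is a plain substring match against the running model's name; a standalone sketch:
    support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"]
    def supports_agent(model_name: str) -> bool:
        # True when any marker occurs inside the model name,
        # e.g. "Qwen-14B-Chat" matches via "Qwen".
        return any(marker in model_name for marker in support_agent)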
diff --git a/server/agent/tools/__init__.py b/server/agent/tools/__init__.py
index 3682787..7031b71 100644
--- a/server/agent/tools/__init__.py
+++ b/server/agent/tools/__init__.py
@@ -2,7 +2,6 @@
from .search_knowledge_simple import knowledge_search_simple
from .search_all_knowledge_once import knowledge_search_once
from .search_all_knowledge_more import knowledge_search_more
-# from .travel_assistant import travel_assistant
from .calculate import calculate
from .translator import translate
from .weather import weathercheck
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index 02212c8..c73d021 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -1,7 +1,5 @@
import os
-
from transformers import AutoTokenizer
-
from configs import (
EMBEDDING_MODEL,
KB_ROOT_PATH,
diff --git a/server/llm_api.py b/server/llm_api.py
index 0d3ba3f..453f147 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -1,5 +1,5 @@
from fastapi import Body
-from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
+from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT, LANGCHAIN_LLM_MODEL
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
get_httpx_client, get_model_worker_config)
@@ -52,7 +52,6 @@ def get_model_config(
获取LLM模型配置项(合并后的)
'''
config = get_model_worker_config(model_name=model_name)
-
# 删除ONLINE_MODEL配置中的敏感信息
del_keys = set(["worker_class"])
for k in config:
diff --git a/server/utils.py b/server/utils.py
index 22185eb..5dff5f1 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -5,12 +5,11 @@ from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
- MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
- logger, log_verbose,
+ MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, LANGCHAIN_LLM_MODEL, logger, log_verbose,
FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
-from langchain.chat_models import ChatOpenAI
+from langchain.chat_models import ChatOpenAI, AzureChatOpenAI, ChatAnthropic
import httpx
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
@@ -40,19 +39,64 @@ def get_ChatOpenAI(
verbose: bool = True,
**kwargs: Any,
) -> ChatOpenAI:
- config = get_model_worker_config(model_name)
- model = ChatOpenAI(
- streaming=streaming,
- verbose=verbose,
- callbacks=callbacks,
- openai_api_key=config.get("api_key", "EMPTY"),
- openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
- model_name=model_name,
- temperature=temperature,
- max_tokens=max_tokens,
- openai_proxy=config.get("openai_proxy"),
- **kwargs
- )
+ ## The following models are natively supported by Langchain and are not routed through the Fschat wrapper
+ config_models = list_config_llm_models()
+ if model_name in config_models.get("langchain", {}):
+ config = config_models["langchain"][model_name]
+ if model_name == "Azure-OpenAI":
+ model = AzureChatOpenAI(
+ streaming=streaming,
+ verbose=verbose,
+ callbacks=callbacks,
+ deployment_name=config.get("deployment_name"),
+ model_version=config.get("model_version"),
+ openai_api_type=config.get("openai_api_type"),
+ openai_api_base=config.get("api_base_url"),
+ openai_api_version=config.get("api_version"),
+ openai_api_key=config.get("api_key"),
+ openai_proxy=config.get("openai_proxy"),
+ temperature=temperature,
+ max_tokens=max_tokens,
+ )
+
+ elif model_name == "OpenAI":
+ model = ChatOpenAI(
+ streaming=streaming,
+ verbose=verbose,
+ callbacks=callbacks,
+ model_name=config.get("model_name"),
+ openai_api_base=config.get("api_base_url"),
+ openai_api_key=config.get("api_key"),
+ openai_proxy=config.get("openai_proxy"),
+ temperature=temperature,
+ max_tokens=max_tokens,
+ )
+ elif model_name == "Anthropic":
+ model = ChatAnthropic(
+ streaming=streaming,
+ verbose=verbose,
+ callbacks=callbacks,
+ model_name=config.get("model_name"),
+ anthropic_api_key=config.get("api_key"),
+ )
+ ## TODO: support other models natively supported by Langchain
+ else:
+ ## Models not natively supported by Langchain go through the Fschat wrapper
+ config = get_model_worker_config(model_name)
+ model = ChatOpenAI(
+ streaming=streaming,
+ verbose=verbose,
+ callbacks=callbacks,
+ openai_api_key=config.get("api_key", "EMPTY"),
+ openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
+ model_name=model_name,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ openai_proxy=config.get("openai_proxy"),
+ **kwargs
+ )
+
return model
@@ -249,8 +293,9 @@ def MakeFastAPIOffline(
redoc_favicon_url=favicon,
)
+ # Get model information from model_config
+
-# 从model_config中获取模型信息
def list_embed_models() -> List[str]:
'''
get names of configured embedding models
@@ -266,9 +311,9 @@ def list_config_llm_models() -> Dict[str, Dict]:
workers = list(FSCHAT_MODEL_WORKERS)
if LLM_MODEL not in workers:
workers.insert(0, LLM_MODEL)
-
return {
"local": MODEL_PATH["llm_model"],
+ "langchain": LANGCHAIN_LLM_MODEL,
"online": ONLINE_LLM_MODEL,
"worker": workers,
}
@@ -300,8 +345,9 @@ def get_model_path(model_name: str, type: str = None) -> Optional[str]:
return str(path)
return path_str # THUDM/chatglm06b
+ # Get service information from server_config
+
-# 从server_config中获取服务信息
def get_model_worker_config(model_name: str = None) -> dict:
'''
加载model worker的配置项。
@@ -316,6 +362,10 @@ def get_model_worker_config(model_name: str = None) -> dict:
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
# 在线模型API
+ if model_name in LANGCHAIN_LLM_MODEL:
+ config["langchain_model"] = True
+ config["worker_class"] = ""
+
if model_name in ONLINE_LLM_MODEL:
config["online_api"] = True
if provider := config.get("provider"):
@@ -389,7 +439,7 @@ def webui_address() -> str:
return f"http://{host}:{port}"
-def get_prompt_template(type:str,name: str) -> Optional[str]:
+def get_prompt_template(type: str, name: str) -> Optional[str]:
'''
从prompt_config中加载模板内容
type: "llm_chat","agent_chat","knowledge_base_chat","search_engine_chat"的其中一种,如果有新功能,应该进行加入。
@@ -459,8 +509,9 @@ def set_httpx_config(
import urllib.request
urllib.request.getproxies = _get_proxies
+ # Auto-detect the device available to torch. In a distributed deployment, machines that do not run the LLM do not need torch installed
+
-# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
def detect_device() -> Literal["cuda", "mps", "cpu"]:
try:
import torch
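A quick usage sketch of the new routing in get_ChatOpenAI (the local model name is a placeholder, and the remaining parameters are assumed to have defaults):
    # A LANGCHAIN_LLM_MODEL key returns a native Langchain model;
    # any other name falls through to the fschat-backed ChatOpenAI.
    native = get_ChatOpenAI(model_name="Azure-OpenAI", temperature=0.7)
    wrapped = get_ChatOpenAI(model_name="chatglm2-6b", temperature=0.7)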
diff --git a/startup.py b/startup.py
index a9504ee..4bb3449 100644
--- a/startup.py
+++ b/startup.py
@@ -68,7 +68,9 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
controller_address:
worker_address:
-
+ For Langchain-supported models:
+ langchain_model: True
+ fschat is not used
对于online_api:
online_api:True
worker_class: `provider`
@@ -85,9 +87,11 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
for k, v in kwargs.items():
setattr(args, k, v)
-
+ if worker_class := kwargs.get("langchain_model"): # Langchain-supported models need no extra handling here
+ from fastchat.serve.base_model_worker import app
+ worker = ""
# 在线模型API
- if worker_class := kwargs.get("worker_class"):
+ elif worker_class := kwargs.get("worker_class"):
from fastchat.serve.base_model_worker import app
worker = worker_class(model_names=args.model_names,
@@ -127,8 +131,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
args.engine_use_ray = False
args.disable_log_requests = False
- # 0.2.0 vllm后要加的参数
- args.max_model_len = 8192 # 模型可以处理的最大序列长度。请根据你的大模型设置,
+ # parameters required since vllm 0.2.0, but not needed here
+ args.max_model_len = None
args.revision = None
args.quantization = None
args.max_log_len = None
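Condensed, the worker dispatch in create_model_worker_app now has three arms (a sketch; the local-worker branch is elided):
    if kwargs.get("langchain_model"):
        # Langchain-native model: keep the FastAPI app but start no fschat worker.
        worker = ""
    elif worker_class := kwargs.get("worker_class"):
        # Online API model: instantiate the provider's worker class.
        worker = worker_class(model_names=args.model_names)
    else:
        # Local model: the unchanged fschat/vllm worker setup.
        ...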
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index bdfdfa3..8628aac 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -4,10 +4,9 @@ from streamlit_chatbox import *
from datetime import datetime
import os
from configs import (LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
- DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE)
+ DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, LANGCHAIN_LLM_MODEL)
from typing import List, Dict
-
chat_box = ChatBox(
assistant_avatar=os.path.join(
"img",
@@ -42,7 +41,6 @@ def get_default_llm_model(api: ApiRequest) -> (str, bool):
返回类型为(model_name, is_local_model)
'''
running_models = api.list_running_models()
-
if not running_models:
return "", False
@@ -52,7 +50,6 @@ def get_default_llm_model(api: ApiRequest) -> (str, bool):
local_models = [k for k, v in running_models.items() if not v.get("online_api")]
if local_models:
return local_models[0], True
-
return list(running_models)[0], False
@@ -82,7 +79,7 @@ def dialogue_page(api: ApiRequest):
"搜索引擎问答",
"自定义Agent问答",
],
- index=3,
+ index=0,
on_change=on_mode_change,
key="dialogue_mode",
)
@@ -100,16 +97,18 @@ def dialogue_page(api: ApiRequest):
return x
running_models = list(api.list_running_models())
+ running_models += LANGCHAIN_LLM_MODEL.keys()
available_models = []
config_models = api.list_config_models()
worker_models = list(config_models.get("worker", {})) # 仅列出在FSCHAT_MODEL_WORKERS中配置的模型
for m in worker_models:
if m not in running_models and m != "default":
available_models.append(m)
- for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型(如GPT)
+ for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型
if not v.get("provider") and k not in running_models:
- print(k, v)
available_models.append(k)
+ for k, v in config_models.get("langchain", {}).items(): # 列出LANGCHAIN_LLM_MODEL支持的模型
+ available_models.append(k)
llm_models = running_models + available_models
index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0]))
llm_model = st.selectbox("选择LLM模型:",
@@ -119,10 +118,10 @@ def dialogue_page(api: ApiRequest):
on_change=on_llm_change,
key="llm_model",
)
- if (llm_model
- and st.session_state.get("prev_llm_model") != llm_model
- and not api.get_model_config(llm_model).get("online_api")
- and llm_model not in running_models):
+ if (st.session_state.get("prev_llm_model") != llm_model
+ and not llm_model in config_models.get("online", {})
+ and not llm_model in config_models.get("langchain", {})
+ and llm_model not in running_models):
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
prev_model = st.session_state.get("prev_llm_model")
r = api.change_llm_model(prev_model, llm_model)
@@ -228,9 +227,9 @@ def dialogue_page(api: ApiRequest):
])
text = ""
ans = ""
- support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型
+ support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型
if not any(agent in llm_model for agent in support_agent):
- ans += "正在思考... \n\n 该模型并没有进行Agent对齐,无法正常使用Agent功能!\n\n\n请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验! \n\n\n"
+ ans += "正在思考... \n\n 该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!\n\n\n"
chat_box.update_msg(ans, element_index=0, streaming=False)
for d in api.agent_chat(prompt,
history=history,