Merge branch 'dev' of github.com:chatchat-space/Langchain-Chatchat into dev

hzg0601 2023-08-25 09:43:46 +08:00
commit 999870c3a7
41 changed files with 1332 additions and 300 deletions


@ -23,10 +23,11 @@ assignees: ''
描述实际发生的结果 / Describe the actual result.
**环境信息 / Environment Information**
- langchain-ChatGLM 版本/commit 号:(例如v1.0.0 或 commit 123456) / langchain-ChatGLM version/commit number: (e.g., v1.0.0 or commit 123456)
- langchain-ChatGLM 版本/commit 号:(例如v2.0.1 或 commit 123456) / langchain-ChatGLM version/commit number: (e.g., v2.0.1 or commit 123456)
- 是否使用 Docker 部署(是/否):是 / Is Docker deployment used (yes/no): yes
- 使用的模型（ChatGLM-6B / ClueAI/ChatYuan-large-v2 等）：ChatGLM-6B / Model used (ChatGLM-6B / ClueAI/ChatYuan-large-v2, etc.): ChatGLM-6B
- 使用的 Embedding 模型（GanymedeNil/text2vec-large-chinese 等）：GanymedeNil/text2vec-large-chinese / Embedding model used (GanymedeNil/text2vec-large-chinese, etc.): GanymedeNil/text2vec-large-chinese
- 使用的模型（ChatGLM2-6B / Qwen-7B 等）：ChatGLM2-6B / Model used (ChatGLM2-6B / Qwen-7B, etc.): ChatGLM2-6B
- 使用的 Embedding 模型（moka-ai/m3e-base 等）：moka-ai/m3e-base / Embedding model used (moka-ai/m3e-base, etc.): moka-ai/m3e-base
- 使用的向量库类型 (faiss / milvus / pg_vector 等)：faiss / Vector library used (faiss, milvus, pg_vector, etc.): faiss
- 操作系统及版本 / Operating system and version:
- Python 版本 / Python version:
- 其他相关环境信息 / Other relevant environment information:

.gitignore

@ -4,4 +4,4 @@ logs
.idea/
__pycache__/
knowledge_base/
configs/model_config.py
configs/*.py

README.md

@ -126,6 +126,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
- [BAAI/bge-small-zh](https://huggingface.co/BAAI/bge-small-zh)
- [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh)
- [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh)
- [BAAI/bge-large-zh-noinstruct](https://huggingface.co/BAAI/bge-large-zh-noinstruct)
- [text2vec-base-chinese-sentence](https://huggingface.co/shibing624/text2vec-base-chinese-sentence)
- [text2vec-base-chinese-paraphrase](https://huggingface.co/shibing624/text2vec-base-chinese-paraphrase)
- [text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual)
@ -133,6 +134,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
- [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese)
- [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
- [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh)
- [OpenAI/text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings)
---
@ -181,9 +183,11 @@ $ git clone https://huggingface.co/moka-ai/m3e-base
### 3. 设置配置项
复制文件 [configs/model_config.py.example](configs/model_config.py.example) 存储至项目路径下 `./configs` 路径下,并重命名为 `model_config.py`
复制模型相关参数配置模板文件 [configs/model_config.py.example](configs/model_config.py.example) 存储至项目路径下 `./configs` 路径下,并重命名为 `model_config.py`
在开始执行 Web UI 或命令行交互前,请先检查 `configs/model_config.py` 中的各项模型参数设计是否符合需求:
复制服务相关参数配置模板文件 [configs/server_config.py.example](configs/server_config.py.example) 存储至项目路径下 `./configs` 路径下,并重命名为 `server_config.py`
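A small, optional Python helper for the two copy-and-rename steps above (this script is not part of the repo; plain `cp` commands work just as well):
```python
# Copy the two config templates to their expected names (run from the project root).
import shutil

for name in ("model_config.py", "server_config.py"):
    shutil.copyfile(f"configs/{name}.example", f"configs/{name}")
```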
在开始执行 Web UI 或命令行交互前,请先检查 `configs/model_config.py` 和 `configs/server_config.py` 中的各项模型参数设计是否符合需求:
- 请确认已下载至本地的 LLM 模型本地存储路径写在 `llm_model_dict` 对应模型的 `local_model_path` 属性中,如:
@ -204,18 +208,18 @@ embedding_model_dict = {
"m3e-base": "/Users/xxx/Downloads/m3e-base",
}
```
如果你选择使用OpenAI的Embedding模型,请将模型的 `key` 写入 `embedding_model_dict` 中。使用该模型,你需要能够访问OpenAI官方的API,或设置代理。
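For reference, a minimal sketch of what the corresponding entry in `configs/model_config.py` could look like, assuming the key is supplied through the `OPENAI_API_KEY` environment variable (as in the config template later in this commit):
```python
import os

embedding_model_dict = {
    # ... local embedding models omitted ...
    "text-embedding-ada-002": os.environ.get("OPENAI_API_KEY"),  # the OpenAI embedding entry holds the API key
}
```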
### 4. 知识库初始化与迁移
当前项目的知识库信息存储在数据库中,在正式运行项目之前请先初始化数据库(我们强烈建议您在执行操作前备份您的知识文件)。
- 如果您是从 `0.1.x` 版本升级过来的用户,针对已建立的知识库,请确认知识库的向量库类型、Embedding 模型与 `configs/model_config.py` 中默认设置一致,如无变化只需以下命令将现有知识库信息添加到数据库即可:
```shell
$ python init_database.py
```
- 如果您是第一次运行本项目,知识库尚未建立,或者配置文件中的知识库类型、嵌入模型发生变化,需要以下命令初始化或重建知识库:
- 如果您是第一次运行本项目,知识库尚未建立,或者配置文件中的知识库类型、嵌入模型发生变化,或者之前的向量库没有开启`normalize_L2`,需要以下命令初始化或重建知识库:
```shell
$ python init_database.py --recreate-vs
@ -228,7 +232,7 @@ embedding_model_dict = {
如需使用开源模型进行本地部署,需首先启动 LLM 服务,启动方式分为三种:
- [基于多进程脚本 llm_api.py 启动 LLM 服务](README.md#5.1.1-基于多进程脚本-llm_api.py-启动-LLM-服务)
- [基于命令行脚本 llm_api_launch.py 启动 LLM 服务](README.md#5.1.2-基于命令行脚本-llm_api_launch.py-启动-LLM-服务)
- [基于命令行脚本 llm_api_stale.py 启动 LLM 服务](README.md#5.1.2-基于命令行脚本-llm_api_stale.py-启动-LLM-服务)
- [PEFT 加载](README.md#5.1.3-PEFT-加载)
三种方式只需选择一个即可,具体操作方式详见 5.1.1 - 5.1.3。
@ -244,6 +248,7 @@ $ python server/llm_api.py
```
项目支持多卡加载,需在 llm_api.py 中修改 create_model_worker_app 函数中,修改如下三个参数:
```python
gpus=None,
num_gpus=1,
@ -256,34 +261,36 @@ max_gpu_memory="20GiB"
`max_gpu_memory` 控制每个卡使用的显存容量。
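As an illustration, the three parameters might be set as follows for a two-GPU setup (values mirror the command-line example in 5.1.2 below; adjust to your hardware):
```python
# Keyword arguments passed to create_model_worker_app in llm_api.py (illustrative values only).
gpus = "0,1"              # comma-separated CUDA device ids to use
num_gpus = 2              # number of GPUs the model is sharded across
max_gpu_memory = "10GiB"  # memory budget per GPU
```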
##### 5.1.2 基于命令行脚本 llm_api_launch.py 启动 LLM 服务
##### 5.1.2 基于命令行脚本 llm_api_stale.py 启动 LLM 服务
⚠️ **注意:**
⚠️ **注意:**
**1.llm_api_launch.py脚本原生仅适用于linux,mac设备需要安装对应的linux命令,win平台请使用WSL;**
**1.llm_api_stale.py脚本原生仅适用于linux,mac设备需要安装对应的linux命令,win平台请使用WSL;**
**2.加载非默认模型需要用命令行参数--model-path-address指定模型不会读取model_config.py配置;**
在项目根目录下,执行 [server/llm_api_launch.py](server/llm_api.py) 脚本启动 **LLM 模型**服务:
在项目根目录下,执行 [server/llm_api_stale.py](server/llm_api_stale.py) 脚本启动 **LLM 模型**服务:
```shell
$ python server/llm_api_launch.py
$ python server/llm_api_stale.py
```
该方式支持启动多个worker,示例启动方式如下:
```shell
$ python server/llm_api_launch.py --model-path-addresss model1@host1@port1 model2@host2@port2
$ python server/llm_api_stale.py --model-path-address model1@host1@port1 model2@host2@port2
```
如果出现server端口占用情况需手动指定server端口,并同步修改model_config.py下对应模型的base_api_url为指定端口:
```shell
$ python server/llm_api_launch.py --server-port 8887
$ python server/llm_api_stale.py --server-port 8887
```
如果要启动多卡加载,示例命令如下:
```shell
$ python server/llm_api_launch.py --gpus 0,1 --num-gpus 2 --max-gpu-memory 10GiB
$ python server/llm_api_stale.py --gpus 0,1 --num-gpus 2 --max-gpu-memory 10GiB
```
以如上方式启动LLM服务会以nohup命令在后台运行 FastChat 服务,如需停止服务,可以运行如下命令:
@ -294,24 +301,13 @@ $ python server/llm_api_shutdown.py --serve all
亦可单独停止一个 FastChat 服务模块,可选 [`all`, `controller`, `model_worker`, `openai_api_server`]
##### 5.1.3 PEFT 加载
##### 5.1.3 PEFT 加载(包括lora,p-tuning,prefix tuning, prompt tuning,ia等)
本项目基于 FastChat 加载 LLM 服务,故需以 FastChat 加载 PEFT 路径,即保证路径名称里必须有 peft 这个词,配置文件的名字为 adapter_config.json,peft 路径下包含 model.bin 格式的 PEFT 权重。
详细步骤参考[加载lora微调后模型失效](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822)
示例代码如下:
![image](https://github.com/chatchat-space/Langchain-Chatchat/assets/22924096/4e056c1c-5c4b-4865-a1af-859cd58a625d)
```shell
PEFT_SHARE_BASE_WEIGHTS=true python3 -m fastchat.serve.multi_model_worker \
--model-path /data/chris/peft-llama-dummy-1 \
--model-names peft-dummy-1 \
--model-path /data/chris/peft-llama-dummy-2 \
--model-names peft-dummy-2 \
--model-path /data/chris/peft-llama-dummy-3 \
--model-names peft-dummy-3 \
--num-gpus 2
```
详见 [FastChat 相关 PR](https://github.com/lm-sys/fastchat/pull/1905#issuecomment-1627801216)
#### 5.2 启动 API 服务
@ -354,7 +350,6 @@ $ streamlit run webui.py --server.port 666
- Web UI 对话界面:
![](img/webui_0813_0.png)
- Web UI 知识库管理页面:
![](img/webui_0813_1.png)
@ -363,86 +358,40 @@ $ streamlit run webui.py --server.port 666
### 6. 一键启动
⚠️ **注意:**
**1. 一键启动脚本仅原生适用于Linux,Mac 设备需要安装对应的linux命令, Windows 平台请使用 WSL;**
**2. 加载非默认模型需要用命令行参数 `--model-path-address` 指定模型,不会读取 `model_config.py` 配置。**
#### 6.1 API 服务一键启动脚本
新增 API 一键启动脚本,可一键开启 FastChat 后台服务及本项目提供的 API 服务,调用示例:
调用默认模型:
更新一键启动脚本 startup.py,一键启动所有 Fastchat 服务、API 服务、WebUI 服务,示例代码:
```shell
$ python server/api_allinone.py
$ python startup.py -a
```
加载多个非默认模型:
并可使用 `Ctrl + C` 直接关闭所有运行服务。如果一次结束不了,可以多按几次。
可选参数包括 `-a (或--all-webui)`, `--all-api`, `--llm-api`, `-c (或--controller)`, `--openai-api`,
`-m (或--model-worker)`, `--api`, `--webui`,其中:
- `--all-webui` 为一键启动 WebUI 所有依赖服务;
- `--all-api` 为一键启动 API 所有依赖服务;
- `--llm-api` 为一键启动 Fastchat 所有依赖的 LLM 服务;
- `--openai-api` 为仅启动 FastChat 的 controller 和 openai-api-server 服务;
- 其他为单独服务启动选项。
若想指定非默认模型,需要用 `--model-name` 选项,示例:
```shell
$ python server/api_allinone.py --model-path-address model1@host1@port1 model2@host2@port2
$ python startup.py --all-webui --model-name Qwen-7B-Chat
```
如果出现server端口占用情况需手动指定server端口,并同步修改model_config.py下对应模型的base_api_url为指定端口:
更多信息可通过`python startup.py -h`查看。
```shell
$ python server/api_allinone.py --server-port 8887
```
**注意:**
多卡启动:
**1. startup 脚本用多进程方式启动各模块的服务,可能会导致打印顺序问题,请等待全部服务发起后再调用,并根据默认或指定端口调用服务(默认 LLM API 服务端口:`127.0.0.1:8888`,默认 API 服务端口:`127.0.0.1:7861`,默认 WebUI 服务端口:`本机IP:8501`,接口调用示例见下文)**
```shell
python server/api_allinone.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB
```
其他参数详见各脚本及 FastChat 服务说明。
#### 6.2 webui一键启动脚本
加载本地模型:
```shell
$ python webui_allinone.py
```
调用远程 API 服务:
```shell
$ python webui_allinone.py --use-remote-api
```
如果出现server端口占用情况需手动指定server端口,并同步修改model_config.py下对应模型的base_api_url为指定端口:
```shell
$ python webui_allinone.py --server-port 8887
```
后台运行webui服务
```shell
$ python webui_allinone.py --nohup
```
加载多个非默认模型:
```shell
$ python webui_allinone.py --model-path-address model1@host1@port1 model2@host2@port2
```
多卡启动:
```shell
python webui_allinone.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB
```
其他参数详见各脚本及 Fastchat 服务说明。
上述两个一键启动脚本会后台运行多个服务,如要停止所有服务,可使用 `shutdown_all.sh` 脚本:
```shell
bash shutdown_all.sh
```
**2.服务启动时间视设备不同而不同,约 3-10 分钟,如长时间没有启动请前往 `./logs`目录下监控日志,定位问题。**
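Once the services report as started, a minimal call against the FastChat OpenAI-compatible endpoint looks like the sketch below, adapted from the commented example at the end of `startup.py` (default LLM API port `8888`; the model name depends on your `LLM_MODEL` setting):
```python
import openai

openai.api_key = "EMPTY"                      # authentication is not supported yet
openai.api_base = "http://127.0.0.1:8888/v1"  # must match llm_model_dict[...]["api_base_url"]

completion = openai.ChatCompletion.create(
    model="chatglm2-6b",  # example model name; use the model you actually started
    messages=[{"role": "user", "content": "你好"}],
)
print(completion.choices[0].message.content)
```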
## 常见问题
@ -486,6 +435,6 @@ bash shutdown_all.sh
## 项目交流群
<img src="img/qr_code_52.jpg" alt="二维码" width="300" height="300" />
<img src="img/qr_code_56.jpg" alt="二维码" width="300" height="300" />
🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。

common/__init__.py Normal file

@ -1 +1,4 @@
from .model_config import *
from .model_config import *
from .server_config import *
VERSION = "v0.2.2-preview"
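With the star imports above, settings and the version string can be pulled straight from the `configs` package, as in this usage sketch (names re-exported from `model_config`/`server_config`):
```python
from configs import VERSION, LLM_MODEL, llm_model_dict

print(VERSION)                    # "v0.2.2-preview"
print(llm_model_dict[LLM_MODEL])  # config of the currently selected LLM
```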


@ -1,14 +1,11 @@
import os
import logging
import torch
import argparse
import json
# 日志格式
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)
import json
# 在以下字典中修改属性值以指定本地embedding模型存储位置
@ -27,7 +24,9 @@ embedding_model_dict = {
"m3e-large": "moka-ai/m3e-large",
"bge-small-zh": "BAAI/bge-small-zh",
"bge-base-zh": "BAAI/bge-base-zh",
"bge-large-zh": "BAAI/bge-large-zh"
"bge-large-zh": "BAAI/bge-large-zh",
"bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
"text-embedding-ada-002": os.environ.get("OPENAI_API_KEY")
}
# 选用的 Embedding 名称
@ -44,27 +43,15 @@ llm_model_dict = {
"api_key": "EMPTY"
},
"chatglm-6b-int4": {
"local_model_path": "THUDM/chatglm-6b-int4",
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
},
"chatglm2-6b": {
"local_model_path": "THUDM/chatglm2-6b",
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_base_url": "http://localhost:8888/v1", # URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
"api_key": "EMPTY"
},
"chatglm2-6b-32k": {
"local_model_path": "THUDM/chatglm2-6b-32k", # "THUDM/chatglm2-6b-32k",
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
},
"vicuna-13b-hf": {
"local_model_path": "",
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_base_url": "http://localhost:8888/v1", # "URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
"api_key": "EMPTY"
},
@ -78,10 +65,15 @@ llm_model_dict = {
# urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x000001FE4BDB85E0>:
# Failed to establish a new connection: [WinError 10060]
# 则是因为内地和香港的IP都被OPENAI封了需要切换为日本、新加坡等地
# 如果出现WARNING: Retrying langchain.chat_models.openai.acompletion_with_retry.<locals>._completion_with_retry in
# 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
# 需要添加代理访问(正常开的代理软件可能会拦截不上)需要设置配置openai_proxy 或者 使用环境变量OPENAI_PROXY 进行设置
"gpt-3.5-turbo": {
"local_model_path": "gpt-3.5-turbo",
"api_base_url": "https://api.openai.com/v1",
"api_key": os.environ.get("OPENAI_API_KEY")
"api_key": os.environ.get("OPENAI_API_KEY"),
"openai_proxy": os.environ.get("OPENAI_PROXY")
},
}
@ -117,7 +109,7 @@ kbs_config = {
"secure": False,
},
"pg": {
"connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatglm",
"connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
}
}
@ -145,12 +137,12 @@ SEARCH_ENGINE_TOP_K = 5
# nltk 模型存储路径
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
# 基于本地知识问答的提示词模版
PROMPT_TEMPLATE = """【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
# 基于本地知识问答的提示词模版(使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号)
PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
【已知信息】{context}
<已知信息>{{ context }}</已知信息>
【问题】{question}"""
<问题>{{ question }}</问题>"""
# API 是否开启跨域默认为False如果需要开启请设置为True
# is open cross domain
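For reference, a short sketch of rendering the new double-brace template, assuming the `jinja2` package is installed; double braces are Jinja2 variables, so plain braces inside retrieved text no longer collide with f-string placeholders:
```python
from jinja2 import Template

PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。</指令>
<已知信息>{{ context }}</已知信息>
<问题>{{ question }}</问题>"""

print(Template(PROMPT_TEMPLATE).render(context="(检索到的文档片段)", question="(用户问题)"))
```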


@ -0,0 +1,100 @@
from .model_config import LLM_MODEL, LLM_DEVICE
# API 是否开启跨域默认为False如果需要开启请设置为True
# is open cross domain
OPEN_CROSS_DOMAIN = False
# 各服务器默认绑定host
DEFAULT_BIND_HOST = "127.0.0.1"
# webui.py server
WEBUI_SERVER = {
"host": DEFAULT_BIND_HOST,
"port": 8501,
}
# api.py server
API_SERVER = {
"host": DEFAULT_BIND_HOST,
"port": 7861,
}
# fastchat openai_api server
FSCHAT_OPENAI_API = {
"host": DEFAULT_BIND_HOST,
"port": 8888, # model_config.llm_model_dict中模型配置的api_base_url需要与这里一致。
}
# fastchat model_worker server
# 这些模型必须是在model_config.llm_model_dict中正确配置的。
# 在启动startup.py时,可通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL
FSCHAT_MODEL_WORKERS = {
LLM_MODEL: {
"host": DEFAULT_BIND_HOST,
"port": 20002,
"device": LLM_DEVICE,
# todo: 多卡加载需要配置的参数
"gpus": None,
"numgpus": 1,
# 以下为非常用参数,可根据需要配置
# "max_gpu_memory": "20GiB",
# "load_8bit": False,
# "cpu_offloading": None,
# "gptq_ckpt": None,
# "gptq_wbits": 16,
# "gptq_groupsize": -1,
# "gptq_act_order": False,
# "awq_ckpt": None,
# "awq_wbits": 16,
# "awq_groupsize": -1,
# "model_names": [LLM_MODEL],
# "conv_template": None,
# "limit_worker_concurrency": 5,
# "stream_interval": 2,
# "no_register": False,
},
}
# fastchat multi model worker server
FSCHAT_MULTI_MODEL_WORKERS = {
# todo
}
# fastchat controller server
FSCHAT_CONTROLLER = {
"host": DEFAULT_BIND_HOST,
"port": 20001,
"dispatch_method": "shortest_queue",
}
# 以下不要更改
def fschat_controller_address() -> str:
host = FSCHAT_CONTROLLER["host"]
port = FSCHAT_CONTROLLER["port"]
return f"http://{host}:{port}"
def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
if model := FSCHAT_MODEL_WORKERS.get(model_name):
host = model["host"]
port = model["port"]
return f"http://{host}:{port}"
def fschat_openai_api_address() -> str:
host = FSCHAT_OPENAI_API["host"]
port = FSCHAT_OPENAI_API["port"]
return f"http://{host}:{port}"
def api_address() -> str:
host = API_SERVER["host"]
port = API_SERVER["port"]
return f"http://{host}:{port}"
def webui_address() -> str:
host = WEBUI_SERVER["host"]
port = WEBUI_SERVER["port"]
return f"http://{host}:{port}"
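A quick sanity-check sketch (assuming the defaults above) showing how these helpers line up with the `api_base_url` values expected in `model_config.llm_model_dict`:
```python
from configs.server_config import fschat_openai_api_address, api_address, webui_address

print(fschat_openai_api_address())  # http://127.0.0.1:8888 -> api_base_url should be this plus "/v1"
print(api_address())                # http://127.0.0.1:7861
print(webui_address())              # http://127.0.0.1:8501
```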


@ -170,3 +170,16 @@ A13: 疑为 chatglm 的 quantization 的问题或 torch 版本差异问题,针
Q14: 修改配置中路径后,加载 text2vec-large-chinese 依然提示 `WARNING: No sentence-transformers model found with name text2vec-large-chinese. Creating a new one with MEAN pooling.`
A14: 尝试更换 embedding如 text2vec-base-chinese请在 [configs/model_config.py](../configs/model_config.py) 文件中,修改 `text2vec-base`参数为本地路径,绝对路径或者相对路径均可
---
Q15: 使用pg向量库建表报错
A15: 需要手动安装对应的vector扩展(连接pg执行 CREATE EXTENSION IF NOT EXISTS vector)
---
Q16: pymilvus 连接超时
A16: pymilvus 版本需要和 milvus 服务端对应,否则会超时,参考 pymilvus==2.1.3
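A one-line check of the installed client version (illustrative; match it to whatever milvus server you run):
```python
import pymilvus

print(pymilvus.__version__)  # should be compatible with the milvus server, e.g. 2.1.3
```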


@ -2,9 +2,9 @@ version: "3.8"
services:
postgresql:
image: ankane/pgvector:v0.4.1
container_name: langchain-chatgml-pg-db
container_name: langchain_chatchat-pg-db
environment:
POSTGRES_DB: langchain_chatgml
POSTGRES_DB: langchain_chatchat
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
ports:


@ -5,3 +5,4 @@
cd docs/docker/vector_db/milvus
docker-compose up -d
```

BIN img/qr_code_53.jpg Normal file (292 KiB)
BIN img/qr_code_54.jpg Normal file (269 KiB)
BIN img/qr_code_55.jpg Normal file (291 KiB)
BIN img/qr_code_56.jpg Normal file (200 KiB)

@ -2,6 +2,8 @@ from server.knowledge_base.migrate import create_tables, folder2db, recreate_all
from configs.model_config import NLTK_DATA_PATH
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
from startup import dump_server_info
if __name__ == "__main__":
import argparse
@ -21,6 +23,8 @@ if __name__ == "__main__":
)
args = parser.parse_args()
dump_server_info()
create_tables()
print("database talbes created")


@ -1,4 +1,4 @@
langchain==0.0.257
langchain==0.0.266
openai
sentence_transformers
fschat==0.2.24


@ -1,4 +1,4 @@
langchain==0.0.257
langchain==0.0.266
openai
sentence_transformers
fschat==0.2.24


@ -4,7 +4,9 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
from configs.model_config import NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
from configs import VERSION
import argparse
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
@ -14,11 +16,10 @@ from server.chat import (chat, knowledge_base_chat, openai_chat,
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
update_doc, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
search_docs, DocumentWithScore)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
from typing import List
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@ -27,7 +28,10 @@ async def document():
def create_app():
app = FastAPI(title="Langchain-Chatchat API Server")
app = FastAPI(
title="Langchain-Chatchat API Server",
version=VERSION
)
MakeFastAPIOffline(app)
# Add CORS middleware to allow all origins
# 在config.py中设置OPEN_DOMAIN=True允许跨域
@ -75,10 +79,10 @@ def create_app():
)(create_kb)
app.post("/knowledge_base/delete_knowledge_base",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="删除知识库"
)(delete_kb)
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="删除知识库"
)(delete_kb)
app.get("/knowledge_base/list_docs",
tags=["Knowledge Base Management"],
@ -87,10 +91,10 @@ def create_app():
)(list_docs)
app.post("/knowledge_base/search_docs",
tags=["Knowledge Base Management"],
response_model=List[DocumentWithScore],
summary="搜索知识库"
)(search_docs)
tags=["Knowledge Base Management"],
response_model=List[DocumentWithScore],
summary="搜索知识库"
)(search_docs)
app.post("/knowledge_base/upload_doc",
tags=["Knowledge Base Management"],
@ -99,10 +103,10 @@ def create_app():
)(upload_doc)
app.post("/knowledge_base/delete_doc",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="删除知识库内指定文件"
)(delete_doc)
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="删除知识库内指定文件"
)(delete_doc)
app.post("/knowledge_base/update_doc",
tags=["Knowledge Base Management"],


@ -15,7 +15,7 @@ import os
sys.path.append(os.path.dirname(__file__))
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from llm_api_launch import launch_all, parser, controller_args, worker_args, server_args
from llm_api_stale import launch_all, parser, controller_args, worker_args, server_args
from api import create_app
import uvicorn


@ -21,7 +21,7 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
),
stream: bool = Body(False, description="流式输出"),
):
history = [History(**h) if isinstance(h, dict) else h for h in history]
history = [History.from_data(h) for h in history]
async def chat_iterator(query: str,
history: List[History] = [],
@ -34,11 +34,13 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
model_name=LLM_MODEL,
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
)
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_tuple() for i in history] + [("human", "{input}")])
[i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)
# Begin a task that runs in the background.


@ -38,7 +38,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
history = [History(**h) if isinstance(h, dict) else h for h in history]
history = [History.from_data(h) for h in history]
async def knowledge_base_chat_iterator(query: str,
kb: KBService,
@ -52,13 +52,15 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
model_name=LLM_MODEL,
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs])
input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
[i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)


@ -73,6 +73,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
history = [History.from_data(h) for h in history]
async def search_engine_chat_iterator(query: str,
search_engine_name: str,
top_k: int,
@ -85,14 +87,16 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
model_name=LLM_MODEL,
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
)
docs = lookup_search_engine(query, search_engine_name, top_k)
context = "\n".join([doc.page_content for doc in docs])
input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
[i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)
@ -117,7 +121,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
answer = ""
async for token in callback.aiter():
answer += token
yield json.dumps({"answer": token,
yield json.dumps({"answer": answer,
"docs": source_documents},
ensure_ascii=False)
await task


@ -1,6 +1,7 @@
import asyncio
from typing import Awaitable
from typing import Awaitable, List, Tuple, Dict, Union
from pydantic import BaseModel, Field
from langchain.prompts.chat import ChatMessagePromptTemplate
async def wrap_done(fn: Awaitable, event: asyncio.Event):
@ -28,3 +29,29 @@ class History(BaseModel):
def to_msg_tuple(self):
return "ai" if self.role=="assistant" else "human", self.content
def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:
role_maps = {
"ai": "assistant",
"human": "user",
}
role = role_maps.get(self.role, self.role)
if is_raw: # 当前默认历史消息都是没有input_variable的文本。
content = "{% raw %}" + self.content + "{% endraw %}"
else:
content = self.content
return ChatMessagePromptTemplate.from_template(
content,
"jinja2",
role=role,
)
@classmethod
def from_data(cls, h: Union[List, Tuple, Dict]) -> "History":
if isinstance(h, (list,tuple)) and len(h) >= 2:
h = cls(role=h[0], content=h[1])
elif isinstance(h, dict):
h = cls(**h)
return h
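A short usage sketch of the new `History` helpers (the import path is assumed; `History` lives in the chat utils module shown above):
```python
from server.chat.utils import History  # assumed import path

h1 = History.from_data(("human", "你好"))                                          # list/tuple form: (role, content)
h2 = History.from_data({"role": "assistant", "content": "你好,有什么可以帮你?"})    # dict form

# is_raw=True (default) wraps the content in {% raw %} so braces inside past messages are
# not treated as Jinja2 variables; is_raw=False keeps {{ input }} as a template variable.
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
```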


@ -15,7 +15,7 @@ async def list_kbs():
async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL),
):
) -> BaseResponse:
# Create selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -27,13 +27,18 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
kb.create_kb()
try:
kb.create_kb()
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"创建知识库出错: {e}")
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
async def delete_kb(
knowledge_base_name: str = Body(..., examples=["samples"])
):
) -> BaseResponse:
# Delete selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -51,5 +56,6 @@ async def delete_kb(
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}")
return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")


@ -22,7 +22,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
) -> List[DocumentWithScore]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []}
return []
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
@ -31,7 +31,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
async def list_docs(
knowledge_base_name: str
):
) -> ListResponse:
if not validate_kb_name(knowledge_base_name):
return ListResponse(code=403, msg="Don't attack me", data=[])
@ -41,13 +41,14 @@ async def list_docs(
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
else:
all_doc_names = kb.list_docs()
return ListResponse(data=all_doc_names)
return ListResponse(data=all_doc_names)
async def upload_doc(file: UploadFile = File(..., description="上传文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
override: bool = Form(False, description="覆盖已有文件"),
):
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -57,31 +58,38 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
file_content = await file.read() # 读取上传文件的内容
kb_file = KnowledgeFile(filename=file.filename,
knowledge_base_name=knowledge_base_name)
if (os.path.exists(kb_file.filepath)
and not override
and os.path.getsize(kb_file.filepath) == len(file_content)
):
# TODO: filesize 不同后的处理
file_status = f"文件 {kb_file.filename} 已存在。"
return BaseResponse(code=404, msg=file_status)
try:
kb_file = KnowledgeFile(filename=file.filename,
knowledge_base_name=knowledge_base_name)
if (os.path.exists(kb_file.filepath)
and not override
and os.path.getsize(kb_file.filepath) == len(file_content)
):
# TODO: filesize 不同后的处理
file_status = f"文件 {kb_file.filename} 已存在。"
return BaseResponse(code=404, msg=file_status)
with open(kb_file.filepath, "wb") as f:
f.write(file_content)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
kb.add_doc(kb_file)
try:
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
doc_name: str = Body(..., examples=["file_name.md"]),
delete_content: bool = Body(False),
):
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -92,17 +100,23 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
if not kb.exist_doc(doc_name):
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
kb_file = KnowledgeFile(filename=doc_name,
knowledge_base_name=knowledge_base_name)
kb.delete_doc(kb_file, delete_content)
try:
kb_file = KnowledgeFile(filename=doc_name,
knowledge_base_name=knowledge_base_name)
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
# return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败")
async def update_doc(
knowledge_base_name: str = Body(..., examples=["samples"]),
file_name: str = Body(..., examples=["file_name"]),
):
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
'''
更新知识库文档
'''
@ -113,14 +127,17 @@ async def update_doc(
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
try:
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath):
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}")
if os.path.exists(kb_file.filepath):
kb.update_doc(kb_file)
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
else:
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
async def download_doc(
@ -137,18 +154,20 @@ async def download_doc(
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath):
return FileResponse(
path=kb_file.filepath,
filename=kb_file.filename,
media_type="multipart/form-data")
else:
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
try:
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath):
return FileResponse(
path=kb_file.filepath,
filename=kb_file.filename,
media_type="multipart/form-data")
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}")
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
async def recreate_vector_store(
@ -163,24 +182,35 @@ async def recreate_vector_store(
by default, get_service_by_name only return knowledge base in the info.db and having document files in it.
set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
'''
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
async def output(kb):
kb.create_kb()
kb.clear_vs()
docs = list_docs_from_folder(knowledge_base_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, knowledge_base_name)
yield json.dumps({
"total": len(docs),
"finished": i,
"doc": doc,
}, ensure_ascii=False)
kb.add_doc(kb_file)
except Exception as e:
print(e)
async def output():
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
else:
kb.create_kb()
kb.clear_vs()
docs = list_docs_from_folder(knowledge_base_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, knowledge_base_name)
yield json.dumps({
"code": 200,
"msg": f"({i + 1} / {len(docs)}): {doc}",
"total": len(docs),
"finished": i,
"doc": doc,
}, ensure_ascii=False)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e:
print(e)
yield json.dumps({
"code": 500,
"msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
})
return StreamingResponse(output(kb), media_type="text/event-stream")
return StreamingResponse(output(), media_type="text/event-stream")
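For illustration, a hedged client-side sketch of consuming this streaming response; the route path and the default API port `7861` are assumptions based on the function name and the `API_SERVER` defaults elsewhere in this commit:
```python
import requests

# Hypothetical URL; adjust to the actual route registration and host/port of your deployment.
url = "http://127.0.0.1:7861/knowledge_base/recreate_vector_store"
with requests.post(url, json={"knowledge_base_name": "samples"}, stream=True) as r:
    # Each chunk is a JSON progress message such as {"code": 200, "msg": "(1 / N): file", ...}
    for chunk in r.iter_content(decode_unicode=True):
        if chunk:
            print(chunk, flush=True)
```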


@ -71,36 +71,37 @@ class KBService(ABC):
status = delete_kb_from_db(self.kb_name)
return status
def add_doc(self, kb_file: KnowledgeFile):
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
向知识库添加文件
"""
docs = kb_file.file2text()
if docs:
self.delete_doc(kb_file)
embeddings = self._load_embeddings()
self.do_add_doc(docs, embeddings)
self.do_add_doc(docs, embeddings, **kwargs)
status = add_doc_to_db(kb_file)
else:
status = False
return status
def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False):
def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs):
"""
从知识库删除文件
"""
self.do_delete_doc(kb_file)
self.do_delete_doc(kb_file, **kwargs)
status = delete_file_from_db(kb_file)
if delete_content and os.path.exists(kb_file.filepath):
os.remove(kb_file.filepath)
return status
def update_doc(self, kb_file: KnowledgeFile):
def update_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
使用content中的文件更新向量库
"""
if os.path.exists(kb_file.filepath):
self.delete_doc(kb_file)
return self.add_doc(kb_file)
self.delete_doc(kb_file, **kwargs)
return self.add_doc(kb_file, **kwargs)
def exist_doc(self, file_name: str):
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
@ -156,6 +157,7 @@ class KBService(ABC):
def do_search(self,
query: str,
top_k: int,
score_threshold: float,
embeddings: Embeddings,
) -> List[Document]:
"""


@ -13,7 +13,8 @@ from functools import lru_cache
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
from langchain.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings,HuggingFaceBgeEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from typing import List
from langchain.docstore.document import Document
from server.utils import torch_gc
@ -21,10 +22,19 @@ from server.utils import torch_gc
# make HuggingFaceEmbeddings hashable
def _embeddings_hash(self):
return hash(self.model_name)
if isinstance(self, HuggingFaceEmbeddings):
return hash(self.model_name)
elif isinstance(self, HuggingFaceBgeEmbeddings):
return hash(self.model_name)
elif isinstance(self, OpenAIEmbeddings):
return hash(self.model)
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
OpenAIEmbeddings.__hash__ = _embeddings_hash
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
_VECTOR_STORE_TICKS = {}
_VECTOR_STORE_TICKS = {}
@ -41,7 +51,23 @@ def load_vector_store(
vs_path = get_vs_path(knowledge_base_name)
if embeddings is None:
embeddings = load_embeddings(embed_model, embed_device)
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
if not os.path.exists(vs_path):
os.makedirs(vs_path)
if "index.faiss" in os.listdir(vs_path):
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
else:
# create an empty vector store
doc = Document(page_content="init", metadata={})
search_index = FAISS.from_documents([doc], embeddings, normalize_L2=True)
ids = [k for k, v in search_index.docstore._dict.items()]
search_index.delete(ids)
search_index.save_local(vs_path)
if tick == 0: # vector store is loaded first time
_VECTOR_STORE_TICKS[knowledge_base_name] = 0
return search_index
@ -50,6 +76,7 @@ def refresh_vs_cache(kb_name: str):
make vector store cache refreshed when next loading
"""
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
class FaissKBService(KBService):
@ -74,8 +101,10 @@ class FaissKBService(KBService):
def do_create_kb(self):
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
load_vector_store(self.kb_name)
def do_drop_kb(self):
self.clear_vs()
shutil.rmtree(self.kb_path)
def do_search(self,
@ -93,38 +122,40 @@ class FaissKBService(KBService):
def do_add_doc(self,
docs: List[Document],
embeddings: Embeddings,
**kwargs,
):
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
vector_store.add_documents(docs)
torch_gc()
else:
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
vector_store = FAISS.from_documents(
docs, embeddings, normalize_L2=True) # docs 为Document列表
torch_gc()
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
def do_delete_doc(self,
kb_file: KnowledgeFile):
embeddings = self._load_embeddings()
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
if len(ids) == 0:
return None
vector_store.delete(ids)
vector_store = load_vector_store(self.kb_name,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
vector_store.add_documents(docs)
torch_gc()
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
return True
else:
def do_delete_doc(self,
kb_file: KnowledgeFile,
**kwargs):
embeddings = self._load_embeddings()
vector_store = load_vector_store(self.kb_name,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
if len(ids) == 0:
return None
vector_store.delete(ids)
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
return True
def do_clear_vs(self):
shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path)
refresh_vs_cache(self.kb_name)
def exist_doc(self, file_name: str):
if super().exist_doc(file_name):


@ -45,12 +45,12 @@ class MilvusKBService(KBService):
def do_drop_kb(self):
self.milvus.col.drop()
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
def do_search(self, query: str, top_k: int,score_threshold: float, embeddings: Embeddings):
# todo: support score threshold
self._load_milvus(embeddings=embeddings)
return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)
return self.milvus.similarity_search_with_score(query, top_k)
def add_doc(self, kb_file: KnowledgeFile):
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
向知识库添加文件
"""
@ -60,22 +60,24 @@ class MilvusKBService(KBService):
status = add_doc_to_db(kb_file)
return status
def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
pass
def do_delete_doc(self, kb_file: KnowledgeFile):
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
filepath = kb_file.filepath.replace('\\', '\\\\')
delete_list = [item.get("pk") for item in
self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
self.milvus.col.delete(expr=f'pk in {delete_list}')
def do_clear_vs(self):
self.milvus.col.drop()
if not self.milvus.col:
self.milvus.col.drop()
if __name__ == '__main__':
# 测试建表使用
from server.db.base import Base, engine
Base.metadata.create_all(bind=engine)
milvusService = MilvusKBService("test")
milvusService.add_doc(KnowledgeFile("README.md", "test"))


@ -43,12 +43,12 @@ class PGKBService(KBService):
'''))
connect.commit()
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
# todo: support score threshold
self._load_pg_vector(embeddings=embeddings)
return self.pg_vector.similarity_search(query, top_k)
return self.pg_vector.similarity_search_with_score(query, top_k)
def add_doc(self, kb_file: KnowledgeFile):
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
向知识库添加文件
"""
@ -58,10 +58,10 @@ class PGKBService(KBService):
status = add_doc_to_db(kb_file)
return status
def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
pass
def do_delete_doc(self, kb_file: KnowledgeFile):
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
with self.pg_vector.connect() as connect:
filepath = kb_file.filepath.replace('\\', '\\\\')
connect.execute(
@ -76,6 +76,7 @@ class PGKBService(KBService):
if __name__ == '__main__':
from server.db.base import Base, engine
Base.metadata.create_all(bind=engine)
pGKBService = PGKBService("test")
pGKBService.create_kb()


@ -43,7 +43,11 @@ def folder2db(
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
kb.add_doc(kb_file)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
@ -67,7 +71,11 @@ def folder2db(
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
kb.update_doc(kb_file)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
@ -81,7 +89,11 @@ def folder2db(
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
kb.add_doc(kb_file)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:


@ -1,5 +1,7 @@
import os
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
from configs.model_config import (
embedding_model_dict,
KB_ROOT_PATH,
@ -41,11 +43,20 @@ def list_docs_from_folder(kb_name: str):
@lru_cache(1)
def load_embeddings(model: str, device: str):
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
model_kwargs={'device': device})
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
elif 'bge-' in model:
embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
model_kwargs={'device': device},
query_instruction="为这个句子生成表示以用于检索相关文章:")
if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
embeddings.query_instruction = ""
else:
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
return embeddings
LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
'.rtf', '.txt', '.xml',
'.doc', '.docx', '.epub', '.odt', '.pdf',
@ -69,7 +80,7 @@ class KnowledgeFile:
):
self.kb_name = knowledge_base_name
self.filename = filename
self.ext = os.path.splitext(filename)[-1]
self.ext = os.path.splitext(filename)[-1].lower()
if self.ext not in SUPPORTED_EXTS:
raise ValueError(f"暂未支持的文件格式 {self.ext}")
self.filepath = get_file_path(knowledge_base_name, filename)


@ -1,5 +1,5 @@
"""
调用示例: python llm_api_launch.py --model-path-address THUDM/chatglm2-6b@localhost@7650 THUDM/chatglm2-6b-32k@localhost@7651
调用示例: python llm_api_stale.py --model-path-address THUDM/chatglm2-6b@localhost@7650 THUDM/chatglm2-6b-32k@localhost@7651
其他fastchat.server.controller/worker/openai_api_server参数可按照fastchat文档调用
但少数非关键参数如--worker-address,--allowed-origins,--allowed-methods,--allowed-headers不支持


@ -9,8 +9,8 @@ from typing import Any, Optional
class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="HTTP status code")
msg: str = pydantic.Field("success", description="HTTP status message")
code: int = pydantic.Field(200, description="API status code")
msg: str = pydantic.Field("success", description="API status message")
class Config:
schema_extra = {


@ -20,9 +20,9 @@ from webui_pages.utils import *
from streamlit_option_menu import option_menu
from webui_pages import *
import os
from server.llm_api_launch import string_args,launch_all,controller_args,worker_args,server_args,LOG_PATH
from server.llm_api_stale import string_args,launch_all,controller_args,worker_args,server_args,LOG_PATH
from server.api_allinone import parser, api_args
from server.api_allinone_stale import parser, api_args
import subprocess
parser.add_argument("--use-remote-api",action="store_true")

startup.py Normal file

@ -0,0 +1,472 @@
from multiprocessing import Process, Queue
import multiprocessing as mp
import subprocess
import sys
import os
from pprint import pprint
# 设置numexpr最大线程数默认为CPU核心数
try:
import numexpr
n_cores = numexpr.utils.detect_number_of_cores()
os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
except:
pass
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
logger
from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
FSCHAT_OPENAI_API, fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, )
from server.utils import MakeFastAPIOffline, FastAPI
import argparse
from typing import Tuple, List
from configs import VERSION
def set_httpx_timeout(timeout=60.0):
import httpx
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
def create_controller_app(
dispatch_method: str,
) -> FastAPI:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.controller import app, Controller
controller = Controller(dispatch_method)
sys.modules["fastchat.serve.controller"].controller = controller
MakeFastAPIOffline(app)
app.title = "FastChat Controller"
return app
def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
import argparse
import threading
import fastchat.serve.model_worker
# workaround to make program exit with Ctrl+c
# it should be deleted after pr is merged by fastchat
def _new_init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
)
self.heart_beat_thread.start()
ModelWorker.init_heart_beat = _new_init_heart_beat
parser = argparse.ArgumentParser()
args = parser.parse_args([])
# default args. should be deleted after pr is merged by fastchat
args.gpus = None
args.max_gpu_memory = "20GiB"
args.load_8bit = False
args.cpu_offloading = None
args.gptq_ckpt = None
args.gptq_wbits = 16
args.gptq_groupsize = -1
args.gptq_act_order = False
args.awq_ckpt = None
args.awq_wbits = 16
args.awq_groupsize = -1
args.num_gpus = 1
args.model_names = []
args.conv_template = None
args.limit_worker_concurrency = 5
args.stream_interval = 2
args.no_register = False
for k, v in kwargs.items():
setattr(args, k, v)
if args.gpus:
if args.num_gpus is None:
args.num_gpus = len(args.gpus.split(','))
if len(args.gpus.split(",")) < args.num_gpus:
raise ValueError(
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
worker = ModelWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
)
sys.modules["fastchat.serve.model_worker"].worker = worker
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({LLM_MODEL})"
return app
def create_openai_api_app(
controller_address: str,
api_keys: List = [],
) -> FastAPI:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
app.add_middleware(
CORSMiddleware,
allow_credentials=True,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
app_settings.controller_address = controller_address
app_settings.api_keys = api_keys
MakeFastAPIOffline(app)
app.title = "FastChat OpeanAI API Server"
return app
def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
if run_seq == 1:
@app.on_event("startup")
async def on_startup():
set_httpx_timeout()
q.put(run_seq)
elif run_seq > 1:
@app.on_event("startup")
async def on_startup():
set_httpx_timeout()
while True:
no = q.get()
if no != run_seq - 1:
q.put(no)
else:
break
q.put(run_seq)
def run_controller(q: Queue, run_seq: int = 1):
import uvicorn
app = create_controller_app(FSCHAT_CONTROLLER.get("dispatch_method"))
_set_app_seq(app, q, run_seq)
host = FSCHAT_CONTROLLER["host"]
port = FSCHAT_CONTROLLER["port"]
uvicorn.run(app, host=host, port=port)
def run_model_worker(
model_name: str = LLM_MODEL,
controller_address: str = "",
q: Queue = None,
run_seq: int = 2,
):
import uvicorn
kwargs = FSCHAT_MODEL_WORKERS[LLM_MODEL].copy()
host = kwargs.pop("host")
port = kwargs.pop("port")
model_path = llm_model_dict[model_name].get("local_model_path", "")
kwargs["model_path"] = model_path
kwargs["model_names"] = [model_name]
kwargs["controller_address"] = controller_address or fschat_controller_address()
kwargs["worker_address"] = fschat_model_worker_address()
app = create_model_worker_app(**kwargs)
_set_app_seq(app, q, run_seq)
uvicorn.run(app, host=host, port=port)
def run_openai_api(q: Queue, run_seq: int = 3):
import uvicorn
controller_addr = fschat_controller_address()
app = create_openai_api_app(controller_addr) # todo: not support keys yet.
_set_app_seq(app, q, run_seq)
host = FSCHAT_OPENAI_API["host"]
port = FSCHAT_OPENAI_API["port"]
uvicorn.run(app, host=host, port=port)
def run_api_server(q: Queue, run_seq: int = 4):
from server.api import create_app
import uvicorn
app = create_app()
_set_app_seq(app, q, run_seq)
host = API_SERVER["host"]
port = API_SERVER["port"]
uvicorn.run(app, host=host, port=port)
def run_webui(q: Queue, run_seq: int = 5):
host = WEBUI_SERVER["host"]
port = WEBUI_SERVER["port"]
while True:
no = q.get()
if no != run_seq - 1:
q.put(no)
else:
break
q.put(run_seq)
p = subprocess.Popen(["streamlit", "run", "webui.py",
"--server.address", host,
"--server.port", str(port)])
p.wait()
def parse_args() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"-a",
"--all-webui",
action="store_true",
help="run fastchat's controller/openai_api/model_worker servers, run api.py and webui.py",
dest="all_webui",
)
parser.add_argument(
"--all-api",
action="store_true",
help="run fastchat's controller/openai_api/model_worker servers, run api.py",
dest="all_api",
)
parser.add_argument(
"--llm-api",
action="store_true",
help="run fastchat's controller/openai_api/model_worker servers",
dest="llm_api",
)
parser.add_argument(
"-o",
"--openai-api",
action="store_true",
help="run fastchat's controller/openai_api servers",
dest="openai_api",
)
parser.add_argument(
"-m",
"--model-worker",
action="store_true",
help="run fastchat's model_worker server with specified model name. specify --model-name if not using default LLM_MODEL",
dest="model_worker",
)
parser.add_argument(
"-n",
"--model-name",
type=str,
default=LLM_MODEL,
help="specify model name for model worker.",
dest="model_name",
)
parser.add_argument(
"-c",
"--controller",
type=str,
help="specify controller address the worker is registered to. default is server_config.FSCHAT_CONTROLLER",
dest="controller_address",
)
parser.add_argument(
"--api",
action="store_true",
help="run api.py server",
dest="api",
)
parser.add_argument(
"-w",
"--webui",
action="store_true",
help="run webui.py server",
dest="webui",
)
args = parser.parse_args()
return args
def dump_server_info(after_start=False):
import platform
import langchain
import fastchat
from configs.server_config import api_address, webui_address
print("\n\n")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print(f"操作系统:{platform.platform()}.")
print(f"python版本{sys.version}")
print(f"项目版本:{VERSION}")
print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
print("\n")
print(f"当前LLM模型{LLM_MODEL} @ {LLM_DEVICE}")
pprint(llm_model_dict[LLM_MODEL])
print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
if after_start:
print("\n")
print(f"服务端运行信息:")
if args.openai_api:
print(f" OpenAI API Server: {fschat_openai_api_address()}/v1")
print(" (请确认llm_model_dict中配置的api_base_url与上面地址一致。)")
if args.api:
print(f" Chatchat API Server: {api_address()}")
if args.webui:
print(f" Chatchat WEBUI Server: {webui_address()}")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print("\n\n")
if __name__ == "__main__":
import time
mp.set_start_method("spawn")
queue = Queue()
args = parse_args()
if args.all_webui:
args.openai_api = True
args.model_worker = True
args.api = True
args.webui = True
elif args.all_api:
args.openai_api = True
args.model_worker = True
args.api = True
args.webui = False
elif args.llm_api:
args.openai_api = True
args.model_worker = True
args.api = False
args.webui = False
dump_server_info()
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
processes = {}
if args.openai_api:
process = Process(
target=run_controller,
name=f"controller({os.getpid()})",
args=(queue, len(processes) + 1),
daemon=True,
)
process.start()
processes["controller"] = process
process = Process(
target=run_openai_api,
name=f"openai_api({os.getpid()})",
args=(queue, len(processes) + 1),
daemon=True,
)
process.start()
processes["openai_api"] = process
if args.model_worker:
process = Process(
target=run_model_worker,
name=f"model_worker({os.getpid()})",
args=(args.model_name, args.controller_address, queue, len(processes) + 1),
daemon=True,
)
process.start()
processes["model_worker"] = process
if args.api:
process = Process(
target=run_api_server,
name=f"API Server{os.getpid()})",
args=(queue, len(processes) + 1),
daemon=True,
)
process.start()
processes["api"] = process
if args.webui:
process = Process(
target=run_webui,
name=f"WEBUI Server{os.getpid()})",
args=(queue, len(processes) + 1),
daemon=True,
)
process.start()
processes["webui"] = process
try:
# log infors
while True:
no = queue.get()
if no == len(processes):
time.sleep(0.5)
dump_server_info(True)
break
else:
queue.put(no)
if model_worker_process := processes.get("model_worker"):
model_worker_process.join()
for name, process in processes.items():
if name != "model_worker":
process.join()
except:
if model_worker_process := processes.get("model_worker"):
model_worker_process.terminate()
for name, process in processes.items():
if name != "model_worker":
process.terminate()
# 服务启动后接口调用示例:
# import openai
# openai.api_key = "EMPTY" # Not support yet
# openai.api_base = "http://localhost:8888/v1"
# model = "chatglm2-6b"
# # create a chat completion
# completion = openai.ChatCompletion.create(
# model=model,
# messages=[{"role": "user", "content": "Hello! What is your name?"}]
# )
# # print the completion
# print(completion.choices[0].message.content)

View File

@ -28,4 +28,14 @@ if __name__ == "__main__":
for line in response.iter_content(decode_unicode=True):
print(line, flush=True)
else:
print("Error:", response.status_code)
print("Error:", response.status_code)
r = requests.post(
openai_url + "/chat/completions",
json={"model": LLM_MODEL, "messages": "你好", "max_tokens": 1000})
data = r.json()
print(f"/chat/completions\n")
print(data)
assert "choices" in data

204
tests/api/test_kb_api.py Normal file
View File

@ -0,0 +1,204 @@
import requests
import json
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path
from pprint import pprint
api_base_url = api_address()
kb = "kb_for_api_test"
test_files = {
# use the on-disk file name casing so open() also works on case-sensitive filesystems
"README.md": str(root_path / "README.md"),
"FAQ.md": str(root_path / "docs" / "FAQ.md"),
}
def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
if not Path(get_kb_path(kb)).exists():
return
url = api_base_url + api
print("\n测试知识库存在,需要删除")
r = requests.post(url, json=kb)
data = r.json()
pprint(data)
# check kb not exists anymore
url = api_base_url + "/knowledge_base/list_knowledge_bases"
print("\n获取知识库列表:")
r = requests.get(url)
data = r.json()
pprint(data)
assert data["code"] == 200
assert isinstance(data["data"], list) and len(data["data"]) > 0
assert kb not in data["data"]
def test_create_kb(api="/knowledge_base/create_knowledge_base"):
url = api_base_url + api
print(f"\n尝试用空名称创建知识库:")
r = requests.post(url, json={"knowledge_base_name": " "})
data = r.json()
pprint(data)
assert data["code"] == 404
assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
print(f"\n创建新知识库: {kb}")
r = requests.post(url, json={"knowledge_base_name": kb})
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"已新增知识库 {kb}"
print(f"\n尝试创建同名知识库: {kb}")
r = requests.post(url, json={"knowledge_base_name": kb})
data = r.json()
pprint(data)
assert data["code"] == 404
assert data["msg"] == f"已存在同名知识库 {kb}"
def test_list_kbs(api="/knowledge_base/list_knowledge_bases"):
url = api_base_url + api
print("\n获取知识库列表:")
r = requests.get(url)
data = r.json()
pprint(data)
assert data["code"] == 200
assert isinstance(data["data"], list) and len(data["data"]) > 0
assert kb in data["data"]
def test_upload_doc(api="/knowledge_base/upload_doc"):
url = api_base_url + api
for name, path in test_files.items():
print(f"\n上传知识文件: {name}")
data = {"knowledge_base_name": kb, "override": True}
files = {"file": (name, open(path, "rb"))}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"成功上传文件 {name}"
for name, path in test_files.items():
print(f"\n尝试重新上传知识文件: {name} 不覆盖")
data = {"knowledge_base_name": kb, "override": False}
files = {"file": (name, open(path, "rb"))}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 404
assert data["msg"] == f"文件 {name} 已存在。"
for name, path in test_files.items():
print(f"\n尝试重新上传知识文件: {name} 覆盖")
data = {"knowledge_base_name": kb, "override": True}
files = {"file": (name, open(path, "rb"))}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"成功上传文件 {name}"
def test_list_docs(api="/knowledge_base/list_docs"):
url = api_base_url + api
print("\n获取知识库中文件列表:")
r = requests.get(url, params={"knowledge_base_name": kb})
data = r.json()
pprint(data)
assert data["code"] == 200
assert isinstance(data["data"], list)
for name in test_files:
assert name in data["data"]
def test_search_docs(api="/knowledge_base/search_docs"):
url = api_base_url + api
query = "介绍一下langchain-chatchat项目"
print("\n检索知识库:")
print(query)
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
data = r.json()
pprint(data)
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
def test_update_doc(api="/knowledge_base/update_doc"):
url = api_base_url + api
for name, path in test_files.items():
print(f"\n更新知识文件: {name}")
r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name})
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"成功更新文件 {name}"
def test_delete_doc(api="/knowledge_base/delete_doc"):
url = api_base_url + api
for name, path in test_files.items():
print(f"\n删除知识文件: {name}")
r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name})
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"{name} 文件删除成功"
url = api_base_url + "/knowledge_base/search_docs"
query = "介绍一下langchain-chatchat项目"
print("\n尝试检索删除后的检索知识库:")
print(query)
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
data = r.json()
pprint(data)
assert isinstance(data, list) and len(data) == 0
def test_recreate_vs(api="/knowledge_base/recreate_vector_store"):
url = api_base_url + api
print("\n重建知识库:")
r = requests.post(url, json={"knowledge_base_name": kb}, stream=True)
for chunk in r.iter_content(None):
data = json.loads(chunk)
assert isinstance(data, dict)
assert data["code"] == 200
print(data["msg"])
url = api_base_url + "/knowledge_base/search_docs"
query = "本项目支持哪些文件格式?"
print("\n尝试检索重建后的检索知识库:")
print(query)
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
data = r.json()
pprint(data)
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
def test_delete_kb_after(api="/knowledge_base/delete_knowledge_base"):
url = api_base_url + api
print("\n删除知识库")
r = requests.post(url, json=kb)
data = r.json()
pprint(data)
# check kb not exists anymore
url = api_base_url + "/knowledge_base/list_knowledge_bases"
print("\n获取知识库列表:")
r = requests.get(url)
data = r.json()
pprint(data)
assert data["code"] == 200
assert isinstance(data["data"], list) and len(data["data"]) > 0
assert kb not in data["data"]

View File

@ -0,0 +1,108 @@
import requests
import json
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
from configs.server_config import API_SERVER, api_address
from pprint import pprint
api_base_url = api_address()
def dump_input(d, title):
print("\n")
print("=" * 30 + title + " input " + "="*30)
pprint(d)
def dump_output(r, title):
print("\n")
print("=" * 30 + title + " output" + "="*30)
for line in r.iter_content(None, decode_unicode=True):
print(line, end="", flush=True)
headers = {
'accept': 'application/json',
'Content-Type': 'application/json',
}
data = {
"query": "请用100字左右的文字介绍自己",
"history": [
{
"role": "user",
"content": "你好"
},
{
"role": "assistant",
"content": "你好,我是 ChatGLM"
}
],
"stream": True
}
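# shared request payload reused by the streaming chat tests below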
def test_chat_fastchat(api="/chat/fastchat"):
url = f"{api_base_url}{api}"
data2 = {
"stream": True,
"messages": data["history"] + [{"role": "user", "content": "推荐一部科幻电影"}]
}
dump_input(data2, api)
response = requests.post(url, headers=headers, json=data2, stream=True)
dump_output(response, api)
assert response.status_code == 200
def test_chat_chat(api="/chat/chat"):
url = f"{api_base_url}{api}"
dump_input(data, api)
response = requests.post(url, headers=headers, json=data, stream=True)
dump_output(response, api)
assert response.status_code == 200
def test_knowledge_chat(api="/chat/knowledge_base_chat"):
url = f"{api_base_url}{api}"
data = {
"query": "如何提问以获得高质量答案",
"knowledge_base_name": "samples",
"history": [
{
"role": "user",
"content": "你好"
},
{
"role": "assistant",
"content": "你好,我是 ChatGLM"
}
],
"stream": True
}
dump_input(data, api)
response = requests.post(url, headers=headers, json=data, stream=True)
print("\n")
print("=" * 30 + api + " output" + "="*30)
first = True
for line in response.iter_content(None, decode_unicode=True):
data = json.loads(line)
if first:
for doc in data["docs"]:
print(doc)
first = False
print(data["answer"], end="", flush=True)
assert response.status_code == 200
def test_search_engine_chat(api="/chat/search_engine_chat"):
url = f"{api_base_url}{api}"
for se in ["bing", "duckduckgo"]:
dump_input(data, api)
response = requests.post(url, json=data, stream=True)
dump_output(response, api)
assert response.status_code == 200
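Because api_base_url is resolved at import time, each test above is also callable as a plain function. A minimal sketch of a direct run, assuming the API server is already up and the bundled samples knowledge base exists:
if __name__ == "__main__":
    test_chat_fastchat()
    test_chat_chat()
    test_knowledge_chat()
    test_search_engine_chat()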

View File

@ -9,6 +9,7 @@ from webui_pages.utils import *
from streamlit_option_menu import option_menu
from webui_pages import *
import os
from configs import VERSION
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
@ -17,6 +18,11 @@ if __name__ == "__main__":
"Langchain-Chatchat WebUI",
os.path.join("img", "chatchat_icon_blue_square_v2.png"),
initial_sidebar_state="expanded",
menu_items={
'Get Help': 'https://github.com/chatchat-space/Langchain-Chatchat',
'Report a bug': "https://github.com/chatchat-space/Langchain-Chatchat/issues",
'About': f"""欢迎使用 Langchain-Chatchat WebUI {VERSION}"""
}
)
if not chat_box.chat_inited:
@ -35,7 +41,7 @@ if __name__ == "__main__":
"func": knowledge_base_page,
},
}
with st.sidebar:
st.image(
os.path.join(
@ -44,6 +50,10 @@ if __name__ == "__main__":
),
use_column_width=True
)
st.caption(
f"""<p align="right">当前版本:{VERSION}</p>""",
unsafe_allow_html=True,
)
options = list(pages)
icons = [x["icon"] for x in pages.values()]

View File

@ -118,7 +118,7 @@ def knowledge_base_page(api: ApiRequest):
vector_store_type=vs_type,
embed_model=embed_model,
)
st.toast(ret["msg"])
st.toast(ret.get("msg", " "))
st.session_state["selected_kb_name"] = kb_name
st.experimental_rerun()
@ -138,12 +138,14 @@ def knowledge_base_page(api: ApiRequest):
# use_container_width=True,
disabled=len(files) == 0,
):
for f in files:
ret = api.upload_kb_doc(f, kb)
if ret["code"] == 200:
st.toast(ret["msg"], icon="")
else:
st.toast(ret["msg"], icon="")
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
data[-1]["not_refresh_vs_cache"]=False
for k in data:
ret = api.upload_kb_doc(**k)
if msg := check_success_msg(ret):
st.toast(msg, icon="")
elif msg := check_error_msg(ret):
st.toast(msg, icon="")
st.session_state.files = []
st.divider()
@ -235,7 +237,7 @@ def knowledge_base_page(api: ApiRequest):
):
for row in selected_rows:
ret = api.delete_kb_doc(kb, row["file_name"], True)
st.toast(ret["msg"])
st.toast(ret.get("msg", " "))
st.experimental_rerun()
st.divider()
@ -249,12 +251,14 @@ def knowledge_base_page(api: ApiRequest):
use_container_width=True,
type="primary",
):
with st.spinner("向量库重构中"):
with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
empty = st.empty()
empty.progress(0.0, "")
for d in api.recreate_vector_store(kb):
print(d)
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
if msg := check_error_msg(d):
st.toast(msg)
else:
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
st.experimental_rerun()
if cols[2].button(
@ -262,6 +266,6 @@ def knowledge_base_page(api: ApiRequest):
use_container_width=True,
):
ret = api.delete_knowledge_base(kb)
st.toast(ret["msg"])
st.toast(ret.get("msg", " "))
time.sleep(1)
st.experimental_rerun()

View File

@ -229,18 +229,18 @@ class ApiRequest:
elif chunk.strip():
yield chunk
except httpx.ConnectError as e:
msg = f"无法连接API服务器请确认已执行python server\\api.py"
msg = f"无法连接API服务器请确认 api.py 已正常启动。"
logger.error(msg)
logger.error(e)
yield {"code": 500, "errorMsg": msg}
yield {"code": 500, "msg": msg}
except httpx.ReadTimeout as e:
msg = f"API通信超时请确认已启动FastChat与API服务详见RADME '5. 启动 API 服务或 Web UI'"
logger.error(msg)
logger.error(e)
yield {"code": 500, "errorMsg": msg}
yield {"code": 500, "msg": msg}
except Exception as e:
logger.error(e)
yield {"code": 500, "errorMsg": str(e)}
yield {"code": 500, "msg": str(e)}
# 对话相关操作
@ -394,7 +394,7 @@ class ApiRequest:
return response.json()
except Exception as e:
logger.error(e)
return {"code": 500, "errorMsg": errorMsg or str(e)}
return {"code": 500, "msg": errorMsg or str(e)}
def list_knowledge_bases(
self,
@ -496,6 +496,7 @@ class ApiRequest:
knowledge_base_name: str,
filename: str = None,
override: bool = False,
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
'''
@ -529,7 +530,11 @@ class ApiRequest:
else:
response = self.post(
"/knowledge_base/upload_doc",
data={"knowledge_base_name": knowledge_base_name, "override": override},
data={
"knowledge_base_name": knowledge_base_name,
"override": override,
"not_refresh_vs_cache": not_refresh_vs_cache,
},
files={"file": (filename, file)},
)
return self._check_httpx_json_response(response)
@ -539,6 +544,7 @@ class ApiRequest:
knowledge_base_name: str,
doc_name: str,
delete_content: bool = False,
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
'''
@ -551,6 +557,7 @@ class ApiRequest:
"knowledge_base_name": knowledge_base_name,
"doc_name": doc_name,
"delete_content": delete_content,
"not_refresh_vs_cache": not_refresh_vs_cache,
}
if no_remote_api:
@ -568,6 +575,7 @@ class ApiRequest:
self,
knowledge_base_name: str,
file_name: str,
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
'''
@ -583,7 +591,11 @@ class ApiRequest:
else:
response = self.post(
"/knowledge_base/update_doc",
json={"knowledge_base_name": knowledge_base_name, "file_name": file_name},
json={
"knowledge_base_name": knowledge_base_name,
"file_name": file_name,
"not_refresh_vs_cache": not_refresh_vs_cache,
},
)
return self._check_httpx_json_response(response)
@ -617,7 +629,7 @@ class ApiRequest:
"/knowledge_base/recreate_vector_store",
json=data,
stream=True,
timeout=False,
timeout=None,
)
return self._httpx_stream2generator(response, as_json=True)
@ -626,7 +638,22 @@ def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
'''
return the error message if an error occurred when requesting the API
'''
if isinstance(data, dict) and key in data:
if isinstance(data, dict):
if key in data:
return data[key]
if "code" in data and data["code"] != 200:
return data["msg"]
return ""
def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
'''
return the success message if the API request succeeded (i.e. code == 200)
'''
if (isinstance(data, dict)
and key in data
and "code" in data
and data["code"] == 200):
return data[key]
return ""