Merge branch 'dev' of github.com:chatchat-space/Langchain-Chatchat into dev
Commit: 999870c3a7
@ -23,10 +23,11 @@ assignees: ''
|
|||
描述实际发生的结果 / Describe the actual result.
|
||||
|
||||
**环境信息 / Environment Information**
|
||||
- langchain-ChatGLM 版本/commit 号:(例如:v1.0.0 或 commit 123456) / langchain-ChatGLM version/commit number: (e.g., v1.0.0 or commit 123456)
|
||||
- langchain-ChatGLM 版本/commit 号:(例如:v2.0.1 或 commit 123456) / langchain-ChatGLM version/commit number: (e.g., v2.0.1 or commit 123456)
|
||||
- 是否使用 Docker 部署(是/否):是 / Is Docker deployment used (yes/no): yes
|
||||
- 使用的模型(ChatGLM-6B / ClueAI/ChatYuan-large-v2 等):ChatGLM-6B / Model used (ChatGLM-6B / ClueAI/ChatYuan-large-v2, etc.): ChatGLM-6B
|
||||
- 使用的 Embedding 模型(GanymedeNil/text2vec-large-chinese 等):GanymedeNil/text2vec-large-chinese / Embedding model used (GanymedeNil/text2vec-large-chinese, etc.): GanymedeNil/text2vec-large-chinese
|
||||
- 使用的模型(ChatGLM2-6B / Qwen-7B 等):ChatGLM-6B / Model used (ChatGLM2-6B / Qwen-7B, etc.): ChatGLM2-6B
|
||||
- 使用的 Embedding 模型(moka-ai/m3e-base 等):moka-ai/m3e-base / Embedding model used (moka-ai/m3e-base, etc.): moka-ai/m3e-base
|
||||
- 使用的向量库类型 (faiss / milvus / pg_vector 等): faiss / Vector library used (faiss, milvus, pg_vector, etc.): faiss
|
||||
- 操作系统及版本 / Operating system and version:
|
||||
- Python 版本 / Python version:
|
||||
- 其他相关环境信息 / Other relevant environment information:
@ -4,4 +4,4 @@ logs
|
|||
.idea/
|
||||
__pycache__/
|
||||
knowledge_base/
|
||||
configs/model_config.py
|
||||
configs/*.py
README.md
@ -126,6 +126,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
|
|||
- [BAAI/bge-small-zh](https://huggingface.co/BAAI/bge-small-zh)
|
||||
- [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh)
|
||||
- [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh)
|
||||
- [BAAI/bge-large-zh-noinstruct](https://huggingface.co/BAAI/bge-large-zh-noinstruct)
|
||||
- [text2vec-base-chinese-sentence](https://huggingface.co/shibing624/text2vec-base-chinese-sentence)
|
||||
- [text2vec-base-chinese-paraphrase](https://huggingface.co/shibing624/text2vec-base-chinese-paraphrase)
|
||||
- [text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual)
|
||||
|
|
@ -133,6 +134,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
|
|||
- [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese)
|
||||
- [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
|
||||
- [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh)
|
||||
- [OpenAI/text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings)
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -181,9 +183,11 @@ $ git clone https://huggingface.co/moka-ai/m3e-base
|
|||
|
||||
### 3. 设置配置项
|
||||
|
||||
复制文件 [configs/model_config.py.example](configs/model_config.py.example) 存储至项目路径下 `./configs` 路径下,并重命名为 `model_config.py`。
|
||||
复制模型相关参数配置模板文件 [configs/model_config.py.example](configs/model_config.py.example) 到项目下的 `./configs` 目录,并重命名为 `model_config.py`。
|
||||
|
||||
在开始执行 Web UI 或命令行交互前,请先检查 `configs/model_config.py` 中的各项模型参数设计是否符合需求:
|
||||
复制服务相关参数配置模板文件 [configs/server_config.py.example](configs/server_config.py.example) 到项目下的 `./configs` 目录,并重命名为 `server_config.py`。
|
||||
|
||||
在开始执行 Web UI 或命令行交互前,请先检查 `configs/model_config.py` 和 `configs/server_config.py` 中的各项模型参数设置是否符合需求:
|
||||
|
||||
- 请确认已下载至本地的 LLM 模型本地存储路径写在 `llm_model_dict` 对应模型的 `local_model_path` 属性中,如:
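下面是一个示意写法(模型名称与路径仅为示例,请替换为实际的本地存储路径):

```python
llm_model_dict = {
    "chatglm2-6b": {
        # 示例路径,请替换为已下载模型的实际本地路径
        "local_model_path": "/Users/xxx/Downloads/chatglm2-6b",
        # URL 需要与 server_config.FSCHAT_OPENAI_API 中的配置一致
        "api_base_url": "http://localhost:8888/v1",
        "api_key": "EMPTY"
    },
}
```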
@ -204,18 +208,18 @@ embedding_model_dict = {
|
|||
"m3e-base": "/Users/xxx/Downloads/m3e-base",
|
||||
}
|
||||
```
|
||||
如果你选择使用 OpenAI 的 Embedding 模型,请将该模型的 `key` 写入 `embedding_model_dict` 中。使用该模型需要能够访问 OpenAI 官方 API,或设置代理。
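例如(示意,key 建议通过环境变量传入,与 configs/model_config.py 中的写法一致):

```python
import os

embedding_model_dict = {
    # 其余本地 Embedding 模型配置省略
    "text-embedding-ada-002": os.environ.get("OPENAI_API_KEY"),
}
```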
|
||||
|
||||
### 4. 知识库初始化与迁移
|
||||
|
||||
当前项目的知识库信息存储在数据库中,在正式运行项目之前请先初始化数据库(我们强烈建议您在执行操作前备份您的知识文件)。
|
||||
|
||||
- 如果您是从 `0.1.x` 版本升级过来的用户,针对已建立的知识库,请确认知识库的向量库类型、Embedding 模型 `configs/model_config.py` 中默认设置一致,如无变化只需以下命令将现有知识库信息添加到数据库即可:
|
||||
- 如果您是从 `0.1.x` 版本升级过来的用户,针对已建立的知识库,请确认知识库的向量库类型、Embedding 模型与 `configs/model_config.py` 中默认设置一致,如无变化只需以下命令将现有知识库信息添加到数据库即可:
|
||||
|
||||
```shell
|
||||
$ python init_database.py
|
||||
```
|
||||
|
||||
- 如果您是第一次运行本项目,知识库尚未建立,或者配置文件中的知识库类型、嵌入模型发生变化,需要以下命令初始化或重建知识库:
|
||||
- 如果您是第一次运行本项目,知识库尚未建立,或者配置文件中的知识库类型、嵌入模型发生变化,或者之前的向量库没有开启`normalize_L2`,需要以下命令初始化或重建知识库:
|
||||
|
||||
```shell
|
||||
$ python init_database.py --recreate-vs
|
||||
|
|
@ -228,7 +232,7 @@ embedding_model_dict = {
|
|||
如需使用开源模型进行本地部署,需首先启动 LLM 服务,启动方式分为三种:
|
||||
|
||||
- [基于多进程脚本 llm_api.py 启动 LLM 服务](README.md#5.1.1-基于多进程脚本-llm_api.py-启动-LLM-服务)
|
||||
- [基于命令行脚本 llm_api_launch.py 启动 LLM 服务](README.md#5.1.2-基于命令行脚本-llm_api_launch.py-启动-LLM-服务)
|
||||
- [基于命令行脚本 llm_api_stale.py 启动 LLM 服务](README.md#5.1.2-基于命令行脚本-llm_api_stale.py-启动-LLM-服务)
|
||||
- [PEFT 加载](README.md#5.1.3-PEFT-加载)
|
||||
|
||||
三种方式只需选择一个即可,具体操作方式详见 5.1.1 - 5.1.3。
|
||||
|
|
@ -244,6 +248,7 @@ $ python server/llm_api.py
|
|||
```
|
||||
|
||||
项目支持多卡加载,需在 llm_api.py 的 create_model_worker_app 函数中修改如下三个参数:
|
||||
|
||||
```python
|
||||
gpus=None,
|
||||
num_gpus=1,
|
||||
|
|
@ -256,34 +261,36 @@ max_gpu_memory="20GiB"
|
|||
|
||||
`max_gpu_memory` 控制每个卡使用的显存容量。
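例如,使用两张卡加载时,可将上述参数修改为如下取值(示意):

```python
gpus="0,1",              # 使用的 GPU 编号,以英文逗号分隔
num_gpus=2,              # 使用的 GPU 数量
max_gpu_memory="20GiB",  # 每张卡使用的显存上限
```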
|
||||
|
||||
##### 5.1.2 基于命令行脚本 llm_api_launch.py 启动 LLM 服务
|
||||
##### 5.1.2 基于命令行脚本 llm_api_stale.py 启动 LLM 服务
|
||||
|
||||
⚠️ **注意:**
|
||||
⚠️ **注意:**
|
||||
|
||||
**1.llm_api_launch.py脚本原生仅适用于linux,mac设备需要安装对应的linux命令,win平台请使用wls;**
|
||||
**1. llm_api_stale.py 脚本原生仅适用于 Linux,Mac 设备需要安装对应的 Linux 命令,Windows 平台请使用 WSL;**
|
||||
|
||||
**2. 加载非默认模型需要用命令行参数 `--model-path-address` 指定模型,不会读取 model_config.py 配置;**
|
||||
|
||||
在项目根目录下,执行 [server/llm_api_launch.py](server/llm_api.py) 脚本启动 **LLM 模型**服务:
|
||||
在项目根目录下,执行 [server/llm_api_stale.py](server/llm_api_stale.py) 脚本启动 **LLM 模型**服务:
|
||||
|
||||
```shell
|
||||
$ python server/llm_api_launch.py
|
||||
$ python server/llm_api_stale.py
|
||||
```
|
||||
|
||||
该方式支持启动多个worker,示例启动方式:
|
||||
|
||||
```shell
|
||||
$ python server/llm_api_launch.py --model-path-addresss model1@host1@port1 model2@host2@port2
|
||||
$ python server/llm_api_stale.py --model-path-address model1@host1@port1 model2@host2@port2
|
||||
```
|
||||
|
||||
如果出现 server 端口被占用的情况,需手动指定 server 端口,并同步修改 model_config.py 中对应模型的 base_api_url 为指定端口:
|
||||
|
||||
```shell
|
||||
$ python server/llm_api_launch.py --server-port 8887
|
||||
$ python server/llm_api_stale.py --server-port 8887
|
||||
```
|
||||
|
||||
如果要启动多卡加载,示例命令如下:
|
||||
|
||||
```shell
|
||||
$ python server/llm_api_launch.py --gpus 0,1 --num-gpus 2 --max-gpu-memory 10GiB
|
||||
$ python server/llm_api_stale.py --gpus 0,1 --num-gpus 2 --max-gpu-memory 10GiB
|
||||
```
|
||||
|
||||
注:以如上方式启动 LLM 服务会以 nohup 命令在后台运行 FastChat 服务,如需停止服务,可以运行如下命令:
|
||||
|
|
@ -294,24 +301,13 @@ $ python server/llm_api_shutdown.py --serve all
|
|||
|
||||
亦可单独停止一个 FastChat 服务模块,可选 [`all`, `controller`, `model_worker`, `openai_api_server`]
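例如,仅停止 model_worker 服务(示意):

```shell
$ python server/llm_api_shutdown.py --serve model_worker
```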
|
||||
|
||||
##### 5.1.3 PEFT 加载
|
||||
##### 5.1.3 PEFT 加载(包括lora,p-tuning,prefix tuning, prompt tuning,ia等)
|
||||
|
||||
本项目基于 FastChat 加载 LLM 服务,故需按 FastChat 的方式加载 PEFT 路径:路径名称里必须包含 peft 这个词,配置文件名为 adapter_config.json,且 peft 路径下包含 model.bin 格式的 PEFT 权重。
|
||||
详细步骤参考[加载lora微调后模型失效](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822)
|
||||
|
||||
示例代码如下:
|
||||

|
||||
|
||||
```shell
|
||||
PEFT_SHARE_BASE_WEIGHTS=true python3 -m fastchat.serve.multi_model_worker \
|
||||
--model-path /data/chris/peft-llama-dummy-1 \
|
||||
--model-names peft-dummy-1 \
|
||||
--model-path /data/chris/peft-llama-dummy-2 \
|
||||
--model-names peft-dummy-2 \
|
||||
--model-path /data/chris/peft-llama-dummy-3 \
|
||||
--model-names peft-dummy-3 \
|
||||
--num-gpus 2
|
||||
```
|
||||
|
||||
详见 [FastChat 相关 PR](https://github.com/lm-sys/fastchat/pull/1905#issuecomment-1627801216)
|
||||
|
||||
#### 5.2 启动 API 服务
|
||||
|
||||
|
|
@ -354,7 +350,6 @@ $ streamlit run webui.py --server.port 666
|
|||
- Web UI 对话界面:
|
||||
|
||||

|
||||
|
||||
- Web UI 知识库管理页面:
|
||||
|
||||

|
||||
|
|
@ -363,86 +358,40 @@ $ streamlit run webui.py --server.port 666
|
|||
|
||||
### 6. 一键启动
|
||||
|
||||
⚠️ **注意:**
|
||||
|
||||
**1. 一键启动脚本仅原生适用于 Linux,Mac 设备需要安装对应的 Linux 命令,Windows 平台请使用 WSL;**
|
||||
|
||||
**2. 加载非默认模型需要用命令行参数 `--model-path-address` 指定模型,不会读取 `model_config.py` 配置。**
|
||||
|
||||
#### 6.1 API 服务一键启动脚本
|
||||
|
||||
新增 API 一键启动脚本,可一键开启 FastChat 后台服务及本项目提供的 API 服务,调用示例:
|
||||
|
||||
调用默认模型:
|
||||
更新一键启动脚本 startup.py,一键启动所有 Fastchat 服务、API 服务、WebUI 服务,示例代码:
|
||||
|
||||
```shell
|
||||
$ python server/api_allinone.py
|
||||
$ python startup.py -a
|
||||
```
|
||||
|
||||
加载多个非默认模型:
|
||||
并可使用 `Ctrl + C` 直接关闭所有运行服务。如果一次结束不了,可以多按几次。
|
||||
|
||||
可选参数包括 `-a (或--all-webui)`, `--all-api`, `--llm-api`, `-c (或--controller)`, `--openai-api`,
|
||||
`-m (或--model-worker)`, `--api`, `--webui`,其中:
|
||||
|
||||
- `--all-webui` 为一键启动 WebUI 所有依赖服务;
|
||||
|
||||
- `--all-api` 为一键启动 API 所有依赖服务;
|
||||
|
||||
- `--llm-api` 为一键启动 Fastchat 所有依赖的 LLM 服务;
|
||||
|
||||
- `--openai-api` 为仅启动 FastChat 的 controller 和 openai-api-server 服务;
|
||||
|
||||
- 其他为单独服务启动选项。
|
||||
|
||||
若想指定非默认模型,需要用 `--model-name` 选项,示例:
|
||||
|
||||
```shell
|
||||
$ python server/api_allinone.py --model-path-address model1@host1@port1 model2@host2@port2
|
||||
$ python startup.py --all-webui --model-name Qwen-7B-Chat
|
||||
```
|
||||
|
||||
如果出现server端口占用情况,需手动指定server端口,并同步修改model_config.py下对应模型的base_api_url为指定端口:
|
||||
更多信息可通过`python startup.py -h`查看。
|
||||
|
||||
```shell
|
||||
$ python server/api_allinone.py --server-port 8887
|
||||
```
|
||||
**注意:**
|
||||
|
||||
多卡启动:
|
||||
**1. startup 脚本用多进程方式启动各模块的服务,可能会导致打印顺序问题,请等待全部服务启动完成后再调用,并根据默认或指定端口调用服务(默认 LLM API 服务端口:`127.0.0.1:8888`,默认 API 服务端口:`127.0.0.1:7861`,默认 WebUI 服务端口:`本机IP:8501`)**
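服务全部启动后,可按默认端口简单确认各服务是否就绪,例如(示意,端口以实际配置为准):

```shell
# API 服务(FastAPI 文档页)
$ curl http://127.0.0.1:7861/docs
# FastChat OpenAI API 服务
$ curl http://127.0.0.1:8888/v1/models
```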
|
||||
|
||||
```shell
|
||||
python server/api_allinone.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB
|
||||
```
|
||||
|
||||
其他参数详见各脚本及 FastChat 服务说明。
|
||||
|
||||
#### 6.2 webui一键启动脚本
|
||||
|
||||
加载本地模型:
|
||||
|
||||
```shell
|
||||
$ python webui_allinone.py
|
||||
```
|
||||
|
||||
调用远程 API 服务:
|
||||
|
||||
```shell
|
||||
$ python webui_allinone.py --use-remote-api
|
||||
```
|
||||
如果出现server端口占用情况,需手动指定server端口,并同步修改model_config.py下对应模型的base_api_url为指定端口:
|
||||
|
||||
```shell
|
||||
$ python webui_allinone.py --server-port 8887
|
||||
```
|
||||
|
||||
后台运行webui服务:
|
||||
|
||||
```shell
|
||||
$ python webui_allinone.py --nohup
|
||||
```
|
||||
|
||||
加载多个非默认模型:
|
||||
|
||||
```shell
|
||||
$ python webui_allinone.py --model-path-address model1@host1@port1 model2@host2@port2
|
||||
```
|
||||
|
||||
多卡启动:
|
||||
|
||||
```shell
|
||||
$ python webui_allinone.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB
|
||||
```
|
||||
|
||||
其他参数详见各脚本及 Fastchat 服务说明。
|
||||
|
||||
上述两个一键启动脚本会后台运行多个服务,如要停止所有服务,可使用 `shutdown_all.sh` 脚本:
|
||||
|
||||
```shell
|
||||
bash shutdown_all.sh
|
||||
```
|
||||
**2. 服务启动时间视设备不同而不同,约 3-10 分钟,如长时间没有启动,请前往 `./logs` 目录下查看日志,定位问题。**
|
||||
|
||||
## 常见问题
|
||||
|
||||
|
|
@ -486,6 +435,6 @@ bash shutdown_all.sh
|
|||
|
||||
## 项目交流群
|
||||
|
||||
<img src="img/qr_code_52.jpg" alt="二维码" width="300" height="300" />
|
||||
<img src="img/qr_code_56.jpg" alt="二维码" width="300" height="300" />
|
||||
|
||||
🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
|
||||
|
|
|
|||
|
|
@ -1 +1,4 @@
|
|||
from .model_config import *
|
||||
from .model_config import *
|
||||
from .server_config import *
|
||||
|
||||
VERSION = "v0.2.2-preview"
|
||||
|
|
|
|||
|
|
@ -1,14 +1,11 @@
|
|||
import os
|
||||
import logging
|
||||
import torch
|
||||
import argparse
|
||||
import json
|
||||
# 日志格式
|
||||
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
logging.basicConfig(format=LOG_FORMAT)
|
||||
import json
|
||||
|
||||
|
||||
# 在以下字典中修改属性值,以指定本地embedding模型存储位置
|
||||
|
|
@ -27,7 +24,9 @@ embedding_model_dict = {
|
|||
"m3e-large": "moka-ai/m3e-large",
|
||||
"bge-small-zh": "BAAI/bge-small-zh",
|
||||
"bge-base-zh": "BAAI/bge-base-zh",
|
||||
"bge-large-zh": "BAAI/bge-large-zh"
|
||||
"bge-large-zh": "BAAI/bge-large-zh",
|
||||
"bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
|
||||
"text-embedding-ada-002": os.environ.get("OPENAI_API_KEY")
|
||||
}
|
||||
|
||||
# 选用的 Embedding 名称
|
||||
|
|
@ -44,27 +43,15 @@ llm_model_dict = {
|
|||
"api_key": "EMPTY"
|
||||
},
|
||||
|
||||
"chatglm-6b-int4": {
|
||||
"local_model_path": "THUDM/chatglm-6b-int4",
|
||||
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
|
||||
"api_key": "EMPTY"
|
||||
},
|
||||
|
||||
"chatglm2-6b": {
|
||||
"local_model_path": "THUDM/chatglm2-6b",
|
||||
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
|
||||
"api_base_url": "http://localhost:8888/v1", # URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
|
||||
"api_key": "EMPTY"
|
||||
},
|
||||
|
||||
"chatglm2-6b-32k": {
|
||||
"local_model_path": "THUDM/chatglm2-6b-32k", # "THUDM/chatglm2-6b-32k",
|
||||
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
|
||||
"api_key": "EMPTY"
|
||||
},
|
||||
|
"vicuna-13b-hf": {
"local_model_path": "",
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_base_url": "http://localhost:8888/v1", # URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
|
||||
"api_key": "EMPTY"
|
||||
},
|
||||
|
||||
|
|
@ -78,10 +65,15 @@ llm_model_dict = {
|
|||
# urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x000001FE4BDB85E0>:
|
||||
# Failed to establish a new connection: [WinError 10060]
|
||||
# 则是因为内地和香港的IP都被OPENAI封了,需要切换为日本、新加坡等地
|
||||
|
||||
# 如果出现WARNING: Retrying langchain.chat_models.openai.acompletion_with_retry.<locals>._completion_with_retry in
|
||||
# 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
|
||||
# 需要添加代理访问(正常开启的代理软件可能拦截不到),需要设置配置项 openai_proxy,或者使用环境变量 OPENAI_PROXY 进行设置
|
||||
"gpt-3.5-turbo": {
|
||||
"local_model_path": "gpt-3.5-turbo",
|
||||
"api_base_url": "https://api.openai.com/v1",
|
||||
"api_key": os.environ.get("OPENAI_API_KEY")
|
||||
"api_key": os.environ.get("OPENAI_API_KEY"),
|
||||
"openai_proxy": os.environ.get("OPENAI_PROXY")
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -117,7 +109,7 @@ kbs_config = {
|
|||
"secure": False,
|
||||
},
|
||||
"pg": {
|
||||
"connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatglm",
|
||||
"connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -145,12 +137,12 @@ SEARCH_ENGINE_TOP_K = 5
|
|||
# nltk 模型存储路径
|
||||
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
|
||||
|
||||
# 基于本地知识问答的提示词模版
|
||||
PROMPT_TEMPLATE = """【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
|
||||
# 基于本地知识问答的提示词模版(使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号)
|
||||
PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
|
||||
|
||||
【已知信息】{context}
|
||||
<已知信息>{{ context }}</已知信息>
|
||||
|
||||
【问题】{question}"""
|
||||
<问题>{{ question }}</问题>"""
|
||||
|
||||
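# 示意:该模板在 server/chat 中经 ChatMessagePromptTemplate.from_template(PROMPT_TEMPLATE, "jinja2") 渲染,
# {{ context }} 与 {{ question }} 会在运行时被替换为检索到的上下文与用户问题。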
# API 是否开启跨域,默认为False,如果需要开启,请设置为True
|
||||
# is open cross domain
|
||||
|
|
|
|||
|
|
@ -0,0 +1,100 @@
|
|||
from .model_config import LLM_MODEL, LLM_DEVICE
|
||||
|
||||
# API 是否开启跨域,默认为False,如果需要开启,请设置为True
|
||||
# is open cross domain
|
||||
OPEN_CROSS_DOMAIN = False
|
||||
|
||||
# 各服务器默认绑定host
|
||||
DEFAULT_BIND_HOST = "127.0.0.1"
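# 如需通过局域网或公网访问本机服务,可改为 "0.0.0.0"(示意,请按实际部署需求调整)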
|
||||
|
||||
# webui.py server
|
||||
WEBUI_SERVER = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": 8501,
|
||||
}
|
||||
|
||||
# api.py server
|
||||
API_SERVER = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": 7861,
|
||||
}
|
||||
|
||||
# fastchat openai_api server
|
||||
FSCHAT_OPENAI_API = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": 8888, # model_config.llm_model_dict中模型配置的api_base_url需要与这里一致。
|
||||
}
|
||||
|
||||
# fastchat model_worker server
|
||||
# 这些模型必须是在model_config.llm_model_dict中正确配置的。
|
||||
# 在启动startup.py时,可以通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL
|
||||
FSCHAT_MODEL_WORKERS = {
|
||||
LLM_MODEL: {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": 20002,
|
||||
"device": LLM_DEVICE,
|
||||
# todo: 多卡加载需要配置的参数
|
||||
"gpus": None,
|
||||
"numgpus": 1,
|
||||
# 以下为非常用参数,可根据需要配置
|
||||
# "max_gpu_memory": "20GiB",
|
||||
# "load_8bit": False,
|
||||
# "cpu_offloading": None,
|
||||
# "gptq_ckpt": None,
|
||||
# "gptq_wbits": 16,
|
||||
# "gptq_groupsize": -1,
|
||||
# "gptq_act_order": False,
|
||||
# "awq_ckpt": None,
|
||||
# "awq_wbits": 16,
|
||||
# "awq_groupsize": -1,
|
||||
# "model_names": [LLM_MODEL],
|
||||
# "conv_template": None,
|
||||
# "limit_worker_concurrency": 5,
|
||||
# "stream_interval": 2,
|
||||
# "no_register": False,
|
||||
},
|
||||
}
|
||||
|
||||
# fastchat multi model worker server
|
||||
FSCHAT_MULTI_MODEL_WORKERS = {
|
||||
# todo
|
||||
}
|
||||
|
||||
# fastchat controller server
|
||||
FSCHAT_CONTROLLER = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": 20001,
|
||||
"dispatch_method": "shortest_queue",
|
||||
}
|
||||
|
||||
|
||||
# 以下不要更改
|
||||
def fschat_controller_address() -> str:
|
||||
host = FSCHAT_CONTROLLER["host"]
|
||||
port = FSCHAT_CONTROLLER["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
|
||||
if model := FSCHAT_MODEL_WORKERS.get(model_name):
|
||||
host = model["host"]
|
||||
port = model["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def fschat_openai_api_address() -> str:
|
||||
host = FSCHAT_OPENAI_API["host"]
|
||||
port = FSCHAT_OPENAI_API["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def api_address() -> str:
|
||||
host = API_SERVER["host"]
|
||||
port = API_SERVER["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def webui_address() -> str:
|
||||
host = WEBUI_SERVER["host"]
|
||||
port = WEBUI_SERVER["port"]
|
||||
return f"http://{host}:{port}"
|
||||
docs/FAQ.md
@ -170,3 +170,16 @@ A13: 疑为 chatglm 的 quantization 的问题或 torch 版本差异问题,针
|
|||
Q14: 修改配置中路径后,加载 text2vec-large-chinese 依然提示 `WARNING: No sentence-transformers model found with name text2vec-large-chinese. Creating a new one with MEAN pooling.`
|
||||
|
||||
A14: 尝试更换 embedding,如 text2vec-base-chinese,请在 [configs/model_config.py](../configs/model_config.py) 文件中,修改 `text2vec-base`参数为本地路径,绝对路径或者相对路径均可
|
||||
|
||||
|
||||
---
|
||||
|
||||
Q15: 使用pg向量库建表报错
|
||||
|
||||
A15: 需要手动安装对应的vector扩展(连接pg执行 CREATE EXTENSION IF NOT EXISTS vector)
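例如(示意,用户名、密码与库名以实际部署/docker-compose 配置为准):

```shell
psql postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat -c "CREATE EXTENSION IF NOT EXISTS vector;"
```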
|
||||
|
||||
---
|
||||
|
||||
Q16: pymilvus 连接超时
|
||||
|
||||
A16: pymilvus 版本需要与 milvus 版本相匹配,否则会连接超时,可参考 pymilvus==2.1.3
|
||||
|
|
@ -2,9 +2,9 @@ version: "3.8"
|
|||
services:
|
||||
postgresql:
|
||||
image: ankane/pgvector:v0.4.1
|
||||
container_name: langchain-chatgml-pg-db
|
||||
container_name: langchain_chatchat-pg-db
|
||||
environment:
|
||||
POSTGRES_DB: langchain_chatgml
|
||||
POSTGRES_DB: langchain_chatchat
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
ports:
|
||||
|
|
|
|||
|
|
@ -5,3 +5,4 @@
|
|||
cd docs/docker/vector_db/milvus
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
|
|
|
|||
Binary file not shown. (added, 292 KiB)
Binary file not shown. (added, 269 KiB)
Binary file not shown. (added, 291 KiB)
Binary file not shown. (added, 200 KiB)
|
|
@ -2,6 +2,8 @@ from server.knowledge_base.migrate import create_tables, folder2db, recreate_all
|
|||
from configs.model_config import NLTK_DATA_PATH
|
||||
import nltk
|
||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||
from startup import dump_server_info
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
|
@ -21,6 +23,8 @@ if __name__ == "__main__":
|
|||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
dump_server_info()
|
||||
|
||||
create_tables()
|
||||
print("database tables created")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
langchain==0.0.257
|
||||
langchain==0.0.266
|
||||
openai
|
||||
sentence_transformers
|
||||
fschat==0.2.24
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
langchain==0.0.257
|
||||
langchain==0.0.266
|
||||
openai
|
||||
sentence_transformers
|
||||
fschat==0.2.24
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@ import os
|
|||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
|
||||
from configs.model_config import NLTK_DATA_PATH
|
||||
from configs.server_config import OPEN_CROSS_DOMAIN
|
||||
from configs import VERSION
|
||||
import argparse
|
||||
import uvicorn
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
|
@ -14,11 +16,10 @@ from server.chat import (chat, knowledge_base_chat, openai_chat,
|
|||
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
||||
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
|
||||
update_doc, download_doc, recreate_vector_store,
|
||||
search_docs, DocumentWithScore)
|
||||
search_docs, DocumentWithScore)
|
||||
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
|
||||
from typing import List
|
||||
|
||||
|
||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||
|
||||
|
||||
|
|
@ -27,7 +28,10 @@ async def document():
|
|||
|
||||
|
||||
def create_app():
|
||||
app = FastAPI(title="Langchain-Chatchat API Server")
|
||||
app = FastAPI(
|
||||
title="Langchain-Chatchat API Server",
|
||||
version=VERSION
|
||||
)
|
||||
MakeFastAPIOffline(app)
|
||||
# Add CORS middleware to allow all origins
|
||||
# 在configs/server_config.py中设置OPEN_CROSS_DOMAIN=True,允许跨域
|
||||
|
|
@ -75,10 +79,10 @@ def create_app():
|
|||
)(create_kb)
|
||||
|
||||
app.post("/knowledge_base/delete_knowledge_base",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
summary="删除知识库"
|
||||
)(delete_kb)
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
summary="删除知识库"
|
||||
)(delete_kb)
|
||||
|
||||
app.get("/knowledge_base/list_docs",
|
||||
tags=["Knowledge Base Management"],
|
||||
|
|
@ -87,10 +91,10 @@ def create_app():
|
|||
)(list_docs)
|
||||
|
||||
app.post("/knowledge_base/search_docs",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=List[DocumentWithScore],
|
||||
summary="搜索知识库"
|
||||
)(search_docs)
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=List[DocumentWithScore],
|
||||
summary="搜索知识库"
|
||||
)(search_docs)
|
||||
|
||||
app.post("/knowledge_base/upload_doc",
|
||||
tags=["Knowledge Base Management"],
|
||||
|
|
@ -99,10 +103,10 @@ def create_app():
|
|||
)(upload_doc)
|
||||
|
||||
app.post("/knowledge_base/delete_doc",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
summary="删除知识库内指定文件"
|
||||
)(delete_doc)
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
summary="删除知识库内指定文件"
|
||||
)(delete_doc)
|
||||
|
||||
app.post("/knowledge_base/update_doc",
|
||||
tags=["Knowledge Base Management"],
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import os
|
|||
sys.path.append(os.path.dirname(__file__))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from llm_api_launch import launch_all, parser, controller_args, worker_args, server_args
|
||||
from llm_api_stale import launch_all, parser, controller_args, worker_args, server_args
|
||||
from api import create_app
|
||||
import uvicorn
|
||||
|
||||
|
|
@ -21,7 +21,7 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
|
|||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
):
|
||||
history = [History(**h) if isinstance(h, dict) else h for h in history]
|
||||
history = [History.from_data(h) for h in history]
|
||||
|
||||
async def chat_iterator(query: str,
|
||||
history: List[History] = [],
|
||||
|
|
@ -34,11 +34,13 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
|
|||
callbacks=[callback],
|
||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
model_name=LLM_MODEL
|
||||
model_name=LLM_MODEL,
|
||||
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
|
||||
)
|
||||
|
||||
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", "{input}")])
|
||||
[i.to_msg_template() for i in history] + [input_msg])
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
|
||||
# Begin a task that runs in the background.
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
|||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
history = [History(**h) if isinstance(h, dict) else h for h in history]
|
||||
history = [History.from_data(h) for h in history]
|
||||
|
||||
async def knowledge_base_chat_iterator(query: str,
|
||||
kb: KBService,
|
||||
|
|
@ -52,13 +52,15 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
|||
callbacks=[callback],
|
||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
model_name=LLM_MODEL
|
||||
model_name=LLM_MODEL,
|
||||
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
|
||||
)
|
||||
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
|
||||
context = "\n".join([doc.page_content for doc in docs])
|
||||
|
||||
input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
|
||||
[i.to_msg_template() for i in history] + [input_msg])
|
||||
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
|
||||
|
|
|
|||
|
|
@ -73,6 +73,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
|
|||
if search_engine_name not in SEARCH_ENGINES.keys():
|
||||
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
|
||||
|
||||
history = [History.from_data(h) for h in history]
|
||||
|
||||
async def search_engine_chat_iterator(query: str,
|
||||
search_engine_name: str,
|
||||
top_k: int,
|
||||
|
|
@ -85,14 +87,16 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
|
|||
callbacks=[callback],
|
||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
model_name=LLM_MODEL
|
||||
model_name=LLM_MODEL,
|
||||
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
|
||||
)
|
||||
|
||||
docs = lookup_search_engine(query, search_engine_name, top_k)
|
||||
context = "\n".join([doc.page_content for doc in docs])
|
||||
|
||||
input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
|
||||
[i.to_msg_template() for i in history] + [input_msg])
|
||||
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
|
||||
|
|
@ -117,7 +121,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
|
|||
answer = ""
|
||||
async for token in callback.aiter():
|
||||
answer += token
|
||||
yield json.dumps({"answer": token,
|
||||
yield json.dumps({"answer": answer,
|
||||
"docs": source_documents},
|
||||
ensure_ascii=False)
|
||||
await task
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
from typing import Awaitable
|
||||
from typing import Awaitable, List, Tuple, Dict, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.prompts.chat import ChatMessagePromptTemplate
|
||||
|
||||
|
||||
async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
||||
|
|
@ -28,3 +29,29 @@ class History(BaseModel):
|
|||
|
||||
def to_msg_tuple(self):
|
||||
return "ai" if self.role=="assistant" else "human", self.content
|
||||
|
||||
def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:
|
||||
role_maps = {
|
||||
"ai": "assistant",
|
||||
"human": "user",
|
||||
}
|
||||
role = role_maps.get(self.role, self.role)
|
||||
if is_raw: # 当前默认历史消息都是没有input_variable的文本。
|
||||
content = "{% raw %}" + self.content + "{% endraw %}"
|
||||
else:
|
||||
content = self.content
|
||||
|
||||
return ChatMessagePromptTemplate.from_template(
|
||||
content,
|
||||
"jinja2",
|
||||
role=role,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_data(cls, h: Union[List, Tuple, Dict]) -> "History":
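"""将 (role, content) 二元组/列表或 dict 统一转换为 History 实例。
使用示例(示意):History.from_data(("human", "你好")) 或 History.from_data({"role": "assistant", "content": "你好!"})
"""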
|
||||
if isinstance(h, (list,tuple)) and len(h) >= 2:
|
||||
h = cls(role=h[0], content=h[1])
|
||||
elif isinstance(h, dict):
|
||||
h = cls(**h)
|
||||
|
||||
return h
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ async def list_kbs():
|
|||
async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
vector_store_type: str = Body("faiss"),
|
||||
embed_model: str = Body(EMBEDDING_MODEL),
|
||||
):
|
||||
) -> BaseResponse:
|
||||
# Create selected knowledge base
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
|
@ -27,13 +27,18 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
|||
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
||||
|
||||
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
|
||||
kb.create_kb()
|
||||
try:
|
||||
kb.create_kb()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"创建知识库出错: {e}")
|
||||
|
||||
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
||||
|
||||
|
||||
async def delete_kb(
|
||||
knowledge_base_name: str = Body(..., examples=["samples"])
|
||||
):
|
||||
) -> BaseResponse:
|
||||
# Delete selected knowledge base
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
|
@ -51,5 +56,6 @@ async def delete_kb(
|
|||
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}")
|
||||
|
||||
return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
|
|||
) -> List[DocumentWithScore]:
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
if kb is None:
|
||||
return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []}
|
||||
return []
|
||||
docs = kb.search_docs(query, top_k, score_threshold)
|
||||
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
|
|||
|
||||
async def list_docs(
|
||||
knowledge_base_name: str
|
||||
):
|
||||
) -> ListResponse:
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return ListResponse(code=403, msg="Don't attack me", data=[])
|
||||
|
||||
|
|
@ -41,13 +41,14 @@ async def list_docs(
|
|||
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
|
||||
else:
|
||||
all_doc_names = kb.list_docs()
|
||||
return ListResponse(data=all_doc_names)
|
||||
return ListResponse(data=all_doc_names)
|
||||
|
||||
|
||||
async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
||||
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
|
||||
override: bool = Form(False, description="覆盖已有文件"),
|
||||
):
|
||||
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
|
|
@ -57,31 +58,38 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
|||
|
||||
file_content = await file.read() # 读取上传文件的内容
|
||||
|
||||
kb_file = KnowledgeFile(filename=file.filename,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
|
||||
if (os.path.exists(kb_file.filepath)
|
||||
and not override
|
||||
and os.path.getsize(kb_file.filepath) == len(file_content)
|
||||
):
|
||||
# TODO: filesize 不同后的处理
|
||||
file_status = f"文件 {kb_file.filename} 已存在。"
|
||||
return BaseResponse(code=404, msg=file_status)
|
||||
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=file.filename,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
|
||||
if (os.path.exists(kb_file.filepath)
|
||||
and not override
|
||||
and os.path.getsize(kb_file.filepath) == len(file_content)
|
||||
):
|
||||
# TODO: filesize 不同后的处理
|
||||
file_status = f"文件 {kb_file.filename} 已存在。"
|
||||
return BaseResponse(code=404, msg=file_status)
|
||||
|
||||
with open(kb_file.filepath, "wb") as f:
|
||||
f.write(file_content)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
|
||||
|
||||
kb.add_doc(kb_file)
|
||||
try:
|
||||
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
|
||||
|
||||
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
|
||||
|
||||
|
||||
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
doc_name: str = Body(..., examples=["file_name.md"]),
|
||||
delete_content: bool = Body(False),
|
||||
):
|
||||
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
|
|
@ -92,17 +100,23 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
|||
|
||||
if not kb.exist_doc(doc_name):
|
||||
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
|
||||
kb_file = KnowledgeFile(filename=doc_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
kb.delete_doc(kb_file, delete_content)
|
||||
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=doc_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
|
||||
|
||||
return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
|
||||
# return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败")
|
||||
|
||||
|
||||
async def update_doc(
|
||||
knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
file_name: str = Body(..., examples=["file_name"]),
|
||||
):
|
||||
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
更新知识库文档
|
||||
'''
|
||||
|
|
@ -113,14 +127,17 @@ async def update_doc(
|
|||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
if os.path.exists(kb_file.filepath):
|
||||
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}")
|
||||
|
||||
if os.path.exists(kb_file.filepath):
|
||||
kb.update_doc(kb_file)
|
||||
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
|
||||
else:
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
|
||||
|
||||
|
||||
async def download_doc(
|
||||
|
|
@ -137,18 +154,20 @@ async def download_doc(
|
|||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
|
||||
if os.path.exists(kb_file.filepath):
|
||||
return FileResponse(
|
||||
path=kb_file.filepath,
|
||||
filename=kb_file.filename,
|
||||
media_type="multipart/form-data")
|
||||
else:
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
|
||||
if os.path.exists(kb_file.filepath):
|
||||
return FileResponse(
|
||||
path=kb_file.filepath,
|
||||
filename=kb_file.filename,
|
||||
media_type="multipart/form-data")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}")
|
||||
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
|
||||
|
||||
|
||||
async def recreate_vector_store(
|
||||
|
|
@ -163,24 +182,35 @@ async def recreate_vector_store(
|
|||
by default, get_service_by_name only returns knowledge bases that are recorded in info.db and have document files in them.
|
||||
set allow_empty_kb to True to make it apply to empty knowledge bases that are not in info.db or have no document files.
|
||||
'''
|
||||
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
|
||||
if not kb.exists() and not allow_empty_kb:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
async def output(kb):
|
||||
kb.create_kb()
|
||||
kb.clear_vs()
|
||||
docs = list_docs_from_folder(knowledge_base_name)
|
||||
for i, doc in enumerate(docs):
|
||||
try:
|
||||
kb_file = KnowledgeFile(doc, knowledge_base_name)
|
||||
yield json.dumps({
|
||||
"total": len(docs),
|
||||
"finished": i,
|
||||
"doc": doc,
|
||||
}, ensure_ascii=False)
|
||||
kb.add_doc(kb_file)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
async def output():
|
||||
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
|
||||
if not kb.exists() and not allow_empty_kb:
|
||||
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
||||
else:
|
||||
kb.create_kb()
|
||||
kb.clear_vs()
|
||||
docs = list_docs_from_folder(knowledge_base_name)
|
||||
for i, doc in enumerate(docs):
|
||||
try:
|
||||
kb_file = KnowledgeFile(doc, knowledge_base_name)
|
||||
yield json.dumps({
|
||||
"code": 200,
|
||||
"msg": f"({i + 1} / {len(docs)}): {doc}",
|
||||
"total": len(docs),
|
||||
"finished": i,
|
||||
"doc": doc,
|
||||
}, ensure_ascii=False)
|
||||
if i == len(docs) - 1:
|
||||
not_refresh_vs_cache = False
|
||||
else:
|
||||
not_refresh_vs_cache = True
|
||||
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
yield json.dumps({
|
||||
"code": 500,
|
||||
"msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
|
||||
})
|
||||
|
||||
return StreamingResponse(output(kb), media_type="text/event-stream")
|
||||
return StreamingResponse(output(), media_type="text/event-stream")
|
||||
|
|
|
|||
|
|
@ -71,36 +71,37 @@ class KBService(ABC):
|
|||
status = delete_kb_from_db(self.kb_name)
|
||||
return status
|
||||
|
||||
def add_doc(self, kb_file: KnowledgeFile):
|
||||
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||
"""
|
||||
向知识库添加文件
|
||||
"""
|
||||
docs = kb_file.file2text()
|
||||
if docs:
|
||||
self.delete_doc(kb_file)
|
||||
embeddings = self._load_embeddings()
|
||||
self.do_add_doc(docs, embeddings)
|
||||
self.do_add_doc(docs, embeddings, **kwargs)
|
||||
status = add_doc_to_db(kb_file)
|
||||
else:
|
||||
status = False
|
||||
return status
|
||||
|
||||
def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False):
|
||||
def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs):
|
||||
"""
|
||||
从知识库删除文件
|
||||
"""
|
||||
self.do_delete_doc(kb_file)
|
||||
self.do_delete_doc(kb_file, **kwargs)
|
||||
status = delete_file_from_db(kb_file)
|
||||
if delete_content and os.path.exists(kb_file.filepath):
|
||||
os.remove(kb_file.filepath)
|
||||
return status
|
||||
|
||||
def update_doc(self, kb_file: KnowledgeFile):
|
||||
def update_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||
"""
|
||||
使用content中的文件更新向量库
|
||||
"""
|
||||
if os.path.exists(kb_file.filepath):
|
||||
self.delete_doc(kb_file)
|
||||
return self.add_doc(kb_file)
|
||||
self.delete_doc(kb_file, **kwargs)
|
||||
return self.add_doc(kb_file, **kwargs)
|
||||
|
||||
def exist_doc(self, file_name: str):
|
||||
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
|
||||
|
|
@ -156,6 +157,7 @@ class KBService(ABC):
|
|||
def do_search(self,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
embeddings: Embeddings,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -13,7 +13,8 @@ from functools import lru_cache
|
|||
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
|
||||
from langchain.vectorstores import FAISS
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings,HuggingFaceBgeEmbeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from typing import List
|
||||
from langchain.docstore.document import Document
|
||||
from server.utils import torch_gc
|
||||
|
|
@ -21,10 +22,19 @@ from server.utils import torch_gc
|
|||
|
||||
# make HuggingFaceEmbeddings hashable
|
||||
def _embeddings_hash(self):
|
||||
return hash(self.model_name)
|
||||
|
||||
if isinstance(self, HuggingFaceEmbeddings):
|
||||
return hash(self.model_name)
|
||||
elif isinstance(self, HuggingFaceBgeEmbeddings):
|
||||
return hash(self.model_name)
|
||||
elif isinstance(self, OpenAIEmbeddings):
|
||||
return hash(self.model)
|
||||
|
||||
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
|
||||
OpenAIEmbeddings.__hash__ = _embeddings_hash
|
||||
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
|
||||
|
||||
_VECTOR_STORE_TICKS = {}
|
||||
|
||||
|
||||
_VECTOR_STORE_TICKS = {}
|
||||
|
||||
|
|
@ -41,7 +51,23 @@ def load_vector_store(
|
|||
vs_path = get_vs_path(knowledge_base_name)
|
||||
if embeddings is None:
|
||||
embeddings = load_embeddings(embed_model, embed_device)
|
||||
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
|
||||
|
||||
if not os.path.exists(vs_path):
|
||||
os.makedirs(vs_path)
|
||||
|
||||
if "index.faiss" in os.listdir(vs_path):
|
||||
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
|
||||
else:
|
||||
# create an empty vector store
|
||||
doc = Document(page_content="init", metadata={})
|
||||
search_index = FAISS.from_documents([doc], embeddings, normalize_L2=True)
|
||||
ids = [k for k, v in search_index.docstore._dict.items()]
|
||||
search_index.delete(ids)
|
||||
search_index.save_local(vs_path)
|
||||
|
||||
if tick == 0: # vector store is loaded first time
|
||||
_VECTOR_STORE_TICKS[knowledge_base_name] = 0
|
||||
|
||||
return search_index
|
||||
|
||||
|
||||
|
|
@ -50,6 +76,7 @@ def refresh_vs_cache(kb_name: str):
|
|||
make vector store cache refreshed when next loading
|
||||
"""
|
||||
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
|
||||
print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
|
||||
|
||||
|
||||
class FaissKBService(KBService):
|
||||
|
|
@ -74,8 +101,10 @@ class FaissKBService(KBService):
|
|||
def do_create_kb(self):
|
||||
if not os.path.exists(self.vs_path):
|
||||
os.makedirs(self.vs_path)
|
||||
load_vector_store(self.kb_name)
|
||||
|
||||
def do_drop_kb(self):
|
||||
self.clear_vs()
|
||||
shutil.rmtree(self.kb_path)
|
||||
|
||||
def do_search(self,
|
||||
|
|
@ -93,38 +122,40 @@ class FaissKBService(KBService):
|
|||
def do_add_doc(self,
|
||||
docs: List[Document],
|
||||
embeddings: Embeddings,
|
||||
**kwargs,
|
||||
):
|
||||
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
|
||||
vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
|
||||
vector_store.add_documents(docs)
|
||||
torch_gc()
|
||||
else:
|
||||
if not os.path.exists(self.vs_path):
|
||||
os.makedirs(self.vs_path)
|
||||
vector_store = FAISS.from_documents(
|
||||
docs, embeddings, normalize_L2=True) # docs 为Document列表
|
||||
torch_gc()
|
||||
vector_store.save_local(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
|
||||
def do_delete_doc(self,
|
||||
kb_file: KnowledgeFile):
|
||||
embeddings = self._load_embeddings()
|
||||
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
|
||||
vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
|
||||
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
|
||||
if len(ids) == 0:
|
||||
return None
|
||||
vector_store.delete(ids)
|
||||
vector_store = load_vector_store(self.kb_name,
|
||||
embeddings=embeddings,
|
||||
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
|
||||
vector_store.add_documents(docs)
|
||||
torch_gc()
|
||||
if not kwargs.get("not_refresh_vs_cache"):
|
||||
vector_store.save_local(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
return True
|
||||
else:
|
||||
|
||||
def do_delete_doc(self,
|
||||
kb_file: KnowledgeFile,
|
||||
**kwargs):
|
||||
embeddings = self._load_embeddings()
|
||||
vector_store = load_vector_store(self.kb_name,
|
||||
embeddings=embeddings,
|
||||
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
|
||||
|
||||
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
|
||||
if len(ids) == 0:
|
||||
return None
|
||||
|
||||
vector_store.delete(ids)
|
||||
if not kwargs.get("not_refresh_vs_cache"):
|
||||
vector_store.save_local(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
|
||||
return True
|
||||
|
||||
def do_clear_vs(self):
|
||||
shutil.rmtree(self.vs_path)
|
||||
os.makedirs(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
|
||||
def exist_doc(self, file_name: str):
|
||||
if super().exist_doc(file_name):
|
||||
|
|
|
|||
|
|
@ -45,12 +45,12 @@ class MilvusKBService(KBService):
|
|||
def do_drop_kb(self):
|
||||
self.milvus.col.drop()
|
||||
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
|
||||
def do_search(self, query: str, top_k: int,score_threshold: float, embeddings: Embeddings):
|
||||
# todo: support score threshold
|
||||
self._load_milvus(embeddings=embeddings)
|
||||
return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)
|
||||
return self.milvus.similarity_search_with_score(query, top_k)
|
||||
|
||||
def add_doc(self, kb_file: KnowledgeFile):
|
||||
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||
"""
|
||||
向知识库添加文件
|
||||
"""
|
||||
|
|
@ -60,22 +60,24 @@ class MilvusKBService(KBService):
|
|||
status = add_doc_to_db(kb_file)
|
||||
return status
|
||||
|
||||
def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
|
||||
def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
|
||||
pass
|
||||
|
||||
def do_delete_doc(self, kb_file: KnowledgeFile):
|
||||
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||
filepath = kb_file.filepath.replace('\\', '\\\\')
|
||||
delete_list = [item.get("pk") for item in
|
||||
self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
|
||||
self.milvus.col.delete(expr=f'pk in {delete_list}')
|
||||
|
||||
def do_clear_vs(self):
|
||||
self.milvus.col.drop()
|
||||
if not self.milvus.col:
|
||||
self.milvus.col.drop()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 测试建表使用
|
||||
from server.db.base import Base, engine
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
milvusService = MilvusKBService("test")
|
||||
milvusService.add_doc(KnowledgeFile("README.md", "test"))
|
||||
|
|
|
|||
|
|
@ -43,12 +43,12 @@ class PGKBService(KBService):
|
|||
'''))
|
||||
connect.commit()
|
||||
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
|
||||
# todo: support score threshold
|
||||
self._load_pg_vector(embeddings=embeddings)
|
||||
return self.pg_vector.similarity_search(query, top_k)
|
||||
return self.pg_vector.similarity_search_with_score(query, top_k)
|
||||
|
||||
def add_doc(self, kb_file: KnowledgeFile):
|
||||
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||
"""
|
||||
向知识库添加文件
|
||||
"""
|
||||
|
|
@ -58,10 +58,10 @@ class PGKBService(KBService):
|
|||
status = add_doc_to_db(kb_file)
|
||||
return status
|
||||
|
||||
def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
|
||||
def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
|
||||
pass
|
||||
|
||||
def do_delete_doc(self, kb_file: KnowledgeFile):
|
||||
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||
with self.pg_vector.connect() as connect:
|
||||
filepath = kb_file.filepath.replace('\\', '\\\\')
|
||||
connect.execute(
|
||||
|
|
@ -76,6 +76,7 @@ class PGKBService(KBService):
|
|||
|
||||
if __name__ == '__main__':
|
||||
from server.db.base import Base, engine
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
pGKBService = PGKBService("test")
|
||||
pGKBService.create_kb()
|
||||
|
|
|
|||
|
|
@ -43,7 +43,11 @@ def folder2db(
|
|||
kb_file = KnowledgeFile(doc, kb_name)
|
||||
if callable(callback_before):
|
||||
callback_before(kb_file, i, docs)
|
||||
kb.add_doc(kb_file)
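# 示意:仅在处理最后一个文件时保存/刷新 FAISS 向量库(not_refresh_vs_cache=False),避免每个文件都写盘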
|
||||
if i == len(docs) - 1:
|
||||
not_refresh_vs_cache = False
|
||||
else:
|
||||
not_refresh_vs_cache = True
|
||||
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
if callable(callback_after):
|
||||
callback_after(kb_file, i, docs)
|
||||
except Exception as e:
|
||||
|
|
@ -67,7 +71,11 @@ def folder2db(
|
|||
kb_file = KnowledgeFile(doc, kb_name)
|
||||
if callable(callback_before):
|
||||
callback_before(kb_file, i, docs)
|
||||
kb.update_doc(kb_file)
|
||||
if i == len(docs) - 1:
|
||||
not_refresh_vs_cache = False
|
||||
else:
|
||||
not_refresh_vs_cache = True
|
||||
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
if callable(callback_after):
|
||||
callback_after(kb_file, i, docs)
|
||||
except Exception as e:
|
||||
|
|
@ -81,7 +89,11 @@ def folder2db(
|
|||
kb_file = KnowledgeFile(doc, kb_name)
|
||||
if callable(callback_before):
|
||||
callback_before(kb_file, i, docs)
|
||||
kb.add_doc(kb_file)
|
||||
if i == len(docs) - 1:
|
||||
not_refresh_vs_cache = False
|
||||
else:
|
||||
not_refresh_vs_cache = True
|
||||
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
if callable(callback_after):
|
||||
callback_after(kb_file, i, docs)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import os
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||
from configs.model_config import (
|
||||
embedding_model_dict,
|
||||
KB_ROOT_PATH,
|
||||
|
|
@ -41,11 +43,20 @@ def list_docs_from_folder(kb_name: str):
|
|||
|
||||
@lru_cache(1)
|
||||
def load_embeddings(model: str, device: str):
|
||||
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
|
||||
model_kwargs={'device': device})
|
||||
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
|
||||
embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
|
||||
elif 'bge-' in model:
|
||||
embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
|
||||
model_kwargs={'device': device},
|
||||
query_instruction="为这个句子生成表示以用于检索相关文章:")
|
||||
if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
|
||||
embeddings.query_instruction = ""
|
||||
else:
|
||||
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
|
||||
return embeddings
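# 使用示例(示意):embeddings = load_embeddings("m3e-base", "cuda")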
|
||||
|
||||
|
||||
|
||||
LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
|
||||
'.rtf', '.txt', '.xml',
|
||||
'.doc', '.docx', '.epub', '.odt', '.pdf',
|
||||
|
|
@ -69,7 +80,7 @@ class KnowledgeFile:
|
|||
):
|
||||
self.kb_name = knowledge_base_name
|
||||
self.filename = filename
|
||||
self.ext = os.path.splitext(filename)[-1]
|
||||
self.ext = os.path.splitext(filename)[-1].lower()
|
||||
if self.ext not in SUPPORTED_EXTS:
|
||||
raise ValueError(f"暂未支持的文件格式 {self.ext}")
|
||||
self.filepath = get_file_path(knowledge_base_name, filename)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
调用示例: python llm_api_launch.py --model-path-address THUDM/chatglm2-6b@localhost@7650 THUDM/chatglm2-6b-32k@localhost@7651
|
||||
调用示例: python llm_api_stale.py --model-path-address THUDM/chatglm2-6b@localhost@7650 THUDM/chatglm2-6b-32k@localhost@7651
|
||||
其他fastchat.server.controller/worker/openai_api_server参数可按照fastchat文档调用
|
||||
但少数非关键参数如--worker-address,--allowed-origins,--allowed-methods,--allowed-headers不支持
|
||||
|
||||
|
|
@ -9,8 +9,8 @@ from typing import Any, Optional
|
|||
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
code: int = pydantic.Field(200, description="HTTP status code")
|
||||
msg: str = pydantic.Field("success", description="HTTP status message")
|
||||
code: int = pydantic.Field(200, description="API status code")
|
||||
msg: str = pydantic.Field("success", description="API status message")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
|
|
|
|||
|
|
@ -20,9 +20,9 @@ from webui_pages.utils import *
|
|||
from streamlit_option_menu import option_menu
|
||||
from webui_pages import *
|
||||
import os
|
||||
from server.llm_api_launch import string_args,launch_all,controller_args,worker_args,server_args,LOG_PATH
|
||||
from server.llm_api_stale import string_args,launch_all,controller_args,worker_args,server_args,LOG_PATH
|
||||
|
||||
from server.api_allinone import parser, api_args
|
||||
from server.api_allinone_stale import parser, api_args
|
||||
import subprocess
|
||||
|
||||
parser.add_argument("--use-remote-api",action="store_true")
|
||||
|
|
@ -0,0 +1,472 @@
|
|||
from multiprocessing import Process, Queue
|
||||
import multiprocessing as mp
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
from pprint import pprint
|
||||
|
||||
# 设置numexpr最大线程数,默认为CPU核心数
|
||||
try:
|
||||
import numexpr
|
||||
n_cores = numexpr.utils.detect_number_of_cores()
|
||||
os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
|
||||
except:
|
||||
pass
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
|
||||
logger
|
||||
from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
|
||||
FSCHAT_OPENAI_API, fschat_controller_address, fschat_model_worker_address,
|
||||
fschat_openai_api_address, )
|
||||
from server.utils import MakeFastAPIOffline, FastAPI
|
||||
import argparse
|
||||
from typing import Tuple, List
|
||||
from configs import VERSION
|
||||
|
||||
|
||||
def set_httpx_timeout(timeout=60.0):
|
||||
import httpx
|
||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
|
||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
|
||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
|
||||
|
||||
|
||||
def create_controller_app(
|
||||
dispatch_method: str,
|
||||
) -> FastAPI:
|
||||
import fastchat.constants
|
||||
fastchat.constants.LOGDIR = LOG_PATH
|
||||
from fastchat.serve.controller import app, Controller
|
||||
|
||||
controller = Controller(dispatch_method)
|
||||
sys.modules["fastchat.serve.controller"].controller = controller
|
||||
|
||||
MakeFastAPIOffline(app)
|
||||
app.title = "FastChat Controller"
|
||||
return app
|
||||
|
||||
|
||||
def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
|
||||
import fastchat.constants
|
||||
fastchat.constants.LOGDIR = LOG_PATH
|
||||
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
|
||||
import argparse
|
||||
import threading
|
||||
import fastchat.serve.model_worker
|
||||
|
||||
# workaround to make program exit with Ctrl+c
|
||||
# it should be deleted after the PR is merged by fastchat
|
||||
def _new_init_heart_beat(self):
|
||||
self.register_to_controller()
|
||||
self.heart_beat_thread = threading.Thread(
|
||||
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
|
||||
)
|
||||
self.heart_beat_thread.start()
|
||||
|
||||
ModelWorker.init_heart_beat = _new_init_heart_beat
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
args = parser.parse_args([])
|
||||
# default args; should be deleted after the PR is merged by fastchat
|
||||
args.gpus = None
|
||||
args.max_gpu_memory = "20GiB"
|
||||
args.load_8bit = False
|
||||
args.cpu_offloading = None
|
||||
args.gptq_ckpt = None
|
||||
args.gptq_wbits = 16
|
||||
args.gptq_groupsize = -1
|
||||
args.gptq_act_order = False
|
||||
args.awq_ckpt = None
|
||||
args.awq_wbits = 16
|
||||
args.awq_groupsize = -1
|
||||
args.num_gpus = 1
|
||||
args.model_names = []
|
||||
args.conv_template = None
|
||||
args.limit_worker_concurrency = 5
|
||||
args.stream_interval = 2
|
||||
args.no_register = False
|
||||
|
||||
for k, v in kwargs.items():
|
||||
setattr(args, k, v)
|
||||
|
||||
if args.gpus:
|
||||
if args.num_gpus is None:
|
||||
args.num_gpus = len(args.gpus.split(','))
|
||||
if len(args.gpus.split(",")) < args.num_gpus:
|
||||
raise ValueError(
|
||||
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
|
||||
)
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
|
||||
|
||||
gptq_config = GptqConfig(
|
||||
ckpt=args.gptq_ckpt or args.model_path,
|
||||
wbits=args.gptq_wbits,
|
||||
groupsize=args.gptq_groupsize,
|
||||
act_order=args.gptq_act_order,
|
||||
)
|
||||
awq_config = AWQConfig(
|
||||
ckpt=args.awq_ckpt or args.model_path,
|
||||
wbits=args.awq_wbits,
|
||||
groupsize=args.awq_groupsize,
|
||||
)
|
||||
|
||||
worker = ModelWorker(
|
||||
controller_addr=args.controller_address,
|
||||
worker_addr=args.worker_address,
|
||||
worker_id=worker_id,
|
||||
model_path=args.model_path,
|
||||
model_names=args.model_names,
|
||||
limit_worker_concurrency=args.limit_worker_concurrency,
|
||||
no_register=args.no_register,
|
||||
device=args.device,
|
||||
num_gpus=args.num_gpus,
|
||||
max_gpu_memory=args.max_gpu_memory,
|
||||
load_8bit=args.load_8bit,
|
||||
cpu_offloading=args.cpu_offloading,
|
||||
gptq_config=gptq_config,
|
||||
awq_config=awq_config,
|
||||
stream_interval=args.stream_interval,
|
||||
conv_template=args.conv_template,
|
||||
)
|
||||
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
sys.modules["fastchat.serve.model_worker"].args = args
|
||||
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
|
||||
|
||||
MakeFastAPIOffline(app)
|
||||
app.title = f"FastChat LLM Server ({LLM_MODEL})"
|
||||
return app
|
||||
|
||||
|
||||
def create_openai_api_app(
|
||||
controller_address: str,
|
||||
api_keys: List = [],
|
||||
) -> FastAPI:
|
||||
import fastchat.constants
|
||||
fastchat.constants.LOGDIR = LOG_PATH
|
||||
from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_credentials=True,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app_settings.controller_address = controller_address
|
||||
app_settings.api_keys = api_keys
|
||||
|
||||
MakeFastAPIOffline(app)
|
||||
app.title = "FastChat OpeanAI API Server"
|
||||
return app
|
||||
|
||||
|
||||
def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
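# Startup-order note: the app with run_seq == 1 starts immediately; every later app blocks in
# its startup event until it reads run_seq - 1 from the shared queue (i.e. the previous service
# is ready), then puts its own run_seq back so the next service can proceed.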
|
||||
if run_seq == 1:
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
set_httpx_timeout()
|
||||
q.put(run_seq)
|
||||
elif run_seq > 1:
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
set_httpx_timeout()
|
||||
while True:
|
||||
no = q.get()
|
||||
if no != run_seq - 1:
|
||||
q.put(no)
|
||||
else:
|
||||
break
|
||||
q.put(run_seq)
|
||||
|
||||
|
||||
def run_controller(q: Queue, run_seq: int = 1):
|
||||
import uvicorn
|
||||
|
||||
app = create_controller_app(FSCHAT_CONTROLLER.get("dispatch_method"))
|
||||
_set_app_seq(app, q, run_seq)
|
||||
|
||||
host = FSCHAT_CONTROLLER["host"]
|
||||
port = FSCHAT_CONTROLLER["port"]
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
def run_model_worker(
|
||||
model_name: str = LLM_MODEL,
|
||||
controller_address: str = "",
|
||||
q: Queue = None,
|
||||
run_seq: int = 2,
|
||||
):
|
||||
import uvicorn
|
||||
|
||||
kwargs = FSCHAT_MODEL_WORKERS[LLM_MODEL].copy()
|
||||
host = kwargs.pop("host")
|
||||
port = kwargs.pop("port")
|
||||
model_path = llm_model_dict[model_name].get("local_model_path", "")
|
||||
kwargs["model_path"] = model_path
|
||||
kwargs["model_names"] = [model_name]
|
||||
kwargs["controller_address"] = controller_address or fschat_controller_address()
|
||||
kwargs["worker_address"] = fschat_model_worker_address()
|
||||
|
||||
app = create_model_worker_app(**kwargs)
|
||||
_set_app_seq(app, q, run_seq)
|
||||
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
def run_openai_api(q: Queue, run_seq: int = 3):
|
||||
import uvicorn
|
||||
|
||||
controller_addr = fschat_controller_address()
|
||||
app = create_openai_api_app(controller_addr) # todo: not support keys yet.
|
||||
_set_app_seq(app, q, run_seq)
|
||||
|
||||
host = FSCHAT_OPENAI_API["host"]
|
||||
port = FSCHAT_OPENAI_API["port"]
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
def run_api_server(q: Queue, run_seq: int = 4):
|
||||
from server.api import create_app
|
||||
import uvicorn
|
||||
|
||||
app = create_app()
|
||||
_set_app_seq(app, q, run_seq)
|
||||
|
||||
host = API_SERVER["host"]
|
||||
port = API_SERVER["port"]
|
||||
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
def run_webui(q: Queue, run_seq: int = 5):
|
||||
host = WEBUI_SERVER["host"]
|
||||
port = WEBUI_SERVER["port"]
|
||||
while True:
|
||||
no = q.get()
|
||||
if no != run_seq - 1:
|
||||
q.put(no)
|
||||
else:
|
||||
break
|
||||
q.put(run_seq)
|
||||
p = subprocess.Popen(["streamlit", "run", "webui.py",
|
||||
"--server.address", host,
|
||||
"--server.port", str(port)])
|
||||
p.wait()
|
||||
|
||||
|
||||
def parse_args() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-a",
|
||||
"--all-webui",
|
||||
action="store_true",
|
||||
help="run fastchat's controller/openai_api/model_worker servers, run api.py and webui.py",
|
||||
dest="all_webui",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--all-api",
|
||||
action="store_true",
|
||||
help="run fastchat's controller/openai_api/model_worker servers, run api.py",
|
||||
dest="all_api",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-api",
|
||||
action="store_true",
|
||||
help="run fastchat's controller/openai_api/model_worker servers",
|
||||
dest="llm_api",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--openai-api",
|
||||
action="store_true",
|
||||
help="run fastchat's controller/openai_api servers",
|
||||
dest="openai_api",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model-worker",
|
||||
action="store_true",
|
||||
help="run fastchat's model_worker server with specified model name. specify --model-name if not using default LLM_MODEL",
|
||||
dest="model_worker",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
"--model-name",
|
||||
type=str,
|
||||
default=LLM_MODEL,
|
||||
help="specify model name for model worker.",
|
||||
dest="model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--controller",
|
||||
type=str,
|
||||
help="specify controller address the worker is registered to. default is server_config.FSCHAT_CONTROLLER",
|
||||
dest="controller_address",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api",
|
||||
action="store_true",
|
||||
help="run api.py server",
|
||||
dest="api",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-w",
|
||||
"--webui",
|
||||
action="store_true",
|
||||
help="run webui.py server",
|
||||
dest="webui",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
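# Typical invocations implied by the flags above (a sketch; the launcher's file name is assumed
# to be startup.py here, since it is not shown in this diff):
#   python startup.py -a            # fastchat controller/openai_api/model_worker + api.py + webui.py
#   python startup.py --all-api     # the same services without webui.py
#   python startup.py --llm-api     # only the fastchat services
#   python startup.py -m -n chatglm2-6b   # a single model_worker; add -c to register it with an existing controller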
|
||||
def dump_server_info(after_start=False):
|
||||
import platform
|
||||
import langchain
|
||||
import fastchat
|
||||
from configs.server_config import api_address, webui_address
|
||||
|
||||
print("\n\n")
|
||||
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
|
||||
print(f"操作系统:{platform.platform()}.")
|
||||
print(f"python版本:{sys.version}")
|
||||
print(f"项目版本:{VERSION}")
|
||||
print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
|
||||
print("\n")
|
||||
print(f"当前LLM模型:{LLM_MODEL} @ {LLM_DEVICE}")
|
||||
pprint(llm_model_dict[LLM_MODEL])
|
||||
print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
|
||||
if after_start:
|
||||
print("\n")
|
||||
print(f"服务端运行信息:")
|
||||
if args.openai_api:
|
||||
print(f" OpenAI API Server: {fschat_openai_api_address()}/v1")
|
||||
print(" (请确认llm_model_dict中配置的api_base_url与上面地址一致。)")
|
||||
if args.api:
|
||||
print(f" Chatchat API Server: {api_address()}")
|
||||
if args.webui:
|
||||
print(f" Chatchat WEBUI Server: {webui_address()}")
|
||||
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
|
||||
print("\n\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
queue = Queue()
|
||||
args = parse_args()
|
||||
if args.all_webui:
|
||||
args.openai_api = True
|
||||
args.model_worker = True
|
||||
args.api = True
|
||||
args.webui = True
|
||||
|
||||
elif args.all_api:
|
||||
args.openai_api = True
|
||||
args.model_worker = True
|
||||
args.api = True
|
||||
args.webui = False
|
||||
|
||||
elif args.llm_api:
|
||||
args.openai_api = True
|
||||
args.model_worker = True
|
||||
args.api = False
|
||||
args.webui = False
|
||||
|
||||
dump_server_info()
|
||||
logger.info(f"正在启动服务:")
|
||||
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
|
||||
|
||||
processes = {}
|
||||
|
||||
if args.openai_api:
|
||||
process = Process(
|
||||
target=run_controller,
|
||||
name=f"controller({os.getpid()})",
|
||||
args=(queue, len(processes) + 1),
|
||||
daemon=True,
|
||||
)
|
||||
process.start()
|
||||
processes["controller"] = process
|
||||
|
||||
process = Process(
|
||||
target=run_openai_api,
|
||||
name=f"openai_api({os.getpid()})",
|
||||
args=(queue, len(processes) + 1),
|
||||
daemon=True,
|
||||
)
|
||||
process.start()
|
||||
processes["openai_api"] = process
|
||||
|
||||
if args.model_worker:
|
||||
process = Process(
|
||||
target=run_model_worker,
|
||||
name=f"model_worker({os.getpid()})",
|
||||
args=(args.model_name, args.controller_address, queue, len(processes) + 1),
|
||||
daemon=True,
|
||||
)
|
||||
process.start()
|
||||
processes["model_worker"] = process
|
||||
|
||||
if args.api:
|
||||
process = Process(
|
||||
target=run_api_server,
|
||||
name=f"API Server{os.getpid()})",
|
||||
args=(queue, len(processes) + 1),
|
||||
daemon=True,
|
||||
)
|
||||
process.start()
|
||||
processes["api"] = process
|
||||
|
||||
if args.webui:
|
||||
process = Process(
|
||||
target=run_webui,
|
||||
name=f"WEBUI Server{os.getpid()})",
|
||||
args=(queue, len(processes) + 1),
|
||||
daemon=True,
|
||||
)
|
||||
process.start()
|
||||
processes["webui"] = process
|
||||
|
||||
try:
|
||||
# wait for the last service to report startup, then dump the final server info
|
||||
while True:
|
||||
no = queue.get()
|
||||
if no == len(processes):
|
||||
time.sleep(0.5)
|
||||
dump_server_info(True)
|
||||
break
|
||||
else:
|
||||
queue.put(no)
|
||||
|
||||
if model_worker_process := processes.get("model_worker"):
|
||||
model_worker_process.join()
|
||||
for name, process in processes.items():
|
||||
if name != "model_worker":
|
||||
process.join()
|
||||
except:
|
||||
if model_worker_process := processes.get("model_worker"):
|
||||
model_worker_process.terminate()
|
||||
for name, process in processes.items():
|
||||
if name != "model_worker":
|
||||
process.terminate()
|
||||
|
||||
# 服务启动后接口调用示例:
|
||||
# import openai
|
||||
# openai.api_key = "EMPTY" # Not support yet
|
||||
# openai.api_base = "http://localhost:8888/v1"
|
||||
|
||||
# model = "chatglm2-6b"
|
||||
|
||||
# # create a chat completion
|
||||
# completion = openai.ChatCompletion.create(
|
||||
# model=model,
|
||||
# messages=[{"role": "user", "content": "Hello! What is your name?"}]
|
||||
# )
|
||||
# # print the completion
|
||||
# print(completion.choices[0].message.content)
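# # A streaming variant of the example above (a sketch; it assumes the same api_base and model,
# # and the pre-1.0 openai client shown in the comments above):
# for chunk in openai.ChatCompletion.create(
#     model=model,
#     messages=[{"role": "user", "content": "用一句话介绍你自己"}],
#     stream=True,
# ):
#     delta = chunk.choices[0].delta
#     print(delta.get("content", ""), end="", flush=True)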
|
||||
|
|
@ -28,4 +28,14 @@ if __name__ == "__main__":
|
|||
for line in response.iter_content(decode_unicode=True):
|
||||
print(line, flush=True)
|
||||
else:
|
||||
print("Error:", response.status_code)
|
||||
print("Error:", response.status_code)
|
||||
|
||||
|
||||
r = requests.post(
|
||||
openai_url + "/chat/completions",
|
||||
json={"model": LLM_MODEL, "messages": "你好", "max_tokens": 1000})
|
||||
data = r.json()
|
||||
print(f"/chat/completions\n")
|
||||
print(data)
|
||||
assert "choices" in data
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,204 @@
|
|||
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
from configs.server_config import api_address
|
||||
from configs.model_config import VECTOR_SEARCH_TOP_K
|
||||
from server.knowledge_base.utils import get_kb_path
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
|
||||
kb = "kb_for_api_test"
|
||||
test_files = {
|
||||
"README.MD": str(root_path / "README.MD"),
|
||||
"FAQ.MD": str(root_path / "docs" / "FAQ.MD")
|
||||
}
|
||||
|
||||
|
||||
def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
|
||||
if not Path(get_kb_path(kb)).exists():
|
||||
return
|
||||
|
||||
url = api_base_url + api
|
||||
print("\n测试知识库存在,需要删除")
|
||||
r = requests.post(url, json=kb)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
|
||||
# check kb not exists anymore
|
||||
url = api_base_url + "/knowledge_base/list_knowledge_bases"
|
||||
print("\n获取知识库列表:")
|
||||
r = requests.get(url)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert isinstance(data["data"], list) and len(data["data"]) > 0
|
||||
assert kb not in data["data"]
|
||||
|
||||
|
||||
def test_create_kb(api="/knowledge_base/create_knowledge_base"):
|
||||
url = api_base_url + api
|
||||
|
||||
print(f"\n尝试用空名称创建知识库:")
|
||||
r = requests.post(url, json={"knowledge_base_name": " "})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
|
||||
|
||||
print(f"\n创建新知识库: {kb}")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"已新增知识库 {kb}"
|
||||
|
||||
print(f"\n尝试创建同名知识库: {kb}")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == f"已存在同名知识库 {kb}"
|
||||
|
||||
|
||||
def test_list_kbs(api="/knowledge_base/list_knowledge_bases"):
|
||||
url = api_base_url + api
|
||||
print("\n获取知识库列表:")
|
||||
r = requests.get(url)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert isinstance(data["data"], list) and len(data["data"]) > 0
|
||||
assert kb in data["data"]
|
||||
|
||||
|
||||
def test_upload_doc(api="/knowledge_base/upload_doc"):
|
||||
url = api_base_url + api
|
||||
for name, path in test_files.items():
|
||||
print(f"\n上传知识文件: {name}")
|
||||
data = {"knowledge_base_name": kb, "override": True}
|
||||
files = {"file": (name, open(path, "rb"))}
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"成功上传文件 {name}"
|
||||
|
||||
for name, path in test_files.items():
|
||||
print(f"\n尝试重新上传知识文件: {name}, 不覆盖")
|
||||
data = {"knowledge_base_name": kb, "override": False}
|
||||
files = {"file": (name, open(path, "rb"))}
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == f"文件 {name} 已存在。"
|
||||
|
||||
for name, path in test_files.items():
|
||||
print(f"\n尝试重新上传知识文件: {name}, 覆盖")
|
||||
data = {"knowledge_base_name": kb, "override": True}
|
||||
files = {"file": (name, open(path, "rb"))}
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"成功上传文件 {name}"
|
||||
|
||||
|
||||
def test_list_docs(api="/knowledge_base/list_docs"):
|
||||
url = api_base_url + api
|
||||
print("\n获取知识库中文件列表:")
|
||||
r = requests.get(url, params={"knowledge_base_name": kb})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert isinstance(data["data"], list)
|
||||
for name in test_files:
|
||||
assert name in data["data"]
|
||||
|
||||
|
||||
def test_search_docs(api="/knowledge_base/search_docs"):
|
||||
url = api_base_url + api
|
||||
query = "介绍一下langchain-chatchat项目"
|
||||
print("\n检索知识库:")
|
||||
print(query)
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_update_doc(api="/knowledge_base/update_doc"):
|
||||
url = api_base_url + api
|
||||
for name, path in test_files.items():
|
||||
print(f"\n更新知识文件: {name}")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"成功更新文件 {name}"
|
||||
|
||||
|
||||
def test_delete_doc(api="/knowledge_base/delete_doc"):
|
||||
url = api_base_url + api
|
||||
for name, path in test_files.items():
|
||||
print(f"\n删除知识文件: {name}")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"{name} 文件删除成功"
|
||||
|
||||
url = api_base_url + "/knowledge_base/search_docs"
|
||||
query = "介绍一下langchain-chatchat项目"
|
||||
print("\n尝试检索删除后的检索知识库:")
|
||||
print(query)
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == 0
|
||||
|
||||
|
||||
def test_recreate_vs(api="/knowledge_base/recreate_vector_store"):
|
||||
url = api_base_url + api
|
||||
print("\n重建知识库:")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb}, stream=True)
|
||||
for chunk in r.iter_content(None):
|
||||
data = json.loads(chunk)
|
||||
assert isinstance(data, dict)
|
||||
assert data["code"] == 200
|
||||
print(data["msg"])
|
||||
|
||||
url = api_base_url + "/knowledge_base/search_docs"
|
||||
query = "本项目支持哪些文件格式?"
|
||||
print("\n尝试检索重建后的检索知识库:")
|
||||
print(query)
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_delete_kb_after(api="/knowledge_base/delete_knowledge_base"):
|
||||
url = api_base_url + api
|
||||
print("\n删除知识库")
|
||||
r = requests.post(url, json=kb)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
|
||||
# check kb not exists anymore
|
||||
url = api_base_url + "/knowledge_base/list_knowledge_bases"
|
||||
print("\n获取知识库列表:")
|
||||
r = requests.get(url)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert isinstance(data["data"], list) and len(data["data"]) > 0
|
||||
assert kb not in data["data"]
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
import requests
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
from configs.server_config import API_SERVER, api_address
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
|
||||
|
||||
def dump_input(d, title):
|
||||
print("\n")
|
||||
print("=" * 30 + title + " input " + "="*30)
|
||||
pprint(d)
|
||||
|
||||
|
||||
def dump_output(r, title):
|
||||
print("\n")
|
||||
print("=" * 30 + title + " output" + "="*30)
|
||||
for line in r.iter_content(None, decode_unicode=True):
|
||||
print(line, end="", flush=True)
|
||||
|
||||
|
||||
headers = {
|
||||
'accept': 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
data = {
|
||||
"query": "请用100字左右的文字介绍自己",
|
||||
"history": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "你好,我是 ChatGLM"
|
||||
}
|
||||
],
|
||||
"stream": True
|
||||
}
|
||||
|
||||
|
||||
|
||||
def test_chat_fastchat(api="/chat/fastchat"):
|
||||
url = f"{api_base_url}{api}"
|
||||
data2 = {
|
||||
"stream": True,
|
||||
"messages": data["history"] + [{"role": "user", "content": "推荐一部科幻电影"}]
|
||||
}
|
||||
dump_input(data2, api)
|
||||
response = requests.post(url, headers=headers, json=data2, stream=True)
|
||||
dump_output(response, api)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_chat_chat(api="/chat/chat"):
|
||||
url = f"{api_base_url}{api}"
|
||||
dump_input(data, api)
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
dump_output(response, api)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_knowledge_chat(api="/chat/knowledge_base_chat"):
|
||||
url = f"{api_base_url}{api}"
|
||||
data = {
|
||||
"query": "如何提问以获得高质量答案",
|
||||
"knowledge_base_name": "samples",
|
||||
"history": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "你好,我是 ChatGLM"
|
||||
}
|
||||
],
|
||||
"stream": True
|
||||
}
|
||||
dump_input(data, api)
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
print("\n")
|
||||
print("=" * 30 + api + " output" + "="*30)
|
||||
first = True
|
||||
for line in response.iter_content(None, decode_unicode=True):
|
||||
data = json.loads(line)
|
||||
if first:
|
||||
for doc in data["docs"]:
|
||||
print(doc)
|
||||
first = False
|
||||
print(data["answer"], end="", flush=True)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_search_engine_chat(api="/chat/search_engine_chat"):
|
||||
url = f"{api_base_url}{api}"
|
||||
for se in ["bing", "duckduckgo"]:
|
||||
dump_input(data, api)
|
||||
response = requests.post(url, json={**data, "search_engine_name": se}, stream=True)  # 假设接口参数名为 search_engine_name
|
||||
dump_output(response, api)
|
||||
assert response.status_code == 200
|
||||
webui.py
|
|
@ -9,6 +9,7 @@ from webui_pages.utils import *
|
|||
from streamlit_option_menu import option_menu
|
||||
from webui_pages import *
|
||||
import os
|
||||
from configs import VERSION
|
||||
|
||||
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
|
||||
|
||||
|
|
@ -17,6 +18,11 @@ if __name__ == "__main__":
|
|||
"Langchain-Chatchat WebUI",
|
||||
os.path.join("img", "chatchat_icon_blue_square_v2.png"),
|
||||
initial_sidebar_state="expanded",
|
||||
menu_items={
|
||||
'Get Help': 'https://github.com/chatchat-space/Langchain-Chatchat',
|
||||
'Report a bug': "https://github.com/chatchat-space/Langchain-Chatchat/issues",
|
||||
'About': f"""欢迎使用 Langchain-Chatchat WebUI {VERSION}!"""
|
||||
}
|
||||
)
|
||||
|
||||
if not chat_box.chat_inited:
|
||||
|
|
@ -35,7 +41,7 @@ if __name__ == "__main__":
|
|||
"func": knowledge_base_page,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
with st.sidebar:
|
||||
st.image(
|
||||
os.path.join(
|
||||
|
|
@ -44,6 +50,10 @@ if __name__ == "__main__":
|
|||
),
|
||||
use_column_width=True
|
||||
)
|
||||
st.caption(
|
||||
f"""<p align="right">当前版本:{VERSION}</p>""",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
options = list(pages)
|
||||
icons = [x["icon"] for x in pages.values()]
|
||||
|
||||
|
|
|
|||
|
|
@ -118,7 +118,7 @@ def knowledge_base_page(api: ApiRequest):
|
|||
vector_store_type=vs_type,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
st.toast(ret["msg"])
|
||||
st.toast(ret.get("msg", " "))
|
||||
st.session_state["selected_kb_name"] = kb_name
|
||||
st.experimental_rerun()
|
||||
|
||||
|
|
@ -138,12 +138,14 @@ def knowledge_base_page(api: ApiRequest):
|
|||
# use_container_width=True,
|
||||
disabled=len(files) == 0,
|
||||
):
|
||||
for f in files:
|
||||
ret = api.upload_kb_doc(f, kb)
|
||||
if ret["code"] == 200:
|
||||
st.toast(ret["msg"], icon="✔")
|
||||
else:
|
||||
st.toast(ret["msg"], icon="✖")
|
||||
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
|
||||
data[-1]["not_refresh_vs_cache"]=False
|
||||
for k in data:
|
||||
ret = api.upload_kb_doc(**k)
|
||||
if msg := check_success_msg(ret):
|
||||
st.toast(msg, icon="✔")
|
||||
elif msg := check_error_msg(ret):
|
||||
st.toast(msg, icon="✖")
|
||||
st.session_state.files = []
|
||||
|
||||
st.divider()
|
||||
|
|
@ -235,7 +237,7 @@ def knowledge_base_page(api: ApiRequest):
|
|||
):
|
||||
for row in selected_rows:
|
||||
ret = api.delete_kb_doc(kb, row["file_name"], True)
|
||||
st.toast(ret["msg"])
|
||||
st.toast(ret.get("msg", " "))
|
||||
st.experimental_rerun()
|
||||
|
||||
st.divider()
|
||||
|
|
@ -249,12 +251,14 @@ def knowledge_base_page(api: ApiRequest):
|
|||
use_container_width=True,
|
||||
type="primary",
|
||||
):
|
||||
with st.spinner("向量库重构中"):
|
||||
with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
|
||||
empty = st.empty()
|
||||
empty.progress(0.0, "")
|
||||
for d in api.recreate_vector_store(kb):
|
||||
print(d)
|
||||
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
|
||||
if msg := check_error_msg(d):
|
||||
st.toast(msg)
|
||||
else:
|
||||
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
|
||||
st.experimental_rerun()
|
||||
|
||||
if cols[2].button(
|
||||
|
|
@ -262,6 +266,6 @@ def knowledge_base_page(api: ApiRequest):
|
|||
use_container_width=True,
|
||||
):
|
||||
ret = api.delete_knowledge_base(kb)
|
||||
st.toast(ret["msg"])
|
||||
st.toast(ret.get("msg", " "))
|
||||
time.sleep(1)
|
||||
st.experimental_rerun()
|
||||
|
|
|
|||
|
|
@ -229,18 +229,18 @@ class ApiRequest:
|
|||
elif chunk.strip():
|
||||
yield chunk
|
||||
except httpx.ConnectError as e:
|
||||
msg = f"无法连接API服务器,请确认已执行python server\\api.py"
|
||||
msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。"
|
||||
logger.error(msg)
|
||||
logger.error(e)
|
||||
yield {"code": 500, "errorMsg": msg}
|
||||
yield {"code": 500, "msg": msg}
|
||||
except httpx.ReadTimeout as e:
|
||||
msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')"
|
||||
logger.error(msg)
|
||||
logger.error(e)
|
||||
yield {"code": 500, "errorMsg": msg}
|
||||
yield {"code": 500, "msg": msg}
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
yield {"code": 500, "errorMsg": str(e)}
|
||||
yield {"code": 500, "msg": str(e)}
|
||||
|
||||
# 对话相关操作
|
||||
|
||||
|
|
@ -394,7 +394,7 @@ class ApiRequest:
|
|||
return response.json()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return {"code": 500, "errorMsg": errorMsg or str(e)}
|
||||
return {"code": 500, "msg": errorMsg or str(e)}
|
||||
|
||||
def list_knowledge_bases(
|
||||
self,
|
||||
|
|
@ -496,6 +496,7 @@ class ApiRequest:
|
|||
knowledge_base_name: str,
|
||||
filename: str = None,
|
||||
override: bool = False,
|
||||
not_refresh_vs_cache: bool = False,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
|
|
@ -529,7 +530,11 @@ class ApiRequest:
|
|||
else:
|
||||
response = self.post(
|
||||
"/knowledge_base/upload_doc",
|
||||
data={"knowledge_base_name": knowledge_base_name, "override": override},
|
||||
data={
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"override": override,
|
||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||
},
|
||||
files={"file": (filename, file)},
|
||||
)
|
||||
return self._check_httpx_json_response(response)
|
||||
|
|
@ -539,6 +544,7 @@ class ApiRequest:
|
|||
knowledge_base_name: str,
|
||||
doc_name: str,
|
||||
delete_content: bool = False,
|
||||
not_refresh_vs_cache: bool = False,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
|
|
@ -551,6 +557,7 @@ class ApiRequest:
|
|||
"knowledge_base_name": knowledge_base_name,
|
||||
"doc_name": doc_name,
|
||||
"delete_content": delete_content,
|
||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||
}
|
||||
|
||||
if no_remote_api:
|
||||
|
|
@ -568,6 +575,7 @@ class ApiRequest:
|
|||
self,
|
||||
knowledge_base_name: str,
|
||||
file_name: str,
|
||||
not_refresh_vs_cache: bool = False,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
|
|
@ -583,7 +591,11 @@ class ApiRequest:
|
|||
else:
|
||||
response = self.post(
|
||||
"/knowledge_base/update_doc",
|
||||
json={"knowledge_base_name": knowledge_base_name, "file_name": file_name},
|
||||
json={
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"file_name": file_name,
|
||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||
},
|
||||
)
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
|
|
@ -617,7 +629,7 @@ class ApiRequest:
|
|||
"/knowledge_base/recreate_vector_store",
|
||||
json=data,
|
||||
stream=True,
|
||||
timeout=False,
|
||||
timeout=None,
|
||||
)
|
||||
return self._httpx_stream2generator(response, as_json=True)
|
||||
|
||||
|
|
@ -626,7 +638,22 @@ def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
|
|||
'''
|
||||
return the error message if an error occurred when requesting the API
|
||||
'''
|
||||
if isinstance(data, dict) and key in data:
|
||||
if isinstance(data, dict):
|
||||
if key in data:
|
||||
return data[key]
|
||||
if "code" in data and data["code"] != 200:
|
||||
return data["msg"]
|
||||
return ""
|
||||
|
||||
|
||||
def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
|
||||
'''
|
||||
return the success message if the API request succeeded
|
||||
'''
|
||||
if (isinstance(data, dict)
|
||||
and key in data
|
||||
and "code" in data
|
||||
and data["code"] == 200):
|
||||
return data[key]
|
||||
return ""
|
||||
|
||||
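A minimal sketch of how the two helpers pair up when handling an ApiRequest result (the ret dict below is illustrative):

ret = {"code": 404, "msg": "已存在同名知识库 samples"}
if msg := check_error_msg(ret):      # non-200 code -> treated as an error
    print("error:", msg)
elif msg := check_success_msg(ret):  # code == 200 with a msg -> success toast
    print("ok:", msg)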
|
|
|
|||