diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 0b7afbd..14bd260 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -23,10 +23,11 @@ assignees: '' 描述实际发生的结果 / Describe the actual result. **环境信息 / Environment Information** -- langchain-ChatGLM 版本/commit 号:(例如:v1.0.0 或 commit 123456) / langchain-ChatGLM version/commit number: (e.g., v1.0.0 or commit 123456) +- langchain-ChatGLM 版本/commit 号:(例如:v2.0.1 或 commit 123456) / langchain-ChatGLM version/commit number: (e.g., v2.0.1 or commit 123456) - 是否使用 Docker 部署(是/否):是 / Is Docker deployment used (yes/no): yes -- 使用的模型(ChatGLM-6B / ClueAI/ChatYuan-large-v2 等):ChatGLM-6B / Model used (ChatGLM-6B / ClueAI/ChatYuan-large-v2, etc.): ChatGLM-6B -- 使用的 Embedding 模型(GanymedeNil/text2vec-large-chinese 等):GanymedeNil/text2vec-large-chinese / Embedding model used (GanymedeNil/text2vec-large-chinese, etc.): GanymedeNil/text2vec-large-chinese +- 使用的模型(ChatGLM2-6B / Qwen-7B 等):ChatGLM-6B / Model used (ChatGLM2-6B / Qwen-7B, etc.): ChatGLM2-6B +- 使用的 Embedding 模型(moka-ai/m3e-base 等):moka-ai/m3e-base / Embedding model used (moka-ai/m3e-base, etc.): moka-ai/m3e-base +- 使用的向量库类型 (faiss / milvus / pg_vector 等): faiss / Vector library used (faiss, milvus, pg_vector, etc.): faiss - 操作系统及版本 / Operating system and version: - Python 版本 / Python version: - 其他相关环境信息 / Other relevant environment information: diff --git a/.gitignore b/.gitignore index af50500..b5918ee 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,4 @@ logs .idea/ __pycache__/ knowledge_base/ -configs/model_config.py \ No newline at end of file +configs/*.py diff --git a/README.md b/README.md index 9766035..bf3a5ae 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch - [BAAI/bge-small-zh](https://huggingface.co/BAAI/bge-small-zh) - [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh) - [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh) +- [BAAI/bge-large-zh-noinstruct](https://huggingface.co/BAAI/bge-large-zh-noinstruct) - [text2vec-base-chinese-sentence](https://huggingface.co/shibing624/text2vec-base-chinese-sentence) - [text2vec-base-chinese-paraphrase](https://huggingface.co/shibing624/text2vec-base-chinese-paraphrase) - [text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual) @@ -133,6 +134,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch - [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese) - [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh) - [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh) +- [OpenAI/text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings) --- @@ -181,9 +183,11 @@ $ git clone https://huggingface.co/moka-ai/m3e-base ### 3. 
设置配置项 -复制文件 [configs/model_config.py.example](configs/model_config.py.example) 存储至项目路径下 `./configs` 路径下,并重命名为 `model_config.py`。 +复制模型相关参数配置模板文件 [configs/model_config.py.example](configs/model_config.py.example) 存储至项目路径下 `./configs` 路径下,并重命名为 `model_config.py`。 -在开始执行 Web UI 或命令行交互前,请先检查 `configs/model_config.py` 中的各项模型参数设计是否符合需求: +复制服务相关参数配置模板文件 [configs/server_config.py.example](configs/server_config.py.example) 存储至项目路径下 `./configs` 路径下,并重命名为 `server_config.py`。 + +在开始执行 Web UI 或命令行交互前,请先检查 `configs/model_config.py` 和 `configs/server_config.py` 中的各项模型参数设计是否符合需求: - 请确认已下载至本地的 LLM 模型本地存储路径写在 `llm_model_dict` 对应模型的 `local_model_path` 属性中,如: @@ -204,18 +208,18 @@ embedding_model_dict = { "m3e-base": "/Users/xxx/Downloads/m3e-base", } ``` +如果你选择使用 OpenAI 的 Embedding 模型,请将模型的 `key` 写入 `embedding_model_dict` 中。使用该模型,你需要能够访问 OpenAI 官方 API,或设置代理。 ### 4. 知识库初始化与迁移 当前项目的知识库信息存储在数据库中,在正式运行项目之前请先初始化数据库(我们强烈建议您在执行操作前备份您的知识文件)。 -- 如果您是从 `0.1.x` 版本升级过来的用户,针对已建立的知识库,请确认知识库的向量库类型、Embedding 模型 `configs/model_config.py` 中默认设置一致,如无变化只需以下命令将现有知识库信息添加到数据库即可: +- 如果您是从 `0.1.x` 版本升级过来的用户,针对已建立的知识库,请确认知识库的向量库类型、Embedding 模型与 `configs/model_config.py` 中默认设置一致,如无变化只需以下命令将现有知识库信息添加到数据库即可: ```shell $ python init_database.py ``` - -- 如果您是第一次运行本项目,知识库尚未建立,或者配置文件中的知识库类型、嵌入模型发生变化,需要以下命令初始化或重建知识库: +- 如果您是第一次运行本项目,知识库尚未建立,或者配置文件中的知识库类型、嵌入模型发生变化,或者之前的向量库没有开启`normalize_L2`,需要以下命令初始化或重建知识库: ```shell $ python init_database.py --recreate-vs @@ -228,7 +232,7 @@ embedding_model_dict = { 如需使用开源模型进行本地部署,需首先启动 LLM 服务,启动方式分为三种: - [基于多进程脚本 llm_api.py 启动 LLM 服务](README.md#5.1.1-基于多进程脚本-llm_api.py-启动-LLM-服务) -- [基于命令行脚本 llm_api_launch.py 启动 LLM 服务](README.md#5.1.2-基于命令行脚本-llm_api_launch.py-启动-LLM-服务) +- [基于命令行脚本 llm_api_stale.py 启动 LLM 服务](README.md#5.1.2-基于命令行脚本-llm_api_stale.py-启动-LLM-服务) - [PEFT 加载](README.md#5.1.3-PEFT-加载) 三种方式只需选择一个即可,具体操作方式详见 5.1.1 - 5.1.3。 @@ -244,6 +248,7 @@ $ python server/llm_api.py ``` 项目支持多卡加载,需在 llm_api.py 中修改 create_model_worker_app 函数中,修改如下三个参数: + ```python gpus=None, num_gpus=1, @@ -256,34 +261,36 @@ max_gpu_memory="20GiB" `max_gpu_memory` 控制每个卡使用的显存容量。 -##### 5.1.2 基于命令行脚本 llm_api_launch.py 启动 LLM 服务 +##### 5.1.2 基于命令行脚本 llm_api_stale.py 启动 LLM 服务 -⚠️ **注意:** +⚠️ **注意:** -**1.llm_api_launch.py脚本原生仅适用于linux,mac设备需要安装对应的linux命令,win平台请使用wls;** +**1.llm_api_stale.py 脚本原生仅适用于 linux,mac 设备需要安装对应的 linux 命令,win 平台请使用 wsl;** **2.加载非默认模型需要用命令行参数--model-path-address指定模型,不会读取model_config.py配置;** -在项目根目录下,执行 [server/llm_api_launch.py](server/llm_api.py) 脚本启动 **LLM 模型**服务: +在项目根目录下,执行 [server/llm_api_stale.py](server/llm_api_stale.py) 脚本启动 **LLM 模型**服务: ```shell -$ python server/llm_api_launch.py +$ python server/llm_api_stale.py ``` 该方式支持启动多个worker,示例启动方式: ```shell -$ python server/llm_api_launch.py --model-path-addresss model1@host1@port1 model2@host2@port2 +$ python server/llm_api_stale.py --model-path-address model1@host1@port1 model2@host2@port2 ``` + 如果出现server端口占用情况,需手动指定server端口,并同步修改model_config.py下对应模型的base_api_url为指定端口: ```shell -$ python server/llm_api_launch.py --server-port 8887 +$ python server/llm_api_stale.py --server-port 8887 ``` + 如果要启动多卡加载,示例命令如下: ```shell -$ python server/llm_api_launch.py --gpus 0,1 --num-gpus 2 --max-gpu-memory 10GiB +$ python server/llm_api_stale.py --gpus 0,1 --num-gpus 2 --max-gpu-memory 10GiB ``` 注:以如上方式启动LLM服务会以nohup命令在后台运行 FastChat 服务,如需停止服务,可以运行如下命令: @@ -294,24 +301,13 @@ $ python server/llm_api_shutdown.py --serve all 亦可单独停止一个 FastChat 服务模块,可选 [`all`, `controller`, `model_worker`, `openai_api_server`] -##### 5.1.3 PEFT 加载 +##### 5.1.3 PEFT 加载(包括lora,p-tuning,prefix tuning, prompt
tuning,ia等) 本项目基于 FastChat 加载 LLM 服务,故需以 FastChat 加载 PEFT 路径,即保证路径名称里必须有 peft 这个词,配置文件的名字为 adapter_config.json,peft 路径下包含 model.bin 格式的 PEFT 权重。 +详细步骤参考[加载lora微调后模型失效](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822) -示例代码如下: +![image](https://github.com/chatchat-space/Langchain-Chatchat/assets/22924096/4e056c1c-5c4b-4865-a1af-859cd58a625d) -```shell -PEFT_SHARE_BASE_WEIGHTS=true python3 -m fastchat.serve.multi_model_worker \ - --model-path /data/chris/peft-llama-dummy-1 \ - --model-names peft-dummy-1 \ - --model-path /data/chris/peft-llama-dummy-2 \ - --model-names peft-dummy-2 \ - --model-path /data/chris/peft-llama-dummy-3 \ - --model-names peft-dummy-3 \ - --num-gpus 2 -``` - -详见 [FastChat 相关 PR](https://github.com/lm-sys/fastchat/pull/1905#issuecomment-1627801216) #### 5.2 启动 API 服务 @@ -354,7 +350,6 @@ $ streamlit run webui.py --server.port 666 - Web UI 对话界面: ![](img/webui_0813_0.png) - - Web UI 知识库管理页面: ![](img/webui_0813_1.png) @@ -363,86 +358,40 @@ $ streamlit run webui.py --server.port 666 ### 6. 一键启动 -⚠️ **注意:** - -**1. 一键启动脚本仅原生适用于Linux,Mac 设备需要安装对应的linux命令, Winodws 平台请使用 WLS;** - -**2. 加载非默认模型需要用命令行参数 `--model-path-address` 指定模型,不会读取 `model_config.py` 配置。** - -#### 6.1 API 服务一键启动脚本 - -新增 API 一键启动脚本,可一键开启 FastChat 后台服务及本项目提供的 API 服务,调用示例: - -调用默认模型: +更新一键启动脚本 startup.py,一键启动所有 Fastchat 服务、API 服务、WebUI 服务,示例代码: ```shell -$ python server/api_allinone.py +$ python startup.py -a ``` -加载多个非默认模型: +并可使用 `Ctrl + C` 直接关闭所有运行服务。如果一次结束不了,可以多按几次。 + +可选参数包括 `-a (或--all-webui)`, `--all-api`, `--llm-api`, `-c (或--controller)`, `--openai-api`, +`-m (或--model-worker)`, `--api`, `--webui`,其中: + +- `--all-webui` 为一键启动 WebUI 所有依赖服务; + +- `--all-api` 为一键启动 API 所有依赖服务; + +- `--llm-api` 为一键启动 Fastchat 所有依赖的 LLM 服务; + +- `--openai-api` 为仅启动 FastChat 的 controller 和 openai-api-server 服务; + +- 其他为单独服务启动选项。 + +若想指定非默认模型,需要用 `--model-name` 选项,示例: ```shell -$ python server/api_allinone.py --model-path-address model1@host1@port1 model2@host2@port2 +$ python startup.py --all-webui --model-name Qwen-7B-Chat ``` -如果出现server端口占用情况,需手动指定server端口,并同步修改model_config.py下对应模型的base_api_url为指定端口: +更多信息可通过`python startup.py -h`查看。 -```shell -$ python server/api_allinone.py --server-port 8887 -``` +**注意:** -多卡启动: +**1. 
startup 脚本用多进程方式启动各模块的服务,可能会导致打印顺序问题,请等待全部服务发起后再调用,并根据默认或指定端口调用服务(默认 LLM API 服务端口:`127.0.0.1:8888`,默认 API 服务端口:`127.0.0.1:7861`,默认 WebUI 服务端口:`本机IP:8501`)** -```shell -python server/api_allinone.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB -``` - -其他参数详见各脚本及 FastChat 服务说明。 - -#### 6.2 webui一键启动脚本 - -加载本地模型: - -```shell -$ python webui_allinone.py -``` - -调用远程 API 服务: - -```shell -$ python webui_allinone.py --use-remote-api -``` -如果出现server端口占用情况,需手动指定server端口,并同步修改model_config.py下对应模型的base_api_url为指定端口: - -```shell -$ python webui_allinone.py --server-port 8887 -``` - -后台运行webui服务: - -```shell -$ python webui_allinone.py --nohup -``` - -加载多个非默认模型: - -```shell -$ python webui_allinone.py --model-path-address model1@host1@port1 model2@host2@port2 -``` - -多卡启动: - -```shell -$ python webui_alline.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB -``` - -其他参数详见各脚本及 Fastchat 服务说明。 - -上述两个一键启动脚本会后台运行多个服务,如要停止所有服务,可使用 `shutdown_all.sh` 脚本: - -```shell -bash shutdown_all.sh -``` +**2.服务启动时间示设备不同而不同,约 3-10 分钟,如长时间没有启动请前往 `./logs`目录下监控日志,定位问题。** ## 常见问题 @@ -486,6 +435,6 @@ bash shutdown_all.sh ## 项目交流群 -二维码 +二维码 🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。 diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/configs/__init__.py b/configs/__init__.py index 0bed9b6..dc9dd40 100644 --- a/configs/__init__.py +++ b/configs/__init__.py @@ -1 +1,4 @@ -from .model_config import * \ No newline at end of file +from .model_config import * +from .server_config import * + +VERSION = "v0.2.2-preview" diff --git a/configs/model_config.py.example b/configs/model_config.py.example index 8771cfc..5b2574e 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -1,14 +1,11 @@ import os import logging import torch -import argparse -import json # 日志格式 LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" logger = logging.getLogger() logger.setLevel(logging.INFO) logging.basicConfig(format=LOG_FORMAT) -import json # 在以下字典中修改属性值,以指定本地embedding模型存储位置 @@ -27,7 +24,9 @@ embedding_model_dict = { "m3e-large": "moka-ai/m3e-large", "bge-small-zh": "BAAI/bge-small-zh", "bge-base-zh": "BAAI/bge-base-zh", - "bge-large-zh": "BAAI/bge-large-zh" + "bge-large-zh": "BAAI/bge-large-zh", + "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct", + "text-embedding-ada-002": os.environ.get("OPENAI_API_KEY") } # 选用的 Embedding 名称 @@ -44,27 +43,15 @@ llm_model_dict = { "api_key": "EMPTY" }, - "chatglm-6b-int4": { - "local_model_path": "THUDM/chatglm-6b-int4", - "api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url" - "api_key": "EMPTY" - }, - "chatglm2-6b": { "local_model_path": "THUDM/chatglm2-6b", - "api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url" + "api_base_url": "http://localhost:8888/v1", # URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致 "api_key": "EMPTY" }, "chatglm2-6b-32k": { "local_model_path": "THUDM/chatglm2-6b-32k", # "THUDM/chatglm2-6b-32k", - "api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url" - "api_key": "EMPTY" - }, - - "vicuna-13b-hf": { - "local_model_path": "", - "api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url" + "api_base_url": "http://localhost:8888/v1", # "URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致 "api_key": "EMPTY" }, @@ -78,10 +65,15 @@ llm_model_dict = { # 
urllib3.exceptions.NewConnectionError: : # Failed to establish a new connection: [WinError 10060] # 则是因为内地和香港的IP都被OPENAI封了,需要切换为日本、新加坡等地 + + # 如果出现WARNING: Retrying langchain.chat_models.openai.acompletion_with_retry.._completion_with_retry in + # 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI. + # 需要添加代理访问(正常开启的代理软件可能会拦截不上),需要设置配置项 openai_proxy,或者使用环境变量 OPENAI_PROXY 进行设置 "gpt-3.5-turbo": { "local_model_path": "gpt-3.5-turbo", "api_base_url": "https://api.openai.com/v1", - "api_key": os.environ.get("OPENAI_API_KEY") + "api_key": os.environ.get("OPENAI_API_KEY"), + "openai_proxy": os.environ.get("OPENAI_PROXY") }, } @@ -117,7 +109,7 @@ kbs_config = { "secure": False, }, "pg": { - "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatglm", + "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat", } } @@ -145,12 +137,12 @@ SEARCH_ENGINE_TOP_K = 5 # nltk 模型存储路径 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data") -# 基于本地知识问答的提示词模版 -PROMPT_TEMPLATE = """【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 +# 基于本地知识问答的提示词模版(使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号) +PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 -【已知信息】{context} +<已知信息>{{ context }} -【问题】{question}""" +<问题>{{ question }}""" # API 是否开启跨域,默认为False,如果需要开启,请设置为True # is open cross domain diff --git a/configs/server_config.py.example b/configs/server_config.py.example new file mode 100644 index 0000000..5f37779 --- /dev/null +++ b/configs/server_config.py.example @@ -0,0 +1,100 @@ +from .model_config import LLM_MODEL, LLM_DEVICE + +# API 是否开启跨域,默认为False,如果需要开启,请设置为True +# is open cross domain +OPEN_CROSS_DOMAIN = False + +# 各服务器默认绑定host +DEFAULT_BIND_HOST = "127.0.0.1" + +# webui.py server +WEBUI_SERVER = { + "host": DEFAULT_BIND_HOST, + "port": 8501, +} + +# api.py server +API_SERVER = { + "host": DEFAULT_BIND_HOST, + "port": 7861, +} + +# fastchat openai_api server +FSCHAT_OPENAI_API = { + "host": DEFAULT_BIND_HOST, + "port": 8888, # model_config.llm_model_dict中模型配置的api_base_url需要与这里一致。 +} + +# fastchat model_worker server +# 这些模型必须是在model_config.llm_model_dict中正确配置的。 +# 在启动startup.py时,可以通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL +FSCHAT_MODEL_WORKERS = { + LLM_MODEL: { + "host": DEFAULT_BIND_HOST, + "port": 20002, + "device": LLM_DEVICE, + # todo: 多卡加载需要配置的参数 + "gpus": None, + "numgpus": 1, + # 以下为非常用参数,可根据需要配置 + # "max_gpu_memory": "20GiB", + # "load_8bit": False, + # "cpu_offloading": None, + # "gptq_ckpt": None, + # "gptq_wbits": 16, + # "gptq_groupsize": -1, + # "gptq_act_order": False, + # "awq_ckpt": None, + # "awq_wbits": 16, + # "awq_groupsize": -1, + # "model_names": [LLM_MODEL], + # "conv_template": None, + # "limit_worker_concurrency": 5, + # "stream_interval": 2, + # "no_register": False, + }, +} + +# fastchat multi model worker server +FSCHAT_MULTI_MODEL_WORKERS = { + # todo +} + +# fastchat controller server +FSCHAT_CONTROLLER = { + "host": DEFAULT_BIND_HOST, + "port": 20001, + "dispatch_method": "shortest_queue", +} + + +# 以下不要更改 +def fschat_controller_address() -> str: + host = FSCHAT_CONTROLLER["host"] + port = FSCHAT_CONTROLLER["port"] + return f"http://{host}:{port}" + + +def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str: + if model := FSCHAT_MODEL_WORKERS.get(model_name): + host = model["host"] + port = model["port"] + return f"http://{host}:{port}" + + +def fschat_openai_api_address() -> str: + host
= FSCHAT_OPENAI_API["host"] + port = FSCHAT_OPENAI_API["port"] + return f"http://{host}:{port}" + + +def api_address() -> str: + host = API_SERVER["host"] + port = API_SERVER["port"] + return f"http://{host}:{port}" + + +def webui_address() -> str: + host = WEBUI_SERVER["host"] + port = WEBUI_SERVER["port"] + return f"http://{host}:{port}" diff --git a/docs/FAQ.md b/docs/FAQ.md index 62fe080..490eb25 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -170,3 +170,16 @@ A13: 疑为 chatglm 的 quantization 的问题或 torch 版本差异问题,针 Q14: 修改配置中路径后,加载 text2vec-large-chinese 依然提示 `WARNING: No sentence-transformers model found with name text2vec-large-chinese. Creating a new one with MEAN pooling.` A14: 尝试更换 embedding,如 text2vec-base-chinese,请在 [configs/model_config.py](../configs/model_config.py) 文件中,修改 `text2vec-base`参数为本地路径,绝对路径或者相对路径均可 + + +--- + +Q15: 使用pg向量库建表报错 + +A15: 需要手动安装对应的vector扩展(连接pg执行 CREATE EXTENSION IF NOT EXISTS vector) + +--- + +Q16: pymilvus 连接超时 + +A16: pymilvus 版本需要与 milvus 版本相匹配,否则连接会超时,可参考 pymilvus==2.1.3 \ No newline at end of file diff --git a/docs/docker/vector_db/pg/docker-compose.yml b/docs/docker/vector_db/pg/docker-compose.yml index 8e8359c..b14296b 100644 --- a/docs/docker/vector_db/pg/docker-compose.yml +++ b/docs/docker/vector_db/pg/docker-compose.yml @@ -2,9 +2,9 @@ version: "3.8" services: postgresql: image: ankane/pgvector:v0.4.1 - container_name: langchain-chatgml-pg-db + container_name: langchain_chatchat-pg-db environment: - POSTGRES_DB: langchain_chatgml + POSTGRES_DB: langchain_chatchat POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres ports: diff --git a/docs/向量库环境docker.md b/docs/向量库环境docker.md index 162b0f0..dd5e2cb 100644 --- a/docs/向量库环境docker.md +++ b/docs/向量库环境docker.md @@ -5,3 +5,4 @@ cd docs/docker/vector_db/milvus docker-compose up -d ``` + diff --git a/img/qr_code_53.jpg b/img/qr_code_53.jpg new file mode 100644 index 0000000..3174ccc Binary files /dev/null and b/img/qr_code_53.jpg differ diff --git a/img/qr_code_54.jpg b/img/qr_code_54.jpg new file mode 100644 index 0000000..1245a16 Binary files /dev/null and b/img/qr_code_54.jpg differ diff --git a/img/qr_code_55.jpg b/img/qr_code_55.jpg new file mode 100644 index 0000000..8ff046c Binary files /dev/null and b/img/qr_code_55.jpg differ diff --git a/img/qr_code_56.jpg b/img/qr_code_56.jpg new file mode 100644 index 0000000..f17458d Binary files /dev/null and b/img/qr_code_56.jpg differ diff --git a/init_database.py b/init_database.py index 61d00e1..7fc8494 100644 --- a/init_database.py +++ b/init_database.py @@ -2,6 +2,8 @@ from server.knowledge_base.migrate import create_tables, folder2db, recreate_all from configs.model_config import NLTK_DATA_PATH import nltk nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path +from startup import dump_server_info + if __name__ == "__main__": import argparse @@ -21,6 +23,8 @@ if __name__ == "__main__": ) args = parser.parse_args() + dump_server_info() + create_tables() print("database talbes created") diff --git a/requirements.txt b/requirements.txt index 646a5c7..93908dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -langchain==0.0.257 +langchain==0.0.266 openai sentence_transformers fschat==0.2.24 diff --git a/requirements_api.txt b/requirements_api.txt index 1e13587..f567f9f 100644 --- a/requirements_api.txt +++ b/requirements_api.txt @@ -1,4 +1,4 @@ -langchain==0.0.257 +langchain==0.0.266 openai sentence_transformers fschat==0.2.24 diff --git a/server/api.py b/server/api.py index 800680c..ecadd7c 100644 --- a/server/api.py +++ b/server/api.py @@ -4,7 +4,9
@@ import os sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN +from configs.model_config import NLTK_DATA_PATH +from configs.server_config import OPEN_CROSS_DOMAIN +from configs import VERSION import argparse import uvicorn from fastapi.middleware.cors import CORSMiddleware @@ -14,11 +16,10 @@ from server.chat import (chat, knowledge_base_chat, openai_chat, from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store, - search_docs, DocumentWithScore) + search_docs, DocumentWithScore) from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline from typing import List - nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path @@ -27,7 +28,10 @@ async def document(): def create_app(): - app = FastAPI(title="Langchain-Chatchat API Server") + app = FastAPI( + title="Langchain-Chatchat API Server", + version=VERSION + ) MakeFastAPIOffline(app) # Add CORS middleware to allow all origins # 在config.py中设置OPEN_DOMAIN=True,允许跨域 @@ -75,10 +79,10 @@ def create_app(): )(create_kb) app.post("/knowledge_base/delete_knowledge_base", - tags=["Knowledge Base Management"], - response_model=BaseResponse, - summary="删除知识库" - )(delete_kb) + tags=["Knowledge Base Management"], + response_model=BaseResponse, + summary="删除知识库" + )(delete_kb) app.get("/knowledge_base/list_docs", tags=["Knowledge Base Management"], @@ -87,10 +91,10 @@ def create_app(): )(list_docs) app.post("/knowledge_base/search_docs", - tags=["Knowledge Base Management"], - response_model=List[DocumentWithScore], - summary="搜索知识库" - )(search_docs) + tags=["Knowledge Base Management"], + response_model=List[DocumentWithScore], + summary="搜索知识库" + )(search_docs) app.post("/knowledge_base/upload_doc", tags=["Knowledge Base Management"], @@ -99,10 +103,10 @@ def create_app(): )(upload_doc) app.post("/knowledge_base/delete_doc", - tags=["Knowledge Base Management"], - response_model=BaseResponse, - summary="删除知识库内指定文件" - )(delete_doc) + tags=["Knowledge Base Management"], + response_model=BaseResponse, + summary="删除知识库内指定文件" + )(delete_doc) app.post("/knowledge_base/update_doc", tags=["Knowledge Base Management"], diff --git a/server/api_allinone.py b/server/api_allinone_stale.py similarity index 95% rename from server/api_allinone.py rename to server/api_allinone_stale.py index 3be8581..78a7a6d 100644 --- a/server/api_allinone.py +++ b/server/api_allinone_stale.py @@ -15,7 +15,7 @@ import os sys.path.append(os.path.dirname(__file__)) sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from llm_api_launch import launch_all, parser, controller_args, worker_args, server_args +from llm_api_stale import launch_all, parser, controller_args, worker_args, server_args from api import create_app import uvicorn diff --git a/server/chat/chat.py b/server/chat/chat.py index 2bc21db..2e939f1 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -21,7 +21,7 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成 ), stream: bool = Body(False, description="流式输出"), ): - history = [History(**h) if isinstance(h, dict) else h for h in history] + history = [History.from_data(h) for h in history] async def chat_iterator(query: str, history: List[History] = [], @@ -34,11 +34,13 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成 callbacks=[callback], 
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], - model_name=LLM_MODEL + model_name=LLM_MODEL, + openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy") ) + input_msg = History(role="user", content="{{ input }}").to_msg_template(False) chat_prompt = ChatPromptTemplate.from_messages( - [i.to_msg_tuple() for i in history] + [("human", "{input}")]) + [i.to_msg_template() for i in history] + [input_msg]) chain = LLMChain(prompt=chat_prompt, llm=model) # Begin a task that runs in the background. diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 84c62f0..2774569 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -38,7 +38,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - history = [History(**h) if isinstance(h, dict) else h for h in history] + history = [History.from_data(h) for h in history] async def knowledge_base_chat_iterator(query: str, kb: KBService, @@ -52,13 +52,15 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp callbacks=[callback], openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], - model_name=LLM_MODEL + model_name=LLM_MODEL, + openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy") ) docs = search_docs(query, knowledge_base_name, top_k, score_threshold) context = "\n".join([doc.page_content for doc in docs]) + input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False) chat_prompt = ChatPromptTemplate.from_messages( - [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)]) + [i.to_msg_template() for i in history] + [input_msg]) chain = LLMChain(prompt=chat_prompt, llm=model) diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index 15834d0..032d06a 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -73,6 +73,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl if search_engine_name not in SEARCH_ENGINES.keys(): return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}") + history = [History.from_data(h) for h in history] + async def search_engine_chat_iterator(query: str, search_engine_name: str, top_k: int, @@ -85,14 +87,16 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl callbacks=[callback], openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], - model_name=LLM_MODEL + model_name=LLM_MODEL, + openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy") ) docs = lookup_search_engine(query, search_engine_name, top_k) context = "\n".join([doc.page_content for doc in docs]) + input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False) chat_prompt = ChatPromptTemplate.from_messages( - [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)]) + [i.to_msg_template() for i in history] + [input_msg]) chain = LLMChain(prompt=chat_prompt, llm=model) @@ -117,7 +121,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl answer = "" async for token in callback.aiter(): answer += token - yield json.dumps({"answer": token, + yield json.dumps({"answer": answer, "docs": source_documents}, ensure_ascii=False) await task diff --git a/server/chat/utils.py b/server/chat/utils.py index 
f8afb10..2167f10 100644 --- a/server/chat/utils.py +++ b/server/chat/utils.py @@ -1,6 +1,7 @@ import asyncio -from typing import Awaitable +from typing import Awaitable, List, Tuple, Dict, Union from pydantic import BaseModel, Field +from langchain.prompts.chat import ChatMessagePromptTemplate async def wrap_done(fn: Awaitable, event: asyncio.Event): @@ -28,3 +29,29 @@ class History(BaseModel): def to_msg_tuple(self): return "ai" if self.role=="assistant" else "human", self.content + + def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate: + role_maps = { + "ai": "assistant", + "human": "user", + } + role = role_maps.get(self.role, self.role) + if is_raw: # 当前默认历史消息都是没有input_variable的文本。 + content = "{% raw %}" + self.content + "{% endraw %}" + else: + content = self.content + + return ChatMessagePromptTemplate.from_template( + content, + "jinja2", + role=role, + ) + + @classmethod + def from_data(cls, h: Union[List, Tuple, Dict]) -> "History": + if isinstance(h, (list,tuple)) and len(h) >= 2: + h = cls(role=h[0], content=h[1]) + elif isinstance(h, dict): + h = cls(**h) + + return h diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py index 4753ba4..b9151b8 100644 --- a/server/knowledge_base/kb_api.py +++ b/server/knowledge_base/kb_api.py @@ -15,7 +15,7 @@ async def list_kbs(): async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), vector_store_type: str = Body("faiss"), embed_model: str = Body(EMBEDDING_MODEL), - ): + ) -> BaseResponse: # Create selected knowledge base if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -27,13 +27,18 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model) - kb.create_kb() + try: + kb.create_kb() + except Exception as e: + print(e) + return BaseResponse(code=500, msg=f"创建知识库出错: {e}") + return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") async def delete_kb( knowledge_base_name: str = Body(..., examples=["samples"]) - ): + ) -> BaseResponse: # Delete selected knowledge base if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -51,5 +56,6 @@ async def delete_kb( return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}") except Exception as e: print(e) + return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}") return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}") diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 0bf2cb7..ae027c1 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -22,7 +22,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=[" ) -> List[DocumentWithScore]: kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: - return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []} + return [] docs = kb.search_docs(query, top_k, score_threshold) data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs] @@ -31,7 +31,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=[" async def list_docs( knowledge_base_name: str -): +) -> ListResponse: if not validate_kb_name(knowledge_base_name): return ListResponse(code=403, msg="Don't attack me", data=[]) @@ -41,13 +41,14 @@ async def list_docs( return 
ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[]) else: all_doc_names = kb.list_docs() - return ListResponse(data=all_doc_names) + return ListResponse(data=all_doc_names) async def upload_doc(file: UploadFile = File(..., description="上传文件"), knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]), override: bool = Form(False, description="覆盖已有文件"), - ): + not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"), + ) -> BaseResponse: if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -57,31 +58,38 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"), file_content = await file.read() # 读取上传文件的内容 - kb_file = KnowledgeFile(filename=file.filename, - knowledge_base_name=knowledge_base_name) - - if (os.path.exists(kb_file.filepath) - and not override - and os.path.getsize(kb_file.filepath) == len(file_content) - ): - # TODO: filesize 不同后的处理 - file_status = f"文件 {kb_file.filename} 已存在。" - return BaseResponse(code=404, msg=file_status) - try: + kb_file = KnowledgeFile(filename=file.filename, + knowledge_base_name=knowledge_base_name) + + if (os.path.exists(kb_file.filepath) + and not override + and os.path.getsize(kb_file.filepath) == len(file_content) + ): + # TODO: filesize 不同后的处理 + file_status = f"文件 {kb_file.filename} 已存在。" + return BaseResponse(code=404, msg=file_status) + with open(kb_file.filepath, "wb") as f: f.write(file_content) except Exception as e: + print(e) return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}") - kb.add_doc(kb_file) + try: + kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) + except Exception as e: + print(e) + return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}") + return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}") async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), doc_name: str = Body(..., examples=["file_name.md"]), delete_content: bool = Body(False), - ): + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), + ) -> BaseResponse: if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -92,17 +100,23 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), if not kb.exist_doc(doc_name): return BaseResponse(code=404, msg=f"未找到文件 {doc_name}") - kb_file = KnowledgeFile(filename=doc_name, - knowledge_base_name=knowledge_base_name) - kb.delete_doc(kb_file, delete_content) + + try: + kb_file = KnowledgeFile(filename=doc_name, + knowledge_base_name=knowledge_base_name) + kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache) + except Exception as e: + print(e) + return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}") + return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功") - # return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败") async def update_doc( knowledge_base_name: str = Body(..., examples=["samples"]), file_name: str = Body(..., examples=["file_name"]), - ): + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), + ) -> BaseResponse: ''' 更新知识库文档 ''' @@ -113,14 +127,17 @@ async def update_doc( if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - kb_file = KnowledgeFile(filename=file_name, - knowledge_base_name=knowledge_base_name) + try: + kb_file = KnowledgeFile(filename=file_name, + knowledge_base_name=knowledge_base_name) + if 
os.path.exists(kb_file.filepath): + kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) + return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}") + except Exception as e: + print(e) + return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}") - if os.path.exists(kb_file.filepath): - kb.update_doc(kb_file) - return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}") - else: - return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败") + return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败") async def download_doc( @@ -137,18 +154,20 @@ async def download_doc( if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - kb_file = KnowledgeFile(filename=file_name, - knowledge_base_name=knowledge_base_name) - - if os.path.exists(kb_file.filepath): - return FileResponse( - path=kb_file.filepath, - filename=kb_file.filename, - media_type="multipart/form-data") - else: - return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败") + try: + kb_file = KnowledgeFile(filename=file_name, + knowledge_base_name=knowledge_base_name) + if os.path.exists(kb_file.filepath): + return FileResponse( + path=kb_file.filepath, + filename=kb_file.filename, + media_type="multipart/form-data") + except Exception as e: + print(e) + return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}") + return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败") async def recreate_vector_store( @@ -163,24 +182,35 @@ async def recreate_vector_store( by default, get_service_by_name only return knowledge base in the info.db and having document files in it. set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents. ''' - kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model) - if not kb.exists() and not allow_empty_kb: - return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - async def output(kb): - kb.create_kb() - kb.clear_vs() - docs = list_docs_from_folder(knowledge_base_name) - for i, doc in enumerate(docs): - try: - kb_file = KnowledgeFile(doc, knowledge_base_name) - yield json.dumps({ - "total": len(docs), - "finished": i, - "doc": doc, - }, ensure_ascii=False) - kb.add_doc(kb_file) - except Exception as e: - print(e) + async def output(): + kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model) + if not kb.exists() and not allow_empty_kb: + yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"} + else: + kb.create_kb() + kb.clear_vs() + docs = list_docs_from_folder(knowledge_base_name) + for i, doc in enumerate(docs): + try: + kb_file = KnowledgeFile(doc, knowledge_base_name) + yield json.dumps({ + "code": 200, + "msg": f"({i + 1} / {len(docs)}): {doc}", + "total": len(docs), + "finished": i, + "doc": doc, + }, ensure_ascii=False) + if i == len(docs) - 1: + not_refresh_vs_cache = False + else: + not_refresh_vs_cache = True + kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) + except Exception as e: + print(e) + yield json.dumps({ + "code": 500, + "msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。", + }) - return StreamingResponse(output(kb), media_type="text/event-stream") + return StreamingResponse(output(), media_type="text/event-stream") diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index d506f63..8d1de48 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -71,36 +71,37 @@ class 
KBService(ABC): status = delete_kb_from_db(self.kb_name) return status - def add_doc(self, kb_file: KnowledgeFile): + def add_doc(self, kb_file: KnowledgeFile, **kwargs): """ 向知识库添加文件 """ docs = kb_file.file2text() if docs: + self.delete_doc(kb_file) embeddings = self._load_embeddings() - self.do_add_doc(docs, embeddings) + self.do_add_doc(docs, embeddings, **kwargs) status = add_doc_to_db(kb_file) else: status = False return status - def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False): + def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs): """ 从知识库删除文件 """ - self.do_delete_doc(kb_file) + self.do_delete_doc(kb_file, **kwargs) status = delete_file_from_db(kb_file) if delete_content and os.path.exists(kb_file.filepath): os.remove(kb_file.filepath) return status - def update_doc(self, kb_file: KnowledgeFile): + def update_doc(self, kb_file: KnowledgeFile, **kwargs): """ 使用content中的文件更新向量库 """ if os.path.exists(kb_file.filepath): - self.delete_doc(kb_file) - return self.add_doc(kb_file) + self.delete_doc(kb_file, **kwargs) + return self.add_doc(kb_file, **kwargs) def exist_doc(self, file_name: str): return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name, @@ -156,6 +157,7 @@ class KBService(ABC): def do_search(self, query: str, top_k: int, + score_threshold: float, embeddings: Embeddings, ) -> List[Document]: """ diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 5c8376f..b3f5439 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -13,7 +13,8 @@ from functools import lru_cache from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile from langchain.vectorstores import FAISS from langchain.embeddings.base import Embeddings -from langchain.embeddings.huggingface import HuggingFaceEmbeddings +from langchain.embeddings.huggingface import HuggingFaceEmbeddings,HuggingFaceBgeEmbeddings +from langchain.embeddings.openai import OpenAIEmbeddings from typing import List from langchain.docstore.document import Document from server.utils import torch_gc @@ -21,10 +22,19 @@ from server.utils import torch_gc # make HuggingFaceEmbeddings hashable def _embeddings_hash(self): - return hash(self.model_name) - + if isinstance(self, HuggingFaceEmbeddings): + return hash(self.model_name) + elif isinstance(self, HuggingFaceBgeEmbeddings): + return hash(self.model_name) + elif isinstance(self, OpenAIEmbeddings): + return hash(self.model) HuggingFaceEmbeddings.__hash__ = _embeddings_hash +OpenAIEmbeddings.__hash__ = _embeddings_hash +HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash + +_VECTOR_STORE_TICKS = {} + _VECTOR_STORE_TICKS = {} @@ -41,7 +51,23 @@ def load_vector_store( vs_path = get_vs_path(knowledge_base_name) if embeddings is None: embeddings = load_embeddings(embed_model, embed_device) - search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True) + + if not os.path.exists(vs_path): + os.makedirs(vs_path) + + if "index.faiss" in os.listdir(vs_path): + search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True) + else: + # create an empty vector store + doc = Document(page_content="init", metadata={}) + search_index = FAISS.from_documents([doc], embeddings, normalize_L2=True) + ids = [k for k, v in search_index.docstore._dict.items()] + search_index.delete(ids) + search_index.save_local(vs_path) + + if tick == 0: # vector store is loaded 
first time + _VECTOR_STORE_TICKS[knowledge_base_name] = 0 + return search_index @@ -50,6 +76,7 @@ def refresh_vs_cache(kb_name: str): make vector store cache refreshed when next loading """ _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1 + print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}") class FaissKBService(KBService): @@ -74,8 +101,10 @@ class FaissKBService(KBService): def do_create_kb(self): if not os.path.exists(self.vs_path): os.makedirs(self.vs_path) + load_vector_store(self.kb_name) def do_drop_kb(self): + self.clear_vs() shutil.rmtree(self.kb_path) def do_search(self, @@ -93,38 +122,40 @@ class FaissKBService(KBService): def do_add_doc(self, docs: List[Document], embeddings: Embeddings, + **kwargs, ): - if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): - vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True) - vector_store.add_documents(docs) - torch_gc() - else: - if not os.path.exists(self.vs_path): - os.makedirs(self.vs_path) - vector_store = FAISS.from_documents( - docs, embeddings, normalize_L2=True) # docs 为Document列表 - torch_gc() - vector_store.save_local(self.vs_path) - refresh_vs_cache(self.kb_name) - - def do_delete_doc(self, - kb_file: KnowledgeFile): - embeddings = self._load_embeddings() - if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): - vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True) - ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] - if len(ids) == 0: - return None - vector_store.delete(ids) + vector_store = load_vector_store(self.kb_name, + embeddings=embeddings, + tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0)) + vector_store.add_documents(docs) + torch_gc() + if not kwargs.get("not_refresh_vs_cache"): vector_store.save_local(self.vs_path) refresh_vs_cache(self.kb_name) - return True - else: + + def do_delete_doc(self, + kb_file: KnowledgeFile, + **kwargs): + embeddings = self._load_embeddings() + vector_store = load_vector_store(self.kb_name, + embeddings=embeddings, + tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0)) + + ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] + if len(ids) == 0: return None + vector_store.delete(ids) + if not kwargs.get("not_refresh_vs_cache"): + vector_store.save_local(self.vs_path) + refresh_vs_cache(self.kb_name) + + return True + def do_clear_vs(self): shutil.rmtree(self.vs_path) os.makedirs(self.vs_path) + refresh_vs_cache(self.kb_name) def exist_doc(self, file_name: str): if super().exist_doc(file_name): diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index f9c40c0..78c22f4 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -45,12 +45,12 @@ class MilvusKBService(KBService): def do_drop_kb(self): self.milvus.col.drop() - def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]: + def do_search(self, query: str, top_k: int,score_threshold: float, embeddings: Embeddings): # todo: support score threshold self._load_milvus(embeddings=embeddings) - return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD) + return self.milvus.similarity_search_with_score(query, top_k) - def add_doc(self, kb_file: KnowledgeFile): + def add_doc(self, kb_file: KnowledgeFile, 
**kwargs): """ 向知识库添加文件 """ @@ -60,22 +60,24 @@ class MilvusKBService(KBService): status = add_doc_to_db(kb_file) return status - def do_add_doc(self, docs: List[Document], embeddings: Embeddings): + def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs): pass - def do_delete_doc(self, kb_file: KnowledgeFile): + def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): filepath = kb_file.filepath.replace('\\', '\\\\') delete_list = [item.get("pk") for item in self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])] self.milvus.col.delete(expr=f'pk in {delete_list}') def do_clear_vs(self): - self.milvus.col.drop() + if not self.milvus.col: + self.milvus.col.drop() if __name__ == '__main__': # 测试建表使用 from server.db.base import Base, engine + Base.metadata.create_all(bind=engine) milvusService = MilvusKBService("test") milvusService.add_doc(KnowledgeFile("README.md", "test")) diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index a3126ec..6876bd8 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -43,12 +43,12 @@ class PGKBService(KBService): ''')) connect.commit() - def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]: + def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings): # todo: support score threshold self._load_pg_vector(embeddings=embeddings) - return self.pg_vector.similarity_search(query, top_k) + return self.pg_vector.similarity_search_with_score(query, top_k) - def add_doc(self, kb_file: KnowledgeFile): + def add_doc(self, kb_file: KnowledgeFile, **kwargs): """ 向知识库添加文件 """ @@ -58,10 +58,10 @@ class PGKBService(KBService): status = add_doc_to_db(kb_file) return status - def do_add_doc(self, docs: List[Document], embeddings: Embeddings): + def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs): pass - def do_delete_doc(self, kb_file: KnowledgeFile): + def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): with self.pg_vector.connect() as connect: filepath = kb_file.filepath.replace('\\', '\\\\') connect.execute( @@ -76,6 +76,7 @@ class PGKBService(KBService): if __name__ == '__main__': from server.db.base import Base, engine + Base.metadata.create_all(bind=engine) pGKBService = PGKBService("test") pGKBService.create_kb() diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index 1c023fa..c96d386 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -43,7 +43,11 @@ def folder2db( kb_file = KnowledgeFile(doc, kb_name) if callable(callback_before): callback_before(kb_file, i, docs) - kb.add_doc(kb_file) + if i == len(docs) - 1: + not_refresh_vs_cache = False + else: + not_refresh_vs_cache = True + kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) if callable(callback_after): callback_after(kb_file, i, docs) except Exception as e: @@ -67,7 +71,11 @@ def folder2db( kb_file = KnowledgeFile(doc, kb_name) if callable(callback_before): callback_before(kb_file, i, docs) - kb.update_doc(kb_file) + if i == len(docs) - 1: + not_refresh_vs_cache = False + else: + not_refresh_vs_cache = True + kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) if callable(callback_after): callback_after(kb_file, i, docs) except Exception as e: @@ -81,7 +89,11 @@ def folder2db( kb_file = KnowledgeFile(doc, kb_name) if 
callable(callback_before): callback_before(kb_file, i, docs) - kb.add_doc(kb_file) + if i == len(docs) - 1: + not_refresh_vs_cache = False + else: + not_refresh_vs_cache = True + kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) if callable(callback_after): callback_after(kb_file, i, docs) except Exception as e: diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 3ab6560..da53049 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -1,5 +1,7 @@ import os from langchain.embeddings.huggingface import HuggingFaceEmbeddings +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.embeddings import HuggingFaceBgeEmbeddings from configs.model_config import ( embedding_model_dict, KB_ROOT_PATH, @@ -41,11 +43,20 @@ def list_docs_from_folder(kb_name: str): @lru_cache(1) def load_embeddings(model: str, device: str): - embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], - model_kwargs={'device': device}) + if model == "text-embedding-ada-002": # openai text-embedding-ada-002 + embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE) + elif 'bge-' in model: + embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model], + model_kwargs={'device': device}, + query_instruction="为这个句子生成表示以用于检索相关文章:") + if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding + embeddings.query_instruction = "" + else: + embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device}) return embeddings + LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst', '.rtf', '.txt', '.xml', '.doc', '.docx', '.epub', '.odt', '.pdf', @@ -69,7 +80,7 @@ class KnowledgeFile: ): self.kb_name = knowledge_base_name self.filename = filename - self.ext = os.path.splitext(filename)[-1] + self.ext = os.path.splitext(filename)[-1].lower() if self.ext not in SUPPORTED_EXTS: raise ValueError(f"暂未支持的文件格式 {self.ext}") self.filepath = get_file_path(knowledge_base_name, filename) diff --git a/server/llm_api_launch.py b/server/llm_api_stale.py similarity index 98% rename from server/llm_api_launch.py rename to server/llm_api_stale.py index 0f7710a..cb02e0d 100644 --- a/server/llm_api_launch.py +++ b/server/llm_api_stale.py @@ -1,5 +1,5 @@ """ -调用示例: python llm_api_launch.py --model-path-address THUDM/chatglm2-6b@localhost@7650 THUDM/chatglm2-6b-32k@localhost@7651 +调用示例: python llm_api_stale.py --model-path-address THUDM/chatglm2-6b@localhost@7650 THUDM/chatglm2-6b-32k@localhost@7651 其他fastchat.server.controller/worker/openai_api_server参数可按照fastchat文档调用 但少数非关键参数如--worker-address,--allowed-origins,--allowed-methods,--allowed-headers不支持 diff --git a/server/utils.py b/server/utils.py index c0f11a5..4a88722 100644 --- a/server/utils.py +++ b/server/utils.py @@ -9,8 +9,8 @@ from typing import Any, Optional class BaseResponse(BaseModel): - code: int = pydantic.Field(200, description="HTTP status code") - msg: str = pydantic.Field("success", description="HTTP status message") + code: int = pydantic.Field(200, description="API status code") + msg: str = pydantic.Field("success", description="API status message") class Config: schema_extra = { diff --git a/webui_allinone.py b/server/webui_allinone_stale.py similarity index 93% rename from webui_allinone.py rename to server/webui_allinone_stale.py index 2992ae5..627f956 100644 --- a/webui_allinone.py +++ b/server/webui_allinone_stale.py @@ -20,9 +20,9 
@@ from webui_pages.utils import * from streamlit_option_menu import option_menu from webui_pages import * import os -from server.llm_api_launch import string_args,launch_all,controller_args,worker_args,server_args,LOG_PATH +from server.llm_api_stale import string_args,launch_all,controller_args,worker_args,server_args,LOG_PATH -from server.api_allinone import parser, api_args +from server.api_allinone_stale import parser, api_args import subprocess parser.add_argument("--use-remote-api",action="store_true") diff --git a/startup.py b/startup.py new file mode 100644 index 0000000..df00851 --- /dev/null +++ b/startup.py @@ -0,0 +1,472 @@ +from multiprocessing import Process, Queue +import multiprocessing as mp +import subprocess +import sys +import os +from pprint import pprint + +# 设置numexpr最大线程数,默认为CPU核心数 +try: + import numexpr + n_cores = numexpr.utils.detect_number_of_cores() + os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores) +except: + pass + +sys.path.append(os.path.dirname(os.path.dirname(__file__))) +from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \ + logger +from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS, + FSCHAT_OPENAI_API, fschat_controller_address, fschat_model_worker_address, + fschat_openai_api_address, ) +from server.utils import MakeFastAPIOffline, FastAPI +import argparse +from typing import Tuple, List +from configs import VERSION + + +def set_httpx_timeout(timeout=60.0): + import httpx + httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout + httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout + httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout + + +def create_controller_app( + dispatch_method: str, +) -> FastAPI: + import fastchat.constants + fastchat.constants.LOGDIR = LOG_PATH + from fastchat.serve.controller import app, Controller + + controller = Controller(dispatch_method) + sys.modules["fastchat.serve.controller"].controller = controller + + MakeFastAPIOffline(app) + app.title = "FastChat Controller" + return app + + +def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]: + import fastchat.constants + fastchat.constants.LOGDIR = LOG_PATH + from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id + import argparse + import threading + import fastchat.serve.model_worker + + # workaround to make program exit with Ctrl+c + # it should be deleted after pr is merged by fastchat + def _new_init_heart_beat(self): + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True, + ) + self.heart_beat_thread.start() + + ModelWorker.init_heart_beat = _new_init_heart_beat + + parser = argparse.ArgumentParser() + args = parser.parse_args([]) + # default args. 
should be deleted after pr is merged by fastchat + args.gpus = None + args.max_gpu_memory = "20GiB" + args.load_8bit = False + args.cpu_offloading = None + args.gptq_ckpt = None + args.gptq_wbits = 16 + args.gptq_groupsize = -1 + args.gptq_act_order = False + args.awq_ckpt = None + args.awq_wbits = 16 + args.awq_groupsize = -1 + args.num_gpus = 1 + args.model_names = [] + args.conv_template = None + args.limit_worker_concurrency = 5 + args.stream_interval = 2 + args.no_register = False + + for k, v in kwargs.items(): + setattr(args, k, v) + + if args.gpus: + if args.num_gpus is None: + args.num_gpus = len(args.gpus.split(',')) + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + + gptq_config = GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ) + awq_config = AWQConfig( + ckpt=args.awq_ckpt or args.model_path, + wbits=args.awq_wbits, + groupsize=args.awq_groupsize, + ) + + worker = ModelWorker( + controller_addr=args.controller_address, + worker_addr=args.worker_address, + worker_id=worker_id, + model_path=args.model_path, + model_names=args.model_names, + limit_worker_concurrency=args.limit_worker_concurrency, + no_register=args.no_register, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + stream_interval=args.stream_interval, + conv_template=args.conv_template, + ) + + sys.modules["fastchat.serve.model_worker"].worker = worker + sys.modules["fastchat.serve.model_worker"].args = args + sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config + + MakeFastAPIOffline(app) + app.title = f"FastChat LLM Server ({LLM_MODEL})" + return app + + +def create_openai_api_app( + controller_address: str, + api_keys: List = [], +) -> FastAPI: + import fastchat.constants + fastchat.constants.LOGDIR = LOG_PATH + from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings + + app.add_middleware( + CORSMiddleware, + allow_credentials=True, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], + ) + + app_settings.controller_address = controller_address + app_settings.api_keys = api_keys + + MakeFastAPIOffline(app) + app.title = "FastChat OpeanAI API Server" + return app + + +def _set_app_seq(app: FastAPI, q: Queue, run_seq: int): + if run_seq == 1: + @app.on_event("startup") + async def on_startup(): + set_httpx_timeout() + q.put(run_seq) + elif run_seq > 1: + @app.on_event("startup") + async def on_startup(): + set_httpx_timeout() + while True: + no = q.get() + if no != run_seq - 1: + q.put(no) + else: + break + q.put(run_seq) + + +def run_controller(q: Queue, run_seq: int = 1): + import uvicorn + + app = create_controller_app(FSCHAT_CONTROLLER.get("dispatch_method")) + _set_app_seq(app, q, run_seq) + + host = FSCHAT_CONTROLLER["host"] + port = FSCHAT_CONTROLLER["port"] + uvicorn.run(app, host=host, port=port) + + +def run_model_worker( + model_name: str = LLM_MODEL, + controller_address: str = "", + q: Queue = None, + run_seq: int = 2, +): + import uvicorn + + kwargs = FSCHAT_MODEL_WORKERS[LLM_MODEL].copy() + host = kwargs.pop("host") + port = kwargs.pop("port") + model_path = llm_model_dict[model_name].get("local_model_path", "") + kwargs["model_path"] = model_path + 
kwargs["model_names"] = [model_name] + kwargs["controller_address"] = controller_address or fschat_controller_address() + kwargs["worker_address"] = fschat_model_worker_address() + + app = create_model_worker_app(**kwargs) + _set_app_seq(app, q, run_seq) + + uvicorn.run(app, host=host, port=port) + + +def run_openai_api(q: Queue, run_seq: int = 3): + import uvicorn + + controller_addr = fschat_controller_address() + app = create_openai_api_app(controller_addr) # todo: not support keys yet. + _set_app_seq(app, q, run_seq) + + host = FSCHAT_OPENAI_API["host"] + port = FSCHAT_OPENAI_API["port"] + uvicorn.run(app, host=host, port=port) + + +def run_api_server(q: Queue, run_seq: int = 4): + from server.api import create_app + import uvicorn + + app = create_app() + _set_app_seq(app, q, run_seq) + + host = API_SERVER["host"] + port = API_SERVER["port"] + + uvicorn.run(app, host=host, port=port) + + +def run_webui(q: Queue, run_seq: int = 5): + host = WEBUI_SERVER["host"] + port = WEBUI_SERVER["port"] + while True: + no = q.get() + if no != run_seq - 1: + q.put(no) + else: + break + q.put(run_seq) + p = subprocess.Popen(["streamlit", "run", "webui.py", + "--server.address", host, + "--server.port", str(port)]) + p.wait() + + +def parse_args() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "-a", + "--all-webui", + action="store_true", + help="run fastchat's controller/openai_api/model_worker servers, run api.py and webui.py", + dest="all_webui", + ) + parser.add_argument( + "--all-api", + action="store_true", + help="run fastchat's controller/openai_api/model_worker servers, run api.py", + dest="all_api", + ) + parser.add_argument( + "--llm-api", + action="store_true", + help="run fastchat's controller/openai_api/model_worker servers", + dest="llm_api", + ) + parser.add_argument( + "-o", + "--openai-api", + action="store_true", + help="run fastchat's controller/openai_api servers", + dest="openai_api", + ) + parser.add_argument( + "-m", + "--model-worker", + action="store_true", + help="run fastchat's model_worker server with specified model name. specify --model-name if not using default LLM_MODEL", + dest="model_worker", + ) + parser.add_argument( + "-n", + "--model-name", + type=str, + default=LLM_MODEL, + help="specify model name for model worker.", + dest="model_name", + ) + parser.add_argument( + "-c", + "--controller", + type=str, + help="specify controller address the worker is registered to. default is server_config.FSCHAT_CONTROLLER", + dest="controller_address", + ) + parser.add_argument( + "--api", + action="store_true", + help="run api.py server", + dest="api", + ) + parser.add_argument( + "-w", + "--webui", + action="store_true", + help="run webui.py server", + dest="webui", + ) + args = parser.parse_args() + return args + + +def dump_server_info(after_start=False): + import platform + import langchain + import fastchat + from configs.server_config import api_address, webui_address + + print("\n\n") + print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) + print(f"操作系统:{platform.platform()}.") + print(f"python版本:{sys.version}") + print(f"项目版本:{VERSION}") + print(f"langchain版本:{langchain.__version__}. 
fastchat版本:{fastchat.__version__}") + print("\n") + print(f"当前LLM模型:{LLM_MODEL} @ {LLM_DEVICE}") + pprint(llm_model_dict[LLM_MODEL]) + print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}") + if after_start: + print("\n") + print(f"服务端运行信息:") + if args.openai_api: + print(f" OpenAI API Server: {fschat_openai_api_address()}/v1") + print(" (请确认llm_model_dict中配置的api_base_url与上面地址一致。)") + if args.api: + print(f" Chatchat API Server: {api_address()}") + if args.webui: + print(f" Chatchat WEBUI Server: {webui_address()}") + print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) + print("\n\n") + + +if __name__ == "__main__": + import time + + mp.set_start_method("spawn") + queue = Queue() + args = parse_args() + if args.all_webui: + args.openai_api = True + args.model_worker = True + args.api = True + args.webui = True + + elif args.all_api: + args.openai_api = True + args.model_worker = True + args.api = True + args.webui = False + + elif args.llm_api: + args.openai_api = True + args.model_worker = True + args.api = False + args.webui = False + + dump_server_info() + logger.info(f"正在启动服务:") + logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}") + + processes = {} + + if args.openai_api: + process = Process( + target=run_controller, + name=f"controller({os.getpid()})", + args=(queue, len(processes) + 1), + daemon=True, + ) + process.start() + processes["controller"] = process + + process = Process( + target=run_openai_api, + name=f"openai_api({os.getpid()})", + args=(queue, len(processes) + 1), + daemon=True, + ) + process.start() + processes["openai_api"] = process + + if args.model_worker: + process = Process( + target=run_model_worker, + name=f"model_worker({os.getpid()})", + args=(args.model_name, args.controller_address, queue, len(processes) + 1), + daemon=True, + ) + process.start() + processes["model_worker"] = process + + if args.api: + process = Process( + target=run_api_server, + name=f"API Server{os.getpid()})", + args=(queue, len(processes) + 1), + daemon=True, + ) + process.start() + processes["api"] = process + + if args.webui: + process = Process( + target=run_webui, + name=f"WEBUI Server{os.getpid()})", + args=(queue, len(processes) + 1), + daemon=True, + ) + process.start() + processes["webui"] = process + + try: + # log infors + while True: + no = queue.get() + if no == len(processes): + time.sleep(0.5) + dump_server_info(True) + break + else: + queue.put(no) + + if model_worker_process := processes.get("model_worker"): + model_worker_process.join() + for name, process in processes.items(): + if name != "model_worker": + process.join() + except: + if model_worker_process := processes.get("model_worker"): + model_worker_process.terminate() + for name, process in processes.items(): + if name != "model_worker": + process.terminate() + +# 服务启动后接口调用示例: +# import openai +# openai.api_key = "EMPTY" # Not support yet +# openai.api_base = "http://localhost:8888/v1" + +# model = "chatglm2-6b" + +# # create a chat completion +# completion = openai.ChatCompletion.create( +# model=model, +# messages=[{"role": "user", "content": "Hello! 
What is your name?"}] +# ) +# # print the completion +# print(completion.choices[0].message.content) diff --git a/tests/api/stream_api_test.py b/tests/api/stream_api_test.py index 06a9654..2902c8a 100644 --- a/tests/api/stream_api_test.py +++ b/tests/api/stream_api_test.py @@ -28,4 +28,14 @@ if __name__ == "__main__": for line in response.iter_content(decode_unicode=True): print(line, flush=True) else: - print("Error:", response.status_code) \ No newline at end of file + print("Error:", response.status_code) + + + r = requests.post( + openai_url + "/chat/completions", + json={"model": LLM_MODEL, "messages": "你好", "max_tokens": 1000}) + data = r.json() + print(f"/chat/completions\n") + print(data) + assert "choices" in data + diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py new file mode 100644 index 0000000..5a8b97d --- /dev/null +++ b/tests/api/test_kb_api.py @@ -0,0 +1,204 @@ +from doctest import testfile +import requests +import json +import sys +from pathlib import Path + +root_path = Path(__file__).parent.parent.parent +sys.path.append(str(root_path)) +from configs.server_config import api_address +from configs.model_config import VECTOR_SEARCH_TOP_K +from server.knowledge_base.utils import get_kb_path + +from pprint import pprint + + +api_base_url = api_address() + +kb = "kb_for_api_test" +test_files = { + "README.MD": str(root_path / "README.MD"), + "FAQ.MD": str(root_path / "docs" / "FAQ.MD") +} + + +def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"): + if not Path(get_kb_path(kb)).exists(): + return + + url = api_base_url + api + print("\n测试知识库存在,需要删除") + r = requests.post(url, json=kb) + data = r.json() + pprint(data) + + # check kb not exists anymore + url = api_base_url + "/knowledge_base/list_knowledge_bases" + print("\n获取知识库列表:") + r = requests.get(url) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert isinstance(data["data"], list) and len(data["data"]) > 0 + assert kb not in data["data"] + + +def test_create_kb(api="/knowledge_base/create_knowledge_base"): + url = api_base_url + api + + print(f"\n尝试用空名称创建知识库:") + r = requests.post(url, json={"knowledge_base_name": " "}) + data = r.json() + pprint(data) + assert data["code"] == 404 + assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称" + + print(f"\n创建新知识库: {kb}") + r = requests.post(url, json={"knowledge_base_name": kb}) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"已新增知识库 {kb}" + + print(f"\n尝试创建同名知识库: {kb}") + r = requests.post(url, json={"knowledge_base_name": kb}) + data = r.json() + pprint(data) + assert data["code"] == 404 + assert data["msg"] == f"已存在同名知识库 {kb}" + + +def test_list_kbs(api="/knowledge_base/list_knowledge_bases"): + url = api_base_url + api + print("\n获取知识库列表:") + r = requests.get(url) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert isinstance(data["data"], list) and len(data["data"]) > 0 + assert kb in data["data"] + + +def test_upload_doc(api="/knowledge_base/upload_doc"): + url = api_base_url + api + for name, path in test_files.items(): + print(f"\n上传知识文件: {name}") + data = {"knowledge_base_name": kb, "override": True} + files = {"file": (name, open(path, "rb"))} + r = requests.post(url, data=data, files=files) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"成功上传文件 {name}" + + for name, path in test_files.items(): + print(f"\n尝试重新上传知识文件: {name}, 不覆盖") + data = {"knowledge_base_name": kb, "override": False} + files = {"file": (name, open(path, 
"rb"))} + r = requests.post(url, data=data, files=files) + data = r.json() + pprint(data) + assert data["code"] == 404 + assert data["msg"] == f"文件 {name} 已存在。" + + for name, path in test_files.items(): + print(f"\n尝试重新上传知识文件: {name}, 覆盖") + data = {"knowledge_base_name": kb, "override": True} + files = {"file": (name, open(path, "rb"))} + r = requests.post(url, data=data, files=files) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"成功上传文件 {name}" + + +def test_list_docs(api="/knowledge_base/list_docs"): + url = api_base_url + api + print("\n获取知识库中文件列表:") + r = requests.get(url, params={"knowledge_base_name": kb}) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert isinstance(data["data"], list) + for name in test_files: + assert name in data["data"] + + +def test_search_docs(api="/knowledge_base/search_docs"): + url = api_base_url + api + query = "介绍一下langchain-chatchat项目" + print("\n检索知识库:") + print(query) + r = requests.post(url, json={"knowledge_base_name": kb, "query": query}) + data = r.json() + pprint(data) + assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K + + +def test_update_doc(api="/knowledge_base/update_doc"): + url = api_base_url + api + for name, path in test_files.items(): + print(f"\n更新知识文件: {name}") + r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name}) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"成功更新文件 {name}" + + +def test_delete_doc(api="/knowledge_base/delete_doc"): + url = api_base_url + api + for name, path in test_files.items(): + print(f"\n删除知识文件: {name}") + r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name}) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"{name} 文件删除成功" + + url = api_base_url + "/knowledge_base/search_docs" + query = "介绍一下langchain-chatchat项目" + print("\n尝试检索删除后的检索知识库:") + print(query) + r = requests.post(url, json={"knowledge_base_name": kb, "query": query}) + data = r.json() + pprint(data) + assert isinstance(data, list) and len(data) == 0 + + +def test_recreate_vs(api="/knowledge_base/recreate_vector_store"): + url = api_base_url + api + print("\n重建知识库:") + r = requests.post(url, json={"knowledge_base_name": kb}, stream=True) + for chunk in r.iter_content(None): + data = json.loads(chunk) + assert isinstance(data, dict) + assert data["code"] == 200 + print(data["msg"]) + + url = api_base_url + "/knowledge_base/search_docs" + query = "本项目支持哪些文件格式?" 
+ print("\n尝试检索重建后的检索知识库:") + print(query) + r = requests.post(url, json={"knowledge_base_name": kb, "query": query}) + data = r.json() + pprint(data) + assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K + + +def test_delete_kb_after(api="/knowledge_base/delete_knowledge_base"): + url = api_base_url + api + print("\n删除知识库") + r = requests.post(url, json=kb) + data = r.json() + pprint(data) + + # check kb not exists anymore + url = api_base_url + "/knowledge_base/list_knowledge_bases" + print("\n获取知识库列表:") + r = requests.get(url) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert isinstance(data["data"], list) and len(data["data"]) > 0 + assert kb not in data["data"] diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py new file mode 100644 index 0000000..56d3237 --- /dev/null +++ b/tests/api/test_stream_chat_api.py @@ -0,0 +1,108 @@ +import requests +import json +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from configs.server_config import API_SERVER, api_address + +from pprint import pprint + + +api_base_url = api_address() + + +def dump_input(d, title): + print("\n") + print("=" * 30 + title + " input " + "="*30) + pprint(d) + + +def dump_output(r, title): + print("\n") + print("=" * 30 + title + " output" + "="*30) + for line in r.iter_content(None, decode_unicode=True): + print(line, end="", flush=True) + + +headers = { + 'accept': 'application/json', + 'Content-Type': 'application/json', +} + +data = { + "query": "请用100字左右的文字介绍自己", + "history": [ + { + "role": "user", + "content": "你好" + }, + { + "role": "assistant", + "content": "你好,我是 ChatGLM" + } + ], + "stream": True +} + + + +def test_chat_fastchat(api="/chat/fastchat"): + url = f"{api_base_url}{api}" + data2 = { + "stream": True, + "messages": data["history"] + [{"role": "user", "content": "推荐一部科幻电影"}] + } + dump_input(data2, api) + response = requests.post(url, headers=headers, json=data2, stream=True) + dump_output(response, api) + assert response.status_code == 200 + + +def test_chat_chat(api="/chat/chat"): + url = f"{api_base_url}{api}" + dump_input(data, api) + response = requests.post(url, headers=headers, json=data, stream=True) + dump_output(response, api) + assert response.status_code == 200 + + +def test_knowledge_chat(api="/chat/knowledge_base_chat"): + url = f"{api_base_url}{api}" + data = { + "query": "如何提问以获得高质量答案", + "knowledge_base_name": "samples", + "history": [ + { + "role": "user", + "content": "你好" + }, + { + "role": "assistant", + "content": "你好,我是 ChatGLM" + } + ], + "stream": True + } + dump_input(data, api) + response = requests.post(url, headers=headers, json=data, stream=True) + print("\n") + print("=" * 30 + api + " output" + "="*30) + first = True + for line in response.iter_content(None, decode_unicode=True): + data = json.loads(line) + if first: + for doc in data["docs"]: + print(doc) + first = False + print(data["answer"], end="", flush=True) + assert response.status_code == 200 + + +def test_search_engine_chat(api="/chat/search_engine_chat"): + url = f"{api_base_url}{api}" + for se in ["bing", "duckduckgo"]: + dump_input(data, api) + response = requests.post(url, json=data, stream=True) + dump_output(response, api) + assert response.status_code == 200 diff --git a/webui.py b/webui.py index 99db3f6..58fc0e3 100644 --- a/webui.py +++ b/webui.py @@ -9,6 +9,7 @@ from webui_pages.utils import * from streamlit_option_menu import option_menu from webui_pages import * import os +from 
configs import VERSION api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False) @@ -17,6 +18,11 @@ if __name__ == "__main__": "Langchain-Chatchat WebUI", os.path.join("img", "chatchat_icon_blue_square_v2.png"), initial_sidebar_state="expanded", + menu_items={ + 'Get Help': 'https://github.com/chatchat-space/Langchain-Chatchat', + 'Report a bug': "https://github.com/chatchat-space/Langchain-Chatchat/issues", + 'About': f"""欢迎使用 Langchain-Chatchat WebUI {VERSION}!""" + } ) if not chat_box.chat_inited: @@ -35,7 +41,7 @@ if __name__ == "__main__": "func": knowledge_base_page, }, } - + with st.sidebar: st.image( os.path.join( @@ -44,6 +50,10 @@ if __name__ == "__main__": ), use_column_width=True ) + st.caption( + f"""
当前版本:{VERSION}
""", + unsafe_allow_html=True, + ) options = list(pages) icons = [x["icon"] for x in pages.values()] diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index f059475..4351e95 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -118,7 +118,7 @@ def knowledge_base_page(api: ApiRequest): vector_store_type=vs_type, embed_model=embed_model, ) - st.toast(ret["msg"]) + st.toast(ret.get("msg", " ")) st.session_state["selected_kb_name"] = kb_name st.experimental_rerun() @@ -138,12 +138,14 @@ def knowledge_base_page(api: ApiRequest): # use_container_width=True, disabled=len(files) == 0, ): - for f in files: - ret = api.upload_kb_doc(f, kb) - if ret["code"] == 200: - st.toast(ret["msg"], icon="✔") - else: - st.toast(ret["msg"], icon="✖") + data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files] + data[-1]["not_refresh_vs_cache"]=False + for k in data: + ret = api.upload_kb_doc(**k) + if msg := check_success_msg(ret): + st.toast(msg, icon="✔") + elif msg := check_error_msg(ret): + st.toast(msg, icon="✖") st.session_state.files = [] st.divider() @@ -235,7 +237,7 @@ def knowledge_base_page(api: ApiRequest): ): for row in selected_rows: ret = api.delete_kb_doc(kb, row["file_name"], True) - st.toast(ret["msg"]) + st.toast(ret.get("msg", " ")) st.experimental_rerun() st.divider() @@ -249,12 +251,14 @@ def knowledge_base_page(api: ApiRequest): use_container_width=True, type="primary", ): - with st.spinner("向量库重构中"): + with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"): empty = st.empty() empty.progress(0.0, "") for d in api.recreate_vector_store(kb): - print(d) - empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}") + if msg := check_error_msg(d): + st.toast(msg) + else: + empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}") st.experimental_rerun() if cols[2].button( @@ -262,6 +266,6 @@ def knowledge_base_page(api: ApiRequest): use_container_width=True, ): ret = api.delete_knowledge_base(kb) - st.toast(ret["msg"]) + st.toast(ret.get("msg", " ")) time.sleep(1) st.experimental_rerun() diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 3e67ed7..c666d45 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -229,18 +229,18 @@ class ApiRequest: elif chunk.strip(): yield chunk except httpx.ConnectError as e: - msg = f"无法连接API服务器,请确认已执行python server\\api.py" + msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。" logger.error(msg) logger.error(e) - yield {"code": 500, "errorMsg": msg} + yield {"code": 500, "msg": msg} except httpx.ReadTimeout as e: msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 
启动 API 服务或 Web UI')" logger.error(msg) logger.error(e) - yield {"code": 500, "errorMsg": msg} + yield {"code": 500, "msg": msg} except Exception as e: logger.error(e) - yield {"code": 500, "errorMsg": str(e)} + yield {"code": 500, "msg": str(e)} # 对话相关操作 @@ -394,7 +394,7 @@ class ApiRequest: return response.json() except Exception as e: logger.error(e) - return {"code": 500, "errorMsg": errorMsg or str(e)} + return {"code": 500, "msg": errorMsg or str(e)} def list_knowledge_bases( self, @@ -496,6 +496,7 @@ class ApiRequest: knowledge_base_name: str, filename: str = None, override: bool = False, + not_refresh_vs_cache: bool = False, no_remote_api: bool = None, ): ''' @@ -529,7 +530,11 @@ class ApiRequest: else: response = self.post( "/knowledge_base/upload_doc", - data={"knowledge_base_name": knowledge_base_name, "override": override}, + data={ + "knowledge_base_name": knowledge_base_name, + "override": override, + "not_refresh_vs_cache": not_refresh_vs_cache, + }, files={"file": (filename, file)}, ) return self._check_httpx_json_response(response) @@ -539,6 +544,7 @@ class ApiRequest: knowledge_base_name: str, doc_name: str, delete_content: bool = False, + not_refresh_vs_cache: bool = False, no_remote_api: bool = None, ): ''' @@ -551,6 +557,7 @@ class ApiRequest: "knowledge_base_name": knowledge_base_name, "doc_name": doc_name, "delete_content": delete_content, + "not_refresh_vs_cache": not_refresh_vs_cache, } if no_remote_api: @@ -568,6 +575,7 @@ class ApiRequest: self, knowledge_base_name: str, file_name: str, + not_refresh_vs_cache: bool = False, no_remote_api: bool = None, ): ''' @@ -583,7 +591,11 @@ class ApiRequest: else: response = self.post( "/knowledge_base/update_doc", - json={"knowledge_base_name": knowledge_base_name, "file_name": file_name}, + json={ + "knowledge_base_name": knowledge_base_name, + "file_name": file_name, + "not_refresh_vs_cache": not_refresh_vs_cache, + }, ) return self._check_httpx_json_response(response) @@ -617,7 +629,7 @@ class ApiRequest: "/knowledge_base/recreate_vector_store", json=data, stream=True, - timeout=False, + timeout=None, ) return self._httpx_stream2generator(response, as_json=True) @@ -626,7 +638,22 @@ def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str: ''' return error message if error occured when requests API ''' - if isinstance(data, dict) and key in data: + if isinstance(data, dict): + if key in data: + return data[key] + if "code" in data and data["code"] != 200: + return data["msg"] + return "" + + +def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str: + ''' + return error message if error occured when requests API + ''' + if (isinstance(data, dict) + and key in data + and "code" in data + and data["code"] == 200): return data[key] return ""
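
A quick illustration of the unified message convention introduced in `webui_pages/utils.py` above: error and success reporting now flow through the `code`/`msg` keys instead of the old `errorMsg` key, and the two helpers route them accordingly. The sketch below is illustrative only; the payload dicts are hypothetical examples of `ApiRequest` responses, and it assumes the project root is on `sys.path` so that `webui_pages.utils` imports cleanly.

```python
# Illustrative sketch only: the payload dicts are hypothetical examples of the
# {"code": ..., "msg": ...} responses returned by ApiRequest helpers.
from webui_pages.utils import check_error_msg, check_success_msg

ok = {"code": 200, "msg": "成功上传文件 README.MD"}
err = {"code": 500, "msg": "无法连接API服务器,请确认 'api.py' 已正常启动。"}

assert check_success_msg(ok) == ok["msg"]   # code == 200 is reported as success
assert check_error_msg(ok) == ""            # successful responses yield no error
assert check_error_msg(err) == err["msg"]   # non-200 codes surface through "msg"
assert check_success_msg(err) == ""         # and are never treated as success
```

This is the same walrus-operator pattern that `webui_pages/knowledge_base/knowledge_base.py` now uses for upload results (`if msg := check_success_msg(ret): ... elif msg := check_error_msg(ret): ...`).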