diff --git a/README.md b/README.md
index b690ad0..991487b 100644
--- a/README.md
+++ b/README.md
@@ -126,6 +126,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
- [BAAI/bge-small-zh](https://huggingface.co/BAAI/bge-small-zh)
- [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh)
- [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh)
+- [BAAI/bge-large-zh-noinstruct](https://huggingface.co/BAAI/bge-large-zh-noinstruct)
- [text2vec-base-chinese-sentence](https://huggingface.co/shibing624/text2vec-base-chinese-sentence)
- [text2vec-base-chinese-paraphrase](https://huggingface.co/shibing624/text2vec-base-chinese-paraphrase)
- [text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual)
@@ -133,6 +134,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
- [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese)
- [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
- [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh)
+- [OpenAI/text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings)
---
@@ -206,6 +208,7 @@ embedding_model_dict = {
"m3e-base": "/Users/xxx/Downloads/m3e-base",
}
```
+如果你选择使用OpenAI的Embedding模型,请将模型的 `key` 写入 `embedding_model_dict` 中。使用该模型,你需要能够访问OpenAI官方的API,或设置代理。
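+
+下面是一个最小的配置示意(仅作演示,假设 API Key 已写入环境变量 `OPENAI_API_KEY`,代理地址为假设值):
+
+```python
+import os
+
+embedding_model_dict = {
+    # ……其余本地 Embedding 模型……
+    "text-embedding-ada-002": os.environ.get("OPENAI_API_KEY"),
+}
+
+# 无法直连 OpenAI 时,可通过环境变量 OPENAI_PROXY 设置代理,例如:
+# os.environ["OPENAI_PROXY"] = "http://127.0.0.1:7890"
+```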
### 4. 知识库初始化与迁移
@@ -298,24 +301,13 @@ $ python server/llm_api_shutdown.py --serve all
亦可单独停止一个 FastChat 服务模块,可选 [`all`, `controller`, `model_worker`, `openai_api_server`]
-##### 5.1.3 PEFT 加载
+##### 5.1.3 PEFT 加载(包括 lora、p-tuning、prefix tuning、prompt tuning、ia3 等)
本项目基于 FastChat 加载 LLM 服务,故需以 FastChat 加载 PEFT 路径,即保证路径名称里必须有 peft 这个词,配置文件的名字为 adapter_config.json,peft 路径下包含 model.bin 格式的 PEFT 权重。
+详细步骤参考[加载lora微调后模型失效](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822)
-示例代码如下:
+
-```shell
-PEFT_SHARE_BASE_WEIGHTS=true python3 -m fastchat.serve.multi_model_worker \
- --model-path /data/chris/peft-llama-dummy-1 \
- --model-names peft-dummy-1 \
- --model-path /data/chris/peft-llama-dummy-2 \
- --model-names peft-dummy-2 \
- --model-path /data/chris/peft-llama-dummy-3 \
- --model-names peft-dummy-3 \
- --num-gpus 2
-```
-
-详见 [FastChat 相关 PR](https://github.com/lm-sys/fastchat/pull/1905#issuecomment-1627801216)
#### 5.2 启动 API 服务
@@ -441,6 +433,6 @@ $ python startup.py --all-webui --model-name Qwen-7B-Chat
## 项目交流群
-
+
🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index ccfd1b6..5b2574e 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -24,7 +24,9 @@ embedding_model_dict = {
"m3e-large": "moka-ai/m3e-large",
"bge-small-zh": "BAAI/bge-small-zh",
"bge-base-zh": "BAAI/bge-base-zh",
- "bge-large-zh": "BAAI/bge-large-zh"
+ "bge-large-zh": "BAAI/bge-large-zh",
+ "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
+ "text-embedding-ada-002": os.environ.get("OPENAI_API_KEY")
}
# 选用的 Embedding 名称
@@ -41,12 +43,6 @@ llm_model_dict = {
"api_key": "EMPTY"
},
- "chatglm-6b-int4": {
- "local_model_path": "THUDM/chatglm-6b-int4",
- "api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
- "api_key": "EMPTY"
- },
-
"chatglm2-6b": {
"local_model_path": "THUDM/chatglm2-6b",
"api_base_url": "http://localhost:8888/v1", # URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
@@ -59,12 +55,6 @@ llm_model_dict = {
"api_key": "EMPTY"
},
- "vicuna-13b-hf": {
- "local_model_path": "",
- "api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
- "api_key": "EMPTY"
- },
-
# 调用chatgpt时如果报出: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
# Max retries exceeded with url: /v1/chat/completions
# 则需要将urllib3版本修改为1.25.11
@@ -75,10 +65,15 @@ llm_model_dict = {
# urllib3.exceptions.NewConnectionError: :
# Failed to establish a new connection: [WinError 10060]
# 则是因为内地和香港的IP都被OPENAI封了,需要切换为日本、新加坡等地
+
+ # 如果出现WARNING: Retrying langchain.chat_models.openai.acompletion_with_retry.._completion_with_retry in
+ # 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
+    # 需要通过代理访问(正常开启的代理软件可能拦截不到),可配置 openai_proxy,或者使用环境变量 OPENAI_PROXY 进行设置
"gpt-3.5-turbo": {
"local_model_path": "gpt-3.5-turbo",
"api_base_url": "https://api.openai.com/v1",
- "api_key": os.environ.get("OPENAI_API_KEY")
+ "api_key": os.environ.get("OPENAI_API_KEY"),
+ "openai_proxy": os.environ.get("OPENAI_PROXY")
},
}
@@ -114,7 +109,7 @@ kbs_config = {
"secure": False,
},
"pg": {
- "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatglm",
+ "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
}
}
@@ -142,12 +137,12 @@ SEARCH_ENGINE_TOP_K = 5
# nltk 模型存储路径
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
-# 基于本地知识问答的提示词模版
-PROMPT_TEMPLATE = """【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
+# 基于本地知识问答的提示词模版(使用 Jinja2 语法,简单点说就是用双大括号代替 f-string 的单大括号)
+PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。</指令>
-【已知信息】{context}
+<已知信息>{{ context }}</已知信息>
-【问题】{question}"""
+<问题>{{ question }}问题>"""
# API 是否开启跨域,默认为False,如果需要开启,请设置为True
# is open cross domain
diff --git a/docs/FAQ.md b/docs/FAQ.md
index 62fe080..490eb25 100644
--- a/docs/FAQ.md
+++ b/docs/FAQ.md
@@ -170,3 +170,16 @@ A13: 疑为 chatglm 的 quantization 的问题或 torch 版本差异问题,针
Q14: 修改配置中路径后,加载 text2vec-large-chinese 依然提示 `WARNING: No sentence-transformers model found with name text2vec-large-chinese. Creating a new one with MEAN pooling.`
A14: 尝试更换 embedding,如 text2vec-base-chinese,请在 [configs/model_config.py](../configs/model_config.py) 文件中,修改 `text2vec-base`参数为本地路径,绝对路径或者相对路径均可
+
+
+---
+
+Q15: 使用pg向量库建表报错
+
+A15: 需要手动安装对应的vector扩展(连接pg执行 CREATE EXTENSION IF NOT EXISTS vector)
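+
+一个最小的执行示意(假设使用 SQLAlchemy 2.x,连接串与 `configs/model_config.py` 中 `kbs_config["pg"]["connection_uri"]` 一致):
+
+```python
+from sqlalchemy import create_engine, text
+
+engine = create_engine("postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat")
+with engine.connect() as conn:
+    # 为当前数据库启用 pgvector 扩展
+    conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
+    conn.commit()
+```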
+
+---
+
+Q16: pymilvus 连接超时
+
+A16: pymilvus 版本需要与 milvus 版本匹配,否则会连接超时,可参考 pymilvus==2.1.3
\ No newline at end of file
diff --git a/docs/docker/vector_db/pg/docker-compose.yml b/docs/docker/vector_db/pg/docker-compose.yml
index 8e8359c..b14296b 100644
--- a/docs/docker/vector_db/pg/docker-compose.yml
+++ b/docs/docker/vector_db/pg/docker-compose.yml
@@ -2,9 +2,9 @@ version: "3.8"
services:
postgresql:
image: ankane/pgvector:v0.4.1
- container_name: langchain-chatgml-pg-db
+ container_name: langchain_chatchat-pg-db
environment:
- POSTGRES_DB: langchain_chatgml
+ POSTGRES_DB: langchain_chatchat
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
ports:
diff --git a/docs/向量库环境docker.md b/docs/向量库环境docker.md
index 162b0f0..dd5e2cb 100644
--- a/docs/向量库环境docker.md
+++ b/docs/向量库环境docker.md
@@ -5,3 +5,4 @@
cd docs/docker/vector_db/milvus
docker-compose up -d
```
+
diff --git a/img/qr_code_55.jpg b/img/qr_code_55.jpg
new file mode 100644
index 0000000..8ff046c
Binary files /dev/null and b/img/qr_code_55.jpg differ
diff --git a/img/qr_code_56.jpg b/img/qr_code_56.jpg
new file mode 100644
index 0000000..f17458d
Binary files /dev/null and b/img/qr_code_56.jpg differ
diff --git a/init_database.py b/init_database.py
index 61d00e1..7fc8494 100644
--- a/init_database.py
+++ b/init_database.py
@@ -2,6 +2,8 @@ from server.knowledge_base.migrate import create_tables, folder2db, recreate_all
from configs.model_config import NLTK_DATA_PATH
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
+from startup import dump_server_info
+
if __name__ == "__main__":
import argparse
@@ -21,6 +23,8 @@ if __name__ == "__main__":
)
args = parser.parse_args()
+ dump_server_info()
+
create_tables()
print("database talbes created")
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 2bc21db..2e939f1 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -21,7 +21,7 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
),
stream: bool = Body(False, description="流式输出"),
):
- history = [History(**h) if isinstance(h, dict) else h for h in history]
+ history = [History.from_data(h) for h in history]
async def chat_iterator(query: str,
history: List[History] = [],
@@ -34,11 +34,13 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
- model_name=LLM_MODEL
+ model_name=LLM_MODEL,
+ openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
)
+ input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
- [i.to_msg_tuple() for i in history] + [("human", "{input}")])
+ [i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)
# Begin a task that runs in the background.
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 84c62f0..2774569 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -38,7 +38,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
- history = [History(**h) if isinstance(h, dict) else h for h in history]
+ history = [History.from_data(h) for h in history]
async def knowledge_base_chat_iterator(query: str,
kb: KBService,
@@ -52,13 +52,15 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
- model_name=LLM_MODEL
+ model_name=LLM_MODEL,
+ openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs])
+ input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
- [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
+ [i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index 15834d0..032d06a 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -73,6 +73,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
+ history = [History.from_data(h) for h in history]
+
async def search_engine_chat_iterator(query: str,
search_engine_name: str,
top_k: int,
@@ -85,14 +87,16 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
- model_name=LLM_MODEL
+ model_name=LLM_MODEL,
+ openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
)
docs = lookup_search_engine(query, search_engine_name, top_k)
context = "\n".join([doc.page_content for doc in docs])
+ input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
- [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
+ [i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)
@@ -117,7 +121,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
answer = ""
async for token in callback.aiter():
answer += token
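+            # 注意:每个 chunk 返回的是到当前为止累计的 answer,而非单个增量 token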
- yield json.dumps({"answer": token,
+ yield json.dumps({"answer": answer,
"docs": source_documents},
ensure_ascii=False)
await task
diff --git a/server/chat/utils.py b/server/chat/utils.py
index f8afb10..2167f10 100644
--- a/server/chat/utils.py
+++ b/server/chat/utils.py
@@ -1,6 +1,7 @@
import asyncio
-from typing import Awaitable
+from typing import Awaitable, List, Tuple, Dict, Union
from pydantic import BaseModel, Field
+from langchain.prompts.chat import ChatMessagePromptTemplate
async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -28,3 +29,29 @@ class History(BaseModel):
def to_msg_tuple(self):
return "ai" if self.role=="assistant" else "human", self.content
+
+ def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:
+ role_maps = {
+ "ai": "assistant",
+ "human": "user",
+ }
+ role = role_maps.get(self.role, self.role)
+ if is_raw: # 当前默认历史消息都是没有input_variable的文本。
+ content = "{% raw %}" + self.content + "{% endraw %}"
+ else:
+ content = self.content
+
+ return ChatMessagePromptTemplate.from_template(
+ content,
+ "jinja2",
+ role=role,
+ )
+
+ @classmethod
+ def from_data(cls, h: Union[List, Tuple, Dict]) -> "History":
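+        # 兼容两种历史消息格式:("user", "你好") 这样的列表/元组,或 {"role": "user", "content": "你好"} 字典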
+        if isinstance(h, (list, tuple)) and len(h) >= 2:
+ h = cls(role=h[0], content=h[1])
+ elif isinstance(h, dict):
+ h = cls(**h)
+
+ return h
diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py
index 4753ba4..b9151b8 100644
--- a/server/knowledge_base/kb_api.py
+++ b/server/knowledge_base/kb_api.py
@@ -15,7 +15,7 @@ async def list_kbs():
async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL),
- ):
+ ) -> BaseResponse:
# Create selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@@ -27,13 +27,18 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
- kb.create_kb()
+ try:
+ kb.create_kb()
+ except Exception as e:
+ print(e)
+ return BaseResponse(code=500, msg=f"创建知识库出错: {e}")
+
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
async def delete_kb(
knowledge_base_name: str = Body(..., examples=["samples"])
- ):
+ ) -> BaseResponse:
# Delete selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@@ -51,5 +56,6 @@ async def delete_kb(
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
except Exception as e:
print(e)
+ return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}")
return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")
diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index 0bf2cb7..ae027c1 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -22,7 +22,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
) -> List[DocumentWithScore]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
- return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []}
+ return []
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
@@ -31,7 +31,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
async def list_docs(
knowledge_base_name: str
-):
+) -> ListResponse:
if not validate_kb_name(knowledge_base_name):
return ListResponse(code=403, msg="Don't attack me", data=[])
@@ -41,13 +41,14 @@ async def list_docs(
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
else:
all_doc_names = kb.list_docs()
- return ListResponse(data=all_doc_names)
+ return ListResponse(data=all_doc_names)
async def upload_doc(file: UploadFile = File(..., description="上传文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
override: bool = Form(False, description="覆盖已有文件"),
- ):
+ not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
+ ) -> BaseResponse:
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@@ -57,31 +58,38 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
file_content = await file.read() # 读取上传文件的内容
- kb_file = KnowledgeFile(filename=file.filename,
- knowledge_base_name=knowledge_base_name)
-
- if (os.path.exists(kb_file.filepath)
- and not override
- and os.path.getsize(kb_file.filepath) == len(file_content)
- ):
- # TODO: filesize 不同后的处理
- file_status = f"文件 {kb_file.filename} 已存在。"
- return BaseResponse(code=404, msg=file_status)
-
try:
+ kb_file = KnowledgeFile(filename=file.filename,
+ knowledge_base_name=knowledge_base_name)
+
+ if (os.path.exists(kb_file.filepath)
+ and not override
+ and os.path.getsize(kb_file.filepath) == len(file_content)
+ ):
+ # TODO: filesize 不同后的处理
+ file_status = f"文件 {kb_file.filename} 已存在。"
+ return BaseResponse(code=404, msg=file_status)
+
with open(kb_file.filepath, "wb") as f:
f.write(file_content)
except Exception as e:
+ print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
- kb.add_doc(kb_file)
+ try:
+ kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
+ except Exception as e:
+ print(e)
+ return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
+
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
doc_name: str = Body(..., examples=["file_name.md"]),
delete_content: bool = Body(False),
- ):
+ not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
+ ) -> BaseResponse:
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@@ -92,17 +100,23 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
if not kb.exist_doc(doc_name):
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
- kb_file = KnowledgeFile(filename=doc_name,
- knowledge_base_name=knowledge_base_name)
- kb.delete_doc(kb_file, delete_content)
+
+ try:
+ kb_file = KnowledgeFile(filename=doc_name,
+ knowledge_base_name=knowledge_base_name)
+ kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache)
+ except Exception as e:
+ print(e)
+ return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
+
return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
- # return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败")
async def update_doc(
knowledge_base_name: str = Body(..., examples=["samples"]),
file_name: str = Body(..., examples=["file_name"]),
- ):
+ not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
+ ) -> BaseResponse:
'''
更新知识库文档
'''
@@ -113,14 +127,17 @@ async def update_doc(
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
- kb_file = KnowledgeFile(filename=file_name,
- knowledge_base_name=knowledge_base_name)
+ try:
+ kb_file = KnowledgeFile(filename=file_name,
+ knowledge_base_name=knowledge_base_name)
+ if os.path.exists(kb_file.filepath):
+ kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
+ return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
+ except Exception as e:
+ print(e)
+ return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}")
- if os.path.exists(kb_file.filepath):
- kb.update_doc(kb_file)
- return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
- else:
- return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
+ return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
async def download_doc(
@@ -137,18 +154,20 @@ async def download_doc(
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
- kb_file = KnowledgeFile(filename=file_name,
- knowledge_base_name=knowledge_base_name)
-
- if os.path.exists(kb_file.filepath):
- return FileResponse(
- path=kb_file.filepath,
- filename=kb_file.filename,
- media_type="multipart/form-data")
- else:
- return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
+ try:
+ kb_file = KnowledgeFile(filename=file_name,
+ knowledge_base_name=knowledge_base_name)
+ if os.path.exists(kb_file.filepath):
+ return FileResponse(
+ path=kb_file.filepath,
+ filename=kb_file.filename,
+ media_type="multipart/form-data")
+ except Exception as e:
+ print(e)
+ return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}")
+ return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
async def recreate_vector_store(
@@ -163,24 +182,35 @@ async def recreate_vector_store(
by default, get_service_by_name only return knowledge base in the info.db and having document files in it.
set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
'''
- kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
- if not kb.exists() and not allow_empty_kb:
- return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
- async def output(kb):
- kb.create_kb()
- kb.clear_vs()
- docs = list_docs_from_folder(knowledge_base_name)
- for i, doc in enumerate(docs):
- try:
- kb_file = KnowledgeFile(doc, knowledge_base_name)
- yield json.dumps({
- "total": len(docs),
- "finished": i,
- "doc": doc,
- }, ensure_ascii=False)
- kb.add_doc(kb_file)
- except Exception as e:
- print(e)
+ async def output():
+ kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
+ if not kb.exists() and not allow_empty_kb:
+ yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
+ else:
+ kb.create_kb()
+ kb.clear_vs()
+ docs = list_docs_from_folder(knowledge_base_name)
+ for i, doc in enumerate(docs):
+ try:
+ kb_file = KnowledgeFile(doc, knowledge_base_name)
+ yield json.dumps({
+ "code": 200,
+ "msg": f"({i + 1} / {len(docs)}): {doc}",
+ "total": len(docs),
+ "finished": i,
+ "doc": doc,
+ }, ensure_ascii=False)
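+                    # 仅在处理最后一个文件时保存向量库并刷新缓存,避免重复写盘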
+ if i == len(docs) - 1:
+ not_refresh_vs_cache = False
+ else:
+ not_refresh_vs_cache = True
+ kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
+ except Exception as e:
+ print(e)
+ yield json.dumps({
+ "code": 500,
+ "msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
+                    }, ensure_ascii=False)
- return StreamingResponse(output(kb), media_type="text/event-stream")
+ return StreamingResponse(output(), media_type="text/event-stream")
diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py
index d506f63..9af5b0e 100644
--- a/server/knowledge_base/kb_service/base.py
+++ b/server/knowledge_base/kb_service/base.py
@@ -71,36 +71,37 @@ class KBService(ABC):
status = delete_kb_from_db(self.kb_name)
return status
- def add_doc(self, kb_file: KnowledgeFile):
+ def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
向知识库添加文件
"""
docs = kb_file.file2text()
if docs:
+ self.delete_doc(kb_file)
embeddings = self._load_embeddings()
- self.do_add_doc(docs, embeddings)
+ self.do_add_doc(docs, embeddings, **kwargs)
status = add_doc_to_db(kb_file)
else:
status = False
return status
- def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False):
+ def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs):
"""
从知识库删除文件
"""
- self.do_delete_doc(kb_file)
+ self.do_delete_doc(kb_file, **kwargs)
status = delete_file_from_db(kb_file)
if delete_content and os.path.exists(kb_file.filepath):
os.remove(kb_file.filepath)
return status
- def update_doc(self, kb_file: KnowledgeFile):
+ def update_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
使用content中的文件更新向量库
"""
if os.path.exists(kb_file.filepath):
- self.delete_doc(kb_file)
- return self.add_doc(kb_file)
+ self.delete_doc(kb_file, **kwargs)
+ return self.add_doc(kb_file, **kwargs)
def exist_doc(self, file_name: str):
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py
index 5c8376f..b3f5439 100644
--- a/server/knowledge_base/kb_service/faiss_kb_service.py
+++ b/server/knowledge_base/kb_service/faiss_kb_service.py
@@ -13,7 +13,8 @@ from functools import lru_cache
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
from langchain.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings,HuggingFaceBgeEmbeddings
+from langchain.embeddings.openai import OpenAIEmbeddings
from typing import List
from langchain.docstore.document import Document
from server.utils import torch_gc
@@ -21,10 +22,19 @@ from server.utils import torch_gc
# make HuggingFaceEmbeddings hashable
def _embeddings_hash(self):
- return hash(self.model_name)
-
+    if isinstance(self, (HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings)):
+        return hash(self.model_name)
+    elif isinstance(self, OpenAIEmbeddings):
+        return hash(self.model)
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
+OpenAIEmbeddings.__hash__ = _embeddings_hash
+HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
+
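+# 说明:_VECTOR_STORE_TICKS 记录各知识库向量库的版本号;tick 随 refresh_vs_cache 递增,
+# 作为 load_vector_store 的参数使其缓存失效并重新加载向量库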
_VECTOR_STORE_TICKS = {}
@@ -41,7 +51,23 @@ def load_vector_store(
vs_path = get_vs_path(knowledge_base_name)
if embeddings is None:
embeddings = load_embeddings(embed_model, embed_device)
- search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
+
+ if not os.path.exists(vs_path):
+ os.makedirs(vs_path)
+
+ if "index.faiss" in os.listdir(vs_path):
+ search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
+ else:
+ # create an empty vector store
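+        # FAISS 无法直接创建空索引:先用占位文档初始化,再删除其向量并保存,即得到一个空的向量库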
+ doc = Document(page_content="init", metadata={})
+ search_index = FAISS.from_documents([doc], embeddings, normalize_L2=True)
+        ids = list(search_index.docstore._dict.keys())
+ search_index.delete(ids)
+ search_index.save_local(vs_path)
+
+ if tick == 0: # vector store is loaded first time
+ _VECTOR_STORE_TICKS[knowledge_base_name] = 0
+
return search_index
@@ -50,6 +76,7 @@ def refresh_vs_cache(kb_name: str):
make vector store cache refreshed when next loading
"""
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
+ print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
class FaissKBService(KBService):
@@ -74,8 +101,10 @@ class FaissKBService(KBService):
def do_create_kb(self):
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
+ load_vector_store(self.kb_name)
def do_drop_kb(self):
+ self.clear_vs()
shutil.rmtree(self.kb_path)
def do_search(self,
@@ -93,38 +122,40 @@ class FaissKBService(KBService):
def do_add_doc(self,
docs: List[Document],
embeddings: Embeddings,
+ **kwargs,
):
- if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
- vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
- vector_store.add_documents(docs)
- torch_gc()
- else:
- if not os.path.exists(self.vs_path):
- os.makedirs(self.vs_path)
- vector_store = FAISS.from_documents(
- docs, embeddings, normalize_L2=True) # docs 为Document列表
- torch_gc()
- vector_store.save_local(self.vs_path)
- refresh_vs_cache(self.kb_name)
-
- def do_delete_doc(self,
- kb_file: KnowledgeFile):
- embeddings = self._load_embeddings()
- if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
- vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
- ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
- if len(ids) == 0:
- return None
- vector_store.delete(ids)
+ vector_store = load_vector_store(self.kb_name,
+ embeddings=embeddings,
+ tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
+ vector_store.add_documents(docs)
+ torch_gc()
+ if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
- return True
- else:
+
+ def do_delete_doc(self,
+ kb_file: KnowledgeFile,
+ **kwargs):
+ embeddings = self._load_embeddings()
+ vector_store = load_vector_store(self.kb_name,
+ embeddings=embeddings,
+ tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
+
+ ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
+ if len(ids) == 0:
return None
+ vector_store.delete(ids)
+ if not kwargs.get("not_refresh_vs_cache"):
+ vector_store.save_local(self.vs_path)
+ refresh_vs_cache(self.kb_name)
+
+ return True
+
def do_clear_vs(self):
shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path)
+ refresh_vs_cache(self.kb_name)
def exist_doc(self, file_name: str):
if super().exist_doc(file_name):
diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py
index f9c40c0..ba74c14 100644
--- a/server/knowledge_base/kb_service/milvus_kb_service.py
+++ b/server/knowledge_base/kb_service/milvus_kb_service.py
@@ -45,12 +45,12 @@ class MilvusKBService(KBService):
def do_drop_kb(self):
self.milvus.col.drop()
- def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
+ def do_search(self, query: str, top_k: int, embeddings: Embeddings):
# todo: support score threshold
self._load_milvus(embeddings=embeddings)
- return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)
+ return self.milvus.similarity_search_with_score(query, top_k)
- def add_doc(self, kb_file: KnowledgeFile):
+ def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
向知识库添加文件
"""
@@ -60,10 +60,10 @@ class MilvusKBService(KBService):
status = add_doc_to_db(kb_file)
return status
- def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
+ def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
pass
- def do_delete_doc(self, kb_file: KnowledgeFile):
+ def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
filepath = kb_file.filepath.replace('\\', '\\\\')
delete_list = [item.get("pk") for item in
self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
@@ -76,6 +76,7 @@ class MilvusKBService(KBService):
if __name__ == '__main__':
# 测试建表使用
from server.db.base import Base, engine
+
Base.metadata.create_all(bind=engine)
milvusService = MilvusKBService("test")
milvusService.add_doc(KnowledgeFile("README.md", "test"))
diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py
index a3126ec..38d8065 100644
--- a/server/knowledge_base/kb_service/pg_kb_service.py
+++ b/server/knowledge_base/kb_service/pg_kb_service.py
@@ -43,12 +43,12 @@ class PGKBService(KBService):
'''))
connect.commit()
- def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
+ def do_search(self, query: str, top_k: int, embeddings: Embeddings):
# todo: support score threshold
self._load_pg_vector(embeddings=embeddings)
- return self.pg_vector.similarity_search(query, top_k)
+ return self.pg_vector.similarity_search_with_score(query, top_k)
- def add_doc(self, kb_file: KnowledgeFile):
+ def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
向知识库添加文件
"""
@@ -58,10 +58,10 @@ class PGKBService(KBService):
status = add_doc_to_db(kb_file)
return status
- def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
+ def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
pass
- def do_delete_doc(self, kb_file: KnowledgeFile):
+ def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
with self.pg_vector.connect() as connect:
filepath = kb_file.filepath.replace('\\', '\\\\')
connect.execute(
@@ -76,6 +76,7 @@ class PGKBService(KBService):
if __name__ == '__main__':
from server.db.base import Base, engine
+
Base.metadata.create_all(bind=engine)
pGKBService = PGKBService("test")
pGKBService.create_kb()
diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py
index 1c023fa..c96d386 100644
--- a/server/knowledge_base/migrate.py
+++ b/server/knowledge_base/migrate.py
@@ -43,7 +43,11 @@ def folder2db(
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
- kb.add_doc(kb_file)
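+                # 与 kb_doc_api.recreate_vector_store 相同:仅最后一个文件触发向量库保存与缓存刷新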
+ if i == len(docs) - 1:
+ not_refresh_vs_cache = False
+ else:
+ not_refresh_vs_cache = True
+ kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
@@ -67,7 +71,11 @@ def folder2db(
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
- kb.update_doc(kb_file)
+ if i == len(docs) - 1:
+ not_refresh_vs_cache = False
+ else:
+ not_refresh_vs_cache = True
+ kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
@@ -81,7 +89,11 @@ def folder2db(
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
- kb.add_doc(kb_file)
+ if i == len(docs) - 1:
+ not_refresh_vs_cache = False
+ else:
+ not_refresh_vs_cache = True
+ kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index 4f0bad9..da53049 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -1,5 +1,7 @@
import os
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.embeddings import HuggingFaceBgeEmbeddings
from configs.model_config import (
embedding_model_dict,
KB_ROOT_PATH,
@@ -41,11 +43,20 @@ def list_docs_from_folder(kb_name: str):
@lru_cache(1)
def load_embeddings(model: str, device: str):
- embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
- model_kwargs={'device': device})
+ if model == "text-embedding-ada-002": # openai text-embedding-ada-002
+ embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
+ elif 'bge-' in model:
+ embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
+ model_kwargs={'device': device},
+ query_instruction="为这个句子生成表示以用于检索相关文章:")
+ if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
+ embeddings.query_instruction = ""
+ else:
+ embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
return embeddings
+
LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
'.rtf', '.txt', '.xml',
'.doc', '.docx', '.epub', '.odt', '.pdf',
diff --git a/server/utils.py b/server/utils.py
index c0f11a5..4a88722 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -9,8 +9,8 @@ from typing import Any, Optional
class BaseResponse(BaseModel):
- code: int = pydantic.Field(200, description="HTTP status code")
- msg: str = pydantic.Field("success", description="HTTP status message")
+ code: int = pydantic.Field(200, description="API status code")
+ msg: str = pydantic.Field("success", description="API status message")
class Config:
schema_extra = {
diff --git a/startup.py b/startup.py
index 66814ad..df00851 100644
--- a/startup.py
+++ b/startup.py
@@ -5,6 +5,14 @@ import sys
import os
from pprint import pprint
+# 设置numexpr最大线程数,默认为CPU核心数
+try:
+ import numexpr
+ n_cores = numexpr.utils.detect_number_of_cores()
+ os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
+except Exception:
+ pass
+
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
logger
@@ -252,35 +260,36 @@ def run_webui(q: Queue, run_seq: int = 5):
def parse_args() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
+ "-a",
"--all-webui",
action="store_true",
- help="run fastchat's controller/model_worker/openai_api servers, run api.py and webui.py",
+ help="run fastchat's controller/openai_api/model_worker servers, run api.py and webui.py",
dest="all_webui",
)
parser.add_argument(
"--all-api",
action="store_true",
- help="run fastchat's controller/model_worker/openai_api servers, run api.py and webui.py",
+ help="run fastchat's controller/openai_api/model_worker servers, run api.py",
dest="all_api",
)
parser.add_argument(
"--llm-api",
action="store_true",
- help="run fastchat's controller/model_worker/openai_api servers, run api.py and webui.py",
+ help="run fastchat's controller/openai_api/model_worker servers",
dest="llm_api",
)
parser.add_argument(
"-o",
"--openai-api",
action="store_true",
- help="run fastchat controller/openai_api servers",
+ help="run fastchat's controller/openai_api servers",
dest="openai_api",
)
parser.add_argument(
"-m",
"--model-worker",
action="store_true",
- help="run fastchat model_worker server with specified model name. specify --model-name if not using default LLM_MODEL",
+ help="run fastchat's model_worker server with specified model name. specify --model-name if not using default LLM_MODEL",
dest="model_worker",
)
parser.add_argument(
@@ -315,13 +324,39 @@ def parse_args() -> argparse.ArgumentParser:
return args
-if __name__ == "__main__":
+def dump_server_info(after_start=False):
import platform
- import time
import langchain
import fastchat
from configs.server_config import api_address, webui_address
+ print("\n\n")
+ print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
+ print(f"操作系统:{platform.platform()}.")
+ print(f"python版本:{sys.version}")
+ print(f"项目版本:{VERSION}")
+ print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
+ print("\n")
+ print(f"当前LLM模型:{LLM_MODEL} @ {LLM_DEVICE}")
+ pprint(llm_model_dict[LLM_MODEL])
+ print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
+ if after_start:
+ print("\n")
+ print(f"服务端运行信息:")
+ if args.openai_api:
+ print(f" OpenAI API Server: {fschat_openai_api_address()}/v1")
+ print(" (请确认llm_model_dict中配置的api_base_url与上面地址一致。)")
+ if args.api:
+ print(f" Chatchat API Server: {api_address()}")
+ if args.webui:
+ print(f" Chatchat WEBUI Server: {webui_address()}")
+ print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
+ print("\n\n")
+
+
+if __name__ == "__main__":
+ import time
+
mp.set_start_method("spawn")
queue = Queue()
args = parse_args()
@@ -343,6 +378,7 @@ if __name__ == "__main__":
args.api = False
args.webui = False
+ dump_server_info()
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
@@ -403,27 +439,7 @@ if __name__ == "__main__":
no = queue.get()
if no == len(processes):
time.sleep(0.5)
- print("\n\n")
- print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
- print(f"操作系统:{platform.platform()}.")
- print(f"python版本:{sys.version}")
- print(f"项目版本:{VERSION}")
- print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
- print("\n")
- print(f"当前LLM模型:{LLM_MODEL} @ {LLM_DEVICE}")
- pprint(llm_model_dict[LLM_MODEL])
- print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
- print("\n")
- print(f"服务端运行信息:")
- if args.openai_api:
- print(f" OpenAI API Server: {fschat_openai_api_address()}/v1")
- print("请确认llm_model_dict中配置的api_base_url与上面地址一致。")
- if args.api:
- print(f" Chatchat API Server: {api_address()}")
- if args.webui:
- print(f" Chatchat WEBUI Server: {webui_address()}")
- print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
- print("\n\n")
+ dump_server_info(True)
break
else:
queue.put(no)
diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py
new file mode 100644
index 0000000..5a8b97d
--- /dev/null
+++ b/tests/api/test_kb_api.py
@@ -0,0 +1,204 @@
+import requests
+import json
+import sys
+from pathlib import Path
+
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+from configs.server_config import api_address
+from configs.model_config import VECTOR_SEARCH_TOP_K
+from server.knowledge_base.utils import get_kb_path
+
+from pprint import pprint
+
+
+api_base_url = api_address()
+
+kb = "kb_for_api_test"
+test_files = {
+ "README.MD": str(root_path / "README.MD"),
+ "FAQ.MD": str(root_path / "docs" / "FAQ.MD")
+}
+
+
+def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
+ if not Path(get_kb_path(kb)).exists():
+ return
+
+ url = api_base_url + api
+ print("\n测试知识库存在,需要删除")
+ r = requests.post(url, json=kb)
+ data = r.json()
+ pprint(data)
+
+ # check kb not exists anymore
+ url = api_base_url + "/knowledge_base/list_knowledge_bases"
+ print("\n获取知识库列表:")
+ r = requests.get(url)
+ data = r.json()
+ pprint(data)
+ assert data["code"] == 200
+ assert isinstance(data["data"], list) and len(data["data"]) > 0
+ assert kb not in data["data"]
+
+
+def test_create_kb(api="/knowledge_base/create_knowledge_base"):
+ url = api_base_url + api
+
+ print(f"\n尝试用空名称创建知识库:")
+ r = requests.post(url, json={"knowledge_base_name": " "})
+ data = r.json()
+ pprint(data)
+ assert data["code"] == 404
+ assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
+
+ print(f"\n创建新知识库: {kb}")
+ r = requests.post(url, json={"knowledge_base_name": kb})
+ data = r.json()
+ pprint(data)
+ assert data["code"] == 200
+ assert data["msg"] == f"已新增知识库 {kb}"
+
+ print(f"\n尝试创建同名知识库: {kb}")
+ r = requests.post(url, json={"knowledge_base_name": kb})
+ data = r.json()
+ pprint(data)
+ assert data["code"] == 404
+ assert data["msg"] == f"已存在同名知识库 {kb}"
+
+
+def test_list_kbs(api="/knowledge_base/list_knowledge_bases"):
+ url = api_base_url + api
+ print("\n获取知识库列表:")
+ r = requests.get(url)
+ data = r.json()
+ pprint(data)
+ assert data["code"] == 200
+ assert isinstance(data["data"], list) and len(data["data"]) > 0
+ assert kb in data["data"]
+
+
+def test_upload_doc(api="/knowledge_base/upload_doc"):
+ url = api_base_url + api
+ for name, path in test_files.items():
+ print(f"\n上传知识文件: {name}")
+ data = {"knowledge_base_name": kb, "override": True}
+ files = {"file": (name, open(path, "rb"))}
+ r = requests.post(url, data=data, files=files)
+ data = r.json()
+ pprint(data)
+ assert data["code"] == 200
+ assert data["msg"] == f"成功上传文件 {name}"
+
+ for name, path in test_files.items():
+ print(f"\n尝试重新上传知识文件: {name}, 不覆盖")
+ data = {"knowledge_base_name": kb, "override": False}
+ files = {"file": (name, open(path, "rb"))}
+ r = requests.post(url, data=data, files=files)
+ data = r.json()
+ pprint(data)
+ assert data["code"] == 404
+ assert data["msg"] == f"文件 {name} 已存在。"
+
+ for name, path in test_files.items():
+ print(f"\n尝试重新上传知识文件: {name}, 覆盖")
+ data = {"knowledge_base_name": kb, "override": True}
+ files = {"file": (name, open(path, "rb"))}
+ r = requests.post(url, data=data, files=files)
+ data = r.json()
+ pprint(data)
+ assert data["code"] == 200
+ assert data["msg"] == f"成功上传文件 {name}"
+
+
+def test_list_docs(api="/knowledge_base/list_docs"):
+ url = api_base_url + api
+ print("\n获取知识库中文件列表:")
+ r = requests.get(url, params={"knowledge_base_name": kb})
+ data = r.json()
+ pprint(data)
+ assert data["code"] == 200
+ assert isinstance(data["data"], list)
+ for name in test_files:
+ assert name in data["data"]
+
+
+def test_search_docs(api="/knowledge_base/search_docs"):
+ url = api_base_url + api
+ query = "介绍一下langchain-chatchat项目"
+ print("\n检索知识库:")
+ print(query)
+ r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
+ data = r.json()
+ pprint(data)
+ assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
+
+
+def test_update_doc(api="/knowledge_base/update_doc"):
+ url = api_base_url + api
+ for name, path in test_files.items():
+ print(f"\n更新知识文件: {name}")
+ r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name})
+ data = r.json()
+ pprint(data)
+ assert data["code"] == 200
+ assert data["msg"] == f"成功更新文件 {name}"
+
+
+def test_delete_doc(api="/knowledge_base/delete_doc"):
+ url = api_base_url + api
+ for name, path in test_files.items():
+ print(f"\n删除知识文件: {name}")
+ r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name})
+ data = r.json()
+ pprint(data)
+ assert data["code"] == 200
+ assert data["msg"] == f"{name} 文件删除成功"
+
+ url = api_base_url + "/knowledge_base/search_docs"
+ query = "介绍一下langchain-chatchat项目"
+ print("\n尝试检索删除后的检索知识库:")
+ print(query)
+ r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
+ data = r.json()
+ pprint(data)
+ assert isinstance(data, list) and len(data) == 0
+
+
+def test_recreate_vs(api="/knowledge_base/recreate_vector_store"):
+ url = api_base_url + api
+ print("\n重建知识库:")
+ r = requests.post(url, json={"knowledge_base_name": kb}, stream=True)
+ for chunk in r.iter_content(None):
+ data = json.loads(chunk)
+ assert isinstance(data, dict)
+ assert data["code"] == 200
+ print(data["msg"])
+
+ url = api_base_url + "/knowledge_base/search_docs"
+ query = "本项目支持哪些文件格式?"
+ print("\n尝试检索重建后的检索知识库:")
+ print(query)
+ r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
+ data = r.json()
+ pprint(data)
+ assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
+
+
+def test_delete_kb_after(api="/knowledge_base/delete_knowledge_base"):
+ url = api_base_url + api
+ print("\n删除知识库")
+ r = requests.post(url, json=kb)
+ data = r.json()
+ pprint(data)
+
+ # check kb not exists anymore
+ url = api_base_url + "/knowledge_base/list_knowledge_bases"
+ print("\n获取知识库列表:")
+ r = requests.get(url)
+ data = r.json()
+ pprint(data)
+ assert data["code"] == 200
+ assert isinstance(data["data"], list) and len(data["data"]) > 0
+ assert kb not in data["data"]
diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py
new file mode 100644
index 0000000..56d3237
--- /dev/null
+++ b/tests/api/test_stream_chat_api.py
@@ -0,0 +1,108 @@
+import requests
+import json
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent.parent))
+from configs.server_config import API_SERVER, api_address
+
+from pprint import pprint
+
+
+api_base_url = api_address()
+
+
+def dump_input(d, title):
+ print("\n")
+ print("=" * 30 + title + " input " + "="*30)
+ pprint(d)
+
+
+def dump_output(r, title):
+ print("\n")
+ print("=" * 30 + title + " output" + "="*30)
+ for line in r.iter_content(None, decode_unicode=True):
+ print(line, end="", flush=True)
+
+
+headers = {
+ 'accept': 'application/json',
+ 'Content-Type': 'application/json',
+}
+
+data = {
+ "query": "请用100字左右的文字介绍自己",
+ "history": [
+ {
+ "role": "user",
+ "content": "你好"
+ },
+ {
+ "role": "assistant",
+ "content": "你好,我是 ChatGLM"
+ }
+ ],
+ "stream": True
+}
+
+
+
+def test_chat_fastchat(api="/chat/fastchat"):
+ url = f"{api_base_url}{api}"
+ data2 = {
+ "stream": True,
+ "messages": data["history"] + [{"role": "user", "content": "推荐一部科幻电影"}]
+ }
+ dump_input(data2, api)
+ response = requests.post(url, headers=headers, json=data2, stream=True)
+ dump_output(response, api)
+ assert response.status_code == 200
+
+
+def test_chat_chat(api="/chat/chat"):
+ url = f"{api_base_url}{api}"
+ dump_input(data, api)
+ response = requests.post(url, headers=headers, json=data, stream=True)
+ dump_output(response, api)
+ assert response.status_code == 200
+
+
+def test_knowledge_chat(api="/chat/knowledge_base_chat"):
+ url = f"{api_base_url}{api}"
+ data = {
+ "query": "如何提问以获得高质量答案",
+ "knowledge_base_name": "samples",
+ "history": [
+ {
+ "role": "user",
+ "content": "你好"
+ },
+ {
+ "role": "assistant",
+ "content": "你好,我是 ChatGLM"
+ }
+ ],
+ "stream": True
+ }
+ dump_input(data, api)
+ response = requests.post(url, headers=headers, json=data, stream=True)
+ print("\n")
+ print("=" * 30 + api + " output" + "="*30)
+ first = True
+ for line in response.iter_content(None, decode_unicode=True):
+ data = json.loads(line)
+ if first:
+ for doc in data["docs"]:
+ print(doc)
+ first = False
+ print(data["answer"], end="", flush=True)
+ assert response.status_code == 200
+
+
+def test_search_engine_chat(api="/chat/search_engine_chat"):
+ url = f"{api_base_url}{api}"
+ for se in ["bing", "duckduckgo"]:
+ dump_input(data, api)
+ response = requests.post(url, json=data, stream=True)
+ dump_output(response, api)
+ assert response.status_code == 200
diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py
index f059475..4351e95 100644
--- a/webui_pages/knowledge_base/knowledge_base.py
+++ b/webui_pages/knowledge_base/knowledge_base.py
@@ -118,7 +118,7 @@ def knowledge_base_page(api: ApiRequest):
vector_store_type=vs_type,
embed_model=embed_model,
)
- st.toast(ret["msg"])
+ st.toast(ret.get("msg", " "))
st.session_state["selected_kb_name"] = kb_name
st.experimental_rerun()
@@ -138,12 +138,14 @@ def knowledge_base_page(api: ApiRequest):
# use_container_width=True,
disabled=len(files) == 0,
):
- for f in files:
- ret = api.upload_kb_doc(f, kb)
- if ret["code"] == 200:
- st.toast(ret["msg"], icon="✔")
- else:
- st.toast(ret["msg"], icon="✖")
+ data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
+ data[-1]["not_refresh_vs_cache"]=False
+ for k in data:
+ ret = api.upload_kb_doc(**k)
+ if msg := check_success_msg(ret):
+ st.toast(msg, icon="✔")
+ elif msg := check_error_msg(ret):
+ st.toast(msg, icon="✖")
st.session_state.files = []
st.divider()
@@ -235,7 +237,7 @@ def knowledge_base_page(api: ApiRequest):
):
for row in selected_rows:
ret = api.delete_kb_doc(kb, row["file_name"], True)
- st.toast(ret["msg"])
+ st.toast(ret.get("msg", " "))
st.experimental_rerun()
st.divider()
@@ -249,12 +251,14 @@ def knowledge_base_page(api: ApiRequest):
use_container_width=True,
type="primary",
):
- with st.spinner("向量库重构中"):
+ with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
empty = st.empty()
empty.progress(0.0, "")
for d in api.recreate_vector_store(kb):
- print(d)
- empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
+ if msg := check_error_msg(d):
+ st.toast(msg)
+ else:
+ empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
st.experimental_rerun()
if cols[2].button(
@@ -262,6 +266,6 @@ def knowledge_base_page(api: ApiRequest):
use_container_width=True,
):
ret = api.delete_knowledge_base(kb)
- st.toast(ret["msg"])
+ st.toast(ret.get("msg", " "))
time.sleep(1)
st.experimental_rerun()
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 3e67ed7..c666d45 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -229,18 +229,18 @@ class ApiRequest:
elif chunk.strip():
yield chunk
except httpx.ConnectError as e:
- msg = f"无法连接API服务器,请确认已执行python server\\api.py"
+ msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。"
logger.error(msg)
logger.error(e)
- yield {"code": 500, "errorMsg": msg}
+ yield {"code": 500, "msg": msg}
except httpx.ReadTimeout as e:
msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')"
logger.error(msg)
logger.error(e)
- yield {"code": 500, "errorMsg": msg}
+ yield {"code": 500, "msg": msg}
except Exception as e:
logger.error(e)
- yield {"code": 500, "errorMsg": str(e)}
+ yield {"code": 500, "msg": str(e)}
# 对话相关操作
@@ -394,7 +394,7 @@ class ApiRequest:
return response.json()
except Exception as e:
logger.error(e)
- return {"code": 500, "errorMsg": errorMsg or str(e)}
+ return {"code": 500, "msg": errorMsg or str(e)}
def list_knowledge_bases(
self,
@@ -496,6 +496,7 @@ class ApiRequest:
knowledge_base_name: str,
filename: str = None,
override: bool = False,
+ not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
'''
@@ -529,7 +530,11 @@ class ApiRequest:
else:
response = self.post(
"/knowledge_base/upload_doc",
- data={"knowledge_base_name": knowledge_base_name, "override": override},
+ data={
+ "knowledge_base_name": knowledge_base_name,
+ "override": override,
+ "not_refresh_vs_cache": not_refresh_vs_cache,
+ },
files={"file": (filename, file)},
)
return self._check_httpx_json_response(response)
@@ -539,6 +544,7 @@ class ApiRequest:
knowledge_base_name: str,
doc_name: str,
delete_content: bool = False,
+ not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
'''
@@ -551,6 +557,7 @@ class ApiRequest:
"knowledge_base_name": knowledge_base_name,
"doc_name": doc_name,
"delete_content": delete_content,
+ "not_refresh_vs_cache": not_refresh_vs_cache,
}
if no_remote_api:
@@ -568,6 +575,7 @@ class ApiRequest:
self,
knowledge_base_name: str,
file_name: str,
+ not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
'''
@@ -583,7 +591,11 @@ class ApiRequest:
else:
response = self.post(
"/knowledge_base/update_doc",
- json={"knowledge_base_name": knowledge_base_name, "file_name": file_name},
+ json={
+ "knowledge_base_name": knowledge_base_name,
+ "file_name": file_name,
+ "not_refresh_vs_cache": not_refresh_vs_cache,
+ },
)
return self._check_httpx_json_response(response)
@@ -617,7 +629,7 @@ class ApiRequest:
"/knowledge_base/recreate_vector_store",
json=data,
stream=True,
- timeout=False,
+ timeout=None,
)
return self._httpx_stream2generator(response, as_json=True)
@@ -626,7 +638,22 @@ def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
'''
    return error message if an error occurred when requesting the API
'''
- if isinstance(data, dict) and key in data:
+ if isinstance(data, dict):
+ if key in data:
+ return data[key]
+ if "code" in data and data["code"] != 200:
+ return data["msg"]
+ return ""
+
+
+def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
+ '''
+    return success message if the API request succeeded
+ '''
+ if (isinstance(data, dict)
+ and key in data
+ and "code" in data
+ and data["code"] == 200):
return data[key]
return ""