Merge branch 'dev' into pre-release

This commit is contained in:
imClumsyPanda 2023-08-24 21:44:22 +08:00
commit e995301995
27 changed files with 700 additions and 208 deletions

View File

@ -126,6 +126,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
- [BAAI/bge-small-zh](https://huggingface.co/BAAI/bge-small-zh)
- [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh)
- [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh)
- [BAAI/bge-large-zh-noinstruct](https://huggingface.co/BAAI/bge-large-zh-noinstruct)
- [text2vec-base-chinese-sentence](https://huggingface.co/shibing624/text2vec-base-chinese-sentence)
- [text2vec-base-chinese-paraphrase](https://huggingface.co/shibing624/text2vec-base-chinese-paraphrase)
- [text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual)
@ -133,6 +134,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
- [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese)
- [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
- [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh)
- [OpenAI/text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings)
---
@ -206,6 +208,7 @@ embedding_model_dict = {
"m3e-base": "/Users/xxx/Downloads/m3e-base",
}
```
If you choose to use an OpenAI embedding model, write the model's API `key` into `embedding_model_dict`. To use this model you must be able to reach the official OpenAI API, or configure a proxy.
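A minimal sketch of such an entry, assuming the key is supplied via the `OPENAI_API_KEY` environment variable as elsewhere in this commit:

```python
import os

embedding_model_dict = {
    # ... local model paths ...
    # For OpenAI, the dict value is the API key rather than a model path.
    "text-embedding-ada-002": os.environ.get("OPENAI_API_KEY"),
}
```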
### 4. Knowledge Base Initialization and Migration
@ -298,24 +301,13 @@ $ python server/llm_api_shutdown.py --serve all
You can also stop a single FastChat service module; choose one of [`all`, `controller`, `model_worker`, `openai_api_server`].
##### 5.1.3 Loading PEFT
##### 5.1.3 Loading PEFT (including LoRA, P-Tuning, Prefix Tuning, Prompt Tuning, IA3, etc.)
This project loads the LLM service via FastChat, so PEFT weights must be loaded the FastChat way: the path name must contain the word `peft`, the config file must be named `adapter_config.json`, and the peft directory must contain PEFT weights in `model.bin` format.
For detailed steps, see [LoRA fine-tuned model fails to load](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822).
Example code:
![image](https://github.com/chatchat-space/Langchain-Chatchat/assets/22924096/4e056c1c-5c4b-4865-a1af-859cd58a625d)
```shell
PEFT_SHARE_BASE_WEIGHTS=true python3 -m fastchat.serve.multi_model_worker \
--model-path /data/chris/peft-llama-dummy-1 \
--model-names peft-dummy-1 \
--model-path /data/chris/peft-llama-dummy-2 \
--model-names peft-dummy-2 \
--model-path /data/chris/peft-llama-dummy-3 \
--model-names peft-dummy-3 \
--num-gpus 2
```
For details, see the [related FastChat PR](https://github.com/lm-sys/fastchat/pull/1905#issuecomment-1627801216).
#### 5.2 Start the API Service
@ -441,6 +433,6 @@ $ python startup.py --all-webui --model-name Qwen-7B-Chat
## Project Discussion Group
<img src="img/qr_code_54.jpg" alt="QR code" width="300" height="300" />
<img src="img/qr_code_56.jpg" alt="QR code" width="300" height="300" />
🎉 WeChat discussion group for the langchain-ChatGLM project. If you are also interested in this project, you are welcome to join the group chat.

0
common/__init__.py Normal file
View File

View File

@ -24,7 +24,9 @@ embedding_model_dict = {
"m3e-large": "moka-ai/m3e-large",
"bge-small-zh": "BAAI/bge-small-zh",
"bge-base-zh": "BAAI/bge-base-zh",
"bge-large-zh": "BAAI/bge-large-zh"
"bge-large-zh": "BAAI/bge-large-zh",
"bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
"text-embedding-ada-002": os.environ.get("OPENAI_API_KEY")
}
# Name of the selected embedding model
@ -41,12 +43,6 @@ llm_model_dict = {
"api_key": "EMPTY"
},
"chatglm-6b-int4": {
"local_model_path": "THUDM/chatglm-6b-int4",
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
},
"chatglm2-6b": {
"local_model_path": "THUDM/chatglm2-6b",
"api_base_url": "http://localhost:8888/v1", # URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
@ -59,12 +55,6 @@ llm_model_dict = {
"api_key": "EMPTY"
},
"vicuna-13b-hf": {
"local_model_path": "",
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
},
# If calling ChatGPT raises urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
# Max retries exceeded with url: /v1/chat/completions
# then pin urllib3 to version 1.25.11
@ -75,10 +65,15 @@ llm_model_dict = {
# urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x000001FE4BDB85E0>:
# Failed to establish a new connection: [WinError 10060]
# this means mainland China and Hong Kong IPs are both blocked by OpenAI; switch to an IP in Japan, Singapore, etc.
# If you see WARNING: Retrying langchain.chat_models.openai.acompletion_with_retry.<locals>._completion_with_retry in
# 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
# you need to access OpenAI through a proxy (an ordinary proxy client may fail to intercept these requests); set openai_proxy in the config, or use the OPENAI_PROXY environment variable
"gpt-3.5-turbo": {
"local_model_path": "gpt-3.5-turbo",
"api_base_url": "https://api.openai.com/v1",
"api_key": os.environ.get("OPENAI_API_KEY")
"api_key": os.environ.get("OPENAI_API_KEY"),
"openai_proxy": os.environ.get("OPENAI_PROXY")
},
}
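For reference, the chat endpoints later in this commit pass the proxy straight into LangChain's `ChatOpenAI` client; a condensed sketch of that call (only the fields shown in this diff):

```python
from langchain.chat_models import ChatOpenAI

model = ChatOpenAI(
    openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
    openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
    model_name=LLM_MODEL,
    openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy"),  # None when unset
)
```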
@ -114,7 +109,7 @@ kbs_config = {
"secure": False,
},
"pg": {
"connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatglm",
"connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
}
}
@ -142,12 +137,12 @@ SEARCH_ENGINE_TOP_K = 5
# storage path for nltk models
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
# prompt template for QA over local knowledge
PROMPT_TEMPLATE = """【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
# prompt template for QA over local knowledge, now using Jinja2 syntax: simply put, double braces replace the single braces of an f-string
PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
【已知信息】{context}
<已知信息>{{ context }}</已知信息>
【问题】{question}"""
<问题>{{ question }}</问题>"""
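As an aside, the double braces are plain Jinja2 variables, so the template can be rendered directly; a standalone sketch using the jinja2 package (placeholder values are illustrative):

```python
from jinja2 import Template

rendered = Template(PROMPT_TEMPLATE).render(
    context="(retrieved passages go here)",
    question="(user question goes here)",
)
print(rendered)
```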
# whether the API enables cross-origin (CORS) requests; defaults to False, set to True if needed

View File

@ -170,3 +170,16 @@ A13: 疑为 chatglm 的 quantization 的问题或 torch 版本差异问题,针
Q14: After changing the paths in the config, loading text2vec-large-chinese still reports `WARNING: No sentence-transformers model found with name text2vec-large-chinese. Creating a new one with MEAN pooling.`
A14: Try switching the embedding model, e.g. to text2vec-base-chinese: in [configs/model_config.py](../configs/model_config.py), change the `text2vec-base` entry to a local path; absolute and relative paths both work.
---
Q15: Creating tables with the pg vector store raises an error.
A15: You need to install the vector extension manually (connect to PostgreSQL and run `CREATE EXTENSION IF NOT EXISTS vector`).
---
Q16: pymilvus connections time out.
A16: The pymilvus version must match the Milvus server version, otherwise connections time out; pymilvus==2.1.3 is a known-good reference.
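A quick compatibility check, assuming a local Milvus on the default port (the address is an assumption):

```python
import pymilvus
from pymilvus import connections

print(pymilvus.__version__)  # client version; must match the Milvus server's minor version, e.g. 2.1.x
connections.connect(host="127.0.0.1", port="19530")  # connection attempts time out on a version mismatch
```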

View File

@ -2,9 +2,9 @@ version: "3.8"
services:
postgresql:
image: ankane/pgvector:v0.4.1
container_name: langchain-chatgml-pg-db
container_name: langchain_chatchat-pg-db
environment:
POSTGRES_DB: langchain_chatgml
POSTGRES_DB: langchain_chatchat
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
ports:

View File

@ -5,3 +5,4 @@
cd docs/docker/vector_db/milvus
docker-compose up -d
```

BIN
img/qr_code_55.jpg Normal file (binary file not shown; 291 KiB)

BIN
img/qr_code_56.jpg Normal file (binary file not shown; 200 KiB)

View File

@ -2,6 +2,8 @@ from server.knowledge_base.migrate import create_tables, folder2db, recreate_all
from configs.model_config import NLTK_DATA_PATH
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
from startup import dump_server_info
if __name__ == "__main__":
import argparse
@ -21,6 +23,8 @@ if __name__ == "__main__":
)
args = parser.parse_args()
dump_server_info()
create_tables()
print("database talbes created")

View File

@ -21,7 +21,7 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
),
stream: bool = Body(False, description="流式输出"),
):
history = [History(**h) if isinstance(h, dict) else h for h in history]
history = [History.from_data(h) for h in history]
async def chat_iterator(query: str,
history: List[History] = [],
@ -34,11 +34,13 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
model_name=LLM_MODEL,
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
)
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_tuple() for i in history] + [("human", "{input}")])
[i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)
# Begin a task that runs in the background.

View File

@ -38,7 +38,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
history = [History(**h) if isinstance(h, dict) else h for h in history]
history = [History.from_data(h) for h in history]
async def knowledge_base_chat_iterator(query: str,
kb: KBService,
@ -52,13 +52,15 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
model_name=LLM_MODEL,
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs])
input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
[i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)

View File

@ -73,6 +73,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
history = [History.from_data(h) for h in history]
async def search_engine_chat_iterator(query: str,
search_engine_name: str,
top_k: int,
@ -85,14 +87,16 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
model_name=LLM_MODEL,
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
)
docs = lookup_search_engine(query, search_engine_name, top_k)
context = "\n".join([doc.page_content for doc in docs])
input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
[i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)
@ -117,7 +121,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
answer = ""
async for token in callback.aiter():
answer += token
yield json.dumps({"answer": token,
yield json.dumps({"answer": answer,
"docs": source_documents},
ensure_ascii=False)
await task

View File

@ -1,6 +1,7 @@
import asyncio
from typing import Awaitable
from typing import Awaitable, List, Tuple, Dict, Union
from pydantic import BaseModel, Field
from langchain.prompts.chat import ChatMessagePromptTemplate
async def wrap_done(fn: Awaitable, event: asyncio.Event):
@ -28,3 +29,29 @@ class History(BaseModel):
def to_msg_tuple(self):
return "ai" if self.role=="assistant" else "human", self.content
def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:
role_maps = {
"ai": "assistant",
"human": "user",
}
role = role_maps.get(self.role, self.role)
if is_raw:  # by default, history messages are plain text without input_variables
content = "{% raw %}" + self.content + "{% endraw %}"
else:
content = self.content
return ChatMessagePromptTemplate.from_template(
content,
"jinja2",
role=role,
)
@classmethod
def from_data(cls, h: Union[List, Tuple, Dict]) -> "History":
if isinstance(h, (list,tuple)) and len(h) >= 2:
h = cls(role=h[0], content=h[1])
elif isinstance(h, dict):
h = cls(**h)
return h
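A minimal usage sketch of these helpers, mirroring how the chat endpoints in this commit build their prompts (the history values are hypothetical):

```python
from langchain.prompts.chat import ChatPromptTemplate

history = [History.from_data(("user", "你好")),
           History.from_data({"role": "assistant", "content": "你好,我是 ChatGLM"})]

# Past messages are wrapped in {% raw %} so any braces they contain are not
# treated as Jinja2 variables; the final message keeps {{ input }} as a real one.
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
    [h.to_msg_template() for h in history] + [input_msg])
print(chat_prompt.format(input="介绍一下 Langchain-Chatchat"))
```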

View File

@ -15,7 +15,7 @@ async def list_kbs():
async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL),
):
) -> BaseResponse:
# Create selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -27,13 +27,18 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
kb.create_kb()
try:
kb.create_kb()
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"创建知识库出错: {e}")
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
async def delete_kb(
knowledge_base_name: str = Body(..., examples=["samples"])
):
) -> BaseResponse:
# Delete selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -51,5 +56,6 @@ async def delete_kb(
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}")
return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")

View File

@ -22,7 +22,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
) -> List[DocumentWithScore]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []}
return []
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
@ -31,7 +31,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
async def list_docs(
knowledge_base_name: str
):
) -> ListResponse:
if not validate_kb_name(knowledge_base_name):
return ListResponse(code=403, msg="Don't attack me", data=[])
@ -41,13 +41,14 @@ async def list_docs(
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
else:
all_doc_names = kb.list_docs()
return ListResponse(data=all_doc_names)
return ListResponse(data=all_doc_names)
async def upload_doc(file: UploadFile = File(..., description="上传文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
override: bool = Form(False, description="覆盖已有文件"),
):
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -57,31 +58,38 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
file_content = await file.read()  # read the content of the uploaded file
kb_file = KnowledgeFile(filename=file.filename,
knowledge_base_name=knowledge_base_name)
if (os.path.exists(kb_file.filepath)
and not override
and os.path.getsize(kb_file.filepath) == len(file_content)
):
# TODO: handle the case where the file size differs
file_status = f"文件 {kb_file.filename} 已存在。"
return BaseResponse(code=404, msg=file_status)
try:
kb_file = KnowledgeFile(filename=file.filename,
knowledge_base_name=knowledge_base_name)
if (os.path.exists(kb_file.filepath)
and not override
and os.path.getsize(kb_file.filepath) == len(file_content)
):
# TODO: handle the case where the file size differs
file_status = f"文件 {kb_file.filename} 已存在。"
return BaseResponse(code=404, msg=file_status)
with open(kb_file.filepath, "wb") as f:
f.write(file_content)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
kb.add_doc(kb_file)
try:
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
doc_name: str = Body(..., examples=["file_name.md"]),
delete_content: bool = Body(False),
):
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -92,17 +100,23 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
if not kb.exist_doc(doc_name):
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
kb_file = KnowledgeFile(filename=doc_name,
knowledge_base_name=knowledge_base_name)
kb.delete_doc(kb_file, delete_content)
try:
kb_file = KnowledgeFile(filename=doc_name,
knowledge_base_name=knowledge_base_name)
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
# return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败")
async def update_doc(
knowledge_base_name: str = Body(..., examples=["samples"]),
file_name: str = Body(..., examples=["file_name"]),
):
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
'''
Update a document in the knowledge base
'''
@ -113,14 +127,17 @@ async def update_doc(
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
try:
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath):
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}")
if os.path.exists(kb_file.filepath):
kb.update_doc(kb_file)
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
else:
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
async def download_doc(
@ -137,18 +154,20 @@ async def download_doc(
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath):
return FileResponse(
path=kb_file.filepath,
filename=kb_file.filename,
media_type="multipart/form-data")
else:
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
try:
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath):
return FileResponse(
path=kb_file.filepath,
filename=kb_file.filename,
media_type="multipart/form-data")
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}")
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
async def recreate_vector_store(
@ -163,24 +182,35 @@ async def recreate_vector_store(
by default, get_service_by_name only returns knowledge bases that are in info.db and have document files in them.
set allow_empty_kb to True to make it apply to empty knowledge bases that are not in info.db or that have no documents.
'''
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
async def output(kb):
kb.create_kb()
kb.clear_vs()
docs = list_docs_from_folder(knowledge_base_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, knowledge_base_name)
yield json.dumps({
"total": len(docs),
"finished": i,
"doc": doc,
}, ensure_ascii=False)
kb.add_doc(kb_file)
except Exception as e:
print(e)
async def output():
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
else:
kb.create_kb()
kb.clear_vs()
docs = list_docs_from_folder(knowledge_base_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, knowledge_base_name)
yield json.dumps({
"code": 200,
"msg": f"({i + 1} / {len(docs)}): {doc}",
"total": len(docs),
"finished": i,
"doc": doc,
}, ensure_ascii=False)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e:
print(e)
yield json.dumps({
"code": 500,
"msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
})
return StreamingResponse(output(kb), media_type="text/event-stream")
return StreamingResponse(output(), media_type="text/event-stream")
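A minimal client-side sketch for this streaming endpoint (the API address is an assumption; compare the test added later in this commit):

```python
import json
import requests

r = requests.post("http://127.0.0.1:7861/knowledge_base/recreate_vector_store",
                  json={"knowledge_base_name": "samples"}, stream=True)
for chunk in r.iter_content(None):
    data = json.loads(chunk)
    print(data.get("code"), data.get("msg"))
```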

View File

@ -71,36 +71,37 @@ class KBService(ABC):
status = delete_kb_from_db(self.kb_name)
return status
def add_doc(self, kb_file: KnowledgeFile):
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
Add a file to the knowledge base
"""
docs = kb_file.file2text()
if docs:
self.delete_doc(kb_file)
embeddings = self._load_embeddings()
self.do_add_doc(docs, embeddings)
self.do_add_doc(docs, embeddings, **kwargs)
status = add_doc_to_db(kb_file)
else:
status = False
return status
def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False):
def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs):
"""
Delete a file from the knowledge base
"""
self.do_delete_doc(kb_file)
self.do_delete_doc(kb_file, **kwargs)
status = delete_file_from_db(kb_file)
if delete_content and os.path.exists(kb_file.filepath):
os.remove(kb_file.filepath)
return status
def update_doc(self, kb_file: KnowledgeFile):
def update_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
Update the vector store using the file under content
"""
if os.path.exists(kb_file.filepath):
self.delete_doc(kb_file)
return self.add_doc(kb_file)
self.delete_doc(kb_file, **kwargs)
return self.add_doc(kb_file, **kwargs)
def exist_doc(self, file_name: str):
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,

View File

@ -13,7 +13,8 @@ from functools import lru_cache
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
from langchain.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings,HuggingFaceBgeEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from typing import List
from langchain.docstore.document import Document
from server.utils import torch_gc
@ -21,10 +22,19 @@ from server.utils import torch_gc
# make HuggingFaceEmbeddings hashable
def _embeddings_hash(self):
return hash(self.model_name)
if isinstance(self, HuggingFaceEmbeddings):
return hash(self.model_name)
elif isinstance(self, HuggingFaceBgeEmbeddings):
return hash(self.model_name)
elif isinstance(self, OpenAIEmbeddings):
return hash(self.model)
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
OpenAIEmbeddings.__hash__ = _embeddings_hash
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
_VECTOR_STORE_TICKS = {}
_VECTOR_STORE_TICKS = {}
@ -41,7 +51,23 @@ def load_vector_store(
vs_path = get_vs_path(knowledge_base_name)
if embeddings is None:
embeddings = load_embeddings(embed_model, embed_device)
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
if not os.path.exists(vs_path):
os.makedirs(vs_path)
if "index.faiss" in os.listdir(vs_path):
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
else:
# create an empty vector store
doc = Document(page_content="init", metadata={})
search_index = FAISS.from_documents([doc], embeddings, normalize_L2=True)
ids = [k for k, v in search_index.docstore._dict.items()]
search_index.delete(ids)
search_index.save_local(vs_path)
if tick == 0: # vector store is loaded first time
_VECTOR_STORE_TICKS[knowledge_base_name] = 0
return search_index
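The empty-store trick above exists because FAISS cannot be built from zero documents: a placeholder is embedded, then deleted by id, leaving a valid empty index. A standalone sketch of the same idea (any LangChain `Embeddings` instance will do):

```python
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document

def empty_faiss(embeddings):
    store = FAISS.from_documents([Document(page_content="init")], embeddings,
                                 normalize_L2=True)
    store.delete(list(store.docstore._dict.keys()))  # drop the placeholder
    return store
```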
@ -50,6 +76,7 @@ def refresh_vs_cache(kb_name: str):
make vector store cache refreshed when next loading
"""
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
class FaissKBService(KBService):
@ -74,8 +101,10 @@ class FaissKBService(KBService):
def do_create_kb(self):
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
load_vector_store(self.kb_name)
def do_drop_kb(self):
self.clear_vs()
shutil.rmtree(self.kb_path)
def do_search(self,
@ -93,38 +122,40 @@ class FaissKBService(KBService):
def do_add_doc(self,
docs: List[Document],
embeddings: Embeddings,
**kwargs,
):
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
vector_store.add_documents(docs)
torch_gc()
else:
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
vector_store = FAISS.from_documents(
docs, embeddings, normalize_L2=True)  # docs is a list of Document objects
torch_gc()
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
def do_delete_doc(self,
kb_file: KnowledgeFile):
embeddings = self._load_embeddings()
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
if len(ids) == 0:
return None
vector_store.delete(ids)
vector_store = load_vector_store(self.kb_name,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
vector_store.add_documents(docs)
torch_gc()
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
return True
else:
def do_delete_doc(self,
kb_file: KnowledgeFile,
**kwargs):
embeddings = self._load_embeddings()
vector_store = load_vector_store(self.kb_name,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
if len(ids) == 0:
return None
vector_store.delete(ids)
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
return True
def do_clear_vs(self):
shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path)
refresh_vs_cache(self.kb_name)
def exist_doc(self, file_name: str):
if super().exist_doc(file_name):

View File

@ -45,12 +45,12 @@ class MilvusKBService(KBService):
def do_drop_kb(self):
self.milvus.col.drop()
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
def do_search(self, query: str, top_k: int, embeddings: Embeddings):
# todo: support score threshold
self._load_milvus(embeddings=embeddings)
return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)
return self.milvus.similarity_search_with_score(query, top_k)
def add_doc(self, kb_file: KnowledgeFile):
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
Add a file to the knowledge base
"""
@ -60,10 +60,10 @@ class MilvusKBService(KBService):
status = add_doc_to_db(kb_file)
return status
def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
pass
def do_delete_doc(self, kb_file: KnowledgeFile):
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
filepath = kb_file.filepath.replace('\\', '\\\\')
delete_list = [item.get("pk") for item in
self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
@ -76,6 +76,7 @@ class MilvusKBService(KBService):
if __name__ == '__main__':
# used to test table creation
from server.db.base import Base, engine
Base.metadata.create_all(bind=engine)
milvusService = MilvusKBService("test")
milvusService.add_doc(KnowledgeFile("README.md", "test"))

View File

@ -43,12 +43,12 @@ class PGKBService(KBService):
'''))
connect.commit()
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
def do_search(self, query: str, top_k: int, embeddings: Embeddings):
# todo: support score threshold
self._load_pg_vector(embeddings=embeddings)
return self.pg_vector.similarity_search(query, top_k)
return self.pg_vector.similarity_search_with_score(query, top_k)
def add_doc(self, kb_file: KnowledgeFile):
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
Add a file to the knowledge base
"""
@ -58,10 +58,10 @@ class PGKBService(KBService):
status = add_doc_to_db(kb_file)
return status
def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
pass
def do_delete_doc(self, kb_file: KnowledgeFile):
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
with self.pg_vector.connect() as connect:
filepath = kb_file.filepath.replace('\\', '\\\\')
connect.execute(
@ -76,6 +76,7 @@ class PGKBService(KBService):
if __name__ == '__main__':
from server.db.base import Base, engine
Base.metadata.create_all(bind=engine)
pGKBService = PGKBService("test")
pGKBService.create_kb()

View File

@ -43,7 +43,11 @@ def folder2db(
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
kb.add_doc(kb_file)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
@ -67,7 +71,11 @@ def folder2db(
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
kb.update_doc(kb_file)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
@ -81,7 +89,11 @@ def folder2db(
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
kb.add_doc(kb_file)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
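The pattern repeated in each branch above defers saving the FAISS index until the last file, so the vector store is persisted once per batch rather than once per document. A condensed sketch of the same logic:

```python
for i, doc in enumerate(docs):
    kb_file = KnowledgeFile(doc, kb_name)
    # save/refresh the vector store cache only after the final file
    kb.add_doc(kb_file, not_refresh_vs_cache=(i != len(docs) - 1))
```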

View File

@ -1,5 +1,7 @@
import os
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
from configs.model_config import (
embedding_model_dict,
KB_ROOT_PATH,
@ -41,11 +43,20 @@ def list_docs_from_folder(kb_name: str):
@lru_cache(1)
def load_embeddings(model: str, device: str):
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
model_kwargs={'device': device})
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
elif 'bge-' in model:
embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
model_kwargs={'device': device},
query_instruction="为这个句子生成表示以用于检索相关文章:")
if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
embeddings.query_instruction = ""
else:
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
return embeddings
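A small usage sketch of this dispatcher (model names are taken from `embedding_model_dict` in this commit; note that `lru_cache(1)` keeps only the most recently used model in memory):

```python
# BGE models are wrapped with a Chinese retrieval instruction for queries.
bge = load_embeddings("bge-large-zh", "cpu")
query_vector = bge.embed_query("什么是知识库?")

# For "text-embedding-ada-002" the dict value is the OpenAI API key,
# so OPENAI_API_KEY must be set in the environment.
ada = load_embeddings("text-embedding-ada-002", "cpu")
```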
LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
'.rtf', '.txt', '.xml',
'.doc', '.docx', '.epub', '.odt', '.pdf',

View File

@ -9,8 +9,8 @@ from typing import Any, Optional
class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="HTTP status code")
msg: str = pydantic.Field("success", description="HTTP status message")
code: int = pydantic.Field(200, description="API status code")
msg: str = pydantic.Field("success", description="API status message")
class Config:
schema_extra = {

View File

@ -5,6 +5,14 @@ import sys
import os
from pprint import pprint
# set the maximum number of numexpr threads; defaults to the number of CPU cores
try:
import numexpr
n_cores = numexpr.utils.detect_number_of_cores()
os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
except:
pass
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
logger
@ -252,35 +260,36 @@ def run_webui(q: Queue, run_seq: int = 5):
def parse_args() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"-a",
"--all-webui",
action="store_true",
help="run fastchat's controller/model_worker/openai_api servers, run api.py and webui.py",
help="run fastchat's controller/openai_api/model_worker servers, run api.py and webui.py",
dest="all_webui",
)
parser.add_argument(
"--all-api",
action="store_true",
help="run fastchat's controller/model_worker/openai_api servers, run api.py and webui.py",
help="run fastchat's controller/openai_api/model_worker servers, run api.py",
dest="all_api",
)
parser.add_argument(
"--llm-api",
action="store_true",
help="run fastchat's controller/model_worker/openai_api servers, run api.py and webui.py",
help="run fastchat's controller/openai_api/model_worker servers",
dest="llm_api",
)
parser.add_argument(
"-o",
"--openai-api",
action="store_true",
help="run fastchat controller/openai_api servers",
help="run fastchat's controller/openai_api servers",
dest="openai_api",
)
parser.add_argument(
"-m",
"--model-worker",
action="store_true",
help="run fastchat model_worker server with specified model name. specify --model-name if not using default LLM_MODEL",
help="run fastchat's model_worker server with specified model name. specify --model-name if not using default LLM_MODEL",
dest="model_worker",
)
parser.add_argument(
@ -315,13 +324,39 @@ def parse_args() -> argparse.ArgumentParser:
return args
if __name__ == "__main__":
def dump_server_info(after_start=False):
import platform
import time
import langchain
import fastchat
from configs.server_config import api_address, webui_address
print("\n\n")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print(f"操作系统:{platform.platform()}.")
print(f"python版本{sys.version}")
print(f"项目版本:{VERSION}")
print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
print("\n")
print(f"当前LLM模型{LLM_MODEL} @ {LLM_DEVICE}")
pprint(llm_model_dict[LLM_MODEL])
print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
if after_start:
print("\n")
print(f"服务端运行信息:")
if args.openai_api:
print(f" OpenAI API Server: {fschat_openai_api_address()}/v1")
print(" (请确认llm_model_dict中配置的api_base_url与上面地址一致。)")
if args.api:
print(f" Chatchat API Server: {api_address()}")
if args.webui:
print(f" Chatchat WEBUI Server: {webui_address()}")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print("\n\n")
if __name__ == "__main__":
import time
mp.set_start_method("spawn")
queue = Queue()
args = parse_args()
@ -343,6 +378,7 @@ if __name__ == "__main__":
args.api = False
args.webui = False
dump_server_info()
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
@ -403,27 +439,7 @@ if __name__ == "__main__":
no = queue.get()
if no == len(processes):
time.sleep(0.5)
print("\n\n")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print(f"操作系统:{platform.platform()}.")
print(f"python版本{sys.version}")
print(f"项目版本:{VERSION}")
print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
print("\n")
print(f"当前LLM模型{LLM_MODEL} @ {LLM_DEVICE}")
pprint(llm_model_dict[LLM_MODEL])
print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
print("\n")
print(f"服务端运行信息:")
if args.openai_api:
print(f" OpenAI API Server: {fschat_openai_api_address()}/v1")
print("请确认llm_model_dict中配置的api_base_url与上面地址一致。")
if args.api:
print(f" Chatchat API Server: {api_address()}")
if args.webui:
print(f" Chatchat WEBUI Server: {webui_address()}")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print("\n\n")
dump_server_info(True)
break
else:
queue.put(no)

204
tests/api/test_kb_api.py Normal file
View File

@ -0,0 +1,204 @@
import requests
import json
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path
from pprint import pprint
api_base_url = api_address()
kb = "kb_for_api_test"
test_files = {
"README.MD": str(root_path / "README.MD"),
"FAQ.MD": str(root_path / "docs" / "FAQ.MD")
}
def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
if not Path(get_kb_path(kb)).exists():
return
url = api_base_url + api
print("\n测试知识库存在,需要删除")
r = requests.post(url, json=kb)
data = r.json()
pprint(data)
# check kb not exists anymore
url = api_base_url + "/knowledge_base/list_knowledge_bases"
print("\n获取知识库列表:")
r = requests.get(url)
data = r.json()
pprint(data)
assert data["code"] == 200
assert isinstance(data["data"], list) and len(data["data"]) > 0
assert kb not in data["data"]
def test_create_kb(api="/knowledge_base/create_knowledge_base"):
url = api_base_url + api
print(f"\n尝试用空名称创建知识库:")
r = requests.post(url, json={"knowledge_base_name": " "})
data = r.json()
pprint(data)
assert data["code"] == 404
assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
print(f"\n创建新知识库: {kb}")
r = requests.post(url, json={"knowledge_base_name": kb})
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"已新增知识库 {kb}"
print(f"\n尝试创建同名知识库: {kb}")
r = requests.post(url, json={"knowledge_base_name": kb})
data = r.json()
pprint(data)
assert data["code"] == 404
assert data["msg"] == f"已存在同名知识库 {kb}"
def test_list_kbs(api="/knowledge_base/list_knowledge_bases"):
url = api_base_url + api
print("\n获取知识库列表:")
r = requests.get(url)
data = r.json()
pprint(data)
assert data["code"] == 200
assert isinstance(data["data"], list) and len(data["data"]) > 0
assert kb in data["data"]
def test_upload_doc(api="/knowledge_base/upload_doc"):
url = api_base_url + api
for name, path in test_files.items():
print(f"\n上传知识文件: {name}")
data = {"knowledge_base_name": kb, "override": True}
files = {"file": (name, open(path, "rb"))}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"成功上传文件 {name}"
for name, path in test_files.items():
print(f"\n尝试重新上传知识文件: {name} 不覆盖")
data = {"knowledge_base_name": kb, "override": False}
files = {"file": (name, open(path, "rb"))}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 404
assert data["msg"] == f"文件 {name} 已存在。"
for name, path in test_files.items():
print(f"\n尝试重新上传知识文件: {name} 覆盖")
data = {"knowledge_base_name": kb, "override": True}
files = {"file": (name, open(path, "rb"))}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"成功上传文件 {name}"
def test_list_docs(api="/knowledge_base/list_docs"):
url = api_base_url + api
print("\n获取知识库中文件列表:")
r = requests.get(url, params={"knowledge_base_name": kb})
data = r.json()
pprint(data)
assert data["code"] == 200
assert isinstance(data["data"], list)
for name in test_files:
assert name in data["data"]
def test_search_docs(api="/knowledge_base/search_docs"):
url = api_base_url + api
query = "介绍一下langchain-chatchat项目"
print("\n检索知识库:")
print(query)
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
data = r.json()
pprint(data)
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
def test_update_doc(api="/knowledge_base/update_doc"):
url = api_base_url + api
for name, path in test_files.items():
print(f"\n更新知识文件: {name}")
r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name})
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"成功更新文件 {name}"
def test_delete_doc(api="/knowledge_base/delete_doc"):
url = api_base_url + api
for name, path in test_files.items():
print(f"\n删除知识文件: {name}")
r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name})
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"{name} 文件删除成功"
url = api_base_url + "/knowledge_base/search_docs"
query = "介绍一下langchain-chatchat项目"
print("\n尝试检索删除后的检索知识库:")
print(query)
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
data = r.json()
pprint(data)
assert isinstance(data, list) and len(data) == 0
def test_recreate_vs(api="/knowledge_base/recreate_vector_store"):
url = api_base_url + api
print("\n重建知识库:")
r = requests.post(url, json={"knowledge_base_name": kb}, stream=True)
for chunk in r.iter_content(None):
data = json.loads(chunk)
assert isinstance(data, dict)
assert data["code"] == 200
print(data["msg"])
url = api_base_url + "/knowledge_base/search_docs"
query = "本项目支持哪些文件格式?"
print("\n尝试检索重建后的检索知识库:")
print(query)
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
data = r.json()
pprint(data)
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
def test_delete_kb_after(api="/knowledge_base/delete_knowledge_base"):
url = api_base_url + api
print("\n删除知识库")
r = requests.post(url, json=kb)
data = r.json()
pprint(data)
# check kb not exists anymore
url = api_base_url + "/knowledge_base/list_knowledge_bases"
print("\n获取知识库列表:")
r = requests.get(url)
data = r.json()
pprint(data)
assert data["code"] == 200
assert isinstance(data["data"], list) and len(data["data"]) > 0
assert kb not in data["data"]

View File

@ -0,0 +1,108 @@
import requests
import json
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
from configs.server_config import API_SERVER, api_address
from pprint import pprint
api_base_url = api_address()
def dump_input(d, title):
print("\n")
print("=" * 30 + title + " input " + "="*30)
pprint(d)
def dump_output(r, title):
print("\n")
print("=" * 30 + title + " output" + "="*30)
for line in r.iter_content(None, decode_unicode=True):
print(line, end="", flush=True)
headers = {
'accept': 'application/json',
'Content-Type': 'application/json',
}
data = {
"query": "请用100字左右的文字介绍自己",
"history": [
{
"role": "user",
"content": "你好"
},
{
"role": "assistant",
"content": "你好,我是 ChatGLM"
}
],
"stream": True
}
def test_chat_fastchat(api="/chat/fastchat"):
url = f"{api_base_url}{api}"
data2 = {
"stream": True,
"messages": data["history"] + [{"role": "user", "content": "推荐一部科幻电影"}]
}
dump_input(data2, api)
response = requests.post(url, headers=headers, json=data2, stream=True)
dump_output(response, api)
assert response.status_code == 200
def test_chat_chat(api="/chat/chat"):
url = f"{api_base_url}{api}"
dump_input(data, api)
response = requests.post(url, headers=headers, json=data, stream=True)
dump_output(response, api)
assert response.status_code == 200
def test_knowledge_chat(api="/chat/knowledge_base_chat"):
url = f"{api_base_url}{api}"
data = {
"query": "如何提问以获得高质量答案",
"knowledge_base_name": "samples",
"history": [
{
"role": "user",
"content": "你好"
},
{
"role": "assistant",
"content": "你好,我是 ChatGLM"
}
],
"stream": True
}
dump_input(data, api)
response = requests.post(url, headers=headers, json=data, stream=True)
print("\n")
print("=" * 30 + api + " output" + "="*30)
first = True
for line in response.iter_content(None, decode_unicode=True):
data = json.loads(line)
if first:
for doc in data["docs"]:
print(doc)
first = False
print(data["answer"], end="", flush=True)
assert response.status_code == 200
def test_search_engine_chat(api="/chat/search_engine_chat"):
url = f"{api_base_url}{api}"
for se in ["bing", "duckduckgo"]:
dump_input(data, api)
response = requests.post(url, json=data, stream=True)
dump_output(response, api)
assert response.status_code == 200

View File

@ -118,7 +118,7 @@ def knowledge_base_page(api: ApiRequest):
vector_store_type=vs_type,
embed_model=embed_model,
)
st.toast(ret["msg"])
st.toast(ret.get("msg", " "))
st.session_state["selected_kb_name"] = kb_name
st.experimental_rerun()
@ -138,12 +138,14 @@ def knowledge_base_page(api: ApiRequest):
# use_container_width=True,
disabled=len(files) == 0,
):
for f in files:
ret = api.upload_kb_doc(f, kb)
if ret["code"] == 200:
st.toast(ret["msg"], icon="")
else:
st.toast(ret["msg"], icon="")
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
data[-1]["not_refresh_vs_cache"]=False
for k in data:
ret = api.upload_kb_doc(**k)
if msg := check_success_msg(ret):
st.toast(msg, icon="")
elif msg := check_error_msg(ret):
st.toast(msg, icon="")
st.session_state.files = []
st.divider()
@ -235,7 +237,7 @@ def knowledge_base_page(api: ApiRequest):
):
for row in selected_rows:
ret = api.delete_kb_doc(kb, row["file_name"], True)
st.toast(ret["msg"])
st.toast(ret.get("msg", " "))
st.experimental_rerun()
st.divider()
@ -249,12 +251,14 @@ def knowledge_base_page(api: ApiRequest):
use_container_width=True,
type="primary",
):
with st.spinner("向量库重构中"):
with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
empty = st.empty()
empty.progress(0.0, "")
for d in api.recreate_vector_store(kb):
print(d)
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
if msg := check_error_msg(d):
st.toast(msg)
else:
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
st.experimental_rerun()
if cols[2].button(
@ -262,6 +266,6 @@ def knowledge_base_page(api: ApiRequest):
use_container_width=True,
):
ret = api.delete_knowledge_base(kb)
st.toast(ret["msg"])
st.toast(ret.get("msg", " "))
time.sleep(1)
st.experimental_rerun()

View File

@ -229,18 +229,18 @@ class ApiRequest:
elif chunk.strip():
yield chunk
except httpx.ConnectError as e:
msg = f"无法连接API服务器请确认已执行python server\\api.py"
msg = f"无法连接API服务器请确认 api.py 已正常启动。"
logger.error(msg)
logger.error(e)
yield {"code": 500, "errorMsg": msg}
yield {"code": 500, "msg": msg}
except httpx.ReadTimeout as e:
msg = f"API通信超时请确认已启动FastChat与API服务详见RADME '5. 启动 API 服务或 Web UI'"
logger.error(msg)
logger.error(e)
yield {"code": 500, "errorMsg": msg}
yield {"code": 500, "msg": msg}
except Exception as e:
logger.error(e)
yield {"code": 500, "errorMsg": str(e)}
yield {"code": 500, "msg": str(e)}
# 对话相关操作
@ -394,7 +394,7 @@ class ApiRequest:
return response.json()
except Exception as e:
logger.error(e)
return {"code": 500, "errorMsg": errorMsg or str(e)}
return {"code": 500, "msg": errorMsg or str(e)}
def list_knowledge_bases(
self,
@ -496,6 +496,7 @@ class ApiRequest:
knowledge_base_name: str,
filename: str = None,
override: bool = False,
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
'''
@ -529,7 +530,11 @@ class ApiRequest:
else:
response = self.post(
"/knowledge_base/upload_doc",
data={"knowledge_base_name": knowledge_base_name, "override": override},
data={
"knowledge_base_name": knowledge_base_name,
"override": override,
"not_refresh_vs_cache": not_refresh_vs_cache,
},
files={"file": (filename, file)},
)
return self._check_httpx_json_response(response)
@ -539,6 +544,7 @@ class ApiRequest:
knowledge_base_name: str,
doc_name: str,
delete_content: bool = False,
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
'''
@ -551,6 +557,7 @@ class ApiRequest:
"knowledge_base_name": knowledge_base_name,
"doc_name": doc_name,
"delete_content": delete_content,
"not_refresh_vs_cache": not_refresh_vs_cache,
}
if no_remote_api:
@ -568,6 +575,7 @@ class ApiRequest:
self,
knowledge_base_name: str,
file_name: str,
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
'''
@ -583,7 +591,11 @@ class ApiRequest:
else:
response = self.post(
"/knowledge_base/update_doc",
json={"knowledge_base_name": knowledge_base_name, "file_name": file_name},
json={
"knowledge_base_name": knowledge_base_name,
"file_name": file_name,
"not_refresh_vs_cache": not_refresh_vs_cache,
},
)
return self._check_httpx_json_response(response)
@ -617,7 +629,7 @@ class ApiRequest:
"/knowledge_base/recreate_vector_store",
json=data,
stream=True,
timeout=False,
timeout=None,
)
return self._httpx_stream2generator(response, as_json=True)
@ -626,7 +638,22 @@ def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
'''
return an error message if an error occurred when requesting the API
'''
if isinstance(data, dict) and key in data:
if isinstance(data, dict):
if key in data:
return data[key]
if "code" in data and data["code"] != 200:
return data["msg"]
return ""
def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
'''
return the success message if the API request succeeded
'''
if (isinstance(data, dict)
and key in data
and "code" in data
and data["code"] == 200):
return data[key]
return ""