import asyncio
import json
import os
import urllib.parse
from typing import Dict, List

from fastapi import Body, File, Form, Query, UploadFile
from fastapi.responses import FileResponse
from langchain.docstore.document import Document
from sse_starlette import EventSourceResponse

from chatchat.settings import Settings
from chatchat.server.db.repository.knowledge_file_repository import get_file_detail
from chatchat.server.knowledge_base.kb_service.base import (
    KBServiceFactory,
    get_kb_file_details,
)
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
from chatchat.server.knowledge_base.utils import (
    KnowledgeFile,
    files2docs_in_thread,
    get_file_path,
    list_files_from_folder,
    validate_kb_name,
)
from chatchat.server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
from chatchat.server.utils import (
    BaseResponse,
    ListResponse,
    check_embed_model,
    run_in_thread_pool,
    get_default_embedding,
)
from chatchat.utils import build_logger

logger = build_logger()


def search_temp_docs(knowledge_id: str = Body(..., description="knowledge base ID", examples=["example_id"]),
                     query: str = Body("", description="user query", examples=["hello"]),
                     top_k: int = Body(..., description="number of documents to return", examples=[5]),
                     score_threshold: float = Body(..., description="score threshold", examples=[0.8])) -> List[Dict]:
    """Retrieve documents from a temporary FAISS knowledge base; used for file chat."""
    with memo_faiss_pool.acquire(knowledge_id) as vs:
        docs = vs.similarity_search_with_score(
            query, k=top_k, score_threshold=score_threshold
        )
        docs = [x[0].dict() for x in docs]
    return docs


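# Usage sketch: `knowledge_id` identifies a temporary FAISS store created
# earlier for a file-chat session. The direct call below is illustrative only
# (route mounting is handled elsewhere and is an assumption here):
#   hits = search_temp_docs(knowledge_id="example_id", query="hello",
#                           top_k=5, score_threshold=0.8)

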
def search_docs(
    query: str = Body("", description="user query", examples=["hello"]),
    knowledge_base_name: str = Body(
        ..., description="knowledge base name", examples=["samples"]
    ),
    top_k: int = Body(Settings.kb_settings.VECTOR_SEARCH_TOP_K, description="number of vectors to match"),
    score_threshold: float = Body(
        Settings.kb_settings.SCORE_THRESHOLD,
        description="knowledge base relevance threshold, in the range 0-2; "
        "the lower the score, the higher the relevance; "
        "2 is equivalent to no filtering; around 0.5 is recommended",
        ge=0.0,
        le=2.0,
    ),
    file_name: str = Body("", description="file name; SQL wildcards supported"),
    metadata: dict = Body({}, description="filter by metadata; only top-level keys are supported"),
) -> List[Dict]:
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    data = []
    if kb is not None:
        if query:
            docs = kb.search_docs(query, top_k, score_threshold)
            if docs is not None:
                logger.info(f"search_docs: query={query}, top_k={top_k}, score_threshold={score_threshold}, len(docs)={len(docs)}")

            docs_key = kb.search_content_internal(query, 2)
            if docs_key is not None:
                logger.info(f"before merge_and_deduplicate, len(docs_key)={len(docs_key)}")
            docs = merge_and_deduplicate(docs, docs_key)
            if docs is not None:
                logger.info(f"after merge_and_deduplicate, len(docs)={len(docs)}")
            data = [DocumentWithVSId(**{"id": x.metadata.get("id"), **x.dict()}) for x in docs]
        elif file_name or metadata:
            data = kb.list_docs(file_name=file_name, metadata=metadata)
            for d in data:
                if "vector" in d.metadata:
                    del d.metadata["vector"]
    return [x.dict() for x in data]


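# A hypothetical client call (assumes the API server listens on the default
# port 7861 and mounts this handler at POST /knowledge_base/search_docs):
#   import requests
#   hits = requests.post(
#       "http://127.0.0.1:7861/knowledge_base/search_docs",
#       json={"query": "hello", "knowledge_base_name": "samples",
#             "top_k": 3, "score_threshold": 0.5},
#   ).json()

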
def merge_and_deduplicate(list1: List[Document], list2: List[Document]) -> List[Document]:
    # Use a dict keyed by page_content to keep each unique Document once
    merged_dict = {}
    if list1 is not None:
        merged_dict = {doc.page_content: doc for doc in list1}

    # Walk list2 and add any Document not seen yet
    if list2 is not None:
        for doc in list2:
            if doc.page_content not in merged_dict:
                merged_dict[doc.page_content] = doc

    # Return the deduplicated list
    return list(merged_dict.values())


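# Illustrative behaviour: documents are keyed by page_content, and on a key
# collision the instance from list1 wins:
#   a = Document(page_content="foo")
#   b = Document(page_content="bar")
#   merge_and_deduplicate([a], [Document(page_content="foo"), b])  # -> [a, b]

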
def list_files(knowledge_base_name: str) -> ListResponse:
    if not validate_kb_name(knowledge_base_name):
        return ListResponse(code=403, msg="Don't attack me", data=[])

    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return ListResponse(
            code=404, msg=f"Knowledge base {knowledge_base_name} not found", data=[]
        )
    else:
        all_docs = get_kb_file_details(knowledge_base_name)
        return ListResponse(data=all_docs)


def _save_files_in_thread(
    files: List[UploadFile], knowledge_base_name: str, override: bool
):
    """
    Save the uploaded files into the target knowledge base directory using multiple threads.
    Yields the result of each save: {"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
    """

    def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict:
        """
        Save a single file.
        """
        # Resolve these before the try block so the except handler can reference them
        filename = file.filename
        data = {"knowledge_base_name": knowledge_base_name, "file_name": filename}
        try:
            file_path = get_file_path(
                knowledge_base_name=knowledge_base_name, doc_name=filename
            )

            file_content = file.file.read()  # read the content of the uploaded file
            if (
                os.path.isfile(file_path)
                and not override
                and os.path.getsize(file_path) == len(file_content)
            ):
                file_status = f"File {filename} already exists."
                logger.warning(file_status)
                return dict(code=404, msg=file_status, data=data)

            if not os.path.isdir(os.path.dirname(file_path)):
                os.makedirs(os.path.dirname(file_path))
            with open(file_path, "wb") as f:
                f.write(file_content)
            return dict(code=200, msg=f"Uploaded file {filename} successfully", data=data)
        except Exception as e:
            msg = f"Failed to upload file {filename}; error: {e}"
            logger.error(f"{e.__class__.__name__}: {msg}")
            return dict(code=500, msg=msg, data=data)

    params = [
        {"file": file, "knowledge_base_name": knowledge_base_name, "override": override}
        for file in files
    ]
    for result in run_in_thread_pool(save_file, params=params):
        yield result


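# Usage sketch (`files` are fastapi.UploadFile objects, e.g. as received by
# upload_docs below):
#   for result in _save_files_in_thread(files, knowledge_base_name="samples",
#                                       override=False):
#       if result["code"] != 200:
#           logger.warning(result["msg"])

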
# def files2docs(files: List[UploadFile] = File(..., description="uploaded files, multiple allowed"),
#                knowledge_base_name: str = Form(..., description="knowledge base name", examples=["samples"]),
#                override: bool = Form(False, description="overwrite existing files"),
#                save: bool = Form(True, description="whether to save the files into the knowledge base directory")):
#     def save_files(files, knowledge_base_name, override):
#         for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
#             yield json.dumps(result, ensure_ascii=False)

#     def files_to_docs(files):
#         for result in files2docs_in_thread(files):
#             yield json.dumps(result, ensure_ascii=False)


def upload_docs(
    files: List[UploadFile] = File(..., description="uploaded files, multiple allowed"),
    knowledge_base_name: str = Form(
        ..., description="knowledge base name", examples=["samples"]
    ),
    override: bool = Form(False, description="overwrite existing files"),
    to_vector_store: bool = Form(True, description="whether to vectorize the files after uploading"),
    chunk_size: int = Form(Settings.kb_settings.CHUNK_SIZE, description="maximum length of a single text chunk in the knowledge base"),
    chunk_overlap: int = Form(Settings.kb_settings.OVERLAP_SIZE, description="overlap length between adjacent chunks"),
    zh_title_enhance: bool = Form(Settings.kb_settings.ZH_TITLE_ENHANCE, description="whether to enable Chinese title enhancement"),
    docs: str = Form("", description="custom docs; must be a JSON string"),
    not_refresh_vs_cache: bool = Form(False, description="do not save the vector store yet (for FAISS)"),
) -> BaseResponse:
    """
    API endpoint: upload files and/or vectorize them.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")

    docs = json.loads(docs) if docs else {}
    failed_files = {}
    file_names = list(docs.keys())

    # First save the uploaded files to disk
    for result in _save_files_in_thread(
        files, knowledge_base_name=knowledge_base_name, override=override
    ):
        filename = result["data"]["file_name"]
        if result["code"] != 200:
            failed_files[filename] = result["msg"]

        if filename not in file_names:
            file_names.append(filename)

    # Vectorize the saved files
    if to_vector_store:
        result = update_docs(
            knowledge_base_name=knowledge_base_name,
            file_names=file_names,
            override_custom_docs=True,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            zh_title_enhance=zh_title_enhance,
            docs=docs,
            not_refresh_vs_cache=True,
        )
        failed_files.update(result.data["failed_files"])
        if not not_refresh_vs_cache:
            kb.save_vector_store()

    return BaseResponse(
        code=200, msg="File upload and vectorization finished", data={"failed_files": failed_files}
    )


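# A hypothetical client call (multipart form; assumes the handler is mounted
# at POST /knowledge_base/upload_docs on the default port 7861):
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:7861/knowledge_base/upload_docs",
#       files=[("files", open("test.txt", "rb"))],
#       data={"knowledge_base_name": "samples", "override": True},
#   )
#   print(resp.json()["data"]["failed_files"])

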
def delete_docs(
    knowledge_base_name: str = Body(..., examples=["samples"]),
    file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
    delete_content: bool = Body(False),
    not_refresh_vs_cache: bool = Body(False, description="do not save the vector store yet (for FAISS)"),
) -> BaseResponse:
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")

    failed_files = {}
    for file_name in file_names:
        if not kb.exist_doc(file_name):
            failed_files[file_name] = f"File {file_name} not found"

        try:
            kb_file = KnowledgeFile(
                filename=file_name, knowledge_base_name=knowledge_base_name
            )
            kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True)
        except Exception as e:
            msg = f"Failed to delete file {file_name}; error: {e}"
            logger.error(f"{e.__class__.__name__}: {msg}")
            failed_files[file_name] = msg

    if not not_refresh_vs_cache:
        kb.save_vector_store()

    return BaseResponse(
        code=200, msg="File deletion finished", data={"failed_files": failed_files}
    )


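# Example request body (hypothetical, assuming this handler is mounted as
# POST /knowledge_base/delete_docs):
#   {"knowledge_base_name": "samples",
#    "file_names": ["test.txt"],
#    "delete_content": true,
#    "not_refresh_vs_cache": false}

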
def update_info(
    knowledge_base_name: str = Body(
        ..., description="knowledge base name", examples=["samples"]
    ),
    kb_info: str = Body(..., description="knowledge base introduction", examples=["This is a knowledge base"]),
):
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")
    kb.update_info(kb_info)

    return BaseResponse(code=200, msg="Knowledge base introduction updated", data={"kb_info": kb_info})


def update_docs(
    knowledge_base_name: str = Body(
        ..., description="knowledge base name", examples=["samples"]
    ),
    file_names: List[str] = Body(
        ..., description="file names, multiple allowed", examples=[["file_name1", "text.txt"]]
    ),
    chunk_size: int = Body(Settings.kb_settings.CHUNK_SIZE, description="maximum length of a single text chunk in the knowledge base"),
    chunk_overlap: int = Body(Settings.kb_settings.OVERLAP_SIZE, description="overlap length between adjacent chunks"),
    zh_title_enhance: bool = Body(Settings.kb_settings.ZH_TITLE_ENHANCE, description="whether to enable Chinese title enhancement"),
    override_custom_docs: bool = Body(False, description="whether to overwrite previously customized docs"),
    docs: str = Body("", description="custom docs; must be a JSON string"),
    not_refresh_vs_cache: bool = Body(False, description="do not save the vector store yet (for FAISS)"),
) -> BaseResponse:
    """
    Update knowledge base documents.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")

    failed_files = {}
    kb_files = []
    # `docs` may arrive as a JSON string (from an HTTP request) or as an
    # already-parsed dict (when called directly from upload_docs)
    if isinstance(docs, str):
        docs = json.loads(docs) if docs else {}

    logger.info(f"update_docs: file_names={file_names}, custom docs count={len(docs)}")
    # Build the list of files whose docs need to be loaded
    for file_name in file_names:
        file_detail = get_file_detail(kb_name=knowledge_base_name, filename=file_name)
        # If the file previously used custom docs, skip or overwrite it depending on the parameter
        if file_detail.get("custom_docs") and not override_custom_docs:
            continue
        if file_name not in docs:
            try:
                logger.info(f"update_docs: no custom docs for {file_name}, loading from file")
                kb_files.append(
                    KnowledgeFile(
                        filename=file_name, knowledge_base_name=knowledge_base_name
                    )
                )
            except Exception as e:
                msg = f"Error while loading document {file_name}: {e}"
                logger.error(f"{e.__class__.__name__}: {msg}")
                failed_files[file_name] = msg

    # Generate docs from the files and vectorize them.
    # This relies on KnowledgeFile's caching: Documents are loaded in worker
    # threads and then handed back to a KnowledgeFile.
    for status, result in files2docs_in_thread(
        kb_files,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        zh_title_enhance=zh_title_enhance,
    ):
        if status:
            kb_name, file_name, new_docs = result
            kb_file = KnowledgeFile(
                filename=file_name, knowledge_base_name=knowledge_base_name
            )
            kb_file.splited_docs = new_docs
            kb.update_doc(kb_file, not_refresh_vs_cache=True)
        else:
            kb_name, file_name, error = result
            failed_files[file_name] = error

    # Vectorize the custom docs
    for file_name, v in docs.items():
        try:
            logger.info(f"update_docs: vectorizing {len(v)} custom docs for {file_name}")
            v = [x if isinstance(x, Document) else Document(**x) for x in v]
            kb_file = KnowledgeFile(
                filename=file_name, knowledge_base_name=knowledge_base_name
            )
            kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True)
        except Exception as e:
            msg = f"Error while adding custom docs for {file_name}: {e}"
            logger.error(f"{e.__class__.__name__}: {msg}")
            failed_files[file_name] = msg

    if not not_refresh_vs_cache:
        kb.save_vector_store()

    return BaseResponse(
        code=200, msg="Document update finished", data={"failed_files": failed_files}
    )


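# The `docs` field is a JSON string mapping file names to lists of Document
# dicts. A minimal sketch (field names follow langchain's Document model):
#   docs = json.dumps({
#       "test.txt": [
#           {"page_content": "first chunk", "metadata": {"source": "test.txt"}},
#       ]
#   }, ensure_ascii=False)

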
def download_doc(
    knowledge_base_name: str = Query(
        ..., description="knowledge base name", examples=["samples"]
    ),
    file_name: str = Query(..., description="file name", examples=["test.txt"]),
    preview: bool = Query(False, description="True: preview in browser; False: download"),
):
    """
    Download a knowledge base document.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")

    if preview:
        content_disposition_type = "inline"
    else:
        content_disposition_type = None

    try:
        kb_file = KnowledgeFile(
            filename=file_name, knowledge_base_name=knowledge_base_name
        )

        if os.path.exists(kb_file.filepath):
            return FileResponse(
                path=kb_file.filepath,
                filename=kb_file.filename,
                media_type="multipart/form-data",
                content_disposition_type=content_disposition_type,
            )
    except Exception as e:
        # Use file_name here: kb_file may be unbound if the constructor raised
        msg = f"Failed to read file {file_name}; error: {e}"
        logger.error(f"{e.__class__.__name__}: {msg}")
        return BaseResponse(code=500, msg=msg)

    return BaseResponse(code=500, msg=f"Failed to read file {file_name}")


def recreate_vector_store(
    knowledge_base_name: str = Body(..., examples=["samples"]),
    allow_empty_kb: bool = Body(True),
    vs_type: str = Body(Settings.kb_settings.DEFAULT_VS_TYPE, description="vector store type for an empty knowledge base; existing knowledge bases keep their original type"),
    embed_model: str = Body(get_default_embedding(), description="embedding model for an empty knowledge base; existing knowledge bases keep their original model"),
    chunk_size: int = Body(Settings.kb_settings.CHUNK_SIZE, description="maximum length of a single text chunk in the knowledge base"),
    chunk_overlap: int = Body(Settings.kb_settings.OVERLAP_SIZE, description="overlap length between adjacent chunks"),
    zh_title_enhance: bool = Body(Settings.kb_settings.ZH_TITLE_ENHANCE, description="whether to enable Chinese title enhancement"),
    not_refresh_vs_cache: bool = Body(False, description="do not save the vector store yet (for FAISS)"),
):
    """
    Recreate the vector store from the content folder.
    This is useful when users copy files into the content folder directly
    instead of uploading them through the network.
    By default, get_service_by_name only returns knowledge bases that are
    registered in info.db and contain document files. Set allow_empty_kb to
    True to also apply this to empty knowledge bases that are missing from
    info.db or contain no documents.
    """

    def output():
        try:
            kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
            if kb is None:
                kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
            if not kb.exists() and not allow_empty_kb:
                yield {"code": 404, "msg": f"Knowledge base '{knowledge_base_name}' not found"}
            else:
                ok, msg = kb.check_embed_model()
                if not ok:
                    yield {"code": 404, "msg": msg}
                else:
                    if kb.exists():
                        kb.clear_vs()
                    kb.create_kb()
                    files = list_files_from_folder(knowledge_base_name)
                    kb_files = [(file, knowledge_base_name) for file in files]
                    i = 0
                    for status, result in files2docs_in_thread(
                        kb_files,
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap,
                        zh_title_enhance=zh_title_enhance,
                    ):
                        if status:
                            kb_name, file_name, docs = result
                            kb_file = KnowledgeFile(
                                filename=file_name, knowledge_base_name=kb_name
                            )
                            kb_file.splited_docs = docs
                            yield json.dumps(
                                {
                                    "code": 200,
                                    "msg": f"({i + 1} / {len(files)}): {file_name}",
                                    "total": len(files),
                                    "finished": i + 1,
                                    "doc": file_name,
                                },
                                ensure_ascii=False,
                            )
                            kb.add_doc(kb_file, not_refresh_vs_cache=True)
                        else:
                            kb_name, file_name, error = result
                            msg = f"Error adding file '{file_name}' to knowledge base '{knowledge_base_name}': {error}. Skipped."
                            logger.error(msg)
                            yield json.dumps(
                                {
                                    "code": 500,
                                    "msg": msg,
                                }
                            )
                        i += 1
                    if not not_refresh_vs_cache:
                        kb.save_vector_store()
        except asyncio.exceptions.CancelledError:
            logger.warning("streaming progress has been interrupted by user.")
            return

    return EventSourceResponse(output())


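# Consuming the SSE progress stream (a sketch, assuming the handler is mounted
# at POST /knowledge_base/recreate_vector_store on the default port 7861;
# requires the `httpx` package):
#   import httpx
#   with httpx.stream("POST",
#                     "http://127.0.0.1:7861/knowledge_base/recreate_vector_store",
#                     json={"knowledge_base_name": "samples"},
#                     timeout=None) as r:
#       for line in r.iter_lines():
#           if line.startswith("data: "):
#               print(line[6:])  # one JSON progress record per event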