diff --git a/server/api.py b/server/api.py
index 37954b7..74426cc 100644
--- a/server/api.py
+++ b/server/api.py
@@ -15,8 +15,8 @@ from starlette.responses import RedirectResponse
 from server.chat import (chat, knowledge_base_chat, openai_chat,
                          search_engine_chat)
 from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
-from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
-                                              update_doc, download_doc, recreate_vector_store,
+from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
+                                              update_docs, download_doc, recreate_vector_store,
                                               search_docs, DocumentWithScore)
 from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
 import httpx
@@ -98,23 +98,23 @@ def create_app():
              summary="搜索知识库"
              )(search_docs)
 
-    app.post("/knowledge_base/upload_doc",
+    app.post("/knowledge_base/upload_docs",
              tags=["Knowledge Base Management"],
              response_model=BaseResponse,
-             summary="上传文件到知识库"
-             )(upload_doc)
+             summary="上传文件到知识库,并/或进行向量化"
+             )(upload_docs)
 
-    app.post("/knowledge_base/delete_doc",
+    app.post("/knowledge_base/delete_docs",
              tags=["Knowledge Base Management"],
              response_model=BaseResponse,
              summary="删除知识库内指定文件"
-             )(delete_doc)
+             )(delete_docs)
 
-    app.post("/knowledge_base/update_doc",
+    app.post("/knowledge_base/update_docs",
              tags=["Knowledge Base Management"],
              response_model=BaseResponse,
              summary="更新现有文件到知识库"
-             )(update_doc)
+             )(update_docs)
 
     app.get("/knowledge_base/download_doc",
             tags=["Knowledge Base Management"],
diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py
index b9151b8..4b135fe 100644
--- a/server/knowledge_base/kb_api.py
+++ b/server/knowledge_base/kb_api.py
@@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse
 from server.knowledge_base.utils import validate_kb_name
 from server.knowledge_base.kb_service.base import KBServiceFactory
 from server.db.repository.knowledge_base_repository import list_kbs_from_db
-from configs.model_config import EMBEDDING_MODEL
+from configs.model_config import EMBEDDING_MODEL, logger
 from fastapi import Body
 
 
@@ -30,8 +30,9 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
     try:
         kb.create_kb()
     except Exception as e:
-        print(e)
-        return BaseResponse(code=500, msg=f"创建知识库出错: {e}")
+        msg = f"创建知识库出错: {e}"
+        logger.error(msg)
+        return BaseResponse(code=500, msg=msg)
 
     return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
 
@@ -55,7 +56,8 @@ async def delete_kb(
         if status:
             return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
     except Exception as e:
-        print(e)
-        return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}")
+        msg = f"删除知识库时出现意外: {e}"
+        logger.error(msg)
+        return BaseResponse(code=500, msg=msg)
 
     return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")
diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index 7ea5d27..24b1c8e 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -1,12 +1,17 @@
 import os
 import urllib
 from fastapi import File, Form, Body, Query, UploadFile
-from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
-from server.utils import BaseResponse, ListResponse
-from server.knowledge_base.utils import validate_kb_name, list_files_from_folder, KnowledgeFile
+from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
+                                  VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
+                                  logger,)
+from server.utils import BaseResponse, ListResponse, run_in_thread_pool
+from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
+                                         files2docs_in_thread, KnowledgeFile)
 from fastapi.responses import StreamingResponse, FileResponse
+from pydantic import Json
 import json
 from server.knowledge_base.kb_service.base import KBServiceFactory
+from server.db.repository.knowledge_file_repository import get_file_detail
 from typing import List, Dict
 from langchain.docstore.document import Document
 
@@ -44,11 +49,83 @@ async def list_files(
     return ListResponse(data=all_doc_names)
 
 
-async def upload_doc(file: UploadFile = File(..., description="上传文件"),
-                     knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
+def _save_files_in_thread(files: List[UploadFile],
+                          knowledge_base_name: str,
+                          override: bool):
+    '''
+    通过多线程将上传的文件保存到对应知识库目录内。
+    生成器返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
+    '''
+    def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict:
+        '''
+        保存单个文件。
+        '''
+        filename = file.filename
+        data = {"knowledge_base_name": knowledge_base_name, "file_name": filename}
+        try:
+            file_path = get_file_path(knowledge_base_name=knowledge_base_name, doc_name=filename)
+            file_content = file.file.read()  # 读取上传文件的内容
+
+            if (os.path.isfile(file_path)
+                and not override
+                and os.path.getsize(file_path) == len(file_content)
+            ):
+                # TODO: filesize 不同后的处理
+                file_status = f"文件 {filename} 已存在。"
+                logger.warning(file_status)
+                return dict(code=404, msg=file_status, data=data)
+
+            with open(file_path, "wb") as f:
+                f.write(file_content)
+            return dict(code=200, msg=f"成功上传文件 {filename}", data=data)
+        except Exception as e:
+            msg = f"{filename} 文件上传失败,报错信息为: {e}"
+            logger.error(msg)
+            return dict(code=500, msg=msg, data=data)
+
+    params = [{"file": file, "knowledge_base_name": knowledge_base_name, "override": override} for file in files]
+    for result in run_in_thread_pool(save_file, params=params):
+        yield result
+
+
+# 似乎没有单独增加一个文件上传API接口的必要
+# def upload_files(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
+#                  knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
+#                  override: bool = Form(False, description="覆盖已有文件")):
+#     '''
+#     API接口:上传文件。流式返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
+#     '''
+#     def generate(files, knowledge_base_name, override):
+#         for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
+#             yield json.dumps(result, ensure_ascii=False)
+
+#     return StreamingResponse(generate(files, knowledge_base_name=knowledge_base_name, override=override), media_type="text/event-stream")
+
+
+# TODO: 等langchain.document_loaders支持内存文件的时候再开通
+# def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
+#                knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
+#                override: bool = Form(False, description="覆盖已有文件"),
+#                save: bool = Form(True, description="是否将文件保存到知识库目录")):
+#     def save_files(files, knowledge_base_name, override):
+#         for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
+#             yield json.dumps(result, ensure_ascii=False)
+
+#     def files_to_docs(files):
+#         for result in files2docs_in_thread(files):
+#             yield json.dumps(result, ensure_ascii=False)
+
+
+async def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
+                      knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
                       override: bool = Form(False, description="覆盖已有文件"),
+                      to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
+                      docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
                       not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
                       ) -> BaseResponse:
+    '''
+    API接口:上传文件,并/或向量化
+    '''
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
 
@@ -56,37 +133,36 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
     if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
 
-    file_content = await file.read()  # 读取上传文件的内容
+    failed_files = {}
+    file_names = list(docs.keys())
 
-    try:
-        kb_file = KnowledgeFile(filename=file.filename,
-                                knowledge_base_name=knowledge_base_name)
+    # 先将上传的文件保存到磁盘
+    for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
+        filename = result["data"]["file_name"]
+        if result["code"] != 200:
+            failed_files[filename] = result["msg"]
+
+        if filename not in file_names:
+            file_names.append(filename)
 
-        if (os.path.exists(kb_file.filepath)
-            and not override
-            and os.path.getsize(kb_file.filepath) == len(file_content)
-        ):
-            # TODO: filesize 不同后的处理
-            file_status = f"文件 {kb_file.filename} 已存在。"
-            return BaseResponse(code=404, msg=file_status)
+    # 对保存的文件进行向量化
+    if to_vector_store:
+        result = await update_docs(
+            knowledge_base_name=knowledge_base_name,
+            file_names=file_names,
+            override_custom_docs=True,
+            docs=docs,
+            not_refresh_vs_cache=True,
+        )
+        failed_files.update(result.data["failed_files"])
+        if not not_refresh_vs_cache:
+            kb.save_vector_store()
 
-        with open(kb_file.filepath, "wb") as f:
-            f.write(file_content)
-    except Exception as e:
-        print(e)
-        return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
-
-    try:
-        kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
-    except Exception as e:
-        print(e)
-        return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
-
-    return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
+    return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})
 
 
-async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
-                     doc_name: str = Body(..., examples=["file_name.md"]),
+async def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
+                      file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
                      delete_content: bool = Body(False),
                      not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
                      ) -> BaseResponse:
@@ -98,23 +174,31 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
     if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
 
-    if not kb.exist_doc(doc_name):
-        return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
+    failed_files = {}
+    for file_name in file_names:
+        if not kb.exist_doc(file_name):
+            failed_files[file_name] = f"未找到文件 {file_name}"
 
-    try:
-        kb_file = KnowledgeFile(filename=doc_name,
-                                knowledge_base_name=knowledge_base_name)
-        kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache)
-    except Exception as e:
-        print(e)
-        return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
+        try:
+            kb_file = KnowledgeFile(filename=file_name,
+                                    knowledge_base_name=knowledge_base_name)
+            kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True)
+        except Exception as e:
+            msg = f"{file_name} 文件删除失败,错误信息:{e}"
+            logger.error(msg)
+            failed_files[file_name] = msg
+
+    if not not_refresh_vs_cache:
+        kb.save_vector_store()
 
-    return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
+    return BaseResponse(code=200, msg="文件删除完成", data={"failed_files": failed_files})
 
 
-async def update_doc(
-        knowledge_base_name: str = Body(..., examples=["samples"]),
-        file_name: str = Body(..., examples=["file_name"]),
+async def update_docs(
+        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
+        file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]),
+        override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
+        docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
         not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
     ) -> BaseResponse:
     '''
@@ -127,22 +211,57 @@ async def update_doc(
     if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
 
-    try:
-        kb_file = KnowledgeFile(filename=file_name,
-                                knowledge_base_name=knowledge_base_name)
-        if os.path.exists(kb_file.filepath):
-            kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
-            return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
-    except Exception as e:
-        print(e)
-        return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}")
+    failed_files = {}
+    kb_files = []
 
-    return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
+    # 生成需要加载docs的文件列表
+    for file_name in file_names:
+        file_detail = get_file_detail(kb_name=knowledge_base_name, filename=file_name)
+        # 如果该文件之前使用了自定义docs,则根据参数决定略过或覆盖
+        if file_detail.get("custom_docs") and not override_custom_docs:
+            continue
+        if file_name not in docs:
+            try:
+                kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name))
+            except Exception as e:
+                msg = f"加载文档 {file_name} 时出错:{e}"
+                logger.error(msg)
+                failed_files[file_name] = msg
+
+    # 从文件生成docs,并进行向量化。
+    # 这里利用了KnowledgeFile的缓存功能,在多线程中加载Document,然后传给KnowledgeFile
+    for status, result in files2docs_in_thread(kb_files):
+        if status:
+            kb_name, file_name, new_docs = result
+            kb_file = KnowledgeFile(filename=file_name,
+                                    knowledge_base_name=knowledge_base_name)
+            kb_file.splited_docs = new_docs
+            kb.update_doc(kb_file, not_refresh_vs_cache=True)
+        else:
+            kb_name, file_name, error = result
+            failed_files[file_name] = error
+
+    # 将自定义的docs进行向量化
+    for file_name, v in docs.items():
+        try:
+            v = [x if isinstance(x, Document) else Document(**x) for x in v]
+            kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name)
+            kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True)
+        except Exception as e:
+            msg = f"为 {file_name} 添加自定义docs时出错:{e}"
+            logger.error(msg)
+            failed_files[file_name] = msg
+
+    if not not_refresh_vs_cache:
+        kb.save_vector_store()
+
+    return BaseResponse(code=200, msg="更新文档完成", data={"failed_files": failed_files})
 
 
 async def download_doc(
-    knowledge_base_name: str = Query(..., examples=["samples"]),
-    file_name: str = Query(..., examples=["test.txt"]),
+    knowledge_base_name: str = Query(..., description="知识库名称", examples=["samples"]),
+    file_name: str = Query(..., description="文件名称", examples=["test.txt"]),
+    preview: bool = Query(False, description="是:浏览器内预览;否:下载"),
 ):
     '''
     下载知识库文档
@@ -154,6 +273,11 @@ async def download_doc(
     if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
 
+    if preview:
+        content_disposition_type = "inline"
+    else:
+        content_disposition_type = None
+
     try:
         kb_file = KnowledgeFile(filename=file_name,
                                 knowledge_base_name=knowledge_base_name)
 
@@ -162,10 +286,13 @@ async def download_doc(
         return FileResponse(
             path=kb_file.filepath,
             filename=kb_file.filename,
-            media_type="multipart/form-data")
+            media_type="multipart/form-data",
+            content_disposition_type=content_disposition_type,
+        )
     except Exception as e:
-        print(e)
-        return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}")
+        msg = f"{kb_file.filename} 读取文件失败,错误信息是:{e}"
+        logger.error(msg)
+        return BaseResponse(code=500, msg=msg)
 
     return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
 
@@ -190,27 +317,30 @@ async def recreate_vector_store(
     else:
         kb.create_kb()
         kb.clear_vs()
-        docs = list_files_from_folder(knowledge_base_name)
-        for i, doc in enumerate(docs):
-            try:
-                kb_file = KnowledgeFile(doc, knowledge_base_name)
+        files = list_files_from_folder(knowledge_base_name)
+        kb_files = [(file, knowledge_base_name) for file in files]
+        i = 0
+        for status, result in files2docs_in_thread(kb_files):
+            if status:
+                kb_name, file_name, docs = result
+                kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
+                kb_file.splited_docs = docs
                 yield json.dumps({
                     "code": 200,
-                    "msg": f"({i + 1} / {len(docs)}): {doc}",
-                    "total": len(docs),
+                    "msg": f"({i + 1} / {len(files)}): {file_name}",
+                    "total": len(files),
                     "finished": i,
-                    "doc": doc,
+                    "doc": file_name,
                 }, ensure_ascii=False)
-                if i == len(docs) - 1:
-                    not_refresh_vs_cache = False
-                else:
-                    not_refresh_vs_cache = True
-                kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
-            except Exception as e:
-                print(e)
+                kb.add_doc(kb_file, not_refresh_vs_cache=True)
+            else:
+                kb_name, file_name, error = result
+                msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
+                logger.error(msg)
                 yield json.dumps({
                     "code": 500,
-                    "msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
+                    "msg": msg,
                 })
+            i += 1
 
     return StreamingResponse(output(), media_type="text/event-stream")
diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py
index ca0919e..5cdebfb 100644
--- a/server/knowledge_base/kb_service/base.py
+++ b/server/knowledge_base/kb_service/base.py
@@ -51,6 +51,13 @@ class KBService(ABC):
     def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
         return load_embeddings(self.embed_model, embed_device)
 
+    def save_vector_store(self, vector_store=None):
+        '''
+        保存向量库,仅支持FAISS。对于其它向量库该函数不做任何操作。
+        减少FAISS向量库操作时的类型判断。
+        '''
+        pass
+
     def create_kb(self):
         """
         创建知识库
@@ -84,6 +91,8 @@ class KBService(ABC):
         """
         if docs:
             custom_docs = True
+            for doc in docs:
+                doc.metadata.setdefault("source", kb_file.filepath)
         else:
             docs = kb_file.file2text()
             custom_docs = False
diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py
index 15cc790..a8fc17d 100644
--- a/server/knowledge_base/kb_service/faiss_kb_service.py
+++ b/server/knowledge_base/kb_service/faiss_kb_service.py
@@ -5,7 +5,8 @@ from configs.model_config import (
     KB_ROOT_PATH,
     CACHED_VS_NUM,
     EMBEDDING_MODEL,
-    SCORE_THRESHOLD
+    SCORE_THRESHOLD,
+    logger,
 )
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from functools import lru_cache
@@ -28,7 +29,7 @@ def load_faiss_vector_store(
         embeddings: Embeddings = None,
         tick: int = 0,  # tick will be changed by upload_doc etc. and make cache refreshed.
 ) -> FAISS:
-    print(f"loading vector store in '{knowledge_base_name}'.")
+    logger.info(f"loading vector store in '{knowledge_base_name}'.")
     vs_path = get_vs_path(knowledge_base_name)
     if embeddings is None:
         embeddings = load_embeddings(embed_model, embed_device)
@@ -57,7 +58,7 @@ def refresh_vs_cache(kb_name: str):
     make vector store cache refreshed when next loading
     """
     _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
-    print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
+    logger.info(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
 
 
 class FaissKBService(KBService):
@@ -133,7 +134,7 @@ class FaissKBService(KBService):
                    **kwargs):
         vector_store = self.load_vector_store()
 
-        ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
+        ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath]
         if len(ids) == 0:
             return None
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index 36c972b..9de4e7c 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -7,16 +7,20 @@ from configs.model_config import (
     KB_ROOT_PATH,
     CHUNK_SIZE,
     OVERLAP_SIZE,
-    ZH_TITLE_ENHANCE
+    ZH_TITLE_ENHANCE,
+    logger,
 )
 from functools import lru_cache
 import importlib
 from text_splitter import zh_title_enhance
 import langchain.document_loaders
 from langchain.docstore.document import Document
+from langchain.text_splitter import TextSplitter
 from pathlib import Path
 import json
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ThreadPoolExecutor
+from server.utils import run_in_thread_pool
+import io
 from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
 
 
@@ -173,12 +177,74 @@ def get_LoaderClass(file_extension):
             return LoaderClass
 
 
+# 把一些向量化共用逻辑从KnowledgeFile抽取出来,等langchain支持内存文件的时候,可以将非磁盘文件向量化
+def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]):
+    '''
+    根据loader_name和文件路径或内容返回文档加载器。
+    '''
+    try:
+        if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
+            document_loaders_module = importlib.import_module('document_loaders')
+        else:
+            document_loaders_module = importlib.import_module('langchain.document_loaders')
+        DocumentLoader = getattr(document_loaders_module, loader_name)
+    except Exception as e:
+        logger.error(f"为文件{file_path_or_content}查找加载器{loader_name}时出错:{e}")
+        document_loaders_module = importlib.import_module('langchain.document_loaders')
+        DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
+
+    if loader_name == "UnstructuredFileLoader":
+        loader = DocumentLoader(file_path_or_content, autodetect_encoding=True)
+    elif loader_name == "CSVLoader":
+        loader = DocumentLoader(file_path_or_content, encoding="utf-8")
+    elif loader_name == "JSONLoader":
+        loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
+    elif loader_name == "CustomJSONLoader":
+        loader = DocumentLoader(file_path_or_content, text_content=False)
+    elif loader_name == "UnstructuredMarkdownLoader":
+        loader = DocumentLoader(file_path_or_content, mode="elements")
+    elif loader_name == "UnstructuredHTMLLoader":
+        loader = DocumentLoader(file_path_or_content, mode="elements")
+    else:
+        loader = DocumentLoader(file_path_or_content)
+    return loader
+
+
+def make_text_splitter(
+    splitter_name: str = "SpacyTextSplitter",
+    chunk_size: int = CHUNK_SIZE,
+    chunk_overlap: int = OVERLAP_SIZE,
+):
+    '''
+    根据参数获取特定的分词器
+    '''
+    splitter_name = splitter_name or "SpacyTextSplitter"
+    text_splitter_module = importlib.import_module('langchain.text_splitter')
+    try:
+        TextSplitter = getattr(text_splitter_module, splitter_name)
+        text_splitter = TextSplitter(
+            pipeline="zh_core_web_sm",
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+        )
+    except Exception as e:
+        logger.error(f"查找分词器 {splitter_name} 时出错:{e}")
+        TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
+        text_splitter = TextSplitter(
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+        )
+    return text_splitter
+
+
 class KnowledgeFile:
     def __init__(
             self,
             filename: str,
             knowledge_base_name: str
     ):
+        '''
+        对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。
+        '''
         self.kb_name = knowledge_base_name
         self.filename = filename
         self.ext = os.path.splitext(filename)[-1].lower()
@@ -186,76 +252,62 @@ class KnowledgeFile:
             raise ValueError(f"暂未支持的文件格式 {self.ext}")
         self.filepath = get_file_path(knowledge_base_name, filename)
         self.docs = None
+        self.splited_docs = None
         self.document_loader_name = get_LoaderClass(self.ext)
 
         # TODO: 增加依据文件格式匹配text_splitter
         self.text_splitter_name = None
 
-    def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = False):
-        if self.docs is not None and not refresh:
-            return self.docs
+    def file2docs(self, refresh: bool = False):
+        if self.docs is None or refresh:
+            logger.info(f"{self.document_loader_name} used for {self.filepath}")
+            loader = get_loader(self.document_loader_name, self.filepath)
+            self.docs = loader.load()
+        return self.docs
 
-        print(f"{self.document_loader_name} used for {self.filepath}")
-        try:
-            if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
-                document_loaders_module = importlib.import_module('document_loaders')
-            else:
-                document_loaders_module = importlib.import_module('langchain.document_loaders')
-            DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
-        except Exception as e:
-            print(e)
-            document_loaders_module = importlib.import_module('langchain.document_loaders')
-            DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
-        if self.document_loader_name == "UnstructuredFileLoader":
-            loader = DocumentLoader(self.filepath, autodetect_encoding=True)
-        elif self.document_loader_name == "CSVLoader":
-            loader = DocumentLoader(self.filepath, encoding="utf-8")
-        elif self.document_loader_name == "JSONLoader":
-            loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
-        elif self.document_loader_name == "CustomJSONLoader":
-            loader = DocumentLoader(self.filepath, text_content=False)
-        elif self.document_loader_name == "UnstructuredMarkdownLoader":
-            loader = DocumentLoader(self.filepath, mode="elements")
-        elif self.document_loader_name == "UnstructuredHTMLLoader":
-            loader = DocumentLoader(self.filepath, mode="elements")
-        else:
-            loader = DocumentLoader(self.filepath)
+    def docs2texts(
+            self,
+            docs: List[Document] = None,
+            using_zh_title_enhance=ZH_TITLE_ENHANCE,
+            refresh: bool = False,
+            chunk_size: int = CHUNK_SIZE,
+            chunk_overlap: int = OVERLAP_SIZE,
+            text_splitter: TextSplitter = None,
+    ):
+        docs = docs or self.file2docs(refresh=refresh)
+        if not docs:
+            return []
+        if self.ext not in [".csv"]:
+            if text_splitter is None:
+                text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+            docs = text_splitter.split_documents(docs)
 
-        if self.ext in ".csv":
-            docs = loader.load()
-        else:
-            try:
-                if self.text_splitter_name is None:
-                    text_splitter_module = importlib.import_module('langchain.text_splitter')
-                    TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter")
-                    text_splitter = TextSplitter(
-                        pipeline="zh_core_web_sm",
-                        chunk_size=CHUNK_SIZE,
-                        chunk_overlap=OVERLAP_SIZE,
-                    )
-                    self.text_splitter_name = "SpacyTextSplitter"
-                else:
-                    text_splitter_module = importlib.import_module('langchain.text_splitter')
-                    TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
-                    text_splitter = TextSplitter(
-                        chunk_size=CHUNK_SIZE,
-                        chunk_overlap=OVERLAP_SIZE)
-            except Exception as e:
-                print(e)
-                text_splitter_module = importlib.import_module('langchain.text_splitter')
-                TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
-                text_splitter = TextSplitter(
-                    chunk_size=CHUNK_SIZE,
-                    chunk_overlap=OVERLAP_SIZE,
-                )
-
-            docs = loader.load_and_split(text_splitter)
-
-        print(docs[0])
+        print(f"文档切分示例:{docs[0]}")
         if using_zh_title_enhance:
             docs = zh_title_enhance(docs)
-        self.docs = docs
-        return docs
+        self.splited_docs = docs
+        return self.splited_docs
+
+    def file2text(
+            self,
+            using_zh_title_enhance=ZH_TITLE_ENHANCE,
+            refresh: bool = False,
+            chunk_size: int = CHUNK_SIZE,
+            chunk_overlap: int = OVERLAP_SIZE,
+            text_splitter: TextSplitter = None,
+    ):
+        if self.splited_docs is None or refresh:
+            docs = self.file2docs()
+            self.splited_docs = self.docs2texts(docs=docs,
+                                                using_zh_title_enhance=using_zh_title_enhance,
+                                                refresh=refresh,
+                                                chunk_size=chunk_size,
+                                                chunk_overlap=chunk_overlap,
+                                                text_splitter=text_splitter)
+        return self.splited_docs
+
+    def file_exist(self):
+        return os.path.isfile(self.filepath)
 
     def get_mtime(self):
         return os.path.getmtime(self.filepath)
@@ -264,53 +316,47 @@ class KnowledgeFile:
         return os.path.getsize(self.filepath)
 
 
-def run_in_thread_pool(
-        func: Callable,
-        params: List[Dict] = [],
-        pool: ThreadPoolExecutor = None,
-) -> Generator:
-    '''
-    在线程池中批量运行任务,并将运行结果以生成器的形式返回。
-    请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
-    '''
-    tasks = []
-    if pool is None:
-        pool = ThreadPoolExecutor()
-
-    for kwargs in params:
-        thread = pool.submit(func, **kwargs)
-        tasks.append(thread)
-
-    for obj in as_completed(tasks):
-        yield obj.result()
-
-
 def files2docs_in_thread(
         files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
         pool: ThreadPoolExecutor = None,
 ) -> Generator:
     '''
-    利用多线程批量将文件转化成langchain Document.
-    生成器返回值为{(kb_name, file_name): docs}
+    利用多线程批量将磁盘文件转化成langchain Document.
+    如果传入参数是Tuple,形式为(filename, kb_name)
+    生成器返回值为 status, (kb_name, file_name, docs | error)
     '''
-    def task(*, file: KnowledgeFile, **kwargs) -> Dict[Tuple[str, str], List[Document]]:
+    def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
         try:
             return True, (file.kb_name, file.filename, file.file2text(**kwargs))
         except Exception as e:
-            return False, e
+            msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}"
+            logger.error(msg)
+            return False, (file.kb_name, file.filename, msg)
 
     kwargs_list = []
     for i, file in enumerate(files):
         kwargs = {}
         if isinstance(file, tuple) and len(file) >= 2:
-            files[i] = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
+            file = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
         elif isinstance(file, dict):
             filename = file.pop("filename")
            kb_name = file.pop("kb_name")
-            files[i] = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
             kwargs = file
+            file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
         kwargs["file"] = file
         kwargs_list.append(kwargs)
-
-    for result in run_in_thread_pool(func=task, params=kwargs_list, pool=pool):
+
+    for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
         yield result
+
+
+if __name__ == "__main__":
+    from pprint import pprint
+
+    kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
+    # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
+    docs = kb_file.file2docs()
+    pprint(docs[-1])
+
+    docs = kb_file.file2text()
+    pprint(docs[-1])
diff --git a/server/utils.py b/server/utils.py
index 0e53e3d..8d78124 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -8,7 +8,11 @@ from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE
 from configs.server_config import FSCHAT_MODEL_WORKERS
 import os
 from server import model_workers
-from typing import Literal, Optional, Any
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Literal, Optional, Callable, Generator, Dict, Any
+
+
+thread_pool = ThreadPoolExecutor()
 
 
 class BaseResponse(BaseModel):
@@ -305,3 +309,24 @@ def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
     if device not in ["cuda", "mps", "cpu"]:
         device = detect_device()
     return device
+
+
+def run_in_thread_pool(
+        func: Callable,
+        params: List[Dict] = [],
+        pool: ThreadPoolExecutor = None,
+) -> Generator:
+    '''
+    在线程池中批量运行任务,并将运行结果以生成器的形式返回。
+    请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
+    '''
+    tasks = []
+    pool = pool or thread_pool
+
+    for kwargs in params:
+        thread = pool.submit(func, **kwargs)
+        tasks.append(thread)
+
+    for obj in as_completed(tasks):
+        yield obj.result()
+
diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py
index 51bbac1..ed4e8b2 100644
--- a/tests/api/test_kb_api.py
+++ b/tests/api/test_kb_api.py
@@ -7,19 +7,23 @@ root_path = Path(__file__).parent.parent.parent
 sys.path.append(str(root_path))
 from server.utils import api_address
 from configs.model_config import VECTOR_SEARCH_TOP_K
-from server.knowledge_base.utils import get_kb_path
+from server.knowledge_base.utils import get_kb_path, get_file_path
 
 from pprint import pprint
 
 
 api_base_url = api_address()
+
 kb = "kb_for_api_test"
 test_files = {
+    "FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
     "README.MD": str(root_path / "README.MD"),
-    "FAQ.MD": str(root_path / "docs" / "FAQ.MD")
+    "test.txt": get_file_path("samples", "test.txt"),
 }
 
+print("\n\n直接url访问\n")
+
 
 def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
     if not Path(get_kb_path(kb)).exists():
@@ -78,37 +82,36 @@ def test_list_kbs(api="/knowledge_base/list_knowledge_bases"):
     assert kb in data["data"]
 
 
-def test_upload_doc(api="/knowledge_base/upload_doc"):
+def test_upload_docs(api="/knowledge_base/upload_docs"):
     url = api_base_url + api
-    for name, path in test_files.items():
-        print(f"\n上传知识文件: {name}")
-        data = {"knowledge_base_name": kb, "override": True}
-        files = {"file": (name, open(path, "rb"))}
-        r = requests.post(url, data=data, files=files)
-        data = r.json()
-        pprint(data)
-        assert data["code"] == 200
-        assert data["msg"] == f"成功上传文件 {name}"
+    files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
 
-    for name, path in test_files.items():
-        print(f"\n尝试重新上传知识文件: {name}, 不覆盖")
-        data = {"knowledge_base_name": kb, "override": False}
-        files = {"file": (name, open(path, "rb"))}
-        r = requests.post(url, data=data, files=files)
-        data = r.json()
-        pprint(data)
-        assert data["code"] == 404
-        assert data["msg"] == f"文件 {name} 已存在。"
+    print("\n上传知识文件")
+    data = {"knowledge_base_name": kb, "override": True}
+    r = requests.post(url, data=data, files=files)
+    data = r.json()
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0
 
-    for name, path in test_files.items():
-        print(f"\n尝试重新上传知识文件: {name}, 覆盖")
-        data = {"knowledge_base_name": kb, "override": True}
-        files = {"file": (name, open(path, "rb"))}
-        r = requests.post(url, data=data, files=files)
-        data = r.json()
-        pprint(data)
-        assert data["code"] == 200
-        assert data["msg"] == f"成功上传文件 {name}"
+    print("\n尝试重新上传知识文件, 不覆盖")
+    data = {"knowledge_base_name": kb, "override": False}
+    files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
+    r = requests.post(url, data=data, files=files)
+    data = r.json()
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == len(test_files)
+
+    print("\n尝试重新上传知识文件, 覆盖,自定义docs")
+    docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
+    data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
+    files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
+    r = requests.post(url, data=data, files=files)
+    data = r.json()
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0
 
 
 def test_list_files(api="/knowledge_base/list_files"):
@@ -134,26 +137,26 @@ def test_search_docs(api="/knowledge_base/search_docs"):
     assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
 
 
-def test_update_doc(api="/knowledge_base/update_doc"):
+def test_update_docs(api="/knowledge_base/update_docs"):
     url = api_base_url + api
-    for name, path in test_files.items():
-        print(f"\n更新知识文件: {name}")
-        r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name})
-        data = r.json()
-        pprint(data)
-        assert data["code"] == 200
-        assert data["msg"] == f"成功更新文件 {name}"
+
+    print("\n更新知识文件")
+    r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
+    data = r.json()
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0
 
 
-def test_delete_doc(api="/knowledge_base/delete_doc"):
+def test_delete_docs(api="/knowledge_base/delete_docs"):
     url = api_base_url + api
-    for name, path in test_files.items():
-        print(f"\n删除知识文件: {name}")
-        r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name})
-        data = r.json()
-        pprint(data)
-        assert data["code"] == 200
-        assert data["msg"] == f"{name} 文件删除成功"
+
+    print("\n删除知识文件")
+    r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
+    data = r.json()
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0
 
     url = api_base_url + "/knowledge_base/search_docs"
     query = "介绍一下langchain-chatchat项目"
diff --git a/tests/api/test_kb_api_request.py b/tests/api/test_kb_api_request.py
new file mode 100644
index 0000000..a0d15ce
--- /dev/null
+++ b/tests/api/test_kb_api_request.py
@@ -0,0 +1,161 @@
+import requests
+import json
+import sys
+from pathlib import Path
+
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+from server.utils import api_address
+from configs.model_config import VECTOR_SEARCH_TOP_K
+from server.knowledge_base.utils import get_kb_path, get_file_path
+from webui_pages.utils import ApiRequest
+
+from pprint import pprint
+
+
+api_base_url = api_address()
+api: ApiRequest = ApiRequest(api_base_url)
+
+
+kb = "kb_for_api_test"
+test_files = {
+    "FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
+    "README.MD": str(root_path / "README.MD"),
+    "test.txt": get_file_path("samples", "test.txt"),
+}
+
+print("\n\nApiRequest调用\n")
+
+
+def test_delete_kb_before():
+    if not Path(get_kb_path(kb)).exists():
+        return
+
+    data = api.delete_knowledge_base(kb)
+    pprint(data)
+    assert data["code"] == 200
+    assert isinstance(data["data"], list) and len(data["data"]) > 0
+    assert kb not in data["data"]
+
+
+def test_create_kb():
+    print("\n尝试用空名称创建知识库:")
+    data = api.create_knowledge_base(" ")
+    pprint(data)
+    assert data["code"] == 404
+    assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
+
+    print(f"\n创建新知识库: {kb}")
+    data = api.create_knowledge_base(kb)
+    pprint(data)
+    assert data["code"] == 200
+    assert data["msg"] == f"已新增知识库 {kb}"
+
+    print(f"\n尝试创建同名知识库: {kb}")
+    data = api.create_knowledge_base(kb)
+    pprint(data)
+    assert data["code"] == 404
+    assert data["msg"] == f"已存在同名知识库 {kb}"
+
+
+def test_list_kbs():
+    data = api.list_knowledge_bases()
+    pprint(data)
+    assert isinstance(data, list) and len(data) > 0
+    assert kb in data
+
+
+def test_upload_docs():
+    files = list(test_files.values())
+
+    print("\n上传知识文件")
+    data = {"knowledge_base_name": kb, "override": True}
+    data = api.upload_kb_docs(files, **data)
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0
+
+    print("\n尝试重新上传知识文件, 不覆盖")
+    data = {"knowledge_base_name": kb, "override": False}
+    data = api.upload_kb_docs(files, **data)
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == len(test_files)
+
+    print("\n尝试重新上传知识文件, 覆盖,自定义docs")
+    docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
+    data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
+    data = api.upload_kb_docs(files, **data)
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0
+
+
+def test_list_files():
+    print("\n获取知识库中文件列表:")
+    data = api.list_kb_docs(knowledge_base_name=kb)
+    pprint(data)
+    assert isinstance(data, list)
+    for name in test_files:
+        assert name in data
+
+
+def test_search_docs():
+    query = "介绍一下langchain-chatchat项目"
+    print("\n检索知识库:")
+    print(query)
+    data = api.search_kb_docs(query, kb)
+    pprint(data)
+    assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
+
+
+def test_update_docs():
+    print("\n更新知识文件")
+    data = api.update_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0
+
+
+def test_delete_docs():
+    print("\n删除知识文件")
+    data = api.delete_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0
+
+    query = "介绍一下langchain-chatchat项目"
+    print("\n尝试检索删除后的知识库:")
+    print(query)
+    data = api.search_kb_docs(query, kb)
+    pprint(data)
+    assert isinstance(data, list) and len(data) == 0
+
+
+def test_recreate_vs():
+    print("\n重建知识库:")
+    r = api.recreate_vector_store(kb)
+    for data in r:
+        assert isinstance(data, dict)
+        assert data["code"] == 200
+        print(data["msg"])
+
+    query = "本项目支持哪些文件格式?"
+    print("\n尝试检索重建后的知识库:")
+    print(query)
+    data = api.search_kb_docs(query, kb)
+    pprint(data)
+    assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
+
+
+def test_delete_kb_after():
+    print("\n删除知识库")
+    data = api.delete_knowledge_base(kb)
+    pprint(data)
+
+    # check kb not exists anymore
+    print("\n获取知识库列表:")
+    data = api.list_knowledge_bases()
+    pprint(data)
+    assert isinstance(data, list) and len(data) > 0
+    assert kb not in data
diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py
index 29a6322..2fafc51 100644
--- a/webui_pages/knowledge_base/knowledge_base.py
+++ b/webui_pages/knowledge_base/knowledge_base.py
@@ -138,14 +138,11 @@ def knowledge_base_page(api: ApiRequest):
             # use_container_width=True,
             disabled=len(files) == 0,
     ):
-        data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
-        data[-1]["not_refresh_vs_cache"]=False
-        for k in data:
-            ret = api.upload_kb_doc(**k)
-            if msg := check_success_msg(ret):
-                st.toast(msg, icon="✔")
-            elif msg := check_error_msg(ret):
-                st.toast(msg, icon="✖")
+        ret = api.upload_kb_docs(files, knowledge_base_name=kb, override=True)
+        if msg := check_success_msg(ret):
+            st.toast(msg, icon="✔")
+        elif msg := check_error_msg(ret):
+            st.toast(msg, icon="✖")
         st.session_state.files = []
 
     st.divider()
@@ -218,8 +215,8 @@ def knowledge_base_page(api: ApiRequest):
                 disabled=not file_exists(kb, selected_rows)[0],
                 use_container_width=True,
         ):
-            for row in selected_rows:
-                api.update_kb_doc(kb, row["file_name"])
+            file_names = [row["file_name"] for row in selected_rows]
+            api.update_kb_docs(kb, file_names=file_names)
             st.experimental_rerun()
 
         # 将文件从向量库中删除,但不删除文件本身。
@@ -228,8 +225,8 @@ def knowledge_base_page(api: ApiRequest):
                 disabled=not (selected_rows and selected_rows[0]["in_db"]),
                 use_container_width=True,
        ):
-            for row in selected_rows:
-                api.delete_kb_doc(kb, row["file_name"])
+            file_names = [row["file_name"] for row in selected_rows]
+            api.delete_kb_docs(kb, file_names=file_names)
             st.experimental_rerun()
 
         if cols[3].button(
                 "删除选中文档!",
                 type="primary",
                 use_container_width=True,
        ):
-            for row in selected_rows:
-                ret = api.delete_kb_doc(kb, row["file_name"], True)
-                st.toast(ret.get("msg", " "))
+            file_names = [row["file_name"] for row in selected_rows]
+            api.delete_kb_docs(kb, file_names=file_names, delete_content=True)
             st.experimental_rerun()
 
     st.divider()
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 0851104..27843f8 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -21,9 +21,7 @@ from fastapi.responses import StreamingResponse
 import contextlib
 import json
 from io import BytesIO
-from server.db.repository.knowledge_base_repository import get_kb_detail
-from server.db.repository.knowledge_file_repository import get_file_detail
-from server.utils import run_async, iter_over_async, set_httpx_timeout
+from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address
 
 from configs.model_config import NLTK_DATA_PATH
 import nltk
@@ -43,7 +41,7 @@ class ApiRequest:
     '''
     def __init__(
         self,
-        base_url: str = "http://127.0.0.1:7861",
+        base_url: str = api_address(),
         timeout: float = HTTPX_DEFAULT_TIMEOUT,
         no_remote_api: bool = False,   # call api view function directly
     ):
@@ -78,7 +76,7 @@ class ApiRequest:
             else:
                 return httpx.get(url, params=params, **kwargs)
         except Exception as e:
-            logger.error(e)
+            logger.error(f"error when get {url}: {e}")
             retry -= 1
 
     async def aget(
@@ -99,7 +97,7 @@ class ApiRequest:
             else:
                 return await client.get(url, params=params, **kwargs)
         except Exception as e:
-            logger.error(e)
+            logger.error(f"error when aget {url}: {e}")
             retry -= 1
 
     def post(
@@ -121,7 +119,7 @@ class ApiRequest:
             else:
                 return httpx.post(url, data=data, json=json, **kwargs)
         except Exception as e:
-            logger.error(e)
+            logger.error(f"error when post {url}: {e}")
             retry -= 1
 
     async def apost(
@@ -143,7 +141,7 @@ class ApiRequest:
             else:
                 return await client.post(url, data=data, json=json, **kwargs)
         except Exception as e:
-            logger.error(e)
+            logger.error(f"error when apost {url}: {e}")
             retry -= 1
 
     def delete(
@@ -164,7 +162,7 @@ class ApiRequest:
             else:
                 return httpx.delete(url, data=data, json=json, **kwargs)
         except Exception as e:
-            logger.error(e)
+            logger.error(f"error when delete {url}: {e}")
             retry -= 1
 
     async def adelete(
@@ -186,7 +184,7 @@ class ApiRequest:
             else:
                 return await client.delete(url, data=data, json=json, **kwargs)
         except Exception as e:
-            logger.error(e)
+            logger.error(f"error when adelete {url}: {e}")
             retry -= 1
 
     def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
@@ -205,7 +203,7 @@ class ApiRequest:
                 elif chunk.strip():
                     yield chunk
         except Exception as e:
-            logger.error(e)
+            logger.error(f"error when run fastapi router: {e}")
 
     def _httpx_stream2generator(
         self,
@@ -231,18 +229,17 @@ class ApiRequest:
                         print(chunk, end="", flush=True)
                     yield chunk
         except httpx.ConnectError as e:
-            msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。"
+            msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})"
             logger.error(msg)
-            logger.error(e)
             yield {"code": 500, "msg": msg}
         except httpx.ReadTimeout as e:
-            msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')"
+            msg = f"API通信超时,请确认已启动FastChat与API服务(详见README '5. 启动 API 服务或 Web UI')。({e})"
             logger.error(msg)
-            logger.error(e)
             yield {"code": 500, "msg": msg}
         except Exception as e:
-            logger.error(e)
-            yield {"code": 500, "msg": str(e)}
+            msg = f"API通信遇到错误:{e}"
+            logger.error(msg)
+            yield {"code": 500, "msg": msg}
 
     # 对话相关操作
@@ -413,8 +411,9 @@ class ApiRequest:
         try:
             return response.json()
         except Exception as e:
-            logger.error(e)
-            return {"code": 500, "msg": errorMsg or str(e)}
+            msg = "API未能返回正确的JSON。" + (errorMsg or str(e))
+            logger.error(msg)
+            return {"code": 500, "msg": msg}
 
     def list_knowledge_bases(
@@ -510,12 +509,45 @@ class ApiRequest:
         data = self._check_httpx_json_response(response)
         return data.get("data", [])
 
-    def upload_kb_doc(
+    def search_kb_docs(
         self,
-        file: Union[str, Path, bytes],
+        query: str,
+        knowledge_base_name: str,
+        top_k: int = VECTOR_SEARCH_TOP_K,
+        score_threshold: int = SCORE_THRESHOLD,
+        no_remote_api: bool = None,
+    ) -> List:
+        '''
+        对应api.py/knowledge_base/search_docs接口
+        '''
+        if no_remote_api is None:
+            no_remote_api = self.no_remote_api
+
+        data = {
+            "query": query,
+            "knowledge_base_name": knowledge_base_name,
+            "top_k": top_k,
+            "score_threshold": score_threshold,
+        }
+
+        if no_remote_api:
+            from server.knowledge_base.kb_doc_api import search_docs
+            return search_docs(**data)
+        else:
+            response = self.post(
+                "/knowledge_base/search_docs",
+                json=data,
+            )
+            data = self._check_httpx_json_response(response)
+            return data
+
+    def upload_kb_docs(
+        self,
+        files: List[Union[str, Path, bytes]],
         knowledge_base_name: str,
-        filename: str = None,
         override: bool = False,
+        to_vector_store: bool = True,
+        docs: Dict = {},
         not_refresh_vs_cache: bool = False,
         no_remote_api: bool = None,
     ):
        '''
@@ -525,97 +557,113 @@ class ApiRequest:
         if no_remote_api is None:
             no_remote_api = self.no_remote_api
 
-        if isinstance(file, bytes): # raw bytes
-            file = BytesIO(file)
-        elif hasattr(file, "read"): # a file io like object
-            filename = filename or file.name
-        else: # a local path
-            file = Path(file).absolute().open("rb")
-            filename = filename or file.name
+        def convert_file(file, filename=None):
+            if isinstance(file, bytes):  # raw bytes
+                file = BytesIO(file)
+            elif hasattr(file, "read"):  # a file io like object
+                filename = filename or file.name
+            else:  # a local path
+                file = Path(file).absolute().open("rb")
+                filename = filename or file.name
+            return filename, file
+
+        files = [convert_file(file) for file in files]
+        data = {
+            "knowledge_base_name": knowledge_base_name,
+            "override": override,
+            "to_vector_store": to_vector_store,
+            "docs": docs,
+            "not_refresh_vs_cache": not_refresh_vs_cache,
+        }
 
         if no_remote_api:
-            from server.knowledge_base.kb_doc_api import upload_doc
+            from server.knowledge_base.kb_doc_api import upload_docs
             from fastapi import UploadFile
             from tempfile import SpooledTemporaryFile
 
-            temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
-            temp_file.write(file.read())
-            temp_file.seek(0)
-            response = run_async(upload_doc(
-                UploadFile(file=temp_file, filename=filename),
-                knowledge_base_name,
-                override,
-            ))
+            upload_files = []
+            for filename, file in files:
+                temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
+                temp_file.write(file.read())
+                temp_file.seek(0)
+                upload_files.append(UploadFile(file=temp_file, filename=filename))
+
+            response = run_async(upload_docs(upload_files, **data))
             return response.dict()
         else:
+            if isinstance(data["docs"], dict):
+                data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
             response = self.post(
-                "/knowledge_base/upload_doc",
-                data={
-                    "knowledge_base_name": knowledge_base_name,
-                    "override": override,
-                    "not_refresh_vs_cache": not_refresh_vs_cache,
-                },
-                files={"file": (filename, file)},
+                "/knowledge_base/upload_docs",
+                data=data,
+                files=[("files", (filename, file)) for filename, file in files],
             )
             return self._check_httpx_json_response(response)
 
-    def delete_kb_doc(
+    def delete_kb_docs(
         self,
         knowledge_base_name: str,
-        doc_name: str,
+        file_names: List[str],
         delete_content: bool = False,
         not_refresh_vs_cache: bool = False,
         no_remote_api: bool = None,
     ):
         '''
-        对应api.py/knowledge_base/delete_doc接口
+        对应api.py/knowledge_base/delete_docs接口
         '''
         if no_remote_api is None:
             no_remote_api = self.no_remote_api
 
         data = {
             "knowledge_base_name": knowledge_base_name,
-            "doc_name": doc_name,
+            "file_names": file_names,
             "delete_content": delete_content,
             "not_refresh_vs_cache": not_refresh_vs_cache,
         }
 
         if no_remote_api:
-            from server.knowledge_base.kb_doc_api import delete_doc
-            response = run_async(delete_doc(**data))
+            from server.knowledge_base.kb_doc_api import delete_docs
+            response = run_async(delete_docs(**data))
             return response.dict()
         else:
             response = self.post(
-                "/knowledge_base/delete_doc",
+                "/knowledge_base/delete_docs",
                 json=data,
             )
             return self._check_httpx_json_response(response)
 
-    def update_kb_doc(
+    def update_kb_docs(
         self,
         knowledge_base_name: str,
-        file_name: str,
+        file_names: List[str],
+        override_custom_docs: bool = False,
+        docs: Dict = {},
         not_refresh_vs_cache: bool = False,
         no_remote_api: bool = None,
     ):
         '''
-        对应api.py/knowledge_base/update_doc接口
+        对应api.py/knowledge_base/update_docs接口
         '''
         if no_remote_api is None:
             no_remote_api = self.no_remote_api
 
+        data = {
+            "knowledge_base_name": knowledge_base_name,
+            "file_names": file_names,
+            "override_custom_docs": override_custom_docs,
+            "docs": docs,
+            "not_refresh_vs_cache": not_refresh_vs_cache,
+        }
+
         if no_remote_api:
-            from server.knowledge_base.kb_doc_api import update_doc
-            response = run_async(update_doc(knowledge_base_name, file_name))
+            from server.knowledge_base.kb_doc_api import update_docs
+            response = run_async(update_docs(**data))
            return response.dict()
         else:
+            if isinstance(data["docs"], dict):
+                data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
             response = self.post(
-                "/knowledge_base/update_doc",
-                json={
-                    "knowledge_base_name": knowledge_base_name,
-                    "file_name": file_name,
-                    "not_refresh_vs_cache": not_refresh_vs_cache,
-                },
+                "/knowledge_base/update_docs",
+                json=data,
             )
            return self._check_httpx_json_response(response)