From 93b133f9acb4a7e01265d31d8b23cf721283c3f2 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Mon, 4 Sep 2023 16:37:44 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E5=B0=86KnowledgeFile=E7=9A=84file2text?= =?UTF-8?q?=E6=8B=86=E5=88=86=E6=88=90file2docs=E3=80=81docs2texts?= =?UTF-8?q?=E5=92=8Cfile2text=E4=B8=89=E4=B8=AA=E9=83=A8=E5=88=86=EF=BC=8C?= =?UTF-8?q?=E5=9C=A8=E4=BF=9D=E6=8C=81=E6=8E=A5=E5=8F=A3=E4=B8=8D=E5=8F=98?= =?UTF-8?q?=E7=9A=84=E6=83=85=E5=86=B5=E4=B8=8B=EF=BC=8C=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=EF=BC=9A=201=E3=80=81=E6=94=AF=E6=8C=81chunk=5Fsize=E5=92=8Cch?= =?UTF-8?q?unk=5Foverlap=E5=8F=82=E6=95=B0=202=E3=80=81=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E8=87=AA=E5=AE=9A=E4=B9=89text=5Fsplitter=203=E3=80=81?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E8=87=AA=E5=AE=9A=E4=B9=89docs=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=EF=BC=9Acsv=E6=96=87=E4=BB=B6=E4=B8=8D=E4=BD=BF?= =?UTF-8?q?=E7=94=A8text=5Fsplitter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/knowledge_base/utils.py | 198 +++++++++++++++++++-------------- server/utils.py | 27 ++++- 2 files changed, 139 insertions(+), 86 deletions(-) diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index a8a9bcc..2d63308 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -14,9 +14,11 @@ import importlib from text_splitter import zh_title_enhance import langchain.document_loaders from langchain.docstore.document import Document +from langchain.text_splitter import TextSplitter from pathlib import Path import json -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import ThreadPoolExecutor +from server.utils import run_in_thread_pool from typing import List, Union, Callable, Dict, Optional, Tuple, Generator @@ -186,75 +188,110 @@ class KnowledgeFile: raise ValueError(f"暂未支持的文件格式 {self.ext}") self.filepath = get_file_path(knowledge_base_name, filename) self.docs = None + self.splited_docs = None self.document_loader_name = get_LoaderClass(self.ext) # TODO: 增加依据文件格式匹配text_splitter self.text_splitter_name = None - def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = False): - if self.docs is not None and not refresh: - return self.docs - - print(f"{self.document_loader_name} used for {self.filepath}") - try: - if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]: - document_loaders_module = importlib.import_module('document_loaders') - else: - document_loaders_module = importlib.import_module('langchain.document_loaders') - DocumentLoader = getattr(document_loaders_module, self.document_loader_name) - except Exception as e: - print(e) - document_loaders_module = importlib.import_module('langchain.document_loaders') - DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") - if self.document_loader_name == "UnstructuredFileLoader": - loader = DocumentLoader(self.filepath, autodetect_encoding=True) - elif self.document_loader_name == "CSVLoader": - loader = DocumentLoader(self.filepath, encoding="utf-8") - elif self.document_loader_name == "JSONLoader": - loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False) - elif self.document_loader_name == "CustomJSONLoader": - loader = DocumentLoader(self.filepath, text_content=False) - elif self.document_loader_name == "UnstructuredMarkdownLoader": - loader = DocumentLoader(self.filepath, mode="elements") - elif self.document_loader_name == "UnstructuredHTMLLoader": - loader = DocumentLoader(self.filepath, 
mode="elements") - else: - loader = DocumentLoader(self.filepath) - - if self.ext in ".csv": - docs = loader.load() - else: + def file2docs(self, refresh: bool=False): + if self.docs is None or refresh: + print(f"{self.document_loader_name} used for {self.filepath}") try: - if self.text_splitter_name is None: - text_splitter_module = importlib.import_module('langchain.text_splitter') - TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter") - text_splitter = TextSplitter( - pipeline="zh_core_web_sm", - chunk_size=CHUNK_SIZE, - chunk_overlap=OVERLAP_SIZE, - ) - self.text_splitter_name = "SpacyTextSplitter" + if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]: + document_loaders_module = importlib.import_module('document_loaders') else: - text_splitter_module = importlib.import_module('langchain.text_splitter') - TextSplitter = getattr(text_splitter_module, self.text_splitter_name) - text_splitter = TextSplitter( - chunk_size=CHUNK_SIZE, - chunk_overlap=OVERLAP_SIZE) + document_loaders_module = importlib.import_module('langchain.document_loaders') + DocumentLoader = getattr(document_loaders_module, self.document_loader_name) except Exception as e: print(e) - text_splitter_module = importlib.import_module('langchain.text_splitter') - TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter") - text_splitter = TextSplitter( - chunk_size=CHUNK_SIZE, - chunk_overlap=OVERLAP_SIZE, - ) + document_loaders_module = importlib.import_module('langchain.document_loaders') + DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") + if self.document_loader_name == "UnstructuredFileLoader": + loader = DocumentLoader(self.filepath, autodetect_encoding=True) + elif self.document_loader_name == "CSVLoader": + loader = DocumentLoader(self.filepath, encoding="utf-8") + elif self.document_loader_name == "JSONLoader": + loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False) + elif self.document_loader_name == "CustomJSONLoader": + loader = DocumentLoader(self.filepath, text_content=False) + elif self.document_loader_name == "UnstructuredMarkdownLoader": + loader = DocumentLoader(self.filepath, mode="elements") + elif self.document_loader_name == "UnstructuredHTMLLoader": + loader = DocumentLoader(self.filepath, mode="elements") + else: + loader = DocumentLoader(self.filepath) + self.docs = loader.load() + return self.docs - docs = loader.load_and_split(text_splitter) - print(docs[0]) + def make_text_splitter( + self, + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, + ): + try: + if self.text_splitter_name is None: + text_splitter_module = importlib.import_module('langchain.text_splitter') + TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter") + text_splitter = TextSplitter( + pipeline="zh_core_web_sm", + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + self.text_splitter_name = "SpacyTextSplitter" + else: + text_splitter_module = importlib.import_module('langchain.text_splitter') + TextSplitter = getattr(text_splitter_module, self.text_splitter_name) + text_splitter = TextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap) + except Exception as e: + print(e) + text_splitter_module = importlib.import_module('langchain.text_splitter') + TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter") + text_splitter = TextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + return text_splitter + + def docs2texts( + self, + docs: 
List[Document] = None, + using_zh_title_enhance=ZH_TITLE_ENHANCE, + refresh: bool = False, + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, + text_splitter: TextSplitter = None, + ): + docs = docs or self.file2docs(refresh=refresh) + + if self.ext not in [".csv"]: + if text_splitter is None: + text_splitter = self.make_text_splitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) + docs = text_splitter.split_documents(docs) + + print(f"文档切分示例:{docs[0]}") if using_zh_title_enhance: docs = zh_title_enhance(docs) - self.docs = docs - return docs + self.splited_docs = docs + return self.splited_docs + + def file2text( + self, + using_zh_title_enhance=ZH_TITLE_ENHANCE, + refresh: bool = False, + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, + text_splitter: TextSplitter = None, + ): + if self.splited_docs is None or refresh: + self.splited_docs = self.docs2texts(using_zh_title_enhance=using_zh_title_enhance, + refresh=refresh, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + text_splitter=text_splitter) + return self.splited_docs def get_mtime(self): return os.path.getmtime(self.filepath) @@ -263,36 +300,15 @@ class KnowledgeFile: return os.path.getsize(self.filepath) -def run_in_thread_pool( - func: Callable, - params: List[Dict] = [], - pool: ThreadPoolExecutor = None, -) -> Generator: - ''' - 在线程池中批量运行任务,并将运行结果以生成器的形式返回。 - 请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。 - ''' - tasks = [] - if pool is None: - pool = ThreadPoolExecutor() - - for kwargs in params: - thread = pool.submit(func, **kwargs) - tasks.append(thread) - - for obj in as_completed(tasks): - yield obj.result() - - def files2docs_in_thread( - files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], + files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], # 如果是Tuple,形式为(filename, kb_name) pool: ThreadPoolExecutor = None, ) -> Generator: ''' - 利用多线程批量将文件转化成langchain Document. + 利用多线程批量将磁盘文件转化成langchain Document. 
生成器返回值为{(kb_name, file_name): docs} ''' - def task(*, file: KnowledgeFile, **kwargs) -> Dict[Tuple[str, str], List[Document]]: + def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]: try: return True, (file.kb_name, file.filename, file.file2text(**kwargs)) except Exception as e: @@ -302,14 +318,26 @@ def files2docs_in_thread( for i, file in enumerate(files): kwargs = {} if isinstance(file, tuple) and len(file) >= 2: - files[i] = KnowledgeFile(filename=file[0], knowledge_base_name=file[1]) + file = KnowledgeFile(filename=file[0], knowledge_base_name=file[1]) elif isinstance(file, dict): filename = file.pop("filename") kb_name = file.pop("kb_name") - files[i] = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) kwargs = file + file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) kwargs["file"] = file kwargs_list.append(kwargs) - for result in run_in_thread_pool(func=task, params=kwargs_list, pool=pool): + for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool): yield result + + +if __name__ == "__main__": + from pprint import pprint + + kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples") + # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter" + docs = kb_file.file2docs() + pprint(docs[-1]) + + docs = kb_file.file2text() + pprint(docs[-1]) diff --git a/server/utils.py b/server/utils.py index ec07dc6..9f4888f 100644 --- a/server/utils.py +++ b/server/utils.py @@ -9,7 +9,11 @@ from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDIN from configs.server_config import FSCHAT_MODEL_WORKERS import os from server import model_workers -from typing import Literal, Optional, Any +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Literal, Optional, Callable, Generator, Dict, Any + + +thread_pool = ThreadPoolExecutor() class BaseResponse(BaseModel): @@ -305,3 +309,24 @@ def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", " if device not in ["cuda", "mps", "cpu"]: device = detect_device() return device + + +def run_in_thread_pool( + func: Callable, + params: List[Dict] = [], + pool: ThreadPoolExecutor = None, +) -> Generator: + ''' + 在线程池中批量运行任务,并将运行结果以生成器的形式返回。 + 请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。 + ''' + tasks = [] + pool = pool or thread_pool + + for kwargs in params: + thread = pool.submit(func, **kwargs) + tasks.append(thread) + + for obj in as_completed(tasks): + yield obj.result() + From 661a0e9d724eb5b1cee97fac75f16560e6540454 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Fri, 8 Sep 2023 08:55:12 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E6=96=B0=E5=8A=9F=E8=83=BD:=20-=20?= =?UTF-8?q?=E7=9F=A5=E8=AF=86=E5=BA=93=E7=AE=A1=E7=90=86=E4=B8=AD=E7=9A=84?= =?UTF-8?q?add=5Fdocs/delete=5Fdocs/update=5Fdocs=E5=9D=87=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E6=89=B9=E9=87=8F=E6=93=8D=E4=BD=9C=EF=BC=8C=E5=B9=B6?= =?UTF-8?q?=E5=88=A9=E7=94=A8=E5=A4=9A=E7=BA=BF=E7=A8=8B=E6=8F=90=E9=AB=98?= =?UTF-8?q?=E6=95=88=E7=8E=87=20-=20API=E7=9A=84=E9=87=8D=E5=BB=BA?= =?UTF-8?q?=E7=9F=A5=E8=AF=86=E5=BA=93=E6=8E=A5=E5=8F=A3=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=A4=9A=E7=BA=BF=E7=A8=8B=20-=20add=5Fdocs=E5=8F=AF=E6=8F=90?= =?UTF-8?q?=E4=BE=9B=E5=8F=82=E6=95=B0=E6=8E=A7=E5=88=B6=E4=B8=8A=E4=BC=A0?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=90=8E=E6=98=AF=E5=90=A6=E7=BB=A7=E7=BB=AD?= =?UTF-8?q?=E8=BF=9B=E8=A1=8C=E5=90=91=E9=87=8F=E5=8C=96=20-=20add=5Fdocs/?= =?UTF-8?q?update=5Fdocs=E6=94=AF=E6=8C=81=E4=BC=A0=E5=85=A5=E8=87=AA?= 
=?UTF-8?q?=E5=AE=9A=E4=B9=89docs(=E4=BB=A5json=E5=BD=A2=E5=BC=8F)?= =?UTF-8?q?=E3=80=82=E5=90=8E=E7=BB=AD=E8=80=83=E8=99=91=E5=8C=BA=E5=88=86?= =?UTF-8?q?=E5=AE=8C=E6=95=B4=E6=88=96=E8=A1=A5=E5=85=85=E5=BC=8F=E8=87=AA?= =?UTF-8?q?=E5=AE=9A=E4=B9=89docs=20-=20download=5Fdoc=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0`preview`=E5=8F=82=E6=95=B0=EF=BC=8C=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E4=B8=8B=E8=BD=BD=E6=88=96=E9=A2=84=E8=A7=88=20-=20kb?= =?UTF-8?q?=5Fservice=E5=A2=9E=E5=8A=A0`save=5Fvector=5Fstore`=E6=96=B9?= =?UTF-8?q?=E6=B3=95=EF=BC=8C=E4=BE=BF=E4=BA=8E=E4=BF=9D=E5=AD=98=E5=90=91?= =?UTF-8?q?=E9=87=8F=E5=BA=93=EF=BC=88=E4=BB=85FAISS=EF=BC=8C=E5=85=B6?= =?UTF-8?q?=E5=AE=83=E6=97=A0=E6=93=8D=E4=BD=9C=EF=BC=89=20-=20=E5=B0=86do?= =?UTF-8?q?cument=5Floader=20&=20text=5Fsplitter=E9=80=BB=E8=BE=91?= =?UTF-8?q?=E4=BB=8EKnowledgeFile=E4=B8=AD=E6=8A=BD=E7=A6=BB=E5=87=BA?= =?UTF-8?q?=E6=9D=A5=EF=BC=8C=E4=B8=BA=E5=90=8E=E7=BB=AD=E5=AF=B9=E5=86=85?= =?UTF-8?q?=E5=AD=98=E6=96=87=E4=BB=B6=E8=BF=9B=E8=A1=8C=E5=90=91=E9=87=8F?= =?UTF-8?q?=E5=8C=96=E5=81=9A=E5=87=86=E5=A4=87=20-=20KowledgeFile?= =?UTF-8?q?=E6=94=AF=E6=8C=81docs=20&=20splitted=5Fdocs=E7=9A=84=E7=BC=93?= =?UTF-8?q?=E5=AD=98=EF=BC=8C=E6=96=B9=E4=BE=BF=E5=9C=A8=E4=B8=AD=E9=97=B4?= =?UTF-8?q?=E8=BF=87=E7=A8=8B=E5=81=9A=E4=B8=80=E4=BA=9B=E8=87=AA=E5=AE=9A?= =?UTF-8?q?=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 其它: - 将部分错误输出由print改为logger.error --- server/api.py | 18 +- server/knowledge_base/kb_api.py | 12 +- server/knowledge_base/kb_doc_api.py | 278 ++++++++--- server/knowledge_base/kb_service/base.py | 9 + .../kb_service/faiss_kb_service.py | 9 +- server/knowledge_base/utils.py | 145 +++--- server/knowledge_base/utils.py.bak | 431 ++++++++++++++++++ tests/api/test_kb_api.py | 95 ++-- webui_pages/utils.py | 44 +- 9 files changed, 818 insertions(+), 223 deletions(-) create mode 100644 server/knowledge_base/utils.py.bak diff --git a/server/api.py b/server/api.py index 37954b7..74426cc 100644 --- a/server/api.py +++ b/server/api.py @@ -15,8 +15,8 @@ from starlette.responses import RedirectResponse from server.chat import (chat, knowledge_base_chat, openai_chat, search_engine_chat) from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb -from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc, - update_doc, download_doc, recreate_vector_store, +from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs, + update_docs, download_doc, recreate_vector_store, search_docs, DocumentWithScore) from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address import httpx @@ -98,23 +98,23 @@ def create_app(): summary="搜索知识库" )(search_docs) - app.post("/knowledge_base/upload_doc", + app.post("/knowledge_base/upload_docs", tags=["Knowledge Base Management"], response_model=BaseResponse, - summary="上传文件到知识库" - )(upload_doc) + summary="上传文件到知识库,并/或进行向量化" + )(upload_docs) - app.post("/knowledge_base/delete_doc", + app.post("/knowledge_base/delete_docs", tags=["Knowledge Base Management"], response_model=BaseResponse, summary="删除知识库内指定文件" - )(delete_doc) + )(delete_docs) - app.post("/knowledge_base/update_doc", + app.post("/knowledge_base/update_docs", tags=["Knowledge Base Management"], response_model=BaseResponse, summary="更新现有文件到知识库" - )(update_doc) + )(update_docs) app.get("/knowledge_base/download_doc", tags=["Knowledge Base Management"], diff --git a/server/knowledge_base/kb_api.py 
b/server/knowledge_base/kb_api.py index b9151b8..4b135fe 100644 --- a/server/knowledge_base/kb_api.py +++ b/server/knowledge_base/kb_api.py @@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse from server.knowledge_base.utils import validate_kb_name from server.knowledge_base.kb_service.base import KBServiceFactory from server.db.repository.knowledge_base_repository import list_kbs_from_db -from configs.model_config import EMBEDDING_MODEL +from configs.model_config import EMBEDDING_MODEL, logger from fastapi import Body @@ -30,8 +30,9 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), try: kb.create_kb() except Exception as e: - print(e) - return BaseResponse(code=500, msg=f"创建知识库出错: {e}") + msg = f"创建知识库出错: {e}" + logger.error(msg) + return BaseResponse(code=500, msg=msg) return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") @@ -55,7 +56,8 @@ async def delete_kb( if status: return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}") except Exception as e: - print(e) - return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}") + msg = f"删除知识库时出现意外: {e}" + logger.error(msg) + return BaseResponse(code=500, msg=msg) return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}") diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 7ea5d27..24b1c8e 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -1,12 +1,17 @@ import os import urllib from fastapi import File, Form, Body, Query, UploadFile -from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD) -from server.utils import BaseResponse, ListResponse -from server.knowledge_base.utils import validate_kb_name, list_files_from_folder, KnowledgeFile +from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, + VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, + logger,) +from server.utils import BaseResponse, ListResponse, run_in_thread_pool +from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder,get_file_path, + files2docs_in_thread, KnowledgeFile) from fastapi.responses import StreamingResponse, FileResponse +from pydantic import Json import json from server.knowledge_base.kb_service.base import KBServiceFactory +from server.db.repository.knowledge_file_repository import get_file_detail from typing import List, Dict from langchain.docstore.document import Document @@ -44,11 +49,83 @@ async def list_files( return ListResponse(data=all_doc_names) -async def upload_doc(file: UploadFile = File(..., description="上传文件"), - knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]), +def _save_files_in_thread(files: List[UploadFile], + knowledge_base_name: str, + override: bool): + ''' + 通过多线程将上传的文件保存到对应知识库目录内。 + 生成器返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}} + ''' + def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict: + ''' + 保存单个文件。 + ''' + try: + filename = file.filename + file_path = get_file_path(knowledge_base_name=knowledge_base_name, doc_name=filename) + data = {"knowledge_base_name": knowledge_base_name, "file_name": filename} + + file_content = file.file.read() # 读取上传文件的内容 + if (os.path.isfile(file_path) + and not override + and os.path.getsize(file_path) == len(file_content) + ): + # TODO: filesize 不同后的处理 + file_status = f"文件 {filename} 已存在。" + logger.warn(file_status) + return dict(code=404, msg=file_status, data=data) + + with 
open(file_path, "wb") as f: + f.write(file_content) + return dict(code=200, msg=f"成功上传文件 {filename}", data=data) + except Exception as e: + msg = f"{filename} 文件上传失败,报错信息为: {e}" + logger.error(msg) + return dict(code=500, msg=msg, data=data) + + params = [{"file": file, "knowledge_base_name": knowledge_base_name, "override": override} for file in files] + for result in run_in_thread_pool(save_file, params=params): + yield result + + +# 似乎没有单独增加一个文件上传API接口的必要 +# def upload_files(files: List[UploadFile] = File(..., description="上传文件,支持多文件"), +# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]), +# override: bool = Form(False, description="覆盖已有文件")): +# ''' +# API接口:上传文件。流式返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}} +# ''' +# def generate(files, knowledge_base_name, override): +# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override): +# yield json.dumps(result, ensure_ascii=False) + +# return StreamingResponse(generate(files, knowledge_base_name=knowledge_base_name, override=override), media_type="text/event-stream") + + +# TODO: 等langchain.document_loaders支持内存文件的时候再开通 +# def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"), +# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]), +# override: bool = Form(False, description="覆盖已有文件"), +# save: bool = Form(True, description="是否将文件保存到知识库目录")): +# def save_files(files, knowledge_base_name, override): +# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override): +# yield json.dumps(result, ensure_ascii=False) + +# def files_to_docs(files): +# for result in files2docs_in_thread(files): +# yield json.dumps(result, ensure_ascii=False) + + +async def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"), + knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]), override: bool = Form(False, description="覆盖已有文件"), + to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"), + docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]), not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"), ) -> BaseResponse: + ''' + API接口:上传文件,并/或向量化 + ''' if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -56,37 +133,36 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"), if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - file_content = await file.read() # 读取上传文件的内容 + failed_files = {} + file_names = list(docs.keys()) - try: - kb_file = KnowledgeFile(filename=file.filename, - knowledge_base_name=knowledge_base_name) + # 先将上传的文件保存到磁盘 + for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override): + filename = result["data"]["file_name"] + if result["code"] != 200: + failed_files[filename] = result["msg"] + + if filename not in file_names: + file_names.append(filename) - if (os.path.exists(kb_file.filepath) - and not override - and os.path.getsize(kb_file.filepath) == len(file_content) - ): - # TODO: filesize 不同后的处理 - file_status = f"文件 {kb_file.filename} 已存在。" - return BaseResponse(code=404, msg=file_status) + # 对保存的文件进行向量化 + if to_vector_store: + result = await update_docs( + knowledge_base_name=knowledge_base_name, + file_names=file_names, + override_custom_docs=True, + docs=docs, + 
not_refresh_vs_cache=True, + ) + failed_files.update(result.data["failed_files"]) + if not not_refresh_vs_cache: + kb.save_vector_store() - with open(kb_file.filepath, "wb") as f: - f.write(file_content) - except Exception as e: - print(e) - return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}") - - try: - kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) - except Exception as e: - print(e) - return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}") - - return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}") + return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files}) -async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), - doc_name: str = Body(..., examples=["file_name.md"]), +async def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]), + file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]), delete_content: bool = Body(False), not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), ) -> BaseResponse: @@ -98,23 +174,31 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - if not kb.exist_doc(doc_name): - return BaseResponse(code=404, msg=f"未找到文件 {doc_name}") + failed_files = {} + for file_name in file_names: + if not kb.exist_doc(file_name): + failed_files[file_name] = f"未找到文件 {file_name}" - try: - kb_file = KnowledgeFile(filename=doc_name, - knowledge_base_name=knowledge_base_name) - kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache) - except Exception as e: - print(e) - return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}") + try: + kb_file = KnowledgeFile(filename=file_name, + knowledge_base_name=knowledge_base_name) + kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True) + except Exception as e: + msg = f"{file_name} 文件删除失败,错误信息:{e}" + logger.error(msg) + failed_files[file_name] = msg + + if not not_refresh_vs_cache: + kb.save_vector_store() - return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功") + return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files}) -async def update_doc( - knowledge_base_name: str = Body(..., examples=["samples"]), - file_name: str = Body(..., examples=["file_name"]), +async def update_docs( + knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), + file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]), + override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"), + docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]), not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), ) -> BaseResponse: ''' @@ -127,22 +211,57 @@ async def update_doc( if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - try: - kb_file = KnowledgeFile(filename=file_name, - knowledge_base_name=knowledge_base_name) - if os.path.exists(kb_file.filepath): - kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) - return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}") - except Exception as e: - print(e) - return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}") + failed_files = {} + kb_files = [] - return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败") + # 生成需要加载docs的文件列表 + for file_name in file_names: + 
file_detail= get_file_detail(kb_name=knowledge_base_name, filename=file_name) + # 如果该文件之前使用了自定义docs,则根据参数决定略过或覆盖 + if file_detail.get("custom_docs") and not override_custom_docs: + continue + if file_name not in docs: + try: + kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name)) + except Exception as e: + msg = f"加载文档 {file_name} 时出错:{e}" + logger.error(msg) + failed_files[file_name] = msg + + # 从文件生成docs,并进行向量化。 + # 这里利用了KnowledgeFile的缓存功能,在多线程中加载Document,然后传给KnowledgeFile + for status, result in files2docs_in_thread(kb_files): + if status: + kb_name, file_name, new_docs = result + kb_file = KnowledgeFile(filename=file_name, + knowledge_base_name=knowledge_base_name) + kb_file.splited_docs = new_docs + kb.update_doc(kb_file, not_refresh_vs_cache=True) + else: + kb_name, file_name, error = result + failed_files[file_name] = error + + # 将自定义的docs进行向量化 + for file_name, v in docs.items(): + try: + v = [x if isinstance(x, Document) else Document(**x) for x in v] + kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name) + kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True) + except Exception as e: + msg = f"为 {file_name} 添加自定义docs时出错:{e}" + logger.error(msg) + failed_files[file_name] = msg + + if not not_refresh_vs_cache: + kb.save_vector_store() + + return BaseResponse(code=200, msg=f"更新文档完成", data={"failed_files": failed_files}) async def download_doc( - knowledge_base_name: str = Query(..., examples=["samples"]), - file_name: str = Query(..., examples=["test.txt"]), + knowledge_base_name: str = Query(...,description="知识库名称", examples=["samples"]), + file_name: str = Query(...,description="文件名称", examples=["test.txt"]), + preview: bool = Query(False, description="是:浏览器内预览;否:下载"), ): ''' 下载知识库文档 @@ -154,6 +273,11 @@ async def download_doc( if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + if preview: + content_disposition_type = "inline" + else: + content_disposition_type = None + try: kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name) @@ -162,10 +286,13 @@ async def download_doc( return FileResponse( path=kb_file.filepath, filename=kb_file.filename, - media_type="multipart/form-data") + media_type="multipart/form-data", + content_disposition_type=content_disposition_type, + ) except Exception as e: - print(e) - return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}") + msg = f"{kb_file.filename} 读取文件失败,错误信息是:{e}" + logger.error(msg) + return BaseResponse(code=500, msg=msg) return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败") @@ -190,27 +317,30 @@ async def recreate_vector_store( else: kb.create_kb() kb.clear_vs() - docs = list_files_from_folder(knowledge_base_name) - for i, doc in enumerate(docs): - try: - kb_file = KnowledgeFile(doc, knowledge_base_name) + files = list_files_from_folder(knowledge_base_name) + kb_files = [(file, knowledge_base_name) for file in files] + i = 0 + for status, result in files2docs_in_thread(kb_files): + if status: + kb_name, file_name, docs = result + kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name) + kb_file.splited_docs = docs yield json.dumps({ "code": 200, - "msg": f"({i + 1} / {len(docs)}): {doc}", - "total": len(docs), + "msg": f"({i + 1} / {len(files)}): {file_name}", + "total": len(files), "finished": i, - "doc": doc, + "doc": file_name, }, ensure_ascii=False) - if i == len(docs) - 1: - not_refresh_vs_cache = False - else: - not_refresh_vs_cache = True - 
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) - except Exception as e: - print(e) + kb.add_doc(kb_file, not_refresh_vs_cache=True) + else: + kb_name, file_name, error = result + msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。" + logger.error(msg) yield json.dumps({ "code": 500, - "msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。", + "msg": msg, }) + i += 1 return StreamingResponse(output(), media_type="text/event-stream") diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 79b1518..82cb849 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -49,6 +49,13 @@ class KBService(ABC): def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings: return load_embeddings(self.embed_model, embed_device) + def save_vector_store(self, vector_store=None): + ''' + 保存向量库,仅支持FAISS。对于其它向量库该函数不做任何操作。 + 减少FAISS向量库操作时的类型判断。 + ''' + pass + def create_kb(self): """ 创建知识库 @@ -82,6 +89,8 @@ class KBService(ABC): """ if docs: custom_docs = True + for doc in docs: + doc.metadata.setdefault("source", kb_file.filepath) else: docs = kb_file.file2text() custom_docs = False diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index f17b2da..764e36b 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -5,7 +5,8 @@ from configs.model_config import ( KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_MODEL, - SCORE_THRESHOLD + SCORE_THRESHOLD, + logger, ) from server.knowledge_base.kb_service.base import KBService, SupportedVSType from functools import lru_cache @@ -28,7 +29,7 @@ def load_faiss_vector_store( embeddings: Embeddings = None, tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed. 
) -> FAISS: - print(f"loading vector store in '{knowledge_base_name}'.") + logger.info(f"loading vector store in '{knowledge_base_name}'.") vs_path = get_vs_path(knowledge_base_name) if embeddings is None: embeddings = load_embeddings(embed_model, embed_device) @@ -57,7 +58,7 @@ def refresh_vs_cache(kb_name: str): make vector store cache refreshed when next loading """ _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1 - print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}") + logger.info(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}") class FaissKBService(KBService): @@ -128,7 +129,7 @@ class FaissKBService(KBService): **kwargs): vector_store = self.load_vector_store() - ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] + ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath] if len(ids) == 0: return None diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 2d63308..395d998 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -7,7 +7,8 @@ from configs.model_config import ( KB_ROOT_PATH, CHUNK_SIZE, OVERLAP_SIZE, - ZH_TITLE_ENHANCE + ZH_TITLE_ENHANCE, + logger, ) from functools import lru_cache import importlib @@ -19,6 +20,7 @@ from pathlib import Path import json from concurrent.futures import ThreadPoolExecutor from server.utils import run_in_thread_pool +import io from typing import List, Union, Callable, Dict, Optional, Tuple, Generator @@ -175,12 +177,74 @@ def get_LoaderClass(file_extension): return LoaderClass +# 把一些向量化共用逻辑从KnowledgeFile抽取出来,等langchain支持内存文件的时候,可以将非磁盘文件向量化 +def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]): + ''' + 根据loader_name和文件路径或内容返回文档加载器。 + ''' + try: + if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]: + document_loaders_module = importlib.import_module('document_loaders') + else: + document_loaders_module = importlib.import_module('langchain.document_loaders') + DocumentLoader = getattr(document_loaders_module, loader_name) + except Exception as e: + logger.error(f"为文件{file_path_or_content}查找加载器{loader_name}时出错:{e}") + document_loaders_module = importlib.import_module('langchain.document_loaders') + DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") + + if loader_name == "UnstructuredFileLoader": + loader = DocumentLoader(file_path_or_content, autodetect_encoding=True) + elif loader_name == "CSVLoader": + loader = DocumentLoader(file_path_or_content, encoding="utf-8") + elif loader_name == "JSONLoader": + loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False) + elif loader_name == "CustomJSONLoader": + loader = DocumentLoader(file_path_or_content, text_content=False) + elif loader_name == "UnstructuredMarkdownLoader": + loader = DocumentLoader(file_path_or_content, mode="elements") + elif loader_name == "UnstructuredHTMLLoader": + loader = DocumentLoader(file_path_or_content, mode="elements") + else: + loader = DocumentLoader(file_path_or_content) + return loader + + +def make_text_splitter( + splitter_name: str = "SpacyTextSplitter", + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, +): + ''' + 根据参数获取特定的分词器 + ''' + splitter_name = splitter_name or "SpacyTextSplitter" + text_splitter_module = importlib.import_module('langchain.text_splitter') + try: + TextSplitter = getattr(text_splitter_module, splitter_name) + text_splitter = TextSplitter( 
+ pipeline="zh_core_web_sm", + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + except Exception as e: + logger.error(f"查找分词器 {splitter_name} 时出错:{e}") + TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter") + text_splitter = TextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + return text_splitter + class KnowledgeFile: def __init__( self, filename: str, knowledge_base_name: str ): + ''' + 对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。 + ''' self.kb_name = knowledge_base_name self.filename = filename self.ext = os.path.splitext(filename)[-1].lower() @@ -196,65 +260,11 @@ class KnowledgeFile: def file2docs(self, refresh: bool=False): if self.docs is None or refresh: - print(f"{self.document_loader_name} used for {self.filepath}") - try: - if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]: - document_loaders_module = importlib.import_module('document_loaders') - else: - document_loaders_module = importlib.import_module('langchain.document_loaders') - DocumentLoader = getattr(document_loaders_module, self.document_loader_name) - except Exception as e: - print(e) - document_loaders_module = importlib.import_module('langchain.document_loaders') - DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") - if self.document_loader_name == "UnstructuredFileLoader": - loader = DocumentLoader(self.filepath, autodetect_encoding=True) - elif self.document_loader_name == "CSVLoader": - loader = DocumentLoader(self.filepath, encoding="utf-8") - elif self.document_loader_name == "JSONLoader": - loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False) - elif self.document_loader_name == "CustomJSONLoader": - loader = DocumentLoader(self.filepath, text_content=False) - elif self.document_loader_name == "UnstructuredMarkdownLoader": - loader = DocumentLoader(self.filepath, mode="elements") - elif self.document_loader_name == "UnstructuredHTMLLoader": - loader = DocumentLoader(self.filepath, mode="elements") - else: - loader = DocumentLoader(self.filepath) + logger.info(f"{self.document_loader_name} used for {self.filepath}") + loader = get_loader(self.document_loader_name, self.filepath) self.docs = loader.load() return self.docs - def make_text_splitter( - self, - chunk_size: int = CHUNK_SIZE, - chunk_overlap: int = OVERLAP_SIZE, - ): - try: - if self.text_splitter_name is None: - text_splitter_module = importlib.import_module('langchain.text_splitter') - TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter") - text_splitter = TextSplitter( - pipeline="zh_core_web_sm", - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - ) - self.text_splitter_name = "SpacyTextSplitter" - else: - text_splitter_module = importlib.import_module('langchain.text_splitter') - TextSplitter = getattr(text_splitter_module, self.text_splitter_name) - text_splitter = TextSplitter( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap) - except Exception as e: - print(e) - text_splitter_module = importlib.import_module('langchain.text_splitter') - TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter") - text_splitter = TextSplitter( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - ) - return text_splitter - def docs2texts( self, docs: List[Document] = None, @@ -265,10 +275,11 @@ class KnowledgeFile: text_splitter: TextSplitter = None, ): docs = docs or self.file2docs(refresh=refresh) - + if not docs: + return [] if self.ext not in [".csv"]: if text_splitter is None: - text_splitter = 
self.make_text_splitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) + text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap) docs = text_splitter.split_documents(docs) print(f"文档切分示例:{docs[0]}") @@ -286,13 +297,18 @@ class KnowledgeFile: text_splitter: TextSplitter = None, ): if self.splited_docs is None or refresh: - self.splited_docs = self.docs2texts(using_zh_title_enhance=using_zh_title_enhance, + docs = self.file2docs() + self.splited_docs = self.docs2texts(docs=docs, + using_zh_title_enhance=using_zh_title_enhance, refresh=refresh, chunk_size=chunk_size, chunk_overlap=chunk_overlap, text_splitter=text_splitter) return self.splited_docs + def file_exist(self): + return os.path.isfile(self.filepath) + def get_mtime(self): return os.path.getmtime(self.filepath) @@ -301,18 +317,21 @@ class KnowledgeFile: def files2docs_in_thread( - files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], # 如果是Tuple,形式为(filename, kb_name) + files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], pool: ThreadPoolExecutor = None, ) -> Generator: ''' 利用多线程批量将磁盘文件转化成langchain Document. - 生成器返回值为{(kb_name, file_name): docs} + 如果传入参数是Tuple,形式为(filename, kb_name) + 生成器返回值为 status, (kb_name, file_name, docs | error) ''' def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]: try: return True, (file.kb_name, file.filename, file.file2text(**kwargs)) except Exception as e: - return False, e + msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}" + logger.error(msg) + return False, (file.kb_name, file.filename, msg) kwargs_list = [] for i, file in enumerate(files): diff --git a/server/knowledge_base/utils.py.bak b/server/knowledge_base/utils.py.bak new file mode 100644 index 0000000..499b85e --- /dev/null +++ b/server/knowledge_base/utils.py.bak @@ -0,0 +1,431 @@ +import os +from langchain.embeddings.huggingface import HuggingFaceEmbeddings +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.embeddings import HuggingFaceBgeEmbeddings +from configs.model_config import ( + embedding_model_dict, + KB_ROOT_PATH, + CHUNK_SIZE, + OVERLAP_SIZE, + ZH_TITLE_ENHANCE, + logger, +) +from functools import lru_cache +import importlib +from text_splitter import zh_title_enhance +import langchain.document_loaders +import document_loaders +import unstructured.partition +from langchain.docstore.document import Document +from langchain.text_splitter import TextSplitter +from pathlib import Path +import json +from concurrent.futures import ThreadPoolExecutor +from server.utils import run_in_thread_pool +import io +import builtins +from datetime import datetime +from typing import List, Union, Callable, Dict, Optional, Tuple, Generator + + +# make HuggingFaceEmbeddings hashable +def _embeddings_hash(self): + if isinstance(self, HuggingFaceEmbeddings): + return hash(self.model_name) + elif isinstance(self, HuggingFaceBgeEmbeddings): + return hash(self.model_name) + elif isinstance(self, OpenAIEmbeddings): + return hash(self.model) + +HuggingFaceEmbeddings.__hash__ = _embeddings_hash +OpenAIEmbeddings.__hash__ = _embeddings_hash +HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash + + +# patch langchain.document_loaders和项目自定义document_loaders,替换其中的open函数。 +# 使其支持对str,bytes,io.StringIO,io.BytesIO进行向量化 +def _new_open(content: Union[str, bytes, io.StringIO, io.BytesIO, Path], *args, **kw): + if isinstance(content, (io.StringIO, io.BytesIO)): + return content + if isinstance(content, str): + if 
os.path.isfile(content): + return builtins.open(content, *args, **kw) + else: + return io.StringIO(content) + if isinstance(content, bytes): + return io.BytesIO(bytes) + if isinstance(content, Path): + return Path.open(*args, **kw) + return open(content, *args, **kw) + +for module in [langchain.document_loaders, document_loaders]: + for k, v in module.__dict__.items(): + if type(v) == type(langchain.document_loaders): + v.open = _new_open + +# path unstructured 使其在处理非磁盘文件时不会出错 +def _new_get_last_modified_date(filename: str) -> Union[str, None]: + try: + modify_date = datetime.fromtimestamp(os.path.getmtime(filename)) + return modify_date.strftime("%Y-%m-%dT%H:%M:%S%z") + except: + return None + +for k, v in unstructured.partition.__dict__.items(): + if type(v) == type(unstructured.partition): + v.open = _new_open + v.get_last_modified_date = _new_get_last_modified_date + + +def validate_kb_name(knowledge_base_id: str) -> bool: + # 检查是否包含预期外的字符或路径攻击关键字 + if "../" in knowledge_base_id: + return False + return True + + +def get_kb_path(knowledge_base_name: str): + return os.path.join(KB_ROOT_PATH, knowledge_base_name) + + +def get_doc_path(knowledge_base_name: str): + return os.path.join(get_kb_path(knowledge_base_name), "content") + + +def get_vs_path(knowledge_base_name: str): + return os.path.join(get_kb_path(knowledge_base_name), "vector_store") + + +def get_file_path(knowledge_base_name: str, doc_name: str): + return os.path.join(get_doc_path(knowledge_base_name), doc_name) + + +def list_kbs_from_folder(): + return [f for f in os.listdir(KB_ROOT_PATH) + if os.path.isdir(os.path.join(KB_ROOT_PATH, f))] + + +def list_files_from_folder(kb_name: str): + doc_path = get_doc_path(kb_name) + return [file for file in os.listdir(doc_path) + if os.path.isfile(os.path.join(doc_path, file))] + + +@lru_cache(1) +def load_embeddings(model: str, device: str): + if model == "text-embedding-ada-002": # openai text-embedding-ada-002 + embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE) + elif 'bge-' in model: + embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model], + model_kwargs={'device': device}, + query_instruction="为这个句子生成表示以用于检索相关文章:") + if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding + embeddings.query_instruction = "" + else: + embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device}) + return embeddings + + +LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'], + "UnstructuredMarkdownLoader": ['.md'], + "CustomJSONLoader": [".json"], + "CSVLoader": [".csv"], + "RapidOCRPDFLoader": [".pdf"], + "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'], + "UnstructuredFileLoader": ['.eml', '.msg', '.rst', + '.rtf', '.txt', '.xml', + '.doc', '.docx', '.epub', '.odt', + '.ppt', '.pptx', '.tsv'], # '.xlsx' + } +SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] + + +class CustomJSONLoader(langchain.document_loaders.JSONLoader): + ''' + langchain的JSONLoader需要jq,在win上使用不便,进行替代。 + ''' + + def __init__( + self, + file_path: Union[str, Path], + content_key: Optional[str] = None, + metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None, + text_content: bool = True, + json_lines: bool = False, + ): + """Initialize the JSONLoader. + + Args: + file_path (Union[str, Path]): The path to the JSON or JSON Lines file. + content_key (str): The key to use to extract the content from the JSON if + results to a list of objects (dict). 
+ metadata_func (Callable[Dict, Dict]): A function that takes in the JSON + object extracted by the jq_schema and the default metadata and returns + a dict of the updated metadata. + text_content (bool): Boolean flag to indicate whether the content is in + string format, default to True. + json_lines (bool): Boolean flag to indicate whether the input is in + JSON Lines format. + """ + self.file_path = Path(file_path).resolve() + self._content_key = content_key + self._metadata_func = metadata_func + self._text_content = text_content + self._json_lines = json_lines + + # TODO: langchain's JSONLoader.load has a encoding bug, raise gbk encoding error on windows. + # This is a workaround for langchain==0.0.266. I have make a pr(#9785) to langchain, it should be deleted after langchain upgraded. + def load(self) -> List[Document]: + """Load and return documents from the JSON file.""" + docs: List[Document] = [] + if self._json_lines: + with self.file_path.open(encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + self._parse(line, docs) + else: + self._parse(self.file_path.read_text(encoding="utf-8"), docs) + return docs + + def _parse(self, content: str, docs: List[Document]) -> None: + """Convert given content to documents.""" + data = json.loads(content) + + # Perform some validation + # This is not a perfect validation, but it should catch most cases + # and prevent the user from getting a cryptic error later on. + if self._content_key is not None: + self._validate_content_key(data) + + for i, sample in enumerate(data, len(docs) + 1): + metadata = dict( + source=str(self.file_path), + seq_num=i, + ) + text = self._get_text(sample=sample, metadata=metadata) + docs.append(Document(page_content=text, metadata=metadata)) + + +langchain.document_loaders.CustomJSONLoader = CustomJSONLoader + + +def get_LoaderClass(file_extension): + for LoaderClass, extensions in LOADER_DICT.items(): + if file_extension in extensions: + return LoaderClass + + +def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]): + ''' + 根据loader_name和文件路径或内容返回文档加载器。 + ''' + try: + if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]: + document_loaders_module = importlib.import_module('document_loaders') + else: + document_loaders_module = importlib.import_module('langchain.document_loaders') + DocumentLoader = getattr(document_loaders_module, loader_name) + except Exception as e: + logger.error(e) + document_loaders_module = importlib.import_module('langchain.document_loaders') + DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") + + if loader_name == "UnstructuredFileLoader": + loader = DocumentLoader(file_path_or_content, autodetect_encoding=True) + elif loader_name == "CSVLoader": + loader = DocumentLoader(file_path_or_content, encoding="utf-8") + elif loader_name == "JSONLoader": + loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False) + elif loader_name == "CustomJSONLoader": + loader = DocumentLoader(file_path_or_content, text_content=False) + elif loader_name == "UnstructuredMarkdownLoader": + loader = DocumentLoader(file_path_or_content, mode="elements") + elif loader_name == "UnstructuredHTMLLoader": + loader = DocumentLoader(file_path_or_content, mode="elements") + else: + loader = DocumentLoader(file_path_or_content) + return loader + + +def make_text_splitter( + splitter_name: str = "SpacyTextSplitter", + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, +): + ''' + 
根据参数获取特定的分词器 + ''' + splitter_name = splitter_name or "SpacyTextSplitter" + text_splitter_module = importlib.import_module('langchain.text_splitter') + try: + TextSplitter = getattr(text_splitter_module, splitter_name) + text_splitter = TextSplitter( + pipeline="zh_core_web_sm", + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + except Exception as e: + logger.error(e) + TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter") + text_splitter = TextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + return text_splitter + + +def content_to_docs(content: Union[str, bytes, io.StringIO, io.BytesIO, Path], ext: str = ".md") -> List[Document]: + ''' + 将磁盘文件、文本、字节、内存文件等转化成Document + ''' + if not ext.startswith("."): + ext = "." + ext + ext = ext.lower() + if ext not in SUPPORTED_EXTS: + raise ValueError(f"暂未支持的文件格式 {ext}") + + loader_name = get_LoaderClass(ext) + loader = get_loader(loader_name=loader_name, file_path_or_content=content) + return loader.load() + + +def split_docs( + docs: List[Document], + splitter_name: str = "SpacyTextSplitter", + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, +) -> List[Document]: + text_splitter = make_text_splitter(splitter_name=splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap) + return text_splitter.split_documents(docs) + + +class KnowledgeFile: + def __init__( + self, + filename: str, + knowledge_base_name: str + ): + ''' + 对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。 + ''' + self.kb_name = knowledge_base_name + self.filename = filename + self.ext = os.path.splitext(filename)[-1].lower() + if self.ext not in SUPPORTED_EXTS: + raise ValueError(f"暂未支持的文件格式 {self.ext}") + self.filepath = get_file_path(knowledge_base_name, filename) + self.docs = None + self.splited_docs = None + self.document_loader_name = get_LoaderClass(self.ext) + + # TODO: 增加依据文件格式匹配text_splitter + self.text_splitter_name = None + + def file2docs(self, refresh: bool=False): + if self.docs is None or refresh: + logger.info(f"{self.document_loader_name} used for {self.filepath}") + loader = get_loader(self.document_loader_name, self.filepath) + self.docs = loader.load() + return self.docs + + def docs2texts( + self, + docs: List[Document] = None, + using_zh_title_enhance=ZH_TITLE_ENHANCE, + refresh: bool = False, + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, + text_splitter: TextSplitter = None, + ): + docs = docs or self.file2docs(refresh=refresh) + + if self.ext not in [".csv"]: + if text_splitter is None: + text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap) + docs = text_splitter.split_documents(docs) + + print(f"文档切分示例:{docs[0]}") + if using_zh_title_enhance: + docs = zh_title_enhance(docs) + self.splited_docs = docs + return self.splited_docs + + def file2text( + self, + using_zh_title_enhance=ZH_TITLE_ENHANCE, + refresh: bool = False, + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, + text_splitter: TextSplitter = None, + ): + if self.splited_docs is None or refresh: + self.splited_docs = self.docs2texts(using_zh_title_enhance=using_zh_title_enhance, + refresh=refresh, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + text_splitter=text_splitter) + return self.splited_docs + + def file_exist(self): + return os.path.isfile(self.filepath) + + def get_mtime(self): + return os.path.getmtime(self.filepath) + + def get_size(self): + return os.path.getsize(self.filepath) + + +def 
files2docs_in_thread( + files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], # 如果是Tuple,形式为(filename, kb_name) + pool: ThreadPoolExecutor = None, +) -> Generator: + ''' + 利用多线程批量将磁盘文件转化成langchain Document. + 生成器返回值为 status, (kb_name, file_name, docs) + ''' + def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]: + try: + return True, (file.kb_name, file.filename, file.file2text(**kwargs)) + except Exception as e: + return False, e + + kwargs_list = [] + for i, file in enumerate(files): + kwargs = {} + if isinstance(file, tuple) and len(file) >= 2: + file = KnowledgeFile(filename=file[0], knowledge_base_name=file[1]) + elif isinstance(file, dict): + filename = file.pop("filename") + kb_name = file.pop("kb_name") + kwargs = file + file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) + kwargs["file"] = file + kwargs_list.append(kwargs) + + for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool): + yield result + + +if __name__ == "__main__": + from pprint import pprint + + # kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples") + # # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter" + # docs = kb_file.file2docs() + # pprint(docs[-1]) + + # docs = kb_file.file2text() + # pprint(docs[-1]) + + docs = content_to_docs(""" + ## this is a title + + ## another title + + how are you + this a wonderful day. + """, "txt") + pprint(docs) + pprint(split_docs(docs, chunk_size=10, chunk_overlap=0)) diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py index 51bbac1..3e371cf 100644 --- a/tests/api/test_kb_api.py +++ b/tests/api/test_kb_api.py @@ -7,17 +7,21 @@ root_path = Path(__file__).parent.parent.parent sys.path.append(str(root_path)) from server.utils import api_address from configs.model_config import VECTOR_SEARCH_TOP_K -from server.knowledge_base.utils import get_kb_path +from server.knowledge_base.utils import get_kb_path, get_file_path +from webui_pages.utils import ApiRequest from pprint import pprint api_base_url = api_address() +api = ApiRequest(api_base_url) + kb = "kb_for_api_test" test_files = { + "FAQ.MD": str(root_path / "docs" / "FAQ.MD"), "README.MD": str(root_path / "README.MD"), - "FAQ.MD": str(root_path / "docs" / "FAQ.MD") + "test.txt": get_file_path("samples", "test.txt"), } @@ -78,37 +82,36 @@ def test_list_kbs(api="/knowledge_base/list_knowledge_bases"): assert kb in data["data"] -def test_upload_doc(api="/knowledge_base/upload_doc"): +def test_upload_docs(api="/knowledge_base/upload_docs"): url = api_base_url + api - for name, path in test_files.items(): - print(f"\n上传知识文件: {name}") - data = {"knowledge_base_name": kb, "override": True} - files = {"file": (name, open(path, "rb"))} - r = requests.post(url, data=data, files=files) - data = r.json() - pprint(data) - assert data["code"] == 200 - assert data["msg"] == f"成功上传文件 {name}" + files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()] - for name, path in test_files.items(): - print(f"\n尝试重新上传知识文件: {name}, 不覆盖") - data = {"knowledge_base_name": kb, "override": False} - files = {"file": (name, open(path, "rb"))} - r = requests.post(url, data=data, files=files) - data = r.json() - pprint(data) - assert data["code"] == 404 - assert data["msg"] == f"文件 {name} 已存在。" + print(f"\n上传知识文件") + data = {"knowledge_base_name": kb, "override": True} + r = requests.post(url, data=data, files=files) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert 
len(data["data"]["failed_files"]) == 0 - for name, path in test_files.items(): - print(f"\n尝试重新上传知识文件: {name}, 覆盖") - data = {"knowledge_base_name": kb, "override": True} - files = {"file": (name, open(path, "rb"))} - r = requests.post(url, data=data, files=files) - data = r.json() - pprint(data) - assert data["code"] == 200 - assert data["msg"] == f"成功上传文件 {name}" + print(f"\n尝试重新上传知识文件, 不覆盖") + data = {"knowledge_base_name": kb, "override": False} + files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()] + r = requests.post(url, data=data, files=files) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == len(test_files) + + print(f"\n尝试重新上传知识文件, 覆盖,自定义docs") + docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]} + data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)} + files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()] + r = requests.post(url, data=data, files=files) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == 0 def test_list_files(api="/knowledge_base/list_files"): @@ -134,26 +137,26 @@ def test_search_docs(api="/knowledge_base/search_docs"): assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K -def test_update_doc(api="/knowledge_base/update_doc"): +def test_update_docs(api="/knowledge_base/update_docs"): url = api_base_url + api - for name, path in test_files.items(): - print(f"\n更新知识文件: {name}") - r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name}) - data = r.json() - pprint(data) - assert data["code"] == 200 - assert data["msg"] == f"成功更新文件 {name}" + + print(f"\n更新知识文件") + r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)}) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == 0 -def test_delete_doc(api="/knowledge_base/delete_doc"): +def test_delete_docs(api="/knowledge_base/delete_docs"): url = api_base_url + api - for name, path in test_files.items(): - print(f"\n删除知识文件: {name}") - r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name}) - data = r.json() - pprint(data) - assert data["code"] == 200 - assert data["msg"] == f"{name} 文件删除成功" + + print(f"\n删除知识文件") + r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)}) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == 0 url = api_base_url + "/knowledge_base/search_docs" query = "介绍一下langchain-chatchat项目" diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 0851104..e14df66 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -21,9 +21,7 @@ from fastapi.responses import StreamingResponse import contextlib import json from io import BytesIO -from server.db.repository.knowledge_base_repository import get_kb_detail -from server.db.repository.knowledge_file_repository import get_file_detail -from server.utils import run_async, iter_over_async, set_httpx_timeout +from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address from configs.model_config import NLTK_DATA_PATH import nltk @@ -43,7 +41,7 @@ class ApiRequest: ''' def __init__( self, - base_url: str = "http://127.0.0.1:7861", + base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT, no_remote_api: bool = False, # call api view function directly ): @@ -78,7 +76,7 @@ class 
ApiRequest: else: return httpx.get(url, params=params, **kwargs) except Exception as e: - logger.error(e) + logger.error(f"error when get {url}: {e}") retry -= 1 async def aget( @@ -99,7 +97,7 @@ class ApiRequest: else: return await client.get(url, params=params, **kwargs) except Exception as e: - logger.error(e) + logger.error(f"error when aget {url}: {e}") retry -= 1 def post( @@ -121,7 +119,7 @@ class ApiRequest: else: return httpx.post(url, data=data, json=json, **kwargs) except Exception as e: - logger.error(e) + logger.error(f"error when post {url}: {e}") retry -= 1 async def apost( @@ -143,7 +141,7 @@ class ApiRequest: else: return await client.post(url, data=data, json=json, **kwargs) except Exception as e: - logger.error(e) + logger.error(f"error when apost {url}: {e}") retry -= 1 def delete( @@ -164,7 +162,7 @@ class ApiRequest: else: return httpx.delete(url, data=data, json=json, **kwargs) except Exception as e: - logger.error(e) + logger.error(f"error when delete {url}: {e}") retry -= 1 async def adelete( @@ -186,7 +184,7 @@ class ApiRequest: else: return await client.delete(url, data=data, json=json, **kwargs) except Exception as e: - logger.error(e) + logger.error(f"error when adelete {url}: {e}") retry -= 1 def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False): @@ -205,7 +203,7 @@ class ApiRequest: elif chunk.strip(): yield chunk except Exception as e: - logger.error(e) + logger.error(f"error when run fastapi router: {e}") def _httpx_stream2generator( self, @@ -231,18 +229,18 @@ class ApiRequest: print(chunk, end="", flush=True) yield chunk except httpx.ConnectError as e: - msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。" + msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})" + logger.error(msg) logger.error(msg) - logger.error(e) yield {"code": 500, "msg": msg} except httpx.ReadTimeout as e: - msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')" + msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 
启动 API 服务或 Web UI')。({e})" logger.error(msg) - logger.error(e) yield {"code": 500, "msg": msg} except Exception as e: - logger.error(e) - yield {"code": 500, "msg": str(e)} + msg = f"API通信遇到错误:{e}" + logger.error(msg) + yield {"code": 500, "msg": msg} # 对话相关操作 @@ -413,8 +411,9 @@ class ApiRequest: try: return response.json() except Exception as e: - logger.error(e) - return {"code": 500, "msg": errorMsg or str(e)} + msg = "API未能返回正确的JSON。" + (errorMsg or str(e)) + logger.error(msg) + return {"code": 500, "msg": msg} def list_knowledge_bases( self, @@ -510,12 +509,13 @@ class ApiRequest: data = self._check_httpx_json_response(response) return data.get("data", []) - def upload_kb_doc( + def upload_kb_docs( self, - file: Union[str, Path, bytes], + files: List[Union[str, Path, bytes]], knowledge_base_name: str, - filename: str = None, override: bool = False, + to_vector_store: bool = True, + docs: List[Dict] = [], not_refresh_vs_cache: bool = False, no_remote_api: bool = None, ): From 4cfee9c17cf8b860ade20220f3aaa5a16f69e3ad Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Fri, 8 Sep 2023 10:22:04 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E6=A0=B9=E6=8D=AE=E6=96=B0=E7=9A=84?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E4=BF=AE=E6=94=B9ApiRequest=E5=92=8Cwebui?= =?UTF-8?q?=EF=BC=8C=E4=BB=A5=E5=8F=8A=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= =?UTF-8?q?=E3=80=82=E4=BF=AE=E6=94=B9=E5=90=8E=E9=A2=84=E6=9C=9Fwebui?= =?UTF-8?q?=E4=B8=AD=E6=89=B9=E9=87=8F=E7=9F=A5=E8=AF=86=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E7=9B=B8=E5=85=B3=E6=93=8D=E4=BD=9C=E5=87=8F=E5=B0=91=E6=97=B6?= =?UTF-8?q?=E9=97=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/api/test_kb_api.py | 4 +- tests/api/test_kb_api_request.py | 161 +++++++++++++++++++ webui_pages/knowledge_base/knowledge_base.py | 28 ++-- webui_pages/utils.py | 132 ++++++++++----- 4 files changed, 265 insertions(+), 60 deletions(-) create mode 100644 tests/api/test_kb_api_request.py diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py index 3e371cf..ed4e8b2 100644 --- a/tests/api/test_kb_api.py +++ b/tests/api/test_kb_api.py @@ -8,13 +8,11 @@ sys.path.append(str(root_path)) from server.utils import api_address from configs.model_config import VECTOR_SEARCH_TOP_K from server.knowledge_base.utils import get_kb_path, get_file_path -from webui_pages.utils import ApiRequest from pprint import pprint api_base_url = api_address() -api = ApiRequest(api_base_url) kb = "kb_for_api_test" @@ -24,6 +22,8 @@ test_files = { "test.txt": get_file_path("samples", "test.txt"), } +print("\n\n直接url访问\n") + def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"): if not Path(get_kb_path(kb)).exists(): diff --git a/tests/api/test_kb_api_request.py b/tests/api/test_kb_api_request.py new file mode 100644 index 0000000..a0d15ce --- /dev/null +++ b/tests/api/test_kb_api_request.py @@ -0,0 +1,161 @@ +import requests +import json +import sys +from pathlib import Path + +root_path = Path(__file__).parent.parent.parent +sys.path.append(str(root_path)) +from server.utils import api_address +from configs.model_config import VECTOR_SEARCH_TOP_K +from server.knowledge_base.utils import get_kb_path, get_file_path +from webui_pages.utils import ApiRequest + +from pprint import pprint + + +api_base_url = api_address() +api: ApiRequest = ApiRequest(api_base_url) + + +kb = "kb_for_api_test" +test_files = { + "FAQ.MD": str(root_path / "docs" / "FAQ.MD"), + "README.MD": str(root_path / "README.MD"), + "test.txt": get_file_path("samples", "test.txt"), 
+} + +print("\n\nApiRquest调用\n") + + +def test_delete_kb_before(): + if not Path(get_kb_path(kb)).exists(): + return + + data = api.delete_knowledge_base(kb) + pprint(data) + assert data["code"] == 200 + assert isinstance(data["data"], list) and len(data["data"]) > 0 + assert kb not in data["data"] + + +def test_create_kb(): + print(f"\n尝试用空名称创建知识库:") + data = api.create_knowledge_base(" ") + pprint(data) + assert data["code"] == 404 + assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称" + + print(f"\n创建新知识库: {kb}") + data = api.create_knowledge_base(kb) + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"已新增知识库 {kb}" + + print(f"\n尝试创建同名知识库: {kb}") + data = api.create_knowledge_base(kb) + pprint(data) + assert data["code"] == 404 + assert data["msg"] == f"已存在同名知识库 {kb}" + + +def test_list_kbs(): + data = api.list_knowledge_bases() + pprint(data) + assert isinstance(data, list) and len(data) > 0 + assert kb in data + + +def test_upload_docs(): + files = list(test_files.values()) + + print(f"\n上传知识文件") + data = {"knowledge_base_name": kb, "override": True} + data = api.upload_kb_docs(files, **data) + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == 0 + + print(f"\n尝试重新上传知识文件, 不覆盖") + data = {"knowledge_base_name": kb, "override": False} + data = api.upload_kb_docs(files, **data) + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == len(test_files) + + print(f"\n尝试重新上传知识文件, 覆盖,自定义docs") + docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]} + data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)} + data = api.upload_kb_docs(files, **data) + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == 0 + + +def test_list_files(): + print("\n获取知识库中文件列表:") + data = api.list_kb_docs(knowledge_base_name=kb) + pprint(data) + assert isinstance(data, list) + for name in test_files: + assert name in data + + +def test_search_docs(): + query = "介绍一下langchain-chatchat项目" + print("\n检索知识库:") + print(query) + data = api.search_kb_docs(query, kb) + pprint(data) + assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K + + +def test_update_docs(): + print(f"\n更新知识文件") + data = api.update_kb_docs(knowledge_base_name=kb, file_names=list(test_files)) + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == 0 + + +def test_delete_docs(): + print(f"\n删除知识文件") + data = api.delete_kb_docs(knowledge_base_name=kb, file_names=list(test_files)) + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == 0 + + query = "介绍一下langchain-chatchat项目" + print("\n尝试检索删除后的检索知识库:") + print(query) + data = api.search_kb_docs(query, kb) + pprint(data) + assert isinstance(data, list) and len(data) == 0 + + +def test_recreate_vs(): + print("\n重建知识库:") + r = api.recreate_vector_store(kb) + for data in r: + assert isinstance(data, dict) + assert data["code"] == 200 + print(data["msg"]) + + query = "本项目支持哪些文件格式?" 
+ print("\n尝试检索重建后的检索知识库:") + print(query) + data = api.search_kb_docs(query, kb) + pprint(data) + assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K + + +def test_delete_kb_after(): + print("\n删除知识库") + data = api.delete_knowledge_base(kb) + pprint(data) + + # check kb not exists anymore + print("\n获取知识库列表:") + data = api.list_knowledge_bases() + pprint(data) + assert isinstance(data, list) and len(data) > 0 + assert kb not in data diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index 0889ca5..17b35d4 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -127,7 +127,7 @@ def knowledge_base_page(api: ApiRequest): # 上传文件 # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True) - files = st.file_uploader("上传知识文件(暂不支持扫描PDF)", + files = st.file_uploader("上传知识文件", [i for ls in LOADER_DICT.values() for i in ls], accept_multiple_files=True, ) @@ -138,14 +138,11 @@ def knowledge_base_page(api: ApiRequest): # use_container_width=True, disabled=len(files) == 0, ): - data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files] - data[-1]["not_refresh_vs_cache"]=False - for k in data: - ret = api.upload_kb_doc(**k) - if msg := check_success_msg(ret): - st.toast(msg, icon="✔") - elif msg := check_error_msg(ret): - st.toast(msg, icon="✖") + ret = api.upload_kb_docs(files, knowledge_base_name=kb, override=True) + if msg := check_success_msg(ret): + st.toast(msg, icon="✔") + elif msg := check_error_msg(ret): + st.toast(msg, icon="✖") st.session_state.files = [] st.divider() @@ -217,8 +214,8 @@ def knowledge_base_page(api: ApiRequest): disabled=not file_exists(kb, selected_rows)[0], use_container_width=True, ): - for row in selected_rows: - api.update_kb_doc(kb, row["file_name"]) + file_names = [row["file_name"] for row in selected_rows] + api.update_kb_docs(kb, file_names=file_names) st.experimental_rerun() # 将文件从向量库中删除,但不删除文件本身。 @@ -227,8 +224,8 @@ def knowledge_base_page(api: ApiRequest): disabled=not (selected_rows and selected_rows[0]["in_db"]), use_container_width=True, ): - for row in selected_rows: - api.delete_kb_doc(kb, row["file_name"]) + file_names = [row["file_name"] for row in selected_rows] + api.delete_kb_docs(kb, file_names=file_names) st.experimental_rerun() if cols[3].button( @@ -236,9 +233,8 @@ def knowledge_base_page(api: ApiRequest): type="primary", use_container_width=True, ): - for row in selected_rows: - ret = api.delete_kb_doc(kb, row["file_name"], True) - st.toast(ret.get("msg", " ")) + file_names = [row["file_name"] for row in selected_rows] + api.delete_kb_docs(kb, file_names=file_names, delete_content=True) st.experimental_rerun() st.divider() diff --git a/webui_pages/utils.py b/webui_pages/utils.py index e14df66..27843f8 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -509,13 +509,45 @@ class ApiRequest: data = self._check_httpx_json_response(response) return data.get("data", []) + def search_kb_docs( + self, + query: str, + knowledge_base_name: str, + top_k: int = VECTOR_SEARCH_TOP_K, + score_threshold: int = SCORE_THRESHOLD, + no_remote_api: bool = None, + ) -> List: + ''' + 对应api.py/knowledge_base/search_docs接口 + ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + + data = { + "query": query, + "knowledge_base_name": knowledge_base_name, + "top_k": top_k, + "score_threshold": score_threshold, + } + + if no_remote_api: + from 
server.knowledge_base.kb_doc_api import search_docs + return search_docs(**data) + else: + response = self.post( + "/knowledge_base/search_docs", + json=data, + ) + data = self._check_httpx_json_response(response) + return data + def upload_kb_docs( self, files: List[Union[str, Path, bytes]], knowledge_base_name: str, override: bool = False, to_vector_store: bool = True, - docs: List[Dict] = [], + docs: Dict = {}, not_refresh_vs_cache: bool = False, no_remote_api: bool = None, ): @@ -525,97 +557,113 @@ class ApiRequest: if no_remote_api is None: no_remote_api = self.no_remote_api - if isinstance(file, bytes): # raw bytes - file = BytesIO(file) - elif hasattr(file, "read"): # a file io like object - filename = filename or file.name - else: # a local path - file = Path(file).absolute().open("rb") - filename = filename or file.name + def convert_file(file, filename=None): + if isinstance(file, bytes): # raw bytes + file = BytesIO(file) + elif hasattr(file, "read"): # a file io like object + filename = filename or file.name + else: # a local path + file = Path(file).absolute().open("rb") + filename = filename or file.name + return filename, file + + files = [convert_file(file) for file in files] + data={ + "knowledge_base_name": knowledge_base_name, + "override": override, + "to_vector_store": to_vector_store, + "docs": docs, + "not_refresh_vs_cache": not_refresh_vs_cache, + } if no_remote_api: - from server.knowledge_base.kb_doc_api import upload_doc + from server.knowledge_base.kb_doc_api import upload_docs from fastapi import UploadFile from tempfile import SpooledTemporaryFile - temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024) - temp_file.write(file.read()) - temp_file.seek(0) - response = run_async(upload_doc( - UploadFile(file=temp_file, filename=filename), - knowledge_base_name, - override, - )) + upload_files = [] + for file, filename in files: + temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024) + temp_file.write(file.read()) + temp_file.seek(0) + upload_files.append(UploadFile(file=temp_file, filename=filename)) + + response = run_async(upload_docs(upload_files, **data)) return response.dict() else: + if isinstance(data["docs"], dict): + data["docs"] = json.dumps(data["docs"], ensure_ascii=False) response = self.post( - "/knowledge_base/upload_doc", - data={ - "knowledge_base_name": knowledge_base_name, - "override": override, - "not_refresh_vs_cache": not_refresh_vs_cache, - }, - files={"file": (filename, file)}, + "/knowledge_base/upload_docs", + data=data, + files=[("files", (filename, file)) for filename, file in files], ) return self._check_httpx_json_response(response) - def delete_kb_doc( + def delete_kb_docs( self, knowledge_base_name: str, - doc_name: str, + file_names: List[str], delete_content: bool = False, not_refresh_vs_cache: bool = False, no_remote_api: bool = None, ): ''' - 对应api.py/knowledge_base/delete_doc接口 + 对应api.py/knowledge_base/delete_docs接口 ''' if no_remote_api is None: no_remote_api = self.no_remote_api data = { "knowledge_base_name": knowledge_base_name, - "doc_name": doc_name, + "file_names": file_names, "delete_content": delete_content, "not_refresh_vs_cache": not_refresh_vs_cache, } if no_remote_api: - from server.knowledge_base.kb_doc_api import delete_doc - response = run_async(delete_doc(**data)) + from server.knowledge_base.kb_doc_api import delete_docs + response = run_async(delete_docs(**data)) return response.dict() else: response = self.post( - "/knowledge_base/delete_doc", + "/knowledge_base/delete_docs", json=data, ) return 
self._check_httpx_json_response(response) - def update_kb_doc( + def update_kb_docs( self, knowledge_base_name: str, - file_name: str, + file_names: List[str], + override_custom_docs: bool = False, + docs: Dict = {}, not_refresh_vs_cache: bool = False, no_remote_api: bool = None, ): ''' - 对应api.py/knowledge_base/update_doc接口 + 对应api.py/knowledge_base/update_docs接口 ''' if no_remote_api is None: no_remote_api = self.no_remote_api + data = { + "knowledge_base_name": knowledge_base_name, + "file_names": file_names, + "override_custom_docs": override_custom_docs, + "docs": docs, + "not_refresh_vs_cache": not_refresh_vs_cache, + } if no_remote_api: - from server.knowledge_base.kb_doc_api import update_doc - response = run_async(update_doc(knowledge_base_name, file_name)) + from server.knowledge_base.kb_doc_api import update_docs + response = run_async(update_docs(**data)) return response.dict() else: + if isinstance(data["docs"], dict): + data["docs"] = json.dumps(data["docs"], ensure_ascii=False) response = self.post( - "/knowledge_base/update_doc", - json={ - "knowledge_base_name": knowledge_base_name, - "file_name": file_name, - "not_refresh_vs_cache": not_refresh_vs_cache, - }, + "/knowledge_base/update_docs", + json=data, ) return self._check_httpx_json_response(response)
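
Usage sketch of the refactored KnowledgeFile pipeline introduced above (a minimal example, not a prescribed workflow: it assumes the samples knowledge base still ships with test.txt as in the tests, and the RecursiveCharacterTextSplitter settings are illustrative, not the project default):

from langchain.text_splitter import RecursiveCharacterTextSplitter
from server.knowledge_base.utils import KnowledgeFile, files2docs_in_thread

# Load the raw documents once; the result is cached on the instance.
kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
docs = kb_file.file2docs()

# Split with an explicit splitter and chunk settings instead of the built-in default.
splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=50)
chunks = kb_file.file2text(text_splitter=splitter)
print(len(docs), len(chunks))

# Convert files in a thread pool; each yielded item is (status, result), where
# result is (kb_name, file_name, docs) on success or the raised exception on failure.
for status, result in files2docs_in_thread([("test.txt", "samples")]):
    if status:
        kb_name, file_name, new_docs = result
        print(kb_name, file_name, len(new_docs))
    else:
        print("failed:", result)

docs2texts also accepts a pre-built Document list, so callers can supply their own docs and reuse only the splitting and zh_title_enhance steps.

The batch knowledge-base methods on ApiRequest can then replace the per-file loops removed from webui_pages. A rough sketch, assuming the API server is running; the knowledge-base name and the file path passed to upload_kb_docs are placeholders:

from webui_pages.utils import ApiRequest

api = ApiRequest()  # base_url now defaults to api_address()

# Upload, update and delete several files in single requests instead of one call per file.
resp = api.upload_kb_docs(["/path/to/FAQ.MD"], knowledge_base_name="samples", override=True)
print(resp["data"]["failed_files"])

api.update_kb_docs(knowledge_base_name="samples", file_names=["FAQ.MD"])
api.delete_kb_docs(knowledge_base_name="samples", file_names=["FAQ.MD"])

# Vector search through the new search_kb_docs wrapper.
hits = api.search_kb_docs("介绍一下langchain-chatchat项目", "samples")
print(len(hits))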