diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index a8a9bcc..2d63308 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -14,9 +14,11 @@ import importlib
 from text_splitter import zh_title_enhance
 import langchain.document_loaders
 from langchain.docstore.document import Document
+from langchain.text_splitter import TextSplitter
 from pathlib import Path
 import json
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ThreadPoolExecutor
+from server.utils import run_in_thread_pool
 from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
 
 
@@ -186,75 +188,110 @@ class KnowledgeFile:
             raise ValueError(f"Unsupported file format {self.ext}")
         self.filepath = get_file_path(knowledge_base_name, filename)
         self.docs = None
+        self.splited_docs = None
         self.document_loader_name = get_LoaderClass(self.ext)
 
         # TODO: choose the text_splitter according to the file format
         self.text_splitter_name = None
 
-    def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = False):
-        if self.docs is not None and not refresh:
-            return self.docs
-
-        print(f"{self.document_loader_name} used for {self.filepath}")
-        try:
-            if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
-                document_loaders_module = importlib.import_module('document_loaders')
-            else:
-                document_loaders_module = importlib.import_module('langchain.document_loaders')
-            DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
-        except Exception as e:
-            print(e)
-            document_loaders_module = importlib.import_module('langchain.document_loaders')
-            DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
-        if self.document_loader_name == "UnstructuredFileLoader":
-            loader = DocumentLoader(self.filepath, autodetect_encoding=True)
-        elif self.document_loader_name == "CSVLoader":
-            loader = DocumentLoader(self.filepath, encoding="utf-8")
-        elif self.document_loader_name == "JSONLoader":
-            loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
-        elif self.document_loader_name == "CustomJSONLoader":
-            loader = DocumentLoader(self.filepath, text_content=False)
-        elif self.document_loader_name == "UnstructuredMarkdownLoader":
-            loader = DocumentLoader(self.filepath, mode="elements")
-        elif self.document_loader_name == "UnstructuredHTMLLoader":
-            loader = DocumentLoader(self.filepath, mode="elements")
-        else:
-            loader = DocumentLoader(self.filepath)
-
-        if self.ext in ".csv":
-            docs = loader.load()
-        else:
+    def file2docs(self, refresh: bool=False):
+        if self.docs is None or refresh:
+            print(f"{self.document_loader_name} used for {self.filepath}")
             try:
-                if self.text_splitter_name is None:
-                    text_splitter_module = importlib.import_module('langchain.text_splitter')
-                    TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter")
-                    text_splitter = TextSplitter(
-                        pipeline="zh_core_web_sm",
-                        chunk_size=CHUNK_SIZE,
-                        chunk_overlap=OVERLAP_SIZE,
-                    )
-                    self.text_splitter_name = "SpacyTextSplitter"
+                if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
+                    document_loaders_module = importlib.import_module('document_loaders')
                 else:
-                    text_splitter_module = importlib.import_module('langchain.text_splitter')
-                    TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
-                    text_splitter = TextSplitter(
-                        chunk_size=CHUNK_SIZE,
-                        chunk_overlap=OVERLAP_SIZE)
+                    document_loaders_module = importlib.import_module('langchain.document_loaders')
+                DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
             except Exception as e:
                 print(e)
-                text_splitter_module = importlib.import_module('langchain.text_splitter')
-                TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
-                text_splitter = TextSplitter(
-                    chunk_size=CHUNK_SIZE,
-                    chunk_overlap=OVERLAP_SIZE,
-                )
+                document_loaders_module = importlib.import_module('langchain.document_loaders')
+                DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
+            if self.document_loader_name == "UnstructuredFileLoader":
+                loader = DocumentLoader(self.filepath, autodetect_encoding=True)
+            elif self.document_loader_name == "CSVLoader":
+                loader = DocumentLoader(self.filepath, encoding="utf-8")
+            elif self.document_loader_name == "JSONLoader":
+                loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
+            elif self.document_loader_name == "CustomJSONLoader":
+                loader = DocumentLoader(self.filepath, text_content=False)
+            elif self.document_loader_name == "UnstructuredMarkdownLoader":
+                loader = DocumentLoader(self.filepath, mode="elements")
+            elif self.document_loader_name == "UnstructuredHTMLLoader":
+                loader = DocumentLoader(self.filepath, mode="elements")
+            else:
+                loader = DocumentLoader(self.filepath)
+            self.docs = loader.load()
+        return self.docs
 
-            docs = loader.load_and_split(text_splitter)
-            print(docs[0])
+    def make_text_splitter(
+        self,
+        chunk_size: int = CHUNK_SIZE,
+        chunk_overlap: int = OVERLAP_SIZE,
+    ):
+        try:
+            if self.text_splitter_name is None:
+                text_splitter_module = importlib.import_module('langchain.text_splitter')
+                TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter")
+                text_splitter = TextSplitter(
+                    pipeline="zh_core_web_sm",
+                    chunk_size=chunk_size,
+                    chunk_overlap=chunk_overlap,
+                )
+                self.text_splitter_name = "SpacyTextSplitter"
+            else:
+                text_splitter_module = importlib.import_module('langchain.text_splitter')
+                TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
+                text_splitter = TextSplitter(
+                    chunk_size=chunk_size,
+                    chunk_overlap=chunk_overlap)
+        except Exception as e:
+            print(e)
+            text_splitter_module = importlib.import_module('langchain.text_splitter')
+            TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
+            text_splitter = TextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+            )
+        return text_splitter
+
+    def docs2texts(
+        self,
+        docs: List[Document] = None,
+        using_zh_title_enhance=ZH_TITLE_ENHANCE,
+        refresh: bool = False,
+        chunk_size: int = CHUNK_SIZE,
+        chunk_overlap: int = OVERLAP_SIZE,
+        text_splitter: TextSplitter = None,
+    ):
+        docs = docs or self.file2docs(refresh=refresh)
+
+        if self.ext not in [".csv"]:
+            if text_splitter is None:
+                text_splitter = self.make_text_splitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+            docs = text_splitter.split_documents(docs)
+
+        print(f"Document splitting sample: {docs[0]}")
         if using_zh_title_enhance:
             docs = zh_title_enhance(docs)
-        self.docs = docs
-        return docs
+        self.splited_docs = docs
+        return self.splited_docs
+
+    def file2text(
+        self,
+        using_zh_title_enhance=ZH_TITLE_ENHANCE,
+        refresh: bool = False,
+        chunk_size: int = CHUNK_SIZE,
+        chunk_overlap: int = OVERLAP_SIZE,
+        text_splitter: TextSplitter = None,
+    ):
+        if self.splited_docs is None or refresh:
+            self.splited_docs = self.docs2texts(using_zh_title_enhance=using_zh_title_enhance,
+                                                refresh=refresh,
+                                                chunk_size=chunk_size,
+                                                chunk_overlap=chunk_overlap,
+                                                text_splitter=text_splitter)
+        return self.splited_docs
 
     def get_mtime(self):
         return os.path.getmtime(self.filepath)
@@ -263,36 +300,15 @@ class KnowledgeFile:
         return os.path.getsize(self.filepath)
 
 
-def run_in_thread_pool(
-        func: Callable,
-        params: List[Dict] = [],
-        pool: ThreadPoolExecutor = None,
-) -> Generator:
-    '''
-    Run tasks in batches in a thread pool, and yield the results as a generator.
-    Make sure all operations in each task are thread-safe, and that task functions take keyword arguments only.
-    '''
-    tasks = []
-    if pool is None:
-        pool = ThreadPoolExecutor()
-
-    for kwargs in params:
-        thread = pool.submit(func, **kwargs)
-        tasks.append(thread)
-
-    for obj in as_completed(tasks):
-        yield obj.result()
-
-
 def files2docs_in_thread(
-        files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
+        files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],  # if given as a tuple: (filename, kb_name)
         pool: ThreadPoolExecutor = None,
 ) -> Generator:
     '''
-    Convert files into langchain Documents in batches, using multiple threads.
+    Convert files on disk into langchain Documents in batches, using multiple threads.
     The generator yields: status, (kb_name, file_name, docs | error)
     '''
-    def task(*, file: KnowledgeFile, **kwargs) -> Dict[Tuple[str, str], List[Document]]:
+    def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
         try:
             return True, (file.kb_name, file.filename, file.file2text(**kwargs))
         except Exception as e:
@@ -302,14 +318,26 @@ def files2docs_in_thread(
     for i, file in enumerate(files):
         kwargs = {}
         if isinstance(file, tuple) and len(file) >= 2:
-            files[i] = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
+            file = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
         elif isinstance(file, dict):
             filename = file.pop("filename")
             kb_name = file.pop("kb_name")
-            files[i] = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
             kwargs = file
+            file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
         kwargs["file"] = file
         kwargs_list.append(kwargs)
 
-    for result in run_in_thread_pool(func=task, params=kwargs_list, pool=pool):
+    for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
         yield result
+
+
+if __name__ == "__main__":
+    from pprint import pprint
+
+    kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
+    # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
+    docs = kb_file.file2docs()
+    pprint(docs[-1])
+
+    docs = kb_file.file2text()
+    pprint(docs[-1])
diff --git a/server/utils.py b/server/utils.py
index ec07dc6..9f4888f 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -9,7 +9,11 @@ from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDIN
 from configs.server_config import FSCHAT_MODEL_WORKERS
 import os
 from server import model_workers
-from typing import Literal, Optional, Any
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Literal, Optional, Callable, Generator, Dict, List, Any
+
+
+thread_pool = ThreadPoolExecutor()
 
 
 class BaseResponse(BaseModel):
@@ -305,3 +309,24 @@ def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "
     if device not in ["cuda", "mps", "cpu"]:
         device = detect_device()
     return device
+
+
+def run_in_thread_pool(
+        func: Callable,
+        params: List[Dict] = [],
+        pool: ThreadPoolExecutor = None,
+) -> Generator:
+    '''
+    Run tasks in batches in a thread pool, and yield the results as a generator.
+    Make sure all operations in each task are thread-safe, and that task functions take keyword arguments only.
+    '''
+    tasks = []
+    pool = pool or thread_pool
+
+    for kwargs in params:
+        thread = pool.submit(func, **kwargs)
+        tasks.append(thread)
+
+    for obj in as_completed(tasks):
+        yield obj.result()
+
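
A minimal consumer sketch for the refactored helpers (hypothetical usage, not part of the patch; it assumes a knowledge base named "samples" containing test.txt, mirroring the __main__ block above). Per the inner file2docs signature, each yielded result is a (success, payload) tuple with payload = (kb_name, file_name, docs) on success; the except branch is elided in the hunk, so the failure payload is assumed to carry the error in place of docs.

    # Hypothetical usage sketch; names and values here are illustrative only.
    from server.knowledge_base.utils import files2docs_in_thread

    files = [
        ("test.txt", "samples"),  # tuple form: (filename, kb_name)
        # dict form; leftover keys are forwarded to file2text() as kwargs
        {"filename": "test.txt", "kb_name": "samples", "chunk_size": 250},
    ]

    for success, payload in files2docs_in_thread(files):
        if success:
            kb_name, file_name, docs = payload
            print(f"{kb_name}/{file_name}: split into {len(docs)} documents")
        else:
            print(f"conversion failed: {payload}")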