diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index a863acb..dbe6886 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -1,6 +1,6 @@ import os from langchain.embeddings.huggingface import HuggingFaceEmbeddings -from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config) +from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, CHUNK_SIZE, OVERLAP_SIZE) from functools import lru_cache import sys from text_splitter import zh_title_enhance @@ -45,6 +45,8 @@ LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg '.doc', '.docx', '.epub', '.odt', '.pdf', '.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv' "CSVLoader": [".csv"], + "PyPDFLoader": [".pdf"], + } SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] @@ -70,15 +72,34 @@ class KnowledgeFile: self.document_loader_name = get_LoaderClass(self.ext) # TODO: 增加依据文件格式匹配text_splitter - self.text_splitter_name = "CharacterTextSplitter" + self.text_splitter_name = None def file2text(self, using_zh_title_enhance): DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name) loader = DocumentLoader(self.filepath) # TODO: 增加依据文件格式匹配text_splitter - TextSplitter = getattr(sys.modules['langchain.text_splitter'], self.text_splitter_name) - text_splitter = TextSplitter(chunk_size=250, chunk_overlap=200) + try: + if self.text_splitter_name is None: + TextSplitter = getattr(sys.modules['langchain.text_splitter'], "SpacyTextSplitter") + text_splitter = TextSplitter( + pipeline="zh_core_web_sm", + chunk_size=CHUNK_SIZE, + chunk_overlap=OVERLAP_SIZE, + ) + else: + TextSplitter = getattr(sys.modules['langchain.text_splitter'], self.text_splitter_name) + text_splitter = TextSplitter( + chunk_size=CHUNK_SIZE, + chunk_overlap=50) + except Exception as e: + print(e) + TextSplitter = getattr(sys.modules['langchain.text_splitter'], "RecursiveCharacterTextSplitter") + text_splitter = TextSplitter( + chunk_size=CHUNK_SIZE, + chunk_overlap=OVERLAP_SIZE, + ) + docs = loader.load_and_split(text_splitter) if using_zh_title_enhance: docs = zh_title_enhance(docs)