diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 2f91652..d302a04 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -51,16 +51,22 @@ def list_kbs_from_folder(): def list_files_from_folder(kb_name: str): + def is_skiped_path(path: str): # 跳过 [temp, tmp, ., ~$] 开头的目录和文件 + tail = os.path.basename(path).lower() + flag = False + for x in ["temp", "tmp", ".", "~$"]: + if tail.startswith(x): + flag = True + break + return flag + doc_path = get_doc_path(kb_name) result = [] for root, _, files in os.walk(doc_path): - tail = os.path.basename(root).lower() - if (tail.startswith("temp") - or tail.startswith("tmp") - or tail.startswith(".")): # 跳过 [temp, tmp, .] 开头的文件夹 + if is_skiped_path(root): continue for file in files: - if file.startswith("~$"): # 跳过 ~$ 开头的文件 + if is_skiped_path(file): continue path = Path(doc_path) / root / file result.append(path.resolve().relative_to(doc_path).as_posix()) @@ -114,10 +120,11 @@ def get_LoaderClass(file_extension): # 把一些向量化共用逻辑从KnowledgeFile抽取出来,等langchain支持内存文件的时候,可以将非磁盘文件向量化 -def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]): +def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): ''' 根据loader_name和文件路径或内容返回文档加载器。 ''' + loader_kwargs = loader_kwargs or {} try: if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader"]: document_loaders_module = importlib.import_module('document_loaders') @@ -125,44 +132,32 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri document_loaders_module = importlib.import_module('langchain.document_loaders') DocumentLoader = getattr(document_loaders_module, loader_name) except Exception as e: - msg = f"为文件{file_path_or_content}查找加载器{loader_name}时出错:{e}" + msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}" logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None) document_loaders_module = importlib.import_module('langchain.document_loaders') DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") if loader_name == "UnstructuredFileLoader": - loader = DocumentLoader(file_path_or_content, autodetect_encoding=True) + loader_kwargs.setdefault("autodetect_encoding", True) elif loader_name == "CSVLoader": - # 自动识别文件编码类型,避免langchain loader 加载文件报编码错误 - with open(file_path_or_content, 'rb') as struct_file: - encode_detect = chardet.detect(struct_file.read()) - if encode_detect is None: - encode_detect = {"encoding": "utf-8"} - - loader = DocumentLoader(file_path_or_content, encoding=encode_detect["encoding"]) + if not loader_kwargs.get("encoding"): + # 如果未指定 encoding,自动识别文件编码类型,避免langchain loader 加载文件报编码错误 + with open(file_path, 'rb') as struct_file: + encode_detect = chardet.detect(struct_file.read()) + if encode_detect is None: + encode_detect = {"encoding": "utf-8"} + loader_kwargs["encoding"] = encode_detect["encoding"] ## TODO:支持更多的自定义CSV读取逻辑 elif loader_name == "JSONLoader": - loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False) + loader_kwargs.setdefault("jq_schema", ".") + loader_kwargs.setdefault("text_content", False) elif loader_name == "JSONLinesLoader": - loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False, json_lines=True) - elif loader_name == "UnstructuredMarkdownLoader": - loader = DocumentLoader(file_path_or_content, mode="elements") - elif loader_name == "UnstructuredHTMLLoader": - loader = DocumentLoader(file_path_or_content, mode="elements") - elif loader_name == "UnstructuredXMLLoader": - loader = DocumentLoader(file_path_or_content, mode="elements") - elif loader_name == "UnstructuredRSTLoader": - loader = DocumentLoader(file_path_or_content, mode="elements") - elif loader_name == "UnstructuredExcelLoader": - loader = DocumentLoader(file_path_or_content, mode="elements") - elif loader_name == "UnstructuredWordDocumentLoader": - loader = DocumentLoader(file_path_or_content, mode="elements") - elif loader_name == "UnstructuredPowerPointLoader": - loader = DocumentLoader(file_path_or_content, mode="elements") - else: - loader = DocumentLoader(file_path_or_content) + loader_kwargs.setdefault("jq_schema", ".") + loader_kwargs.setdefault("text_content", False) + + loader = DocumentLoader(file_path, **loader_kwargs) return loader @@ -248,7 +243,8 @@ class KnowledgeFile: def __init__( self, filename: str, - knowledge_base_name: str + knowledge_base_name: str, + loader_kwargs: Dict = {}, ): ''' 对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。 @@ -257,7 +253,8 @@ class KnowledgeFile: self.filename = filename self.ext = os.path.splitext(filename)[-1].lower() if self.ext not in SUPPORTED_EXTS: - raise ValueError(f"暂未支持的文件格式 {self.ext}") + raise ValueError(f"暂未支持的文件格式 {self.filename}") + self.loader_kwargs = loader_kwargs self.filepath = get_file_path(knowledge_base_name, filename) self.docs = None self.splited_docs = None @@ -267,7 +264,9 @@ class KnowledgeFile: def file2docs(self, refresh: bool = False): if self.docs is None or refresh: logger.info(f"{self.document_loader_name} used for {self.filepath}") - loader = get_loader(self.document_loader_name, self.filepath) + loader = get_loader(loader_name=self.document_loader_name, + file_path=self.filepath, + loader_kwargs=self.loader_kwargs) self.docs = loader.load() return self.docs