开发者: (#2091)
- 修复列出知识库磁盘文件时跳过临时文件的bug:只有目录被排除了,文件未排除 - 优化知识库文档加载器: - 将 elements 模式改为 single 模式,避免文档被切分得太碎 - 给 get_loader 和 KnowledgeFile 增加 loader_kwargs 参数,可以自定义文档加载器参数
This commit is contained in:
parent
68a544ea33
commit
ad7a6fd438
|
|
@ -51,16 +51,22 @@ def list_kbs_from_folder():
|
|||
|
||||
|
||||
def list_files_from_folder(kb_name: str):
|
||||
def is_skiped_path(path: str): # 跳过 [temp, tmp, ., ~$] 开头的目录和文件
|
||||
tail = os.path.basename(path).lower()
|
||||
flag = False
|
||||
for x in ["temp", "tmp", ".", "~$"]:
|
||||
if tail.startswith(x):
|
||||
flag = True
|
||||
break
|
||||
return flag
|
||||
|
||||
doc_path = get_doc_path(kb_name)
|
||||
result = []
|
||||
for root, _, files in os.walk(doc_path):
|
||||
tail = os.path.basename(root).lower()
|
||||
if (tail.startswith("temp")
|
||||
or tail.startswith("tmp")
|
||||
or tail.startswith(".")): # 跳过 [temp, tmp, .] 开头的文件夹
|
||||
if is_skiped_path(root):
|
||||
continue
|
||||
for file in files:
|
||||
if file.startswith("~$"): # 跳过 ~$ 开头的文件
|
||||
if is_skiped_path(file):
|
||||
continue
|
||||
path = Path(doc_path) / root / file
|
||||
result.append(path.resolve().relative_to(doc_path).as_posix())
|
||||
|
|
@ -114,10 +120,11 @@ def get_LoaderClass(file_extension):
|
|||
|
||||
|
||||
# 把一些向量化共用逻辑从KnowledgeFile抽取出来,等langchain支持内存文件的时候,可以将非磁盘文件向量化
|
||||
def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]):
|
||||
def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
|
||||
'''
|
||||
根据loader_name和文件路径或内容返回文档加载器。
|
||||
'''
|
||||
loader_kwargs = loader_kwargs or {}
|
||||
try:
|
||||
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader"]:
|
||||
document_loaders_module = importlib.import_module('document_loaders')
|
||||
|
|
@ -125,44 +132,32 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
|
|||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||
DocumentLoader = getattr(document_loaders_module, loader_name)
|
||||
except Exception as e:
|
||||
msg = f"为文件{file_path_or_content}查找加载器{loader_name}时出错:{e}"
|
||||
msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}"
|
||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||
exc_info=e if log_verbose else None)
|
||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
|
||||
|
||||
if loader_name == "UnstructuredFileLoader":
|
||||
loader = DocumentLoader(file_path_or_content, autodetect_encoding=True)
|
||||
loader_kwargs.setdefault("autodetect_encoding", True)
|
||||
elif loader_name == "CSVLoader":
|
||||
# 自动识别文件编码类型,避免langchain loader 加载文件报编码错误
|
||||
with open(file_path_or_content, 'rb') as struct_file:
|
||||
encode_detect = chardet.detect(struct_file.read())
|
||||
if encode_detect is None:
|
||||
encode_detect = {"encoding": "utf-8"}
|
||||
|
||||
loader = DocumentLoader(file_path_or_content, encoding=encode_detect["encoding"])
|
||||
if not loader_kwargs.get("encoding"):
|
||||
# 如果未指定 encoding,自动识别文件编码类型,避免langchain loader 加载文件报编码错误
|
||||
with open(file_path, 'rb') as struct_file:
|
||||
encode_detect = chardet.detect(struct_file.read())
|
||||
if encode_detect is None:
|
||||
encode_detect = {"encoding": "utf-8"}
|
||||
loader_kwargs["encoding"] = encode_detect["encoding"]
|
||||
## TODO:支持更多的自定义CSV读取逻辑
|
||||
|
||||
elif loader_name == "JSONLoader":
|
||||
loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
|
||||
loader_kwargs.setdefault("jq_schema", ".")
|
||||
loader_kwargs.setdefault("text_content", False)
|
||||
elif loader_name == "JSONLinesLoader":
|
||||
loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False, json_lines=True)
|
||||
elif loader_name == "UnstructuredMarkdownLoader":
|
||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||
elif loader_name == "UnstructuredHTMLLoader":
|
||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||
elif loader_name == "UnstructuredXMLLoader":
|
||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||
elif loader_name == "UnstructuredRSTLoader":
|
||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||
elif loader_name == "UnstructuredExcelLoader":
|
||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||
elif loader_name == "UnstructuredWordDocumentLoader":
|
||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||
elif loader_name == "UnstructuredPowerPointLoader":
|
||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||
else:
|
||||
loader = DocumentLoader(file_path_or_content)
|
||||
loader_kwargs.setdefault("jq_schema", ".")
|
||||
loader_kwargs.setdefault("text_content", False)
|
||||
|
||||
loader = DocumentLoader(file_path, **loader_kwargs)
|
||||
return loader
|
||||
|
||||
|
||||
|
|
@ -248,7 +243,8 @@ class KnowledgeFile:
|
|||
def __init__(
|
||||
self,
|
||||
filename: str,
|
||||
knowledge_base_name: str
|
||||
knowledge_base_name: str,
|
||||
loader_kwargs: Dict = {},
|
||||
):
|
||||
'''
|
||||
对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。
|
||||
|
|
@ -257,7 +253,8 @@ class KnowledgeFile:
|
|||
self.filename = filename
|
||||
self.ext = os.path.splitext(filename)[-1].lower()
|
||||
if self.ext not in SUPPORTED_EXTS:
|
||||
raise ValueError(f"暂未支持的文件格式 {self.ext}")
|
||||
raise ValueError(f"暂未支持的文件格式 {self.filename}")
|
||||
self.loader_kwargs = loader_kwargs
|
||||
self.filepath = get_file_path(knowledge_base_name, filename)
|
||||
self.docs = None
|
||||
self.splited_docs = None
|
||||
|
|
@ -267,7 +264,9 @@ class KnowledgeFile:
|
|||
def file2docs(self, refresh: bool = False):
|
||||
if self.docs is None or refresh:
|
||||
logger.info(f"{self.document_loader_name} used for {self.filepath}")
|
||||
loader = get_loader(self.document_loader_name, self.filepath)
|
||||
loader = get_loader(loader_name=self.document_loader_name,
|
||||
file_path=self.filepath,
|
||||
loader_kwargs=self.loader_kwargs)
|
||||
self.docs = loader.load()
|
||||
return self.docs
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue