开发者: (#2091)

- 修复列出知识库磁盘文件时跳过临时文件的bug:只有目录被排除了,文件未排除
- 优化知识库文档加载器:
  - 将 elements 模式改为 single 模式,避免文档被切分得太碎
  - 给 get_loader 和 KnowledgeFile 增加 loader_kwargs 参数,可以自定义文档加载器参数
This commit is contained in:
liunux4odoo 2023-11-17 11:39:32 +08:00 committed by GitHub
parent 68a544ea33
commit ad7a6fd438
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 35 additions and 36 deletions

View File

@ -51,16 +51,22 @@ def list_kbs_from_folder():
def list_files_from_folder(kb_name: str):
def is_skiped_path(path: str): # 跳过 [temp, tmp, ., ~$] 开头的目录和文件
tail = os.path.basename(path).lower()
flag = False
for x in ["temp", "tmp", ".", "~$"]:
if tail.startswith(x):
flag = True
break
return flag
doc_path = get_doc_path(kb_name)
result = []
for root, _, files in os.walk(doc_path):
tail = os.path.basename(root).lower()
if (tail.startswith("temp")
or tail.startswith("tmp")
or tail.startswith(".")): # 跳过 [temp, tmp, .] 开头的文件夹
if is_skiped_path(root):
continue
for file in files:
if file.startswith("~$"): # 跳过 ~$ 开头的文件
if is_skiped_path(file):
continue
path = Path(doc_path) / root / file
result.append(path.resolve().relative_to(doc_path).as_posix())
@ -114,10 +120,11 @@ def get_LoaderClass(file_extension):
# 把一些向量化共用逻辑从KnowledgeFile抽取出来等langchain支持内存文件的时候可以将非磁盘文件向量化
def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]):
def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
'''
根据loader_name和文件路径或内容返回文档加载器
'''
loader_kwargs = loader_kwargs or {}
try:
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader"]:
document_loaders_module = importlib.import_module('document_loaders')
@ -125,44 +132,32 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
document_loaders_module = importlib.import_module('langchain.document_loaders')
DocumentLoader = getattr(document_loaders_module, loader_name)
except Exception as e:
msg = f"为文件{file_path_or_content}查找加载器{loader_name}时出错:{e}"
msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
document_loaders_module = importlib.import_module('langchain.document_loaders')
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
if loader_name == "UnstructuredFileLoader":
loader = DocumentLoader(file_path_or_content, autodetect_encoding=True)
loader_kwargs.setdefault("autodetect_encoding", True)
elif loader_name == "CSVLoader":
# 自动识别文件编码类型避免langchain loader 加载文件报编码错误
with open(file_path_or_content, 'rb') as struct_file:
encode_detect = chardet.detect(struct_file.read())
if encode_detect is None:
encode_detect = {"encoding": "utf-8"}
loader = DocumentLoader(file_path_or_content, encoding=encode_detect["encoding"])
if not loader_kwargs.get("encoding"):
# 如果未指定 encoding自动识别文件编码类型避免langchain loader 加载文件报编码错误
with open(file_path, 'rb') as struct_file:
encode_detect = chardet.detect(struct_file.read())
if encode_detect is None:
encode_detect = {"encoding": "utf-8"}
loader_kwargs["encoding"] = encode_detect["encoding"]
## TODO支持更多的自定义CSV读取逻辑
elif loader_name == "JSONLoader":
loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
loader_kwargs.setdefault("jq_schema", ".")
loader_kwargs.setdefault("text_content", False)
elif loader_name == "JSONLinesLoader":
loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False, json_lines=True)
elif loader_name == "UnstructuredMarkdownLoader":
loader = DocumentLoader(file_path_or_content, mode="elements")
elif loader_name == "UnstructuredHTMLLoader":
loader = DocumentLoader(file_path_or_content, mode="elements")
elif loader_name == "UnstructuredXMLLoader":
loader = DocumentLoader(file_path_or_content, mode="elements")
elif loader_name == "UnstructuredRSTLoader":
loader = DocumentLoader(file_path_or_content, mode="elements")
elif loader_name == "UnstructuredExcelLoader":
loader = DocumentLoader(file_path_or_content, mode="elements")
elif loader_name == "UnstructuredWordDocumentLoader":
loader = DocumentLoader(file_path_or_content, mode="elements")
elif loader_name == "UnstructuredPowerPointLoader":
loader = DocumentLoader(file_path_or_content, mode="elements")
else:
loader = DocumentLoader(file_path_or_content)
loader_kwargs.setdefault("jq_schema", ".")
loader_kwargs.setdefault("text_content", False)
loader = DocumentLoader(file_path, **loader_kwargs)
return loader
@ -248,7 +243,8 @@ class KnowledgeFile:
def __init__(
self,
filename: str,
knowledge_base_name: str
knowledge_base_name: str,
loader_kwargs: Dict = {},
):
'''
对应知识库目录中的文件必须是磁盘上存在的才能进行向量化等操作
@ -257,7 +253,8 @@ class KnowledgeFile:
self.filename = filename
self.ext = os.path.splitext(filename)[-1].lower()
if self.ext not in SUPPORTED_EXTS:
raise ValueError(f"暂未支持的文件格式 {self.ext}")
raise ValueError(f"暂未支持的文件格式 {self.filename}")
self.loader_kwargs = loader_kwargs
self.filepath = get_file_path(knowledge_base_name, filename)
self.docs = None
self.splited_docs = None
@ -267,7 +264,9 @@ class KnowledgeFile:
def file2docs(self, refresh: bool = False):
if self.docs is None or refresh:
logger.info(f"{self.document_loader_name} used for {self.filepath}")
loader = get_loader(self.document_loader_name, self.filepath)
loader = get_loader(loader_name=self.document_loader_name,
file_path=self.filepath,
loader_kwargs=self.loader_kwargs)
self.docs = loader.load()
return self.docs