开发者: (#2091)
- 修复列出知识库磁盘文件时跳过临时文件的bug:只有目录被排除了,文件未排除 - 优化知识库文档加载器: - 将 elements 模式改为 single 模式,避免文档被切分得太碎 - 给 get_loader 和 KnowledgeFile 增加 loader_kwargs 参数,可以自定义文档加载器参数
This commit is contained in:
parent
68a544ea33
commit
ad7a6fd438
|
|
@ -51,16 +51,22 @@ def list_kbs_from_folder():
|
||||||
|
|
||||||
|
|
||||||
def list_files_from_folder(kb_name: str):
|
def list_files_from_folder(kb_name: str):
|
||||||
|
def is_skiped_path(path: str): # 跳过 [temp, tmp, ., ~$] 开头的目录和文件
|
||||||
|
tail = os.path.basename(path).lower()
|
||||||
|
flag = False
|
||||||
|
for x in ["temp", "tmp", ".", "~$"]:
|
||||||
|
if tail.startswith(x):
|
||||||
|
flag = True
|
||||||
|
break
|
||||||
|
return flag
|
||||||
|
|
||||||
doc_path = get_doc_path(kb_name)
|
doc_path = get_doc_path(kb_name)
|
||||||
result = []
|
result = []
|
||||||
for root, _, files in os.walk(doc_path):
|
for root, _, files in os.walk(doc_path):
|
||||||
tail = os.path.basename(root).lower()
|
if is_skiped_path(root):
|
||||||
if (tail.startswith("temp")
|
|
||||||
or tail.startswith("tmp")
|
|
||||||
or tail.startswith(".")): # 跳过 [temp, tmp, .] 开头的文件夹
|
|
||||||
continue
|
continue
|
||||||
for file in files:
|
for file in files:
|
||||||
if file.startswith("~$"): # 跳过 ~$ 开头的文件
|
if is_skiped_path(file):
|
||||||
continue
|
continue
|
||||||
path = Path(doc_path) / root / file
|
path = Path(doc_path) / root / file
|
||||||
result.append(path.resolve().relative_to(doc_path).as_posix())
|
result.append(path.resolve().relative_to(doc_path).as_posix())
|
||||||
|
|
@ -114,10 +120,11 @@ def get_LoaderClass(file_extension):
|
||||||
|
|
||||||
|
|
||||||
# 把一些向量化共用逻辑从KnowledgeFile抽取出来,等langchain支持内存文件的时候,可以将非磁盘文件向量化
|
# 把一些向量化共用逻辑从KnowledgeFile抽取出来,等langchain支持内存文件的时候,可以将非磁盘文件向量化
|
||||||
def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]):
|
def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
|
||||||
'''
|
'''
|
||||||
根据loader_name和文件路径或内容返回文档加载器。
|
根据loader_name和文件路径或内容返回文档加载器。
|
||||||
'''
|
'''
|
||||||
|
loader_kwargs = loader_kwargs or {}
|
||||||
try:
|
try:
|
||||||
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader"]:
|
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader"]:
|
||||||
document_loaders_module = importlib.import_module('document_loaders')
|
document_loaders_module = importlib.import_module('document_loaders')
|
||||||
|
|
@ -125,44 +132,32 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
|
||||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||||
DocumentLoader = getattr(document_loaders_module, loader_name)
|
DocumentLoader = getattr(document_loaders_module, loader_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg = f"为文件{file_path_or_content}查找加载器{loader_name}时出错:{e}"
|
msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}"
|
||||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||||
exc_info=e if log_verbose else None)
|
exc_info=e if log_verbose else None)
|
||||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||||
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
|
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
|
||||||
|
|
||||||
if loader_name == "UnstructuredFileLoader":
|
if loader_name == "UnstructuredFileLoader":
|
||||||
loader = DocumentLoader(file_path_or_content, autodetect_encoding=True)
|
loader_kwargs.setdefault("autodetect_encoding", True)
|
||||||
elif loader_name == "CSVLoader":
|
elif loader_name == "CSVLoader":
|
||||||
# 自动识别文件编码类型,避免langchain loader 加载文件报编码错误
|
if not loader_kwargs.get("encoding"):
|
||||||
with open(file_path_or_content, 'rb') as struct_file:
|
# 如果未指定 encoding,自动识别文件编码类型,避免langchain loader 加载文件报编码错误
|
||||||
encode_detect = chardet.detect(struct_file.read())
|
with open(file_path, 'rb') as struct_file:
|
||||||
if encode_detect is None:
|
encode_detect = chardet.detect(struct_file.read())
|
||||||
encode_detect = {"encoding": "utf-8"}
|
if encode_detect is None:
|
||||||
|
encode_detect = {"encoding": "utf-8"}
|
||||||
loader = DocumentLoader(file_path_or_content, encoding=encode_detect["encoding"])
|
loader_kwargs["encoding"] = encode_detect["encoding"]
|
||||||
## TODO:支持更多的自定义CSV读取逻辑
|
## TODO:支持更多的自定义CSV读取逻辑
|
||||||
|
|
||||||
elif loader_name == "JSONLoader":
|
elif loader_name == "JSONLoader":
|
||||||
loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
|
loader_kwargs.setdefault("jq_schema", ".")
|
||||||
|
loader_kwargs.setdefault("text_content", False)
|
||||||
elif loader_name == "JSONLinesLoader":
|
elif loader_name == "JSONLinesLoader":
|
||||||
loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False, json_lines=True)
|
loader_kwargs.setdefault("jq_schema", ".")
|
||||||
elif loader_name == "UnstructuredMarkdownLoader":
|
loader_kwargs.setdefault("text_content", False)
|
||||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
|
||||||
elif loader_name == "UnstructuredHTMLLoader":
|
loader = DocumentLoader(file_path, **loader_kwargs)
|
||||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
|
||||||
elif loader_name == "UnstructuredXMLLoader":
|
|
||||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
|
||||||
elif loader_name == "UnstructuredRSTLoader":
|
|
||||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
|
||||||
elif loader_name == "UnstructuredExcelLoader":
|
|
||||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
|
||||||
elif loader_name == "UnstructuredWordDocumentLoader":
|
|
||||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
|
||||||
elif loader_name == "UnstructuredPowerPointLoader":
|
|
||||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
|
||||||
else:
|
|
||||||
loader = DocumentLoader(file_path_or_content)
|
|
||||||
return loader
|
return loader
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -248,7 +243,8 @@ class KnowledgeFile:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
filename: str,
|
filename: str,
|
||||||
knowledge_base_name: str
|
knowledge_base_name: str,
|
||||||
|
loader_kwargs: Dict = {},
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。
|
对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。
|
||||||
|
|
@ -257,7 +253,8 @@ class KnowledgeFile:
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
self.ext = os.path.splitext(filename)[-1].lower()
|
self.ext = os.path.splitext(filename)[-1].lower()
|
||||||
if self.ext not in SUPPORTED_EXTS:
|
if self.ext not in SUPPORTED_EXTS:
|
||||||
raise ValueError(f"暂未支持的文件格式 {self.ext}")
|
raise ValueError(f"暂未支持的文件格式 {self.filename}")
|
||||||
|
self.loader_kwargs = loader_kwargs
|
||||||
self.filepath = get_file_path(knowledge_base_name, filename)
|
self.filepath = get_file_path(knowledge_base_name, filename)
|
||||||
self.docs = None
|
self.docs = None
|
||||||
self.splited_docs = None
|
self.splited_docs = None
|
||||||
|
|
@ -267,7 +264,9 @@ class KnowledgeFile:
|
||||||
def file2docs(self, refresh: bool = False):
|
def file2docs(self, refresh: bool = False):
|
||||||
if self.docs is None or refresh:
|
if self.docs is None or refresh:
|
||||||
logger.info(f"{self.document_loader_name} used for {self.filepath}")
|
logger.info(f"{self.document_loader_name} used for {self.filepath}")
|
||||||
loader = get_loader(self.document_loader_name, self.filepath)
|
loader = get_loader(loader_name=self.document_loader_name,
|
||||||
|
file_path=self.filepath,
|
||||||
|
loader_kwargs=self.loader_kwargs)
|
||||||
self.docs = loader.load()
|
self.docs = loader.load()
|
||||||
return self.docs
|
return self.docs
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue