Co-authored-by: liunux4odoo <liunu@qq.com>
This commit is contained in:
parent
f95d41ef47
commit
9ca53fa3ad
|
|
@ -12,6 +12,11 @@ from configs.model_config import (
|
|||
from functools import lru_cache
|
||||
import importlib
|
||||
from text_splitter import zh_title_enhance
|
||||
import langchain.document_loaders
|
||||
from langchain.docstore.document import Document
|
||||
from pathlib import Path
|
||||
import json
|
||||
from typing import List, Union, Callable, Dict, Optional
|
||||
|
||||
|
||||
def validate_kb_name(knowledge_base_id: str) -> bool:
|
||||
|
|
@ -57,15 +62,88 @@ def load_embeddings(model: str, device: str):
|
|||
|
||||
|
||||
|
||||
LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
|
||||
'.rtf', '.txt', '.xml',
|
||||
'.doc', '.docx', '.epub', '.odt', '.pdf',
|
||||
'.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv'
|
||||
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
|
||||
"UnstructuredMarkdownLoader": ['.md'],
|
||||
"CustomJSONLoader": [".json"],
|
||||
"CSVLoader": [".csv"],
|
||||
"PyPDFLoader": [".pdf"],
|
||||
"UnstructuredFileLoader": ['.eml', '.msg', '.rst',
|
||||
'.rtf', '.txt', '.xml',
|
||||
'.doc', '.docx', '.epub', '.odt',
|
||||
'.ppt', '.pptx', '.tsv'], # '.xlsx'
|
||||
}
|
||||
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
|
||||
|
||||
|
||||
class CustomJSONLoader(langchain.document_loaders.JSONLoader):
|
||||
'''
|
||||
langchain的JSONLoader需要jq,在win上使用不便,进行替代。
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: Union[str, Path],
|
||||
content_key: Optional[str] = None,
|
||||
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
|
||||
text_content: bool = True,
|
||||
json_lines: bool = False,
|
||||
):
|
||||
"""Initialize the JSONLoader.
|
||||
|
||||
Args:
|
||||
file_path (Union[str, Path]): The path to the JSON or JSON Lines file.
|
||||
content_key (str): The key to use to extract the content from the JSON if
|
||||
results to a list of objects (dict).
|
||||
metadata_func (Callable[Dict, Dict]): A function that takes in the JSON
|
||||
object extracted by the jq_schema and the default metadata and returns
|
||||
a dict of the updated metadata.
|
||||
text_content (bool): Boolean flag to indicate whether the content is in
|
||||
string format, default to True.
|
||||
json_lines (bool): Boolean flag to indicate whether the input is in
|
||||
JSON Lines format.
|
||||
"""
|
||||
self.file_path = Path(file_path).resolve()
|
||||
self._content_key = content_key
|
||||
self._metadata_func = metadata_func
|
||||
self._text_content = text_content
|
||||
self._json_lines = json_lines
|
||||
|
||||
# TODO: langchain's JSONLoader.load has a encoding bug, raise gbk encoding error on windows.
|
||||
# This is a workaround for langchain==0.0.266. I have make a pr(#9785) to langchain, it should be deleted after langchain upgraded.
|
||||
def load(self) -> List[Document]:
|
||||
"""Load and return documents from the JSON file."""
|
||||
docs: List[Document] = []
|
||||
if self._json_lines:
|
||||
with self.file_path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
self._parse(line, docs)
|
||||
else:
|
||||
self._parse(self.file_path.read_text(encoding="utf-8"), docs)
|
||||
return docs
|
||||
|
||||
def _parse(self, content: str, docs: List[Document]) -> None:
|
||||
"""Convert given content to documents."""
|
||||
data = json.loads(content)
|
||||
|
||||
# Perform some validation
|
||||
# This is not a perfect validation, but it should catch most cases
|
||||
# and prevent the user from getting a cryptic error later on.
|
||||
if self._content_key is not None:
|
||||
self._validate_content_key(data)
|
||||
|
||||
for i, sample in enumerate(data, len(docs) + 1):
|
||||
metadata = dict(
|
||||
source=str(self.file_path),
|
||||
seq_num=i,
|
||||
)
|
||||
text = self._get_text(sample=sample, metadata=metadata)
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
langchain.document_loaders.CustomJSONLoader = CustomJSONLoader
|
||||
|
||||
|
||||
def get_LoaderClass(file_extension):
|
||||
for LoaderClass, extensions in LOADER_DICT.items():
|
||||
if file_extension in extensions:
|
||||
|
|
@ -101,6 +179,16 @@ class KnowledgeFile:
|
|||
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
|
||||
if self.document_loader_name == "UnstructuredFileLoader":
|
||||
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
|
||||
elif self.document_loader_name == "CSVLoader":
|
||||
loader = DocumentLoader(self.filepath, encoding="utf-8")
|
||||
elif self.document_loader_name == "JSONLoader":
|
||||
loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
|
||||
elif self.document_loader_name == "CustomJSONLoader":
|
||||
loader = DocumentLoader(self.filepath, text_content=False)
|
||||
elif self.document_loader_name == "UnstructuredMarkdownLoader":
|
||||
loader = DocumentLoader(self.filepath, mode="elements") # TODO: 需要在实践中测试`elements`是否优于`single`
|
||||
elif self.document_loader_name == "UnstructuredHTMLLoader":
|
||||
loader = DocumentLoader(self.filepath, mode="elements") # TODO: 需要在实践中测试`elements`是否优于`single`
|
||||
else:
|
||||
loader = DocumentLoader(self.filepath)
|
||||
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ def knowledge_base_page(api: ApiRequest):
|
|||
|
||||
# 上传文件
|
||||
# sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
|
||||
files = st.file_uploader("上传知识文件",
|
||||
files = st.file_uploader("上传知识文件(暂不支持扫描PDF)",
|
||||
[i for ls in LOADER_DICT.values() for i in ls],
|
||||
accept_multiple_files=True,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue