Langchain-Chatchat/server/knowledge_base/utils.py


import os
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
from configs.model_config import (
    embedding_model_dict,
    KB_ROOT_PATH,
    CHUNK_SIZE,
    OVERLAP_SIZE,
    ZH_TITLE_ENHANCE
)
from functools import lru_cache
import importlib
from text_splitter import zh_title_enhance
import langchain.document_loaders
from langchain.docstore.document import Document
from pathlib import Path
import json
from typing import List, Union, Callable, Dict, Optional


def validate_kb_name(knowledge_base_id: str) -> bool:
    # Reject names containing unexpected characters or path-traversal sequences.
    if "../" in knowledge_base_id:
        return False
    return True
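# A quick illustration of the intended behavior (hypothetical inputs):
#     validate_kb_name("samples")        -> True
#     validate_kb_name("../etc/passwd")  -> False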


def get_kb_path(knowledge_base_name: str):
    return os.path.join(KB_ROOT_PATH, knowledge_base_name)


def get_doc_path(knowledge_base_name: str):
    return os.path.join(get_kb_path(knowledge_base_name), "content")


def get_vs_path(knowledge_base_name: str):
    return os.path.join(get_kb_path(knowledge_base_name), "vector_store")


def get_file_path(knowledge_base_name: str, doc_name: str):
    return os.path.join(get_doc_path(knowledge_base_name), doc_name)
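# For orientation, these helpers imply an on-disk layout like the following
# (illustrative paths; the actual root comes from KB_ROOT_PATH):
#     <KB_ROOT_PATH>/<kb_name>/content/       # raw uploaded documents
#     <KB_ROOT_PATH>/<kb_name>/vector_store/  # the serialized vector index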


def list_kbs_from_folder():
    return [f for f in os.listdir(KB_ROOT_PATH)
            if os.path.isdir(os.path.join(KB_ROOT_PATH, f))]


def list_docs_from_folder(kb_name: str):
    doc_path = get_doc_path(kb_name)
    return [file for file in os.listdir(doc_path)
            if os.path.isfile(os.path.join(doc_path, file))]


@lru_cache(1)
def load_embeddings(model: str, device: str):
    if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
        embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
    elif 'bge-' in model:
        embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
                                              model_kwargs={'device': device},
                                              query_instruction="为这个句子生成表示以用于检索相关文章:")
        if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
            embeddings.query_instruction = ""
    else:
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
    return embeddings
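# Usage sketch (assumes "bge-large-zh" is a key in embedding_model_dict from
# configs.model_config; repeated calls with the same arguments hit the lru_cache):
#     embeddings = load_embeddings("bge-large-zh", "cuda")
#     vector = embeddings.embed_query("知识库检索测试")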


LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
               "UnstructuredMarkdownLoader": ['.md'],
               "CustomJSONLoader": [".json"],
               "CSVLoader": [".csv"],
               "PyPDFLoader": [".pdf"],
               "UnstructuredFileLoader": ['.eml', '.msg', '.rst',
                                          '.rtf', '.txt', '.xml',
                                          '.doc', '.docx', '.epub', '.odt',
                                          '.ppt', '.pptx', '.tsv'],  # '.xlsx'
               }

SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]


class CustomJSONLoader(langchain.document_loaders.JSONLoader):
    '''
    A drop-in replacement for langchain's JSONLoader, which depends on jq and
    is therefore awkward to use on Windows.
    '''

    def __init__(
            self,
            file_path: Union[str, Path],
            content_key: Optional[str] = None,
            metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
            text_content: bool = True,
            json_lines: bool = False,
    ):
        """Initialize the JSONLoader.

        Args:
            file_path (Union[str, Path]): The path to the JSON or JSON Lines file.
            content_key (str): The key used to extract the content when the JSON
                parses to a list of objects (dicts).
            metadata_func (Callable[Dict, Dict]): A function that takes the JSON
                object extracted by the jq_schema and the default metadata, and
                returns a dict of updated metadata.
            text_content (bool): Whether the content is in string format.
                Defaults to True.
            json_lines (bool): Whether the input is in JSON Lines format.
        """
        self.file_path = Path(file_path).resolve()
        self._content_key = content_key
        self._metadata_func = metadata_func
        self._text_content = text_content
        self._json_lines = json_lines

    # TODO: langchain's JSONLoader.load has an encoding bug that raises a gbk
    # encoding error on Windows. This is a workaround for langchain==0.0.266;
    # I have opened a PR (#9785) against langchain, so this override should be
    # deleted once langchain is upgraded.
    def load(self) -> List[Document]:
        """Load and return documents from the JSON file."""
        docs: List[Document] = []
        if self._json_lines:
            with self.file_path.open(encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        self._parse(line, docs)
        else:
            self._parse(self.file_path.read_text(encoding="utf-8"), docs)
        return docs

    def _parse(self, content: str, docs: List[Document]) -> None:
        """Convert given content to documents."""
        data = json.loads(content)

        # Perform some validation.
        # This is not a perfect validation, but it should catch most cases
        # and prevent the user from getting a cryptic error later on.
        if self._content_key is not None:
            self._validate_content_key(data)

        for i, sample in enumerate(data, len(docs) + 1):
            metadata = dict(
                source=str(self.file_path),
                seq_num=i,
            )
            text = self._get_text(sample=sample, metadata=metadata)
            docs.append(Document(page_content=text, metadata=metadata))


langchain.document_loaders.CustomJSONLoader = CustomJSONLoader
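# The assignment above registers CustomJSONLoader on the langchain.document_loaders
# module so that KnowledgeFile.file2text below can resolve it by name with getattr,
# exactly as it resolves langchain's built-in loader classes.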


def get_LoaderClass(file_extension):
    for LoaderClass, extensions in LOADER_DICT.items():
        if file_extension in extensions:
            return LoaderClass
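# For example (values taken from LOADER_DICT above):
#     get_LoaderClass(".md")   -> "UnstructuredMarkdownLoader"
#     get_LoaderClass(".wav")  -> None  (unsupported extensions fall through)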


class KnowledgeFile:
    def __init__(
            self,
            filename: str,
            knowledge_base_name: str
    ):
        self.kb_name = knowledge_base_name
        self.filename = filename
        self.ext = os.path.splitext(filename)[-1].lower()
        if self.ext not in SUPPORTED_EXTS:
            raise ValueError(f"Unsupported file format: {self.ext}")
        self.filepath = get_file_path(knowledge_base_name, filename)
        self.docs = None
        self.document_loader_name = get_LoaderClass(self.ext)

        # TODO: pick a text_splitter based on the file type.
        self.text_splitter_name = None
    def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE):
        print(self.document_loader_name)
        try:
            document_loaders_module = importlib.import_module('langchain.document_loaders')
            DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
        except Exception as e:
            print(e)
            document_loaders_module = importlib.import_module('langchain.document_loaders')
            DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
        if self.document_loader_name == "UnstructuredFileLoader":
            loader = DocumentLoader(self.filepath, autodetect_encoding=True)
        elif self.document_loader_name == "CSVLoader":
            loader = DocumentLoader(self.filepath, encoding="utf-8")
        elif self.document_loader_name == "JSONLoader":
            loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
        elif self.document_loader_name == "CustomJSONLoader":
            loader = DocumentLoader(self.filepath, text_content=False)
        elif self.document_loader_name == "UnstructuredMarkdownLoader":
            loader = DocumentLoader(self.filepath, mode="elements")  # TODO: verify in practice whether `elements` beats `single`
        elif self.document_loader_name == "UnstructuredHTMLLoader":
            loader = DocumentLoader(self.filepath, mode="elements")  # TODO: verify in practice whether `elements` beats `single`
        else:
            loader = DocumentLoader(self.filepath)

        if self.ext == ".csv":
            docs = loader.load()
        else:
            try:
                if self.text_splitter_name is None:
                    text_splitter_module = importlib.import_module('langchain.text_splitter')
                    TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter")
                    text_splitter = TextSplitter(
                        pipeline="zh_core_web_sm",
                        chunk_size=CHUNK_SIZE,
                        chunk_overlap=OVERLAP_SIZE,
                    )
                    self.text_splitter_name = "SpacyTextSplitter"
                else:
                    text_splitter_module = importlib.import_module('langchain.text_splitter')
                    TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
                    text_splitter = TextSplitter(
                        chunk_size=CHUNK_SIZE,
                        chunk_overlap=OVERLAP_SIZE)
            except Exception as e:
                print(e)
                text_splitter_module = importlib.import_module('langchain.text_splitter')
                TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
                text_splitter = TextSplitter(
                    chunk_size=CHUNK_SIZE,
                    chunk_overlap=OVERLAP_SIZE,
                )
            docs = loader.load_and_split(text_splitter)
        print(docs[0])
        if using_zh_title_enhance:
            docs = zh_title_enhance(docs)
        return docs
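

# A minimal usage sketch (assumptions: a knowledge base named "samples" exists
# under KB_ROOT_PATH with content/README.md inside, and configs.model_config
# is importable):
#     kf = KnowledgeFile(filename="README.md", knowledge_base_name="samples")
#     docs = kf.file2text()
#     print(len(docs), docs[0].metadata)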