Langchain-Chatchat/server/knowledge_base/utils.py.bak

import os
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
from configs.model_config import (
    embedding_model_dict,
    KB_ROOT_PATH,
    CHUNK_SIZE,
    OVERLAP_SIZE,
    ZH_TITLE_ENHANCE,
    logger,
)
from functools import lru_cache
import importlib
from text_splitter import zh_title_enhance
import langchain.document_loaders
import document_loaders
import unstructured.partition
from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter
from pathlib import Path
import json
from concurrent.futures import ThreadPoolExecutor
from server.utils import run_in_thread_pool
import io
import builtins
from datetime import datetime
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator

# make HuggingFaceEmbeddings hashable
def _embeddings_hash(self):
    if isinstance(self, HuggingFaceEmbeddings):
        return hash(self.model_name)
    elif isinstance(self, HuggingFaceBgeEmbeddings):
        return hash(self.model_name)
    elif isinstance(self, OpenAIEmbeddings):
        return hash(self.model)


HuggingFaceEmbeddings.__hash__ = _embeddings_hash
OpenAIEmbeddings.__hash__ = _embeddings_hash
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash

# Patch langchain.document_loaders and the project's own document_loaders, replacing their
# open function so that str, bytes, io.StringIO and io.BytesIO content can also be vectorized.
def _new_open(content: Union[str, bytes, io.StringIO, io.BytesIO, Path], *args, **kw):
    if isinstance(content, (io.StringIO, io.BytesIO)):
        return content
    if isinstance(content, str):
        if os.path.isfile(content):
            return builtins.open(content, *args, **kw)
        else:
            return io.StringIO(content)
    if isinstance(content, bytes):
        return io.BytesIO(content)
    if isinstance(content, Path):
        return content.open(*args, **kw)
    return open(content, *args, **kw)


for module in [langchain.document_loaders, document_loaders]:
    for k, v in module.__dict__.items():
        if type(v) == type(langchain.document_loaders):  # patch every loader submodule
            v.open = _new_open

# Patch unstructured so it does not fail when handling content that is not a file on disk.
def _new_get_last_modified_date(filename: str) -> Union[str, None]:
    try:
        modify_date = datetime.fromtimestamp(os.path.getmtime(filename))
        return modify_date.strftime("%Y-%m-%dT%H:%M:%S%z")
    except Exception:
        return None


for k, v in unstructured.partition.__dict__.items():
    if type(v) == type(unstructured.partition):
        v.open = _new_open
        v.get_last_modified_date = _new_get_last_modified_date
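
# What the patches above are meant to enable (a sketch, not part of the original module):
# loaders can be handed in-memory content instead of a path. The exact behaviour still
# depends on the individual loader and the installed unstructured version.
#
#   from langchain.document_loaders import UnstructuredFileLoader
#   loader = UnstructuredFileLoader("plain text that is not a path on disk")
#   docs = loader.load()  # the patched open() wraps the string in io.StringIO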


def validate_kb_name(knowledge_base_id: str) -> bool:
    # Check for unexpected characters or path-traversal attempts.
    if "../" in knowledge_base_id:
        return False
    return True


def get_kb_path(knowledge_base_name: str):
    return os.path.join(KB_ROOT_PATH, knowledge_base_name)


def get_doc_path(knowledge_base_name: str):
    return os.path.join(get_kb_path(knowledge_base_name), "content")


def get_vs_path(knowledge_base_name: str):
    return os.path.join(get_kb_path(knowledge_base_name), "vector_store")


def get_file_path(knowledge_base_name: str, doc_name: str):
    return os.path.join(get_doc_path(knowledge_base_name), doc_name)


def list_kbs_from_folder():
    return [f for f in os.listdir(KB_ROOT_PATH)
            if os.path.isdir(os.path.join(KB_ROOT_PATH, f))]


def list_files_from_folder(kb_name: str):
    doc_path = get_doc_path(kb_name)
    return [file for file in os.listdir(doc_path)
            if os.path.isfile(os.path.join(doc_path, file))]


@lru_cache(1)
def load_embeddings(model: str, device: str):
    if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
        embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
    elif 'bge-' in model:
        embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
                                              model_kwargs={'device': device},
                                              query_instruction="为这个句子生成表示以用于检索相关文章:")
        if model == "bge-large-zh-noinstruct":  # the -noinstruct embedding takes no query instruction
            embeddings.query_instruction = ""
    else:
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
    return embeddings
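
# Typical use (a sketch; "bge-large-zh" is only an example key and must exist in
# embedding_model_dict from configs.model_config):
#
#   embeddings = load_embeddings("bge-large-zh", device="cuda")
#   vectors = embeddings.embed_documents(["你好", "how are you"])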

LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
               "UnstructuredMarkdownLoader": ['.md'],
               "CustomJSONLoader": [".json"],
               "CSVLoader": [".csv"],
               "RapidOCRPDFLoader": [".pdf"],
               "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
               "UnstructuredFileLoader": ['.eml', '.msg', '.rst',
                                          '.rtf', '.txt', '.xml',
                                          '.doc', '.docx', '.epub', '.odt',
                                          '.ppt', '.pptx', '.tsv'],  # '.xlsx'
               }
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]


class CustomJSONLoader(langchain.document_loaders.JSONLoader):
    '''
    A replacement for langchain's JSONLoader, which depends on jq and is therefore
    inconvenient to use on Windows.
    '''

    def __init__(
            self,
            file_path: Union[str, Path],
            content_key: Optional[str] = None,
            metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
            text_content: bool = True,
            json_lines: bool = False,
    ):
        """Initialize the JSONLoader.

        Args:
            file_path (Union[str, Path]): The path to the JSON or JSON Lines file.
            content_key (str): The key to use to extract the content from the JSON if
                the result is a list of objects (dict).
            metadata_func (Callable[Dict, Dict]): A function that takes in the JSON
                object extracted by the jq_schema and the default metadata and returns
                a dict of the updated metadata.
            text_content (bool): Boolean flag to indicate whether the content is in
                string format, default to True.
            json_lines (bool): Boolean flag to indicate whether the input is in
                JSON Lines format.
        """
        self.file_path = Path(file_path).resolve()
        self._content_key = content_key
        self._metadata_func = metadata_func
        self._text_content = text_content
        self._json_lines = json_lines
    # TODO: langchain's JSONLoader.load has an encoding bug that raises a gbk decoding error
    # on Windows. This is a workaround for langchain==0.0.266; a PR (#9785) has been submitted
    # to langchain, and this override should be removed once langchain is upgraded.
    def load(self) -> List[Document]:
        """Load and return documents from the JSON file."""
        docs: List[Document] = []
        if self._json_lines:
            with self.file_path.open(encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        self._parse(line, docs)
        else:
            self._parse(self.file_path.read_text(encoding="utf-8"), docs)
        return docs

    def _parse(self, content: str, docs: List[Document]) -> None:
        """Convert given content to documents."""
        data = json.loads(content)

        # Perform some validation
        # This is not a perfect validation, but it should catch most cases
        # and prevent the user from getting a cryptic error later on.
        if self._content_key is not None:
            self._validate_content_key(data)

        for i, sample in enumerate(data, len(docs) + 1):
            metadata = dict(
                source=str(self.file_path),
                seq_num=i,
            )
            text = self._get_text(sample=sample, metadata=metadata)
            docs.append(Document(page_content=text, metadata=metadata))


langchain.document_loaders.CustomJSONLoader = CustomJSONLoader


def get_LoaderClass(file_extension):
    for LoaderClass, extensions in LOADER_DICT.items():
        if file_extension in extensions:
            return LoaderClass


def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]):
    '''
    Return a document loader based on loader_name and the file path or content.
    '''
    try:
        if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
            document_loaders_module = importlib.import_module('document_loaders')
        else:
            document_loaders_module = importlib.import_module('langchain.document_loaders')
        DocumentLoader = getattr(document_loaders_module, loader_name)
    except Exception as e:
        logger.error(e)
        document_loaders_module = importlib.import_module('langchain.document_loaders')
        DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")

    if loader_name == "UnstructuredFileLoader":
        loader = DocumentLoader(file_path_or_content, autodetect_encoding=True)
    elif loader_name == "CSVLoader":
        loader = DocumentLoader(file_path_or_content, encoding="utf-8")
    elif loader_name == "JSONLoader":
        loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
    elif loader_name == "CustomJSONLoader":
        loader = DocumentLoader(file_path_or_content, text_content=False)
    elif loader_name == "UnstructuredMarkdownLoader":
        loader = DocumentLoader(file_path_or_content, mode="elements")
    elif loader_name == "UnstructuredHTMLLoader":
        loader = DocumentLoader(file_path_or_content, mode="elements")
    else:
        loader = DocumentLoader(file_path_or_content)
    return loader


def make_text_splitter(
        splitter_name: str = "SpacyTextSplitter",
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
):
    '''
    Get the requested text splitter, falling back to RecursiveCharacterTextSplitter on failure.
    '''
    splitter_name = splitter_name or "SpacyTextSplitter"
    text_splitter_module = importlib.import_module('langchain.text_splitter')
    try:
        TextSplitter = getattr(text_splitter_module, splitter_name)
        text_splitter = TextSplitter(
            pipeline="zh_core_web_sm",
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    except Exception as e:
        logger.error(e)
        TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
        text_splitter = TextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    return text_splitter


def content_to_docs(content: Union[str, bytes, io.StringIO, io.BytesIO, Path], ext: str = ".md") -> List[Document]:
    '''
    Convert a disk file, plain text, bytes or in-memory file object into a list of Documents.
    '''
    if not ext.startswith("."):
        ext = "." + ext
    ext = ext.lower()
    if ext not in SUPPORTED_EXTS:
        raise ValueError(f"Unsupported file format: {ext}")
    loader_name = get_LoaderClass(ext)
    loader = get_loader(loader_name=loader_name, file_path_or_content=content)
    return loader.load()


def split_docs(
        docs: List[Document],
        splitter_name: str = "SpacyTextSplitter",
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
) -> List[Document]:
    text_splitter = make_text_splitter(splitter_name=splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(docs)
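
# Typical use of the two helpers above (a sketch; the markdown string is arbitrary sample content):
#
#   docs = content_to_docs("# a title\n\nsome body text", ext="md")
#   chunks = split_docs(docs, chunk_size=200, chunk_overlap=20)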


class KnowledgeFile:
    def __init__(
            self,
            filename: str,
            knowledge_base_name: str
    ):
        '''
        A file in the knowledge base directory; it must exist on disk before it can be
        vectorized or otherwise processed.
        '''
        self.kb_name = knowledge_base_name
        self.filename = filename
        self.ext = os.path.splitext(filename)[-1].lower()
        if self.ext not in SUPPORTED_EXTS:
            raise ValueError(f"Unsupported file format: {self.ext}")
        self.filepath = get_file_path(knowledge_base_name, filename)
        self.docs = None
        self.splited_docs = None
        self.document_loader_name = get_LoaderClass(self.ext)
        # TODO: choose the text_splitter based on the file format
        self.text_splitter_name = None

    def file2docs(self, refresh: bool = False):
        if self.docs is None or refresh:
            logger.info(f"{self.document_loader_name} used for {self.filepath}")
            loader = get_loader(self.document_loader_name, self.filepath)
            self.docs = loader.load()
        return self.docs

    def docs2texts(
            self,
            docs: List[Document] = None,
            using_zh_title_enhance=ZH_TITLE_ENHANCE,
            refresh: bool = False,
            chunk_size: int = CHUNK_SIZE,
            chunk_overlap: int = OVERLAP_SIZE,
            text_splitter: TextSplitter = None,
    ):
        docs = docs or self.file2docs(refresh=refresh)
        if self.ext not in [".csv"]:
            if text_splitter is None:
                text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
            docs = text_splitter.split_documents(docs)

        print(f"Document splitting example: {docs[0]}")
        if using_zh_title_enhance:
            docs = zh_title_enhance(docs)
        self.splited_docs = docs
        return self.splited_docs

    def file2text(
            self,
            using_zh_title_enhance=ZH_TITLE_ENHANCE,
            refresh: bool = False,
            chunk_size: int = CHUNK_SIZE,
            chunk_overlap: int = OVERLAP_SIZE,
            text_splitter: TextSplitter = None,
    ):
        if self.splited_docs is None or refresh:
            self.splited_docs = self.docs2texts(using_zh_title_enhance=using_zh_title_enhance,
                                                refresh=refresh,
                                                chunk_size=chunk_size,
                                                chunk_overlap=chunk_overlap,
                                                text_splitter=text_splitter)
        return self.splited_docs

    def file_exist(self):
        return os.path.isfile(self.filepath)

    def get_mtime(self):
        return os.path.getmtime(self.filepath)

    def get_size(self):
        return os.path.getsize(self.filepath)
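
# Typical use (a sketch; "samples" and "test.txt" are placeholders for an existing knowledge
# base and a file already present in its content directory):
#
#   kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
#   chunks = kb_file.file2text(chunk_size=250, chunk_overlap=50)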


def files2docs_in_thread(
        files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],  # a tuple is interpreted as (filename, kb_name)
        pool: ThreadPoolExecutor = None,
) -> Generator:
    '''
    Convert disk files into langchain Documents in batches using a thread pool.
    The generator yields status, (kb_name, file_name, docs).
    '''
    def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
        try:
            return True, (file.kb_name, file.filename, file.file2text(**kwargs))
        except Exception as e:
            return False, e

    kwargs_list = []
    for i, file in enumerate(files):
        kwargs = {}
        if isinstance(file, tuple) and len(file) >= 2:
            file = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
        elif isinstance(file, dict):
            filename = file.pop("filename")
            kb_name = file.pop("kb_name")
            kwargs = file
            file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
        kwargs["file"] = file
        kwargs_list.append(kwargs)

    for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
        yield result
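
# Typical use (a sketch; the file names and knowledge base name are placeholders):
#
#   for status, result in files2docs_in_thread([("a.md", "samples"), ("b.pdf", "samples")]):
#       if status:
#           kb_name, file_name, docs = result
#       else:
#           logger.error(result)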


if __name__ == "__main__":
    from pprint import pprint

    # kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
    # # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
    # docs = kb_file.file2docs()
    # pprint(docs[-1])
    # docs = kb_file.file2text()
    # pprint(docs[-1])
    docs = content_to_docs("""
## this is a title
## another title
how are you
this a wonderful day.
""", "txt")
    pprint(docs)
    pprint(split_docs(docs, chunk_size=10, chunk_overlap=0))