Langchain-Chatchat/server/knowledge_base/utils.py

126 lines
4.8 KiB
Python
Raw Normal View History

import os
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from configs.model_config import (
embedding_model_dict,
KB_ROOT_PATH,
CHUNK_SIZE,
OVERLAP_SIZE,
ZH_TITLE_ENHANCE
)
from functools import lru_cache
2023-08-10 23:04:05 +08:00
import importlib
2023-08-09 23:09:24 +08:00
from text_splitter import zh_title_enhance
2023-07-27 23:22:07 +08:00
def validate_kb_name(knowledge_base_id: str) -> bool:
    """Return True if *knowledge_base_id* is safe to use as a knowledge-base folder name.

    get_kb_path joins names onto KB_ROOT_PATH with os.path.join, so this rejects
    anything that could escape that root or resolve to the root itself:
    ``..`` path components (with either ``/`` or ``\\`` separators), absolute
    paths (os.path.join discards KB_ROOT_PATH for those), and names that are
    empty or normalize to ``.``.
    """
    # Original rule, kept so every previously rejected name is still rejected.
    if "../" in knowledge_base_id:
        return False
    # Treat backslashes as separators too, so "..\\x" is caught on any OS.
    parts = knowledge_base_id.replace("\\", "/").split("/")
    if ".." in parts:
        return False
    if knowledge_base_id.startswith(("/", "\\")) or os.path.isabs(knowledge_base_id):
        return False
    # "" and "." (or "./") would point at KB_ROOT_PATH itself.
    if os.path.normpath(knowledge_base_id) == ".":
        return False
    return True
def get_kb_path(knowledge_base_name: str):
    """Return the folder of knowledge base *knowledge_base_name* under KB_ROOT_PATH."""
    kb_root = KB_ROOT_PATH
    return os.path.join(kb_root, knowledge_base_name)
def get_doc_path(knowledge_base_name: str):
    """Return the ``content`` sub-folder of a knowledge base's folder."""
    kb_folder = get_kb_path(knowledge_base_name)
    return os.path.join(kb_folder, "content")
def get_vs_path(knowledge_base_name: str):
    """Return the ``vector_store`` sub-folder of a knowledge base's folder."""
    kb_folder = get_kb_path(knowledge_base_name)
    return os.path.join(kb_folder, "vector_store")
def get_file_path(knowledge_base_name: str, doc_name: str):
    """Return the path of document *doc_name* inside a knowledge base's content folder."""
    content_folder = get_doc_path(knowledge_base_name)
    return os.path.join(content_folder, doc_name)
def list_kbs_from_folder():
    """Return the names of the sub-directories of KB_ROOT_PATH (one per knowledge base)."""
    root = KB_ROOT_PATH
    kb_names = []
    for entry in os.listdir(root):
        if os.path.isdir(os.path.join(root, entry)):
            kb_names.append(entry)
    return kb_names
def list_docs_from_folder(kb_name: str):
    """Return the names of the regular files directly inside *kb_name*'s content folder."""
    content_dir = get_doc_path(kb_name)

    def _is_doc(entry_name):
        return os.path.isfile(os.path.join(content_dir, entry_name))

    return list(filter(_is_doc, os.listdir(content_dir)))
@lru_cache(1)
def load_embeddings(model: str, device: str):
    """Create a HuggingFaceEmbeddings instance for *model* on *device*.

    *model* is a key of ``embedding_model_dict``; the mapped value is passed as
    ``model_name``. ``lru_cache(1)`` keeps only the most recently requested
    (model, device) instance alive.
    """
    model_kwargs = {'device': device}
    return HuggingFaceEmbeddings(
        model_name=embedding_model_dict[model],
        model_kwargs=model_kwargs,
    )
# Maps langchain document-loader class names to the file extensions they handle.
# Each extension must appear under exactly one loader: get_LoaderClass returns the
# first match in insertion order, so a duplicate makes the later entry unreachable
# (".pdf" used to be listed under UnstructuredFileLoader as well, which shadowed
# PyPDFLoader).
LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
                                          '.rtf', '.txt', '.xml',
                                          '.doc', '.docx', '.epub', '.odt',
                                          '.ppt', '.pptx', '.tsv'],
               # not listed above: '.pdf' (PyPDFLoader), '.csv' (CSVLoader), '.xlsx' (unsupported)
               "CSVLoader": [".csv"],
               "PyPDFLoader": [".pdf"],
               }
# Every supported extension, de-duplicated, in LOADER_DICT order.
SUPPORTED_EXTS = list(dict.fromkeys(ext for exts in LOADER_DICT.values() for ext in exts))
def get_LoaderClass(file_extension):
    """Return the name of the first LOADER_DICT loader that lists *file_extension*.

    Returns None when no loader lists the extension.
    """
    return next(
        (loader_name
         for loader_name, handled_exts in LOADER_DICT.items()
         if file_extension in handled_exts),
        None,
    )
class KnowledgeFile:
    """A document file stored in a knowledge base's ``content`` folder.

    Resolves the file's on-disk path and the name of the langchain document
    loader for its extension, and turns the file into split chunks via
    :meth:`file2text`.
    """

    def __init__(
            self,
            filename: str,
            knowledge_base_name: str
    ):
        """
        Args:
            filename: name of the file inside the knowledge base's content folder.
            knowledge_base_name: name of the knowledge base that owns the file.

        Raises:
            ValueError: if the (lower-cased) file extension is not in SUPPORTED_EXTS.
        """
        self.kb_name = knowledge_base_name
        self.filename = filename
        # Lower-case so "Report.PDF" matches the lower-case entries in SUPPORTED_EXTS.
        self.ext = os.path.splitext(filename)[-1].lower()
        if self.ext not in SUPPORTED_EXTS:
            raise ValueError(f"暂未支持的文件格式 {self.ext}")
        self.filepath = get_file_path(knowledge_base_name, filename)
        self.docs = None
        self.document_loader_name = get_LoaderClass(self.ext)
        # TODO: choose the text splitter according to the file format.
        # None selects the default SpacyTextSplitter; file2text stores the name of
        # the splitter it successfully built.
        self.text_splitter_name = None

    def _get_loader(self):
        """Instantiate the langchain document loader for this file.

        Falls back to ``UnstructuredFileLoader`` if ``document_loader_name``
        cannot be resolved in ``langchain.document_loaders``.
        """
        document_loaders_module = importlib.import_module('langchain.document_loaders')
        loader_name = self.document_loader_name
        try:
            DocumentLoader = getattr(document_loaders_module, loader_name)
        except Exception as e:
            print(e)
            loader_name = "UnstructuredFileLoader"
            DocumentLoader = getattr(document_loaders_module, loader_name)
        # Decide on the loader actually resolved, so the fallback path gets the
        # same arguments as a direct UnstructuredFileLoader match.
        if loader_name == "UnstructuredFileLoader":
            return DocumentLoader(self.filepath, autodetect_encoding=True)
        return DocumentLoader(self.filepath)

    def _get_text_splitter(self):
        """Build the text splitter named by ``text_splitter_name``.

        ``None`` selects ``SpacyTextSplitter``. SpacyTextSplitter always gets the
        ``zh_core_web_sm`` pipeline. Previously, only the first call passed it;
        later calls (after the name had been stored) fell back to langchain's
        default pipeline. On success, the chosen name is stored back in
        ``text_splitter_name``. Any failure (unknown name, missing spaCy model, ...)
        falls back to ``RecursiveCharacterTextSplitter``.
        """
        text_splitter_module = importlib.import_module('langchain.text_splitter')
        if self.text_splitter_name is None:
            splitter_name = "SpacyTextSplitter"
        else:
            splitter_name = self.text_splitter_name
        try:
            TextSplitter = getattr(text_splitter_module, splitter_name)
            splitter_kwargs = {"chunk_size": CHUNK_SIZE, "chunk_overlap": OVERLAP_SIZE}
            if splitter_name == "SpacyTextSplitter":
                splitter_kwargs["pipeline"] = "zh_core_web_sm"
            text_splitter = TextSplitter(**splitter_kwargs)
        except Exception as e:
            print(e)
            TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
            return TextSplitter(
                chunk_size=CHUNK_SIZE,
                chunk_overlap=OVERLAP_SIZE,
            )
        self.text_splitter_name = splitter_name
        return text_splitter

    def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE):
        """Load this file and split it into chunks.

        Args:
            using_zh_title_enhance: when truthy, post-process the chunks with
                ``zh_title_enhance``. Defaults to the ZH_TITLE_ENHANCE config value.

        Returns:
            The chunks from ``loader.load_and_split`` (optionally passed through
            ``zh_title_enhance``). A file with no extractable text no longer
            raises here: the old debug ``print(docs[0])`` failed on an empty result.
        """
        loader = self._get_loader()
        text_splitter = self._get_text_splitter()
        docs = loader.load_and_split(text_splitter)
        if using_zh_title_enhance:
            docs = zh_title_enhance(docs)
        return docs