from typing import Union
import os
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config)
from functools import lru_cache
import langchain.document_loaders
import sys
import langchain.text_splitter  # imported explicitly so sys.modules['langchain.text_splitter'] is available in file2text()
def validate_kb_name(knowledge_base_id: str) -> bool:
    # Check for unexpected characters or path-traversal keywords
    if "../" in knowledge_base_id:
        return False
    return True
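# Usage sketch (hypothetical knowledge base IDs, only to illustrate the check above):
#   validate_kb_name("samples")        -> True
#   validate_kb_name("../etc/passwd")  -> False  (path-traversal attempt rejected)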
def get_kb_path(knowledge_base_name: str):
    return os.path.join(KB_ROOT_PATH, knowledge_base_name)
def get_doc_path(knowledge_base_name: str):
    return os.path.join(get_kb_path(knowledge_base_name), "content")
def get_vs_path(knowledge_base_name: str):
    return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
def get_file_path(knowledge_base_name: str, doc_name: str):
    return os.path.join(get_doc_path(knowledge_base_name), doc_name)
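# The path helpers above imply the following on-disk layout (rooted at KB_ROOT_PATH;
# the knowledge base name is whatever the caller passes in):
#   KB_ROOT_PATH/
#     <knowledge_base_name>/
#       content/        <- raw documents, see get_doc_path() / get_file_path()
#       vector_store/   <- persisted vector index, see get_vs_path()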
def list_kbs_from_folder():
    return [f for f in os.listdir(KB_ROOT_PATH)
            if os.path.isdir(os.path.join(KB_ROOT_PATH, f))]
def list_docs_from_folder(kb_name: str):
    doc_path = get_doc_path(kb_name)
    return [file for file in os.listdir(doc_path)
            if os.path.isfile(os.path.join(doc_path, file))]
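# Usage sketch (return values are illustrative and depend on what exists under KB_ROOT_PATH):
#   list_kbs_from_folder()            -> e.g. ["samples", "faq"]          (sub-directories only)
#   list_docs_from_folder("samples")  -> e.g. ["intro.md", "manual.pdf"]  (regular files only)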
@lru_cache(1)
def load_embeddings(model: str, device: str):
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
                                       model_kwargs={'device': device})
    return embeddings
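# Usage sketch (assumes EMBEDDING_MODEL, imported from configs.model_config, is a key of
# embedding_model_dict; the device string depends on local hardware, e.g. "cpu" or "cuda"):
#   embeddings = load_embeddings(EMBEDDING_MODEL, "cpu")
# Because of @lru_cache(1), repeated calls with the same arguments reuse the loaded model.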
LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
                                          '.rtf', '.txt', '.xml',
                                          '.doc', '.docx', '.epub', '.odt', '.pdf',
                                          '.ppt', '.pptx', '.tsv'],  # '.pdf', '.xlsx', '.csv'
               "CSVLoader": [".csv"],
               }

SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
def get_LoaderClass(file_extension):
    for LoaderClass, extensions in LOADER_DICT.items():
        if file_extension in extensions:
            return LoaderClass
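# Examples (values follow directly from LOADER_DICT above):
#   get_LoaderClass(".md")   -> "UnstructuredFileLoader"
#   get_LoaderClass(".csv")  -> "CSVLoader"
#   get_LoaderClass(".xlsx") -> None  (extension not registered; KnowledgeFile rejects it)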
class KnowledgeFile:
    def __init__(
            self,
            filename: str,
            knowledge_base_name: str
    ):
        self.kb_name = knowledge_base_name
        self.filename = filename
        self.ext = os.path.splitext(filename)[-1]
        if self.ext not in SUPPORTED_EXTS:
            raise ValueError(f"暂未支持的文件格式 {self.ext}")
        self.filepath = get_file_path(knowledge_base_name, filename)
        self.docs = None
        self.document_loader_name = get_LoaderClass(self.ext)

        # TODO: pick the text_splitter according to the file format
        self.text_splitter_name = "CharacterTextSplitter"

    def file2text(self):
        # Resolve the loader class by name from langchain.document_loaders
        DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name)
        loader = DocumentLoader(self.filepath)

        # TODO: pick the text_splitter according to the file format
        TextSplitter = getattr(sys.modules['langchain.text_splitter'], self.text_splitter_name)
        text_splitter = TextSplitter(chunk_size=500, chunk_overlap=200)
        return loader.load_and_split(text_splitter)
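# Usage sketch (assumes a knowledge base folder named "samples" whose content/ directory
# already contains "README.md"; both names are illustrative):
#   kb_file = KnowledgeFile(filename="README.md", knowledge_base_name="samples")
#   docs = kb_file.file2text()  # list of langchain Documents produced by the resolved loader,
#                               # split with CharacterTextSplitter(chunk_size=500, chunk_overlap=200)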