from typing import Union
import importlib
import os
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config)
from functools import lru_cache


def validate_kb_name(knowledge_base_id: str) -> bool:
    """Return True if *knowledge_base_id* is safe to use as a knowledge-base name.

    The name is later joined onto KB_ROOT_PATH (see ``get_kb_path``), so any
    name that could resolve outside that root is rejected: parent-directory
    segments (``"../"``, ``"..\\"`` or a bare ``".."``) and absolute paths
    (which ``os.path.join`` would use in place of the root).
    """
    # Check for path-traversal sequences, using both POSIX and Windows separators.
    if "../" in knowledge_base_id or "..\\" in knowledge_base_id:
        return False
    if knowledge_base_id == "..":
        return False
    # os.path.join discards KB_ROOT_PATH entirely when given an absolute path.
    if os.path.isabs(knowledge_base_id):
        return False
    return True


def get_kb_path(knowledge_base_name: str):
    """Return the root directory of the given knowledge base under KB_ROOT_PATH."""
    return os.path.join(KB_ROOT_PATH, knowledge_base_name)


def get_doc_path(knowledge_base_name: str):
    """Return the directory holding the knowledge base's source documents."""
    return os.path.join(get_kb_path(knowledge_base_name), "content")


def get_vs_path(knowledge_base_name: str):
    """Return the directory holding the knowledge base's vector store."""
    return os.path.join(get_kb_path(knowledge_base_name), "vector_store")


def get_file_path(knowledge_base_name: str, doc_name: str):
    """Return the path of *doc_name* inside the knowledge base's content directory."""
    return os.path.join(get_doc_path(knowledge_base_name), doc_name)


def list_kbs_from_folder():
    """List the names of all sub-directories of KB_ROOT_PATH (one per knowledge base)."""
    return [f for f in os.listdir(KB_ROOT_PATH)
            if os.path.isdir(os.path.join(KB_ROOT_PATH, f))]


def list_docs_from_folder(kb_name: str):
    """List the names of regular files in the knowledge base's content directory."""
    doc_path = get_doc_path(kb_name)
    return [file for file in os.listdir(doc_path)
            if os.path.isfile(os.path.join(doc_path, file))]


@lru_cache(1)
def load_embeddings(model: str, device: str):
    """Build HuggingFace embeddings for *model* on *device*.

    *model* is a key into ``embedding_model_dict``; the mapped value is passed as
    ``model_name``. ``lru_cache(1)`` keeps only the most recently requested
    (model, device) pair, so switching between pairs rebuilds the embeddings.
    """
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
                                       model_kwargs={'device': device})
    return embeddings


# Maps a langchain document-loader class name to the file extensions it handles.
# Extensions are lower-case; KnowledgeFile lower-cases the file's extension before lookup.
LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
                                          '.rtf', '.txt', '.xml',
                                          '.doc', '.docx', '.epub', '.odt', '.pdf',
                                          '.ppt', '.pptx', '.tsv'],  # '.pdf', '.xlsx', '.csv'
               "CSVLoader": [".csv"],
               }
# Flat list of every extension any loader supports.
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]


def get_LoaderClass(file_extension):
    """Return the loader class name registered for *file_extension*, or None if none is."""
    for LoaderClass, extensions in LOADER_DICT.items():
        if file_extension in extensions:
            return LoaderClass


class KnowledgeFile:
    """A single document file belonging to a knowledge base.

    Resolves the file's on-disk path and picks the langchain loader and text
    splitter (by class name) used to turn it into chunks.
    """

    def __init__(
            self,
            filename: str,
            knowledge_base_name: str
    ):
        """
        Args:
            filename: Name of the file inside the knowledge base's content directory.
            knowledge_base_name: Name of the owning knowledge base.

        Raises:
            ValueError: If the file's extension is not in SUPPORTED_EXTS.
        """
        self.kb_name = knowledge_base_name
        self.filename = filename
        # Normalize case so that e.g. "report.PDF" matches the ".pdf" entry.
        self.ext = os.path.splitext(filename)[-1].lower()
        if self.ext not in SUPPORTED_EXTS:
            raise ValueError(f"暂未支持的文件格式 {self.ext}")
        self.filepath = get_file_path(knowledge_base_name, filename)
        self.docs = None
        self.document_loader_name = get_LoaderClass(self.ext)

        # TODO: choose the text splitter based on the file format.
        self.text_splitter_name = "CharacterTextSplitter"

    def file2text(self):
        """Load the file with its configured loader and split it into chunks.

        Returns:
            The result of ``loader.load_and_split(text_splitter)``; presumably
            a list of langchain Documents — confirm against the langchain version in use.
        """
        # import_module works even if the package has not been imported yet,
        # unlike a bare sys.modules lookup, which raises KeyError in that case.
        loaders_module = importlib.import_module('langchain.document_loaders')
        DocumentLoader = getattr(loaders_module, self.document_loader_name)
        loader = DocumentLoader(self.filepath)

        # TODO: choose the text splitter based on the file format.
        splitters_module = importlib.import_module('langchain.text_splitter')
        TextSplitter = getattr(splitters_module, self.text_splitter_name)
        text_splitter = TextSplitter(chunk_size=500, chunk_overlap=200)
        return loader.load_and_split(text_splitter)