import importlib
import os
from functools import lru_cache
from typing import Union

from langchain.embeddings.huggingface import HuggingFaceEmbeddings

from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config)
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.db.repository.knowledge_base_repository import load_kb_from_db
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService


def validate_kb_name(knowledge_base_id: str) -> bool:
    """Return False when the knowledge base id contains the substring "../".

    Only that single substring is rejected; every other value is accepted.
    """
    # Check for unexpected characters or path-traversal keywords.
    # NOTE(review): absolute paths and backslash-style traversal are not
    # rejected here -- confirm whether callers rely on additional checks.
    return "../" not in knowledge_base_id


def get_kb_path(knowledge_base_name: str):
    """Return the directory of a knowledge base under KB_ROOT_PATH."""
    return os.path.join(KB_ROOT_PATH, knowledge_base_name)


def get_doc_path(knowledge_base_name: str):
    """Return the "content" sub-directory of a knowledge base."""
    return os.path.join(get_kb_path(knowledge_base_name), "content")


def get_vs_path(knowledge_base_name: str):
    """Return the "vector_store" sub-directory of a knowledge base."""
    return os.path.join(get_kb_path(knowledge_base_name), "vector_store")


def get_file_path(knowledge_base_name: str, doc_name: str):
    """Return the path of document `doc_name` inside the content directory."""
    return os.path.join(get_doc_path(knowledge_base_name), doc_name)


def list_docs_from_folder(kb_name: str):
    """List the names of regular files directly inside the content directory.

    Sub-directories are skipped; the listing is not recursive.
    """
    doc_path = get_doc_path(kb_name)
    return [
        entry
        for entry in os.listdir(doc_path)
        if os.path.isfile(os.path.join(doc_path, entry))
    ]


@lru_cache(1)
def load_embeddings(model: str, device: str):
    """Build a HuggingFaceEmbeddings for `model` on `device`.

    `model` is a key of `embedding_model_dict`; the mapped value is passed as
    `model_name`. The cache holds a single entry, so only the result for the
    most recently used (model, device) pair is retained.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name=embedding_model_dict[model],
        model_kwargs={'device': device},
    )
    return embeddings


# Maps a loader class name to the file extensions it is used for.
LOADER_DICT = {
    "UnstructuredFileLoader": [
        '.eml', '.html', '.json', '.md', '.msg', '.rst',
        '.rtf', '.txt', '.xml',
        '.doc', '.docx', '.epub', '.odt', '.pdf',
        '.ppt', '.pptx', '.tsv',
    ],  # '.pdf', '.xlsx', '.csv'
    "CSVLoader": [".csv"],
}

# Flat list of every extension that appears in LOADER_DICT.
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]


def get_LoaderClass(file_extension):
    """Return the loader class *name* registered for `file_extension`.

    Returns None when no entry of LOADER_DICT lists the extension. The match
    is exact (case-sensitive, leading dot included).
    """
    for LoaderClass, extensions in LOADER_DICT.items():
        if file_extension in extensions:
            return LoaderClass
    return None


class KnowledgeFile:
    """A single document file that belongs to a knowledge base."""

    def __init__(
            self,
            filename: str,
            knowledge_base_name: str
    ):
        """Record file metadata and pick the loader/splitter names.

        Raises:
            ValueError: if the file extension is not in SUPPORTED_EXTS.
        """
        self.kb_name = knowledge_base_name
        self.filename = filename
        self.ext = os.path.splitext(filename)[-1]
        if self.ext not in SUPPORTED_EXTS:
            raise ValueError(f"暂未支持的文件格式 {self.ext}")
        self.filepath = get_file_path(knowledge_base_name, filename)
        self.docs = None
        self.document_loader_name = get_LoaderClass(self.ext)

        # TODO: choose the text_splitter based on the file format
        self.text_splitter_name = "CharacterTextSplitter"

    def file2text(self):
        """Load the file and split it into chunks.

        The loader and splitter classes are resolved by name from
        `langchain.document_loaders` and `langchain.text_splitter`. The
        splitter is created with chunk_size=500 and chunk_overlap=200, and
        the result of `loader.load_and_split(text_splitter)` is returned.
        """
        # import_module imports the module on first use and returns the cached
        # module afterwards, so this works even if nothing else has imported
        # these langchain sub-modules yet (the previous sys.modules lookup
        # relied on an unimported `sys` and on a prior import elsewhere).
        loaders_module = importlib.import_module('langchain.document_loaders')
        DocumentLoader = getattr(loaders_module, self.document_loader_name)
        loader = DocumentLoader(self.filepath)

        # TODO: choose the text_splitter based on the file format
        splitter_module = importlib.import_module('langchain.text_splitter')
        TextSplitter = getattr(splitter_module, self.text_splitter_name)
        text_splitter = TextSplitter(chunk_size=500, chunk_overlap=200)
        return loader.load_and_split(text_splitter)


class KBServiceFactory:
    """Factory helpers that build KBService instances."""

    @staticmethod
    def get_service(kb_name: str,
                    vector_store_type: Union[str, SupportedVSType],
                    embed_model: str = EMBEDDING_MODEL,
                    ) -> KBService:
        """Create the KBService matching `vector_store_type`.

        A string `vector_store_type` is upper-cased and looked up as an
        attribute of SupportedVSType (an unknown name raises AttributeError).
        Returns None when the resolved type is none of FAISS, MILVUS, DEFAULT.
        """
        if isinstance(vector_store_type, str):
            vector_store_type = getattr(SupportedVSType, vector_store_type.upper())

        # Backend modules are imported lazily, only for the branch taken.
        if SupportedVSType.FAISS == vector_store_type:
            from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
            return FaissKBService(kb_name, embed_model=embed_model)

        if SupportedVSType.MILVUS == vector_store_type:
            from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
            # other milvus parameters are set in model_config.kbs_config
            return MilvusKBService(kb_name, embed_model=embed_model)

        if SupportedVSType.DEFAULT == vector_store_type:
            # kb_exists of default kbservice is False, to make validation easier.
            return DefaultKBService(kb_name)

        return None

    @staticmethod
    def get_service_by_name(kb_name: str
                            ) -> KBService:
        """Create the KBService for a knowledge base recorded in the database.

        The name, vector store type and embedding model come from
        `load_kb_from_db(kb_name)`.
        NOTE(review): behaviour for a name missing from the database depends
        on what load_kb_from_db returns -- confirm against that function.
        """
        kb_name, vs_type, embed_model = load_kb_from_db(kb_name)
        return KBServiceFactory.get_service(kb_name, vs_type, embed_model)

    @staticmethod
    def get_default():
        """Return the service named "default" of type SupportedVSType.DEFAULT."""
        return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)