diff --git a/server/knowledge_base/__init__.py b/server/knowledge_base/__init__.py index b64e921..4556f7d 100644 --- a/server/knowledge_base/__init__.py +++ b/server/knowledge_base/__init__.py @@ -1,4 +1,3 @@ from .kb_api import list_kbs, create_kb, delete_kb from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store -from .knowledge_file import KnowledgeFile -from .knowledge_base_factory import KBServiceFactory +from .utils import KnowledgeFile, KBServiceFactory diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py index 9a92ea6..3abee4b 100644 --- a/server/knowledge_base/kb_api.py +++ b/server/knowledge_base/kb_api.py @@ -1,7 +1,6 @@ import urllib from server.utils import BaseResponse, ListResponse -from server.knowledge_base.utils import validate_kb_name -from server.knowledge_base.knowledge_base_factory import KBServiceFactory +from server.knowledge_base.utils import validate_kb_name, KBServiceFactory from server.knowledge_base.kb_service.base import list_kbs_from_db from configs.model_config import EMBEDDING_MODEL diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 2ebe828..fc8a782 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -5,8 +5,7 @@ from server.utils import BaseResponse, ListResponse from server.knowledge_base.utils import (validate_kb_name) from fastapi.responses import StreamingResponse import json -from server.knowledge_base.knowledge_file import KnowledgeFile -from server.knowledge_base.knowledge_base_factory import KBServiceFactory +from server.knowledge_base.utils import KnowledgeFile, KBServiceFactory from server.knowledge_base.kb_service.base import SupportedVSType, list_docs_from_folder from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 277f5a7..6f78649 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -1,22 +1,18 @@ from abc import ABC, abstractmethod import os -import sqlite3 from functools import lru_cache from langchain.embeddings import HuggingFaceEmbeddings from langchain.embeddings.base import Embeddings from langchain.docstore.document import Document -from configs.model_config import (VECTOR_SEARCH_TOP_K, - embedding_model_dict, EMBEDDING_DEVICE, EMBEDDING_MODEL, - kbs_config) - from server.db.repository.knowledge_base_repository import add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists from server.db.repository.knowledge_file_repository import add_doc_to_db, delete_file_from_db, doc_exists, \ list_docs_from_db -from server.knowledge_base.utils import (get_kb_path, get_doc_path) -from server.knowledge_base.knowledge_file import KnowledgeFile +from configs.model_config import (DB_ROOT_PATH, kbs_config, VECTOR_SEARCH_TOP_K, + embedding_model_dict, EMBEDDING_DEVICE, EMBEDDING_MODEL) +from server.knowledge_base.utils import (get_kb_path, get_doc_path, KnowledgeFile) from typing import List diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 5e9e8b7..9798206 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -4,8 +4,7 @@ import shutil from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_DEVICE from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings from functools import lru_cache -from server.knowledge_base.utils import get_vs_path -from server.knowledge_base.knowledge_file import KnowledgeFile +from server.knowledge_base.utils import get_vs_path, KnowledgeFile from langchain.vectorstores import FAISS from langchain.embeddings.base import Embeddings from typing import List diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 4854a9d..641a43b 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -5,8 +5,9 @@ from langchain.schema import Document from langchain.vectorstores import Milvus from configs.model_config import EMBEDDING_DEVICE, kbs_config -from server.knowledge_base import KnowledgeFile + from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings +from server.knowledge_base.utils import KnowledgeFile class MilvusKBService(KBService): diff --git a/server/knowledge_base/knowledge_base_factory.py b/server/knowledge_base/knowledge_base_factory.py deleted file mode 100644 index b45b6d6..0000000 --- a/server/knowledge_base/knowledge_base_factory.py +++ /dev/null @@ -1,42 +0,0 @@ -from server.knowledge_base.kb_service.base import KBService, SupportedVSType -from server.db.repository.knowledge_base_repository import load_kb_from_db -from server.knowledge_base.kb_service.default_kb_service import DefaultKBService - - -class KBServiceFactory: - - @staticmethod - def get_service(kb_name: str, - vector_store_type: SupportedVSType - ) -> KBService: - if SupportedVSType.FAISS == vector_store_type: - from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService - return FaissKBService(kb_name) - elif SupportedVSType.MILVUS == vector_store_type: - from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService - return MilvusKBService(kb_name) - elif SupportedVSType.DEFAULT == vector_store_type: - return DefaultKBService(kb_name) - - @staticmethod - def get_service_by_name(kb_name: str - ) -> KBService: - kb_name, vs_type, _ = load_kb_from_db(kb_name) - return KBServiceFactory.get_service(kb_name, vs_type) - - @staticmethod - def get_default(): - return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT) - - -if __name__ == '__main__': - # 测试建表使用 - # from server.db.base import Base, engine - # Base.metadata.create_all(bind=engine) - KBService = KBServiceFactory.get_service("test", SupportedVSType.FAISS) - KBService.create_kb() - KBService = KBServiceFactory.get_default() - print(KBService.list_kbs()) - KBService = KBServiceFactory.get_service_by_name("test") - print(KBService.list_docs()) - KBService.drop_kb() diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index fa6f132..00344e1 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -1,7 +1,12 @@ +from typing import Union import os from langchain.embeddings.huggingface import HuggingFaceEmbeddings -from configs.model_config import (embedding_model_dict, KB_ROOT_PATH) +from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config) from functools import lru_cache +from server.knowledge_base.kb_service.base import KBService, SupportedVSType +from server.db.repository.knowledge_base_repository import load_kb_from_db +from server.knowledge_base.kb_service.default_kb_service import DefaultKBService +from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService def validate_kb_name(knowledge_base_id: str) -> bool: @@ -22,8 +27,81 @@ def get_vs_path(knowledge_base_name: str): def get_file_path(knowledge_base_name: str, doc_name: str): return os.path.join(get_doc_path(knowledge_base_name), doc_name) + @lru_cache(1) def load_embeddings(model: str, device: str): embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device}) return embeddings + + +LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst', + '.rtf', '.txt', '.xml', + '.doc', '.docx', '.epub', '.odt', '.pdf', + '.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv' + "CSVLoader": [".csv"], + } +SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] + +def get_LoaderClass(file_extension): + for LoaderClass, extensions in LOADER_DICT.items(): + if file_extension in extensions: + return LoaderClass + + +class KnowledgeFile: + def __init__( + self, + filename: str, + knowledge_base_name: str + ): + self.kb_name = knowledge_base_name + self.filename = filename + self.ext = os.path.splitext(filename)[-1] + if self.ext not in SUPPORTED_EXTS: + raise ValueError(f"暂未支持的文件格式 {self.ext}") + self.filepath = get_file_path(knowledge_base_name, filename) + self.docs = None + self.document_loader_name = get_LoaderClass(self.ext) + + # TODO: 增加依据文件格式匹配text_splitter + self.text_splitter_name = "CharacterTextSplitter" + + def file2text(self): + DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name) + loader = DocumentLoader(self.filepath) + + # TODO: 增加依据文件格式匹配text_splitter + TextSplitter = getattr(sys.modules['langchain.text_splitter'], self.text_splitter_name) + text_splitter = TextSplitter(chunk_size=500, chunk_overlap=200) + return loader.load_and_split(text_splitter) + + +class KBServiceFactory: + + @staticmethod + def get_service(kb_name: str, + vector_store_type: Union[str, SupportedVSType], + embed_model: str = EMBEDDING_MODEL, + ) -> KBService: + if isinstance(vector_store_type, str): + vector_store_type = getattr(SupportedVSType, vector_store_type.upper()) + if SupportedVSType.FAISS == vector_store_type: + from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService + return FaissKBService(kb_name, embed_model=embed_model) + elif SupportedVSType.MILVUS == vector_store_type: + from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService + return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config + elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier. + return DefaultKBService(kb_name) + + @staticmethod + def get_service_by_name(kb_name: str + ) -> KBService: + kb_name, vs_type, embed_model = load_kb_from_db(kb_name) + return KBServiceFactory.get_service(kb_name, vs_type, embed_model) + + @staticmethod + def get_default(): + return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT) +