Langchain-Chatchat/server/knowledge_base/knowledge_base_factory.py

47 lines
2.0 KiB
Python

from typing import Union
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
from configs.model_config import EMBEDDING_MODEL
class KBServiceFactory:
@staticmethod
def get_service(kb_name: str,
vector_store_type: Union[str, SupportedVSType],
embed_model: str = EMBEDDING_MODEL,
) -> KBService:
if isinstance(vector_store_type, str):
vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
if SupportedVSType.FAISS == vector_store_type:
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
return FaissKBService(kb_name, embed_model=embed_model)
# todo: Milvus has different init params
# elif SupportedVSType.MILVUS == vector_store_type:
# from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
# return MilvusKBService(kb_name,)
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
return DefaultKBService(kb_name)
@staticmethod
def get_service_by_name(kb_name: str
) -> KBService:
kb_name, vs_type, embed_model = load_kb_from_db(kb_name)
return KBServiceFactory.get_service(kb_name, vs_type, embed_model)
@staticmethod
def get_default():
return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)
if __name__ == '__main__':
KBService = KBServiceFactory.get_service("test", SupportedVSType.FAISS)
init_db()
KBService.create_kb()
KBService = KBServiceFactory.get_default()
print(KBService.list_kbs())
KBService = KBServiceFactory.get_service_by_name("test")
print(KBService.list_docs())
KBService.drop_kb()