remove server/knowledge_base/knowledge_base_factory.py
This commit is contained in:
parent
7e07f0bbf4
commit
584e3a9234
|
|
@ -1,4 +1,3 @@
|
||||||
from .kb_api import list_kbs, create_kb, delete_kb
|
from .kb_api import list_kbs, create_kb, delete_kb
|
||||||
from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
|
from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
|
||||||
from .knowledge_file import KnowledgeFile
|
from .utils import KnowledgeFile, KBServiceFactory
|
||||||
from .knowledge_base_factory import KBServiceFactory
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import urllib
|
import urllib
|
||||||
from server.utils import BaseResponse, ListResponse
|
from server.utils import BaseResponse, ListResponse
|
||||||
from server.knowledge_base.utils import validate_kb_name
|
from server.knowledge_base.utils import validate_kb_name, KBServiceFactory
|
||||||
from server.knowledge_base.knowledge_base_factory import KBServiceFactory
|
|
||||||
from server.knowledge_base.kb_service.base import list_kbs_from_db
|
from server.knowledge_base.kb_service.base import list_kbs_from_db
|
||||||
from configs.model_config import EMBEDDING_MODEL
|
from configs.model_config import EMBEDDING_MODEL
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,7 @@ from server.utils import BaseResponse, ListResponse
|
||||||
from server.knowledge_base.utils import (validate_kb_name)
|
from server.knowledge_base.utils import (validate_kb_name)
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import json
|
import json
|
||||||
from server.knowledge_base.knowledge_file import KnowledgeFile
|
from server.knowledge_base.utils import KnowledgeFile, KBServiceFactory
|
||||||
from server.knowledge_base.knowledge_base_factory import KBServiceFactory
|
|
||||||
from server.knowledge_base.kb_service.base import SupportedVSType, list_docs_from_folder
|
from server.knowledge_base.kb_service.base import SupportedVSType, list_docs_from_folder
|
||||||
from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache
|
from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,18 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sqlite3
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
from configs.model_config import (VECTOR_SEARCH_TOP_K,
|
|
||||||
embedding_model_dict, EMBEDDING_DEVICE, EMBEDDING_MODEL,
|
|
||||||
kbs_config)
|
|
||||||
|
|
||||||
from server.db.repository.knowledge_base_repository import add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists
|
from server.db.repository.knowledge_base_repository import add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists
|
||||||
from server.db.repository.knowledge_file_repository import add_doc_to_db, delete_file_from_db, doc_exists, \
|
from server.db.repository.knowledge_file_repository import add_doc_to_db, delete_file_from_db, doc_exists, \
|
||||||
list_docs_from_db
|
list_docs_from_db
|
||||||
from server.knowledge_base.utils import (get_kb_path, get_doc_path)
|
from configs.model_config import (DB_ROOT_PATH, kbs_config, VECTOR_SEARCH_TOP_K,
|
||||||
from server.knowledge_base.knowledge_file import KnowledgeFile
|
embedding_model_dict, EMBEDDING_DEVICE, EMBEDDING_MODEL)
|
||||||
|
from server.knowledge_base.utils import (get_kb_path, get_doc_path, KnowledgeFile)
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,7 @@ import shutil
|
||||||
from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_DEVICE
|
from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_DEVICE
|
||||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings
|
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from server.knowledge_base.utils import get_vs_path
|
from server.knowledge_base.utils import get_vs_path, KnowledgeFile
|
||||||
from server.knowledge_base.knowledge_file import KnowledgeFile
|
|
||||||
from langchain.vectorstores import FAISS
|
from langchain.vectorstores import FAISS
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,9 @@ from langchain.schema import Document
|
||||||
from langchain.vectorstores import Milvus
|
from langchain.vectorstores import Milvus
|
||||||
|
|
||||||
from configs.model_config import EMBEDDING_DEVICE, kbs_config
|
from configs.model_config import EMBEDDING_DEVICE, kbs_config
|
||||||
from server.knowledge_base import KnowledgeFile
|
|
||||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings
|
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings
|
||||||
|
from server.knowledge_base.utils import KnowledgeFile
|
||||||
|
|
||||||
|
|
||||||
class MilvusKBService(KBService):
|
class MilvusKBService(KBService):
|
||||||
|
|
|
||||||
|
|
@ -1,42 +0,0 @@
|
||||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
|
||||||
from server.db.repository.knowledge_base_repository import load_kb_from_db
|
|
||||||
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
|
||||||
|
|
||||||
|
|
||||||
class KBServiceFactory:
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_service(kb_name: str,
|
|
||||||
vector_store_type: SupportedVSType
|
|
||||||
) -> KBService:
|
|
||||||
if SupportedVSType.FAISS == vector_store_type:
|
|
||||||
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
|
||||||
return FaissKBService(kb_name)
|
|
||||||
elif SupportedVSType.MILVUS == vector_store_type:
|
|
||||||
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
|
||||||
return MilvusKBService(kb_name)
|
|
||||||
elif SupportedVSType.DEFAULT == vector_store_type:
|
|
||||||
return DefaultKBService(kb_name)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_service_by_name(kb_name: str
|
|
||||||
) -> KBService:
|
|
||||||
kb_name, vs_type, _ = load_kb_from_db(kb_name)
|
|
||||||
return KBServiceFactory.get_service(kb_name, vs_type)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_default():
|
|
||||||
return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
# 测试建表使用
|
|
||||||
# from server.db.base import Base, engine
|
|
||||||
# Base.metadata.create_all(bind=engine)
|
|
||||||
KBService = KBServiceFactory.get_service("test", SupportedVSType.FAISS)
|
|
||||||
KBService.create_kb()
|
|
||||||
KBService = KBServiceFactory.get_default()
|
|
||||||
print(KBService.list_kbs())
|
|
||||||
KBService = KBServiceFactory.get_service_by_name("test")
|
|
||||||
print(KBService.list_docs())
|
|
||||||
KBService.drop_kb()
|
|
||||||
|
|
@ -1,7 +1,12 @@
|
||||||
|
from typing import Union
|
||||||
import os
|
import os
|
||||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||||
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH)
|
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config)
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||||
|
from server.db.repository.knowledge_base_repository import load_kb_from_db
|
||||||
|
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
||||||
|
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||||
|
|
||||||
|
|
||||||
def validate_kb_name(knowledge_base_id: str) -> bool:
|
def validate_kb_name(knowledge_base_id: str) -> bool:
|
||||||
|
|
@ -22,8 +27,81 @@ def get_vs_path(knowledge_base_name: str):
|
||||||
def get_file_path(knowledge_base_name: str, doc_name: str):
|
def get_file_path(knowledge_base_name: str, doc_name: str):
|
||||||
return os.path.join(get_doc_path(knowledge_base_name), doc_name)
|
return os.path.join(get_doc_path(knowledge_base_name), doc_name)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(1)
|
@lru_cache(1)
|
||||||
def load_embeddings(model: str, device: str):
|
def load_embeddings(model: str, device: str):
|
||||||
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
|
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
|
||||||
model_kwargs={'device': device})
|
model_kwargs={'device': device})
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
|
||||||
|
'.rtf', '.txt', '.xml',
|
||||||
|
'.doc', '.docx', '.epub', '.odt', '.pdf',
|
||||||
|
'.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv'
|
||||||
|
"CSVLoader": [".csv"],
|
||||||
|
}
|
||||||
|
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
|
||||||
|
|
||||||
|
def get_LoaderClass(file_extension):
|
||||||
|
for LoaderClass, extensions in LOADER_DICT.items():
|
||||||
|
if file_extension in extensions:
|
||||||
|
return LoaderClass
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeFile:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
filename: str,
|
||||||
|
knowledge_base_name: str
|
||||||
|
):
|
||||||
|
self.kb_name = knowledge_base_name
|
||||||
|
self.filename = filename
|
||||||
|
self.ext = os.path.splitext(filename)[-1]
|
||||||
|
if self.ext not in SUPPORTED_EXTS:
|
||||||
|
raise ValueError(f"暂未支持的文件格式 {self.ext}")
|
||||||
|
self.filepath = get_file_path(knowledge_base_name, filename)
|
||||||
|
self.docs = None
|
||||||
|
self.document_loader_name = get_LoaderClass(self.ext)
|
||||||
|
|
||||||
|
# TODO: 增加依据文件格式匹配text_splitter
|
||||||
|
self.text_splitter_name = "CharacterTextSplitter"
|
||||||
|
|
||||||
|
def file2text(self):
|
||||||
|
DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name)
|
||||||
|
loader = DocumentLoader(self.filepath)
|
||||||
|
|
||||||
|
# TODO: 增加依据文件格式匹配text_splitter
|
||||||
|
TextSplitter = getattr(sys.modules['langchain.text_splitter'], self.text_splitter_name)
|
||||||
|
text_splitter = TextSplitter(chunk_size=500, chunk_overlap=200)
|
||||||
|
return loader.load_and_split(text_splitter)
|
||||||
|
|
||||||
|
|
||||||
|
class KBServiceFactory:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_service(kb_name: str,
|
||||||
|
vector_store_type: Union[str, SupportedVSType],
|
||||||
|
embed_model: str = EMBEDDING_MODEL,
|
||||||
|
) -> KBService:
|
||||||
|
if isinstance(vector_store_type, str):
|
||||||
|
vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
|
||||||
|
if SupportedVSType.FAISS == vector_store_type:
|
||||||
|
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
||||||
|
return FaissKBService(kb_name, embed_model=embed_model)
|
||||||
|
elif SupportedVSType.MILVUS == vector_store_type:
|
||||||
|
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||||
|
return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
|
||||||
|
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
|
||||||
|
return DefaultKBService(kb_name)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_service_by_name(kb_name: str
|
||||||
|
) -> KBService:
|
||||||
|
kb_name, vs_type, embed_model = load_kb_from_db(kb_name)
|
||||||
|
return KBServiceFactory.get_service(kb_name, vs_type, embed_model)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default():
|
||||||
|
return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue