From 44c713ef989ae640079a710d28cb48cd8a5685bd Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Mon, 7 Aug 2023 20:37:16 +0800 Subject: [PATCH] use KBServiceFactory to replace all the KnowledgeBase. make KBServiceFactory support embed_model parameter. rewrite api: recreate_vector_store. fix some bugs. --- requirements.txt | 3 ++ server/chat/knowledge_base_chat.py | 9 ++-- server/knowledge_base/__init__.py | 2 +- server/knowledge_base/kb_api.py | 18 ++++---- server/knowledge_base/kb_doc_api.py | 41 +++++++++++-------- server/knowledge_base/kb_service/base.py | 5 +++ .../kb_service/default_kb_service.py | 7 +++- .../kb_service/faiss_kb_service.py | 1 + .../knowledge_base/knowledge_base_factory.py | 22 ++++++---- 9 files changed, 68 insertions(+), 40 deletions(-) diff --git a/requirements.txt b/requirements.txt index 765cf9f..7ec558f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,6 @@ streamlit-option-menu streamlit-antd-components streamlit-chatbox>=1.1.6 httpx + +faiss-cpu +pymilvus==2.1.3 # requires milvus==2.1.3 diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 4eaf3fd..0a85df3 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -10,7 +10,8 @@ from langchain.callbacks import AsyncIteratorCallbackHandler from typing import AsyncIterable import asyncio from langchain.prompts import PromptTemplate -from server.knowledge_base.knowledge_base import KnowledgeBase +from server.knowledge_base.knowledge_base_factory import KBServiceFactory +from server.knowledge_base.kb_service.base import KBService import json @@ -18,12 +19,12 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp knowledge_base_name: str = Body(..., description="知识库名称", example="samples"), top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), ): - if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name) async def knowledge_base_chat_iterator(query: str, - kb: KnowledgeBase, + kb: KBService, top_k: int, ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() diff --git a/server/knowledge_base/__init__.py b/server/knowledge_base/__init__.py index 14b7898..b64e921 100644 --- a/server/knowledge_base/__init__.py +++ b/server/knowledge_base/__init__.py @@ -1,4 +1,4 @@ from .kb_api import list_kbs, create_kb, delete_kb from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store -from .knowledge_base import KnowledgeBase from .knowledge_file import KnowledgeFile +from .knowledge_base_factory import KBServiceFactory diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py index d5ef57b..9a92ea6 100644 --- a/server/knowledge_base/kb_api.py +++ b/server/knowledge_base/kb_api.py @@ -1,13 +1,14 @@ import urllib from server.utils import BaseResponse, ListResponse from server.knowledge_base.utils import validate_kb_name -from server.knowledge_base.knowledge_base import KnowledgeBase +from server.knowledge_base.knowledge_base_factory import KBServiceFactory +from server.knowledge_base.kb_service.base import list_kbs_from_db from configs.model_config import EMBEDDING_MODEL async def list_kbs(): # Get List of Knowledge Base - return ListResponse(data=KnowledgeBase.list_kbs()) + return ListResponse(data=list_kbs_from_db()) async def create_kb(knowledge_base_name: str, @@ -19,11 +20,10 @@ async def create_kb(knowledge_base_name: str, return BaseResponse(code=403, msg="Don't attack me") if knowledge_base_name is None or knowledge_base_name.strip() == "": return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称") - if KnowledgeBase.exists(knowledge_base_name): + + kb = KBServiceFactory.get_service(knowledge_base_name, "faiss") + if kb is not None: return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") - kb = KnowledgeBase(knowledge_base_name=knowledge_base_name, - vector_store_type=vector_store_type, - embed_model=embed_model) kb.create() return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") @@ -34,10 +34,12 @@ async def delete_kb(knowledge_base_name: str): return BaseResponse(code=403, msg="Don't attack me") knowledge_base_name = urllib.parse.unquote(knowledge_base_name) - if not KnowledgeBase.exists(knowledge_base_name): + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + + if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - status = KnowledgeBase.delete(knowledge_base_name) + status = kb.drop_kb() if status: return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}") else: diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 62de2d9..2ebe828 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -6,7 +6,9 @@ from server.knowledge_base.utils import (validate_kb_name) from fastapi.responses import StreamingResponse import json from server.knowledge_base.knowledge_file import KnowledgeFile -from server.knowledge_base.knowledge_base import KnowledgeBase +from server.knowledge_base.knowledge_base_factory import KBServiceFactory +from server.knowledge_base.kb_service.base import SupportedVSType, list_docs_from_folder +from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache async def list_docs(knowledge_base_name: str): @@ -14,10 +16,11 @@ async def list_docs(knowledge_base_name: str): return ListResponse(code=403, msg="Don't attack me", data=[]) knowledge_base_name = urllib.parse.unquote(knowledge_base_name) - if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[]) else: - all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs() + all_doc_names = kb.list_docs() return ListResponse(data=all_doc_names) @@ -28,11 +31,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"), if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") - if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name) - file_content = await file.read() # 读取上传文件的内容 kb_file = KnowledgeFile(filename=file.filename, @@ -63,10 +65,10 @@ async def delete_doc(knowledge_base_name: str, return BaseResponse(code=403, msg="Don't attack me") knowledge_base_name = urllib.parse.unquote(knowledge_base_name) - if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name) if not kb.exist_doc(doc_name): return BaseResponse(code=404, msg=f"未找到文件 {doc_name}") kb_file = KnowledgeFile(filename=doc_name, @@ -92,21 +94,26 @@ async def recreate_vector_store(knowledge_base_name: str): recreate vector store from the content. this is usefull when user can copy files to content folder directly instead of upload through network. ''' - kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name) + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - async def output(kb: KnowledgeBase): - kb.recreate_vs() + async def output(kb): + kb.clear_vs() print(f"start to recreate vector store of {kb.kb_name}") - docs = kb.list_docs() + docs = list_docs_from_folder(knowledge_base_name) + print(docs) for i, filename in enumerate(docs): + yield json.dumps({ + "total": len(docs), + "finished": i, + "doc": filename, + }) kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb.kb_name) print(f"processing {kb_file.filepath} to vector store.") kb.add_doc(kb_file) - yield json.dumps({ - "total": len(docs), - "finished": i + 1, - "doc": filename, - }) + if kb.vs_type == SupportedVSType.FAISS: + refresh_vs_cache(knowledge_base_name) return StreamingResponse(output(kb), media_type="text/event-stream") diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index c7fb24e..243ceda 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -14,6 +14,7 @@ import datetime from server.knowledge_base.utils import (get_kb_path, get_doc_path) from server.knowledge_base.knowledge_file import KnowledgeFile from typing import List +import os class SupportedVSType: @@ -125,6 +126,10 @@ def list_docs_from_db(kb_name): conn.close() return kbs +def list_docs_from_folder(kb_name: str): + doc_path = get_doc_path(kb_name) + return [file for file in os.listdir(doc_path) + if os.path.isfile(os.path.join(doc_path, file))] def add_doc_to_db(kb_file: KnowledgeFile): conn = sqlite3.connect(DB_ROOT_PATH) diff --git a/server/knowledge_base/kb_service/default_kb_service.py b/server/knowledge_base/kb_service/default_kb_service.py index fb31859..36ebfcf 100644 --- a/server/knowledge_base/kb_service/default_kb_service.py +++ b/server/knowledge_base/kb_service/default_kb_service.py @@ -5,13 +5,13 @@ class DefaultKBService(KBService): def vs_type(self) -> str: return "default" - def do_create_kbs(self): + def do_create_kb(self): pass def do_init(self): pass - def do_drop_kbs(self): + def do_drop_kb(self): pass def do_search(self): @@ -25,3 +25,6 @@ class DefaultKBService(KBService): def do_delete_doc(self): pass + + def kb_exists(self): + return False diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 64ed9ed..5e9e8b7 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -122,3 +122,4 @@ class FaissKBService(KBService): def do_clear_vs(self): shutil.rmtree(self.vs_path) + os.makedirs(self.vs_path) diff --git a/server/knowledge_base/knowledge_base_factory.py b/server/knowledge_base/knowledge_base_factory.py index 7d61748..0ae93c0 100644 --- a/server/knowledge_base/knowledge_base_factory.py +++ b/server/knowledge_base/knowledge_base_factory.py @@ -1,28 +1,34 @@ +from typing import Union from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db from server.knowledge_base.kb_service.default_kb_service import DefaultKBService from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService +from configs.model_config import EMBEDDING_MODEL class KBServiceFactory: @staticmethod def get_service(kb_name: str, - vector_store_type: SupportedVSType + vector_store_type: Union[str, SupportedVSType], + embed_model: str = EMBEDDING_MODEL, ) -> KBService: + if isinstance(vector_store_type, str): + vector_store_type = getattr(SupportedVSType, vector_store_type.upper()) if SupportedVSType.FAISS == vector_store_type: from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService - return FaissKBService(kb_name) - elif SupportedVSType.MILVUS == vector_store_type: - from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService - return MilvusKBService(kb_name) - elif SupportedVSType.DEFAULT == vector_store_type: + return FaissKBService(kb_name, embed_model=embed_model) + # todo: Milvus has different init params + # elif SupportedVSType.MILVUS == vector_store_type: + # from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService + # return MilvusKBService(kb_name,) + elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier. return DefaultKBService(kb_name) @staticmethod def get_service_by_name(kb_name: str ) -> KBService: - kb_name, vs_type = load_kb_from_db(kb_name) - return KBServiceFactory.get_service(kb_name, vs_type) + kb_name, vs_type, embed_model = load_kb_from_db(kb_name) + return KBServiceFactory.get_service(kb_name, vs_type, embed_model) @staticmethod def get_default():