use KBServiceFactory to replace all the KnowledgeBase.

make KBServiceFactory support embed_model parameter.
rewrite api: recreate_vector_store.
fix some bugs.
This commit is contained in:
liunux4odoo 2023-08-07 20:37:16 +08:00
parent 0f46185cfb
commit 44c713ef98
9 changed files with 68 additions and 40 deletions

View File

@ -21,3 +21,6 @@ streamlit-option-menu
streamlit-antd-components streamlit-antd-components
streamlit-chatbox>=1.1.6 streamlit-chatbox>=1.1.6
httpx httpx
faiss-cpu
pymilvus==2.1.3 # requires milvus==2.1.3

View File

@ -10,7 +10,8 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable from typing import AsyncIterable
import asyncio import asyncio
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from server.knowledge_base.knowledge_base import KnowledgeBase from server.knowledge_base.knowledge_base_factory import KBServiceFactory
from server.knowledge_base.kb_service.base import KBService
import json import json
@ -18,12 +19,12 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
knowledge_base_name: str = Body(..., description="知识库名称", example="samples"), knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
): ):
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
async def knowledge_base_chat_iterator(query: str, async def knowledge_base_chat_iterator(query: str,
kb: KnowledgeBase, kb: KBService,
top_k: int, top_k: int,
) -> AsyncIterable[str]: ) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler() callback = AsyncIteratorCallbackHandler()

View File

@ -1,4 +1,4 @@
from .kb_api import list_kbs, create_kb, delete_kb from .kb_api import list_kbs, create_kb, delete_kb
from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
from .knowledge_base import KnowledgeBase
from .knowledge_file import KnowledgeFile from .knowledge_file import KnowledgeFile
from .knowledge_base_factory import KBServiceFactory

View File

@ -1,13 +1,14 @@
import urllib import urllib
from server.utils import BaseResponse, ListResponse from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name from server.knowledge_base.utils import validate_kb_name
from server.knowledge_base.knowledge_base import KnowledgeBase from server.knowledge_base.knowledge_base_factory import KBServiceFactory
from server.knowledge_base.kb_service.base import list_kbs_from_db
from configs.model_config import EMBEDDING_MODEL from configs.model_config import EMBEDDING_MODEL
async def list_kbs(): async def list_kbs():
# Get List of Knowledge Base # Get List of Knowledge Base
return ListResponse(data=KnowledgeBase.list_kbs()) return ListResponse(data=list_kbs_from_db())
async def create_kb(knowledge_base_name: str, async def create_kb(knowledge_base_name: str,
@ -19,11 +20,10 @@ async def create_kb(knowledge_base_name: str,
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
if knowledge_base_name is None or knowledge_base_name.strip() == "": if knowledge_base_name is None or knowledge_base_name.strip() == "":
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称") return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
if KnowledgeBase.exists(knowledge_base_name):
kb = KBServiceFactory.get_service(knowledge_base_name, "faiss")
if kb is not None:
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
kb = KnowledgeBase(knowledge_base_name=knowledge_base_name,
vector_store_type=vector_store_type,
embed_model=embed_model)
kb.create() kb.create()
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
@ -34,10 +34,12 @@ async def delete_kb(knowledge_base_name: str):
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
knowledge_base_name = urllib.parse.unquote(knowledge_base_name) knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
if not KnowledgeBase.exists(knowledge_base_name): kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
status = KnowledgeBase.delete(knowledge_base_name) status = kb.drop_kb()
if status: if status:
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}") return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
else: else:

View File

@ -6,7 +6,9 @@ from server.knowledge_base.utils import (validate_kb_name)
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
import json import json
from server.knowledge_base.knowledge_file import KnowledgeFile from server.knowledge_base.knowledge_file import KnowledgeFile
from server.knowledge_base.knowledge_base import KnowledgeBase from server.knowledge_base.knowledge_base_factory import KBServiceFactory
from server.knowledge_base.kb_service.base import SupportedVSType, list_docs_from_folder
from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache
async def list_docs(knowledge_base_name: str): async def list_docs(knowledge_base_name: str):
@ -14,10 +16,11 @@ async def list_docs(knowledge_base_name: str):
return ListResponse(code=403, msg="Don't attack me", data=[]) return ListResponse(code=403, msg="Don't attack me", data=[])
knowledge_base_name = urllib.parse.unquote(knowledge_base_name) knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[]) return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
else: else:
all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs() all_doc_names = kb.list_docs()
return ListResponse(data=all_doc_names) return ListResponse(data=all_doc_names)
@ -28,11 +31,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
file_content = await file.read() # 读取上传文件的内容 file_content = await file.read() # 读取上传文件的内容
kb_file = KnowledgeFile(filename=file.filename, kb_file = KnowledgeFile(filename=file.filename,
@ -63,10 +65,10 @@ async def delete_doc(knowledge_base_name: str,
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
knowledge_base_name = urllib.parse.unquote(knowledge_base_name) knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name): kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
if not kb.exist_doc(doc_name): if not kb.exist_doc(doc_name):
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}") return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
kb_file = KnowledgeFile(filename=doc_name, kb_file = KnowledgeFile(filename=doc_name,
@ -92,21 +94,26 @@ async def recreate_vector_store(knowledge_base_name: str):
recreate vector store from the content. recreate vector store from the content.
this is usefull when user can copy files to content folder directly instead of upload through network. this is usefull when user can copy files to content folder directly instead of upload through network.
''' '''
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name) kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
async def output(kb: KnowledgeBase): async def output(kb):
kb.recreate_vs() kb.clear_vs()
print(f"start to recreate vector store of {kb.kb_name}") print(f"start to recreate vector store of {kb.kb_name}")
docs = kb.list_docs() docs = list_docs_from_folder(knowledge_base_name)
print(docs)
for i, filename in enumerate(docs): for i, filename in enumerate(docs):
yield json.dumps({
"total": len(docs),
"finished": i,
"doc": filename,
})
kb_file = KnowledgeFile(filename=filename, kb_file = KnowledgeFile(filename=filename,
knowledge_base_name=kb.kb_name) knowledge_base_name=kb.kb_name)
print(f"processing {kb_file.filepath} to vector store.") print(f"processing {kb_file.filepath} to vector store.")
kb.add_doc(kb_file) kb.add_doc(kb_file)
yield json.dumps({ if kb.vs_type == SupportedVSType.FAISS:
"total": len(docs), refresh_vs_cache(knowledge_base_name)
"finished": i + 1,
"doc": filename,
})
return StreamingResponse(output(kb), media_type="text/event-stream") return StreamingResponse(output(kb), media_type="text/event-stream")

View File

@ -14,6 +14,7 @@ import datetime
from server.knowledge_base.utils import (get_kb_path, get_doc_path) from server.knowledge_base.utils import (get_kb_path, get_doc_path)
from server.knowledge_base.knowledge_file import KnowledgeFile from server.knowledge_base.knowledge_file import KnowledgeFile
from typing import List from typing import List
import os
class SupportedVSType: class SupportedVSType:
@ -125,6 +126,10 @@ def list_docs_from_db(kb_name):
conn.close() conn.close()
return kbs return kbs
def list_docs_from_folder(kb_name: str):
doc_path = get_doc_path(kb_name)
return [file for file in os.listdir(doc_path)
if os.path.isfile(os.path.join(doc_path, file))]
def add_doc_to_db(kb_file: KnowledgeFile): def add_doc_to_db(kb_file: KnowledgeFile):
conn = sqlite3.connect(DB_ROOT_PATH) conn = sqlite3.connect(DB_ROOT_PATH)

View File

@ -5,13 +5,13 @@ class DefaultKBService(KBService):
def vs_type(self) -> str: def vs_type(self) -> str:
return "default" return "default"
def do_create_kbs(self): def do_create_kb(self):
pass pass
def do_init(self): def do_init(self):
pass pass
def do_drop_kbs(self): def do_drop_kb(self):
pass pass
def do_search(self): def do_search(self):
@ -25,3 +25,6 @@ class DefaultKBService(KBService):
def do_delete_doc(self): def do_delete_doc(self):
pass pass
def kb_exists(self):
return False

View File

@ -122,3 +122,4 @@ class FaissKBService(KBService):
def do_clear_vs(self): def do_clear_vs(self):
shutil.rmtree(self.vs_path) shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path)

View File

@ -1,28 +1,34 @@
from typing import Union
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
from configs.model_config import EMBEDDING_MODEL
class KBServiceFactory: class KBServiceFactory:
@staticmethod @staticmethod
def get_service(kb_name: str, def get_service(kb_name: str,
vector_store_type: SupportedVSType vector_store_type: Union[str, SupportedVSType],
embed_model: str = EMBEDDING_MODEL,
) -> KBService: ) -> KBService:
if isinstance(vector_store_type, str):
vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
if SupportedVSType.FAISS == vector_store_type: if SupportedVSType.FAISS == vector_store_type:
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
return FaissKBService(kb_name) return FaissKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.MILVUS == vector_store_type: # todo: Milvus has different init params
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService # elif SupportedVSType.MILVUS == vector_store_type:
return MilvusKBService(kb_name) # from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
elif SupportedVSType.DEFAULT == vector_store_type: # return MilvusKBService(kb_name,)
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
return DefaultKBService(kb_name) return DefaultKBService(kb_name)
@staticmethod @staticmethod
def get_service_by_name(kb_name: str def get_service_by_name(kb_name: str
) -> KBService: ) -> KBService:
kb_name, vs_type = load_kb_from_db(kb_name) kb_name, vs_type, embed_model = load_kb_from_db(kb_name)
return KBServiceFactory.get_service(kb_name, vs_type) return KBServiceFactory.get_service(kb_name, vs_type, embed_model)
@staticmethod @staticmethod
def get_default(): def get_default():