Use KBServiceFactory to replace all uses of KnowledgeBase.
Make KBServiceFactory support the embed_model parameter. Rewrite the recreate_vector_store API. Fix some bugs.
This commit is contained in:
parent
0f46185cfb
commit
44c713ef98
|
|
@ -21,3 +21,6 @@ streamlit-option-menu
|
|||
streamlit-antd-components
|
||||
streamlit-chatbox>=1.1.6
|
||||
httpx
|
||||
|
||||
faiss-cpu
|
||||
pymilvus==2.1.3 # requires milvus==2.1.3
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
|
|||
from typing import AsyncIterable
|
||||
import asyncio
|
||||
from langchain.prompts import PromptTemplate
|
||||
from server.knowledge_base.knowledge_base import KnowledgeBase
|
||||
from server.knowledge_base.knowledge_base_factory import KBServiceFactory
|
||||
from server.knowledge_base.kb_service.base import KBService
|
||||
import json
|
||||
|
||||
|
||||
|
|
@ -18,12 +19,12 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
|||
knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
|
||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||
):
|
||||
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
|
||||
|
||||
async def knowledge_base_chat_iterator(query: str,
|
||||
kb: KnowledgeBase,
|
||||
kb: KBService,
|
||||
top_k: int,
|
||||
) -> AsyncIterable[str]:
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from .kb_api import list_kbs, create_kb, delete_kb
|
||||
from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
|
||||
from .knowledge_base import KnowledgeBase
|
||||
from .knowledge_file import KnowledgeFile
|
||||
from .knowledge_base_factory import KBServiceFactory
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
import urllib
|
||||
from server.utils import BaseResponse, ListResponse
|
||||
from server.knowledge_base.utils import validate_kb_name
|
||||
from server.knowledge_base.knowledge_base import KnowledgeBase
|
||||
from server.knowledge_base.knowledge_base_factory import KBServiceFactory
|
||||
from server.knowledge_base.kb_service.base import list_kbs_from_db
|
||||
from configs.model_config import EMBEDDING_MODEL
|
||||
|
||||
|
||||
async def list_kbs():
|
||||
# Get List of Knowledge Base
|
||||
return ListResponse(data=KnowledgeBase.list_kbs())
|
||||
return ListResponse(data=list_kbs_from_db())
|
||||
|
||||
|
||||
async def create_kb(knowledge_base_name: str,
|
||||
|
|
@ -19,11 +20,10 @@ async def create_kb(knowledge_base_name: str,
|
|||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
if knowledge_base_name is None or knowledge_base_name.strip() == "":
|
||||
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
|
||||
if KnowledgeBase.exists(knowledge_base_name):
|
||||
|
||||
kb = KBServiceFactory.get_service(knowledge_base_name, "faiss")
|
||||
if kb is not None:
|
||||
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
||||
kb = KnowledgeBase(knowledge_base_name=knowledge_base_name,
|
||||
vector_store_type=vector_store_type,
|
||||
embed_model=embed_model)
|
||||
kb.create()
|
||||
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
||||
|
||||
|
|
@ -34,10 +34,12 @@ async def delete_kb(knowledge_base_name: str):
|
|||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
|
||||
|
||||
if not KnowledgeBase.exists(knowledge_base_name):
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
|
||||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
status = KnowledgeBase.delete(knowledge_base_name)
|
||||
status = kb.drop_kb()
|
||||
if status:
|
||||
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,9 @@ from server.knowledge_base.utils import (validate_kb_name)
|
|||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
from server.knowledge_base.knowledge_file import KnowledgeFile
|
||||
from server.knowledge_base.knowledge_base import KnowledgeBase
|
||||
from server.knowledge_base.knowledge_base_factory import KBServiceFactory
|
||||
from server.knowledge_base.kb_service.base import SupportedVSType, list_docs_from_folder
|
||||
from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache
|
||||
|
||||
|
||||
async def list_docs(knowledge_base_name: str):
|
||||
|
|
@ -14,10 +16,11 @@ async def list_docs(knowledge_base_name: str):
|
|||
return ListResponse(code=403, msg="Don't attack me", data=[])
|
||||
|
||||
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
|
||||
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
if kb is None:
|
||||
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
|
||||
else:
|
||||
all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs()
|
||||
all_doc_names = kb.list_docs()
|
||||
return ListResponse(data=all_doc_names)
|
||||
|
||||
|
||||
|
|
@ -28,11 +31,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
|
|||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
|
||||
|
||||
file_content = await file.read() # 读取上传文件的内容
|
||||
|
||||
kb_file = KnowledgeFile(filename=file.filename,
|
||||
|
|
@ -63,10 +65,10 @@ async def delete_doc(knowledge_base_name: str,
|
|||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
|
||||
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
|
||||
if not kb.exist_doc(doc_name):
|
||||
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
|
||||
kb_file = KnowledgeFile(filename=doc_name,
|
||||
|
|
@ -92,21 +94,26 @@ async def recreate_vector_store(knowledge_base_name: str):
|
|||
recreate vector store from the content.
|
||||
this is useful when users copy files into the content folder directly instead of uploading them over the network.
|
||||
'''
|
||||
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
async def output(kb: KnowledgeBase):
|
||||
kb.recreate_vs()
|
||||
async def output(kb):
|
||||
kb.clear_vs()
|
||||
print(f"start to recreate vector store of {kb.kb_name}")
|
||||
docs = kb.list_docs()
|
||||
docs = list_docs_from_folder(knowledge_base_name)
|
||||
print(docs)
|
||||
for i, filename in enumerate(docs):
|
||||
yield json.dumps({
|
||||
"total": len(docs),
|
||||
"finished": i,
|
||||
"doc": filename,
|
||||
})
|
||||
kb_file = KnowledgeFile(filename=filename,
|
||||
knowledge_base_name=kb.kb_name)
|
||||
print(f"processing {kb_file.filepath} to vector store.")
|
||||
kb.add_doc(kb_file)
|
||||
yield json.dumps({
|
||||
"total": len(docs),
|
||||
"finished": i + 1,
|
||||
"doc": filename,
|
||||
})
|
||||
if kb.vs_type == SupportedVSType.FAISS:
|
||||
refresh_vs_cache(knowledge_base_name)
|
||||
|
||||
return StreamingResponse(output(kb), media_type="text/event-stream")
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import datetime
|
|||
from server.knowledge_base.utils import (get_kb_path, get_doc_path)
|
||||
from server.knowledge_base.knowledge_file import KnowledgeFile
|
||||
from typing import List
|
||||
import os
|
||||
|
||||
|
||||
class SupportedVSType:
|
||||
|
|
@ -125,6 +126,10 @@ def list_docs_from_db(kb_name):
|
|||
conn.close()
|
||||
return kbs
|
||||
|
||||
def list_docs_from_folder(kb_name: str):
    """Return the names of the files stored directly in a knowledge base's doc folder.

    The folder is resolved with ``get_doc_path(kb_name)``. Only entries for
    which ``os.path.isfile`` is true are returned, so sub-directories are
    skipped. Names are returned in the order ``os.listdir`` yields them.
    """
    folder = get_doc_path(kb_name)
    names = []
    for entry in os.listdir(folder):
        # Keep plain files only; nested folders are not documents.
        if os.path.isfile(os.path.join(folder, entry)):
            names.append(entry)
    return names
|
||||
|
||||
def add_doc_to_db(kb_file: KnowledgeFile):
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
|
|
|
|||
|
|
@ -5,13 +5,13 @@ class DefaultKBService(KBService):
|
|||
def vs_type(self) -> str:
|
||||
return "default"
|
||||
|
||||
def do_create_kbs(self):
|
||||
def do_create_kb(self):
|
||||
pass
|
||||
|
||||
def do_init(self):
|
||||
pass
|
||||
|
||||
def do_drop_kbs(self):
|
||||
def do_drop_kb(self):
|
||||
pass
|
||||
|
||||
def do_search(self):
|
||||
|
|
@ -25,3 +25,6 @@ class DefaultKBService(KBService):
|
|||
|
||||
def do_delete_doc(self):
|
||||
pass
|
||||
|
||||
def kb_exists(self):
    """Report whether this knowledge base exists.

    The default service always answers ``False``.
    """
    exists = False
    return exists
|
||||
|
|
|
|||
|
|
@ -122,3 +122,4 @@ class FaissKBService(KBService):
|
|||
|
||||
def do_clear_vs(self):
    """Empty the vector store folder at ``self.vs_path``.

    The folder tree is removed and an empty folder is created in its place.
    A folder that does not exist yet is not an error: it is simply created,
    so clearing a knowledge base that has never been indexed does not fail.
    Other filesystem errors (for example permission problems) still propagate.
    """
    try:
        shutil.rmtree(self.vs_path)
    except FileNotFoundError:
        # Nothing to remove; fall through and create the empty folder.
        pass
    # exist_ok guards against the folder being recreated concurrently.
    os.makedirs(self.vs_path, exist_ok=True)
|
||||
|
|
|
|||
|
|
@ -1,28 +1,34 @@
|
|||
from typing import Union
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db
|
||||
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
||||
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||
from configs.model_config import EMBEDDING_MODEL
|
||||
|
||||
|
||||
class KBServiceFactory:
|
||||
|
||||
@staticmethod
|
||||
def get_service(kb_name: str,
|
||||
vector_store_type: SupportedVSType
|
||||
vector_store_type: Union[str, SupportedVSType],
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
) -> KBService:
|
||||
if isinstance(vector_store_type, str):
|
||||
vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
|
||||
if SupportedVSType.FAISS == vector_store_type:
|
||||
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
||||
return FaissKBService(kb_name)
|
||||
elif SupportedVSType.MILVUS == vector_store_type:
|
||||
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||
return MilvusKBService(kb_name)
|
||||
elif SupportedVSType.DEFAULT == vector_store_type:
|
||||
return FaissKBService(kb_name, embed_model=embed_model)
|
||||
# todo: Milvus has different init params
|
||||
# elif SupportedVSType.MILVUS == vector_store_type:
|
||||
# from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||
# return MilvusKBService(kb_name,)
|
||||
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
|
||||
return DefaultKBService(kb_name)
|
||||
|
||||
@staticmethod
|
||||
def get_service_by_name(kb_name: str
|
||||
) -> KBService:
|
||||
kb_name, vs_type = load_kb_from_db(kb_name)
|
||||
return KBServiceFactory.get_service(kb_name, vs_type)
|
||||
kb_name, vs_type, embed_model = load_kb_from_db(kb_name)
|
||||
return KBServiceFactory.get_service(kb_name, vs_type, embed_model)
|
||||
|
||||
@staticmethod
|
||||
def get_default():
|
||||
|
|
|
|||
Loading…
Reference in New Issue