Use KBServiceFactory to replace all uses of KnowledgeBase.
Make KBServiceFactory support the embed_model parameter. Rewrite the recreate_vector_store API. Fix some bugs.
This commit is contained in:
parent
0f46185cfb
commit
44c713ef98
|
|
@ -21,3 +21,6 @@ streamlit-option-menu
|
||||||
streamlit-antd-components
|
streamlit-antd-components
|
||||||
streamlit-chatbox>=1.1.6
|
streamlit-chatbox>=1.1.6
|
||||||
httpx
|
httpx
|
||||||
|
|
||||||
|
faiss-cpu
|
||||||
|
pymilvus==2.1.3 # requires milvus==2.1.3
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,8 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||||
from typing import AsyncIterable
|
from typing import AsyncIterable
|
||||||
import asyncio
|
import asyncio
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from server.knowledge_base.knowledge_base import KnowledgeBase
|
from server.knowledge_base.knowledge_base_factory import KBServiceFactory
|
||||||
|
from server.knowledge_base.kb_service.base import KBService
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -18,12 +19,12 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
||||||
knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
|
knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
|
||||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||||
):
|
):
|
||||||
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
|
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||||
|
if kb is None:
|
||||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||||
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
|
|
||||||
|
|
||||||
async def knowledge_base_chat_iterator(query: str,
|
async def knowledge_base_chat_iterator(query: str,
|
||||||
kb: KnowledgeBase,
|
kb: KBService,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
) -> AsyncIterable[str]:
|
) -> AsyncIterable[str]:
|
||||||
callback = AsyncIteratorCallbackHandler()
|
callback = AsyncIteratorCallbackHandler()
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from .kb_api import list_kbs, create_kb, delete_kb
|
from .kb_api import list_kbs, create_kb, delete_kb
|
||||||
from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
|
from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
|
||||||
from .knowledge_base import KnowledgeBase
|
|
||||||
from .knowledge_file import KnowledgeFile
|
from .knowledge_file import KnowledgeFile
|
||||||
|
from .knowledge_base_factory import KBServiceFactory
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,14 @@
|
||||||
import urllib
|
import urllib
|
||||||
from server.utils import BaseResponse, ListResponse
|
from server.utils import BaseResponse, ListResponse
|
||||||
from server.knowledge_base.utils import validate_kb_name
|
from server.knowledge_base.utils import validate_kb_name
|
||||||
from server.knowledge_base.knowledge_base import KnowledgeBase
|
from server.knowledge_base.knowledge_base_factory import KBServiceFactory
|
||||||
|
from server.knowledge_base.kb_service.base import list_kbs_from_db
|
||||||
from configs.model_config import EMBEDDING_MODEL
|
from configs.model_config import EMBEDDING_MODEL
|
||||||
|
|
||||||
|
|
||||||
async def list_kbs():
|
async def list_kbs():
|
||||||
# Get List of Knowledge Base
|
# Get List of Knowledge Base
|
||||||
return ListResponse(data=KnowledgeBase.list_kbs())
|
return ListResponse(data=list_kbs_from_db())
|
||||||
|
|
||||||
|
|
||||||
async def create_kb(knowledge_base_name: str,
|
async def create_kb(knowledge_base_name: str,
|
||||||
|
|
@ -19,11 +20,10 @@ async def create_kb(knowledge_base_name: str,
|
||||||
return BaseResponse(code=403, msg="Don't attack me")
|
return BaseResponse(code=403, msg="Don't attack me")
|
||||||
if knowledge_base_name is None or knowledge_base_name.strip() == "":
|
if knowledge_base_name is None or knowledge_base_name.strip() == "":
|
||||||
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
|
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
|
||||||
if KnowledgeBase.exists(knowledge_base_name):
|
|
||||||
|
kb = KBServiceFactory.get_service(knowledge_base_name, "faiss")
|
||||||
|
if kb is not None:
|
||||||
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
||||||
kb = KnowledgeBase(knowledge_base_name=knowledge_base_name,
|
|
||||||
vector_store_type=vector_store_type,
|
|
||||||
embed_model=embed_model)
|
|
||||||
kb.create()
|
kb.create()
|
||||||
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
|
|
@ -34,10 +34,12 @@ async def delete_kb(knowledge_base_name: str):
|
||||||
return BaseResponse(code=403, msg="Don't attack me")
|
return BaseResponse(code=403, msg="Don't attack me")
|
||||||
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
|
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
|
||||||
|
|
||||||
if not KnowledgeBase.exists(knowledge_base_name):
|
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||||
|
|
||||||
|
if kb is None:
|
||||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
status = KnowledgeBase.delete(knowledge_base_name)
|
status = kb.drop_kb()
|
||||||
if status:
|
if status:
|
||||||
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
|
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,9 @@ from server.knowledge_base.utils import (validate_kb_name)
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import json
|
import json
|
||||||
from server.knowledge_base.knowledge_file import KnowledgeFile
|
from server.knowledge_base.knowledge_file import KnowledgeFile
|
||||||
from server.knowledge_base.knowledge_base import KnowledgeBase
|
from server.knowledge_base.knowledge_base_factory import KBServiceFactory
|
||||||
|
from server.knowledge_base.kb_service.base import SupportedVSType, list_docs_from_folder
|
||||||
|
from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache
|
||||||
|
|
||||||
|
|
||||||
async def list_docs(knowledge_base_name: str):
|
async def list_docs(knowledge_base_name: str):
|
||||||
|
|
@ -14,10 +16,11 @@ async def list_docs(knowledge_base_name: str):
|
||||||
return ListResponse(code=403, msg="Don't attack me", data=[])
|
return ListResponse(code=403, msg="Don't attack me", data=[])
|
||||||
|
|
||||||
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
|
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
|
||||||
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
|
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||||
|
if kb is None:
|
||||||
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
|
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
|
||||||
else:
|
else:
|
||||||
all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs()
|
all_doc_names = kb.list_docs()
|
||||||
return ListResponse(data=all_doc_names)
|
return ListResponse(data=all_doc_names)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -28,11 +31,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
|
||||||
if not validate_kb_name(knowledge_base_name):
|
if not validate_kb_name(knowledge_base_name):
|
||||||
return BaseResponse(code=403, msg="Don't attack me")
|
return BaseResponse(code=403, msg="Don't attack me")
|
||||||
|
|
||||||
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
|
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||||
|
if kb is None:
|
||||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
|
|
||||||
|
|
||||||
file_content = await file.read() # 读取上传文件的内容
|
file_content = await file.read() # 读取上传文件的内容
|
||||||
|
|
||||||
kb_file = KnowledgeFile(filename=file.filename,
|
kb_file = KnowledgeFile(filename=file.filename,
|
||||||
|
|
@ -63,10 +65,10 @@ async def delete_doc(knowledge_base_name: str,
|
||||||
return BaseResponse(code=403, msg="Don't attack me")
|
return BaseResponse(code=403, msg="Don't attack me")
|
||||||
|
|
||||||
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
|
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
|
||||||
if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
|
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||||
|
if kb is None:
|
||||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
|
|
||||||
if not kb.exist_doc(doc_name):
|
if not kb.exist_doc(doc_name):
|
||||||
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
|
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
|
||||||
kb_file = KnowledgeFile(filename=doc_name,
|
kb_file = KnowledgeFile(filename=doc_name,
|
||||||
|
|
@ -92,21 +94,26 @@ async def recreate_vector_store(knowledge_base_name: str):
|
||||||
recreate vector store from the content.
|
recreate vector store from the content.
|
||||||
this is useful when users copy files to the content folder directly instead of uploading them through the network.
|
this is useful when users copy files to the content folder directly instead of uploading them through the network.
|
||||||
'''
|
'''
|
||||||
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
|
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||||
|
if kb is None:
|
||||||
|
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
async def output(kb: KnowledgeBase):
|
async def output(kb):
|
||||||
kb.recreate_vs()
|
kb.clear_vs()
|
||||||
print(f"start to recreate vector store of {kb.kb_name}")
|
print(f"start to recreate vector store of {kb.kb_name}")
|
||||||
docs = kb.list_docs()
|
docs = list_docs_from_folder(knowledge_base_name)
|
||||||
|
print(docs)
|
||||||
for i, filename in enumerate(docs):
|
for i, filename in enumerate(docs):
|
||||||
|
yield json.dumps({
|
||||||
|
"total": len(docs),
|
||||||
|
"finished": i,
|
||||||
|
"doc": filename,
|
||||||
|
})
|
||||||
kb_file = KnowledgeFile(filename=filename,
|
kb_file = KnowledgeFile(filename=filename,
|
||||||
knowledge_base_name=kb.kb_name)
|
knowledge_base_name=kb.kb_name)
|
||||||
print(f"processing {kb_file.filepath} to vector store.")
|
print(f"processing {kb_file.filepath} to vector store.")
|
||||||
kb.add_doc(kb_file)
|
kb.add_doc(kb_file)
|
||||||
yield json.dumps({
|
if kb.vs_type == SupportedVSType.FAISS:
|
||||||
"total": len(docs),
|
refresh_vs_cache(knowledge_base_name)
|
||||||
"finished": i + 1,
|
|
||||||
"doc": filename,
|
|
||||||
})
|
|
||||||
|
|
||||||
return StreamingResponse(output(kb), media_type="text/event-stream")
|
return StreamingResponse(output(kb), media_type="text/event-stream")
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ import datetime
|
||||||
from server.knowledge_base.utils import (get_kb_path, get_doc_path)
|
from server.knowledge_base.utils import (get_kb_path, get_doc_path)
|
||||||
from server.knowledge_base.knowledge_file import KnowledgeFile
|
from server.knowledge_base.knowledge_file import KnowledgeFile
|
||||||
from typing import List
|
from typing import List
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
class SupportedVSType:
|
class SupportedVSType:
|
||||||
|
|
@ -125,6 +126,10 @@ def list_docs_from_db(kb_name):
|
||||||
conn.close()
|
conn.close()
|
||||||
return kbs
|
return kbs
|
||||||
|
|
||||||
|
def list_docs_from_folder(kb_name: str):
|
||||||
|
doc_path = get_doc_path(kb_name)
|
||||||
|
return [file for file in os.listdir(doc_path)
|
||||||
|
if os.path.isfile(os.path.join(doc_path, file))]
|
||||||
|
|
||||||
def add_doc_to_db(kb_file: KnowledgeFile):
|
def add_doc_to_db(kb_file: KnowledgeFile):
|
||||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||||
|
|
|
||||||
|
|
@ -5,13 +5,13 @@ class DefaultKBService(KBService):
|
||||||
def vs_type(self) -> str:
|
def vs_type(self) -> str:
|
||||||
return "default"
|
return "default"
|
||||||
|
|
||||||
def do_create_kbs(self):
|
def do_create_kb(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def do_init(self):
|
def do_init(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def do_drop_kbs(self):
|
def do_drop_kb(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def do_search(self):
|
def do_search(self):
|
||||||
|
|
@ -25,3 +25,6 @@ class DefaultKBService(KBService):
|
||||||
|
|
||||||
def do_delete_doc(self):
|
def do_delete_doc(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def kb_exists(self):
|
||||||
|
return False
|
||||||
|
|
|
||||||
|
|
@ -122,3 +122,4 @@ class FaissKBService(KBService):
|
||||||
|
|
||||||
def do_clear_vs(self):
|
def do_clear_vs(self):
|
||||||
shutil.rmtree(self.vs_path)
|
shutil.rmtree(self.vs_path)
|
||||||
|
os.makedirs(self.vs_path)
|
||||||
|
|
|
||||||
|
|
@ -1,28 +1,34 @@
|
||||||
|
from typing import Union
|
||||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db
|
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db
|
||||||
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
||||||
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||||
|
from configs.model_config import EMBEDDING_MODEL
|
||||||
|
|
||||||
|
|
||||||
class KBServiceFactory:
|
class KBServiceFactory:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_service(kb_name: str,
|
def get_service(kb_name: str,
|
||||||
vector_store_type: SupportedVSType
|
vector_store_type: Union[str, SupportedVSType],
|
||||||
|
embed_model: str = EMBEDDING_MODEL,
|
||||||
) -> KBService:
|
) -> KBService:
|
||||||
|
if isinstance(vector_store_type, str):
|
||||||
|
vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
|
||||||
if SupportedVSType.FAISS == vector_store_type:
|
if SupportedVSType.FAISS == vector_store_type:
|
||||||
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
||||||
return FaissKBService(kb_name)
|
return FaissKBService(kb_name, embed_model=embed_model)
|
||||||
elif SupportedVSType.MILVUS == vector_store_type:
|
# todo: Milvus has different init params
|
||||||
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
# elif SupportedVSType.MILVUS == vector_store_type:
|
||||||
return MilvusKBService(kb_name)
|
# from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||||
elif SupportedVSType.DEFAULT == vector_store_type:
|
# return MilvusKBService(kb_name,)
|
||||||
|
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
|
||||||
return DefaultKBService(kb_name)
|
return DefaultKBService(kb_name)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_service_by_name(kb_name: str
|
def get_service_by_name(kb_name: str
|
||||||
) -> KBService:
|
) -> KBService:
|
||||||
kb_name, vs_type = load_kb_from_db(kb_name)
|
kb_name, vs_type, embed_model = load_kb_from_db(kb_name)
|
||||||
return KBServiceFactory.get_service(kb_name, vs_type)
|
return KBServiceFactory.get_service(kb_name, vs_type, embed_model)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_default():
|
def get_default():
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue