improve recreate_vector_store:

1. allow empty knowledge base
2. skip unspported files instead of exception
This commit is contained in:
liunux4odoo 2023-08-08 17:59:41 +08:00
parent e918244159
commit 2b0f8caa62
3 changed files with 28 additions and 8 deletions

View File

@ -9,6 +9,7 @@ from server.knowledge_base.utils import KnowledgeFile, list_docs_from_folder
from server.knowledge_base.kb_service.base import KBServiceFactory from server.knowledge_base.kb_service.base import KBServiceFactory
from server.knowledge_base.kb_service.base import SupportedVSType from server.knowledge_base.kb_service.base import SupportedVSType
from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache
from typing import Union
async def list_docs(knowledge_base_name: str): async def list_docs(knowledge_base_name: str):
@ -89,14 +90,23 @@ async def download_doc():
pass pass
async def recreate_vector_store(knowledge_base_name: str): async def recreate_vector_store(
knowledge_base_name: str,
allow_empty_kb: bool = True,
vs_type: Union[str, SupportedVSType] = "faiss",
):
''' '''
recreate vector store from the content. recreate vector store from the content.
this is usefull when user can copy files to content folder directly instead of upload through network. this is usefull when user can copy files to content folder directly instead of upload through network.
by default, get_service_by_name only return knowledge base in the info.db and having document files in it.
set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
''' '''
kb = KBServiceFactory.get_service_by_name(knowledge_base_name) kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None: if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") if allow_empty_kb:
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type)
else:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
async def output(kb): async def output(kb):
kb.clear_vs() kb.clear_vs()
@ -109,10 +119,13 @@ async def recreate_vector_store(knowledge_base_name: str):
"finished": i, "finished": i,
"doc": filename, "doc": filename,
}) })
kb_file = KnowledgeFile(filename=filename, try:
knowledge_base_name=kb.kb_name) kb_file = KnowledgeFile(filename=filename,
print(f"processing {kb_file.filepath} to vector store.") knowledge_base_name=kb.kb_name)
kb.add_doc(kb_file) print(f"processing {kb_file.filepath} to vector store.")
kb.add_doc(kb_file)
except ValueError as e:
print(e)
if kb.vs_type == SupportedVSType.FAISS: if kb.vs_type == SupportedVSType.FAISS:
refresh_vs_cache(knowledge_base_name) refresh_vs_cache(knowledge_base_name)

View File

@ -3,6 +3,7 @@ import os
from langchain.embeddings.huggingface import HuggingFaceEmbeddings from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config) from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config)
from functools import lru_cache from functools import lru_cache
import sys
def validate_kb_name(knowledge_base_id: str) -> bool: def validate_kb_name(knowledge_base_id: str) -> bool:

View File

@ -16,6 +16,7 @@ from fastapi.responses import StreamingResponse
import contextlib import contextlib
import json import json
from io import BytesIO from io import BytesIO
from server.knowledge_base.utils import list_kbs_from_folder
def set_httpx_timeout(timeout=60.0): def set_httpx_timeout(timeout=60.0):
@ -468,6 +469,9 @@ class ApiRequest:
if __name__ == "__main__": if __name__ == "__main__":
from server.db.base import Base, engine
Base.metadata.create_all(bind=engine)
api = ApiRequest(no_remote_api=True) api = ApiRequest(no_remote_api=True)
# print(api.chat_fastchat( # print(api.chat_fastchat(
@ -488,5 +492,7 @@ if __name__ == "__main__":
# print(api.list_knowledge_bases()) # print(api.list_knowledge_bases())
for t in api.recreate_vector_store('kblog'): # recreate all vector store
print(t) for kb in list_kbs_from_folder():
for t in api.recreate_vector_store(kb):
print(t)