From 2b0f8caa620a1ba593cde1195c76a76feb5ef7fe Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Tue, 8 Aug 2023 17:59:41 +0800 Subject: [PATCH] improve recreate_vector_store: 1. allow empty knowledge base 2. skip unspported files instead of exception --- server/knowledge_base/kb_doc_api.py | 25 +++++++++++++++++++------ server/knowledge_base/utils.py | 1 + webui_pages/utils.py | 10 ++++++++-- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 41c399e..a995014 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -9,6 +9,7 @@ from server.knowledge_base.utils import KnowledgeFile, list_docs_from_folder from server.knowledge_base.kb_service.base import KBServiceFactory from server.knowledge_base.kb_service.base import SupportedVSType from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache +from typing import Union async def list_docs(knowledge_base_name: str): @@ -89,14 +90,23 @@ async def download_doc(): pass -async def recreate_vector_store(knowledge_base_name: str): +async def recreate_vector_store( + knowledge_base_name: str, + allow_empty_kb: bool = True, + vs_type: Union[str, SupportedVSType] = "faiss", + ): ''' recreate vector store from the content. this is usefull when user can copy files to content folder directly instead of upload through network. + by default, get_service_by_name only return knowledge base in the info.db and having document files in it. + set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents. ''' kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: - return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + if allow_empty_kb: + kb = KBServiceFactory.get_service(knowledge_base_name, vs_type) + else: + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") async def output(kb): kb.clear_vs() @@ -109,10 +119,13 @@ async def recreate_vector_store(knowledge_base_name: str): "finished": i, "doc": filename, }) - kb_file = KnowledgeFile(filename=filename, - knowledge_base_name=kb.kb_name) - print(f"processing {kb_file.filepath} to vector store.") - kb.add_doc(kb_file) + try: + kb_file = KnowledgeFile(filename=filename, + knowledge_base_name=kb.kb_name) + print(f"processing {kb_file.filepath} to vector store.") + kb.add_doc(kb_file) + except ValueError as e: + print(e) if kb.vs_type == SupportedVSType.FAISS: refresh_vs_cache(knowledge_base_name) diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index a616424..99f9472 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -3,6 +3,7 @@ import os from langchain.embeddings.huggingface import HuggingFaceEmbeddings from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config) from functools import lru_cache +import sys def validate_kb_name(knowledge_base_id: str) -> bool: diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 5745fc5..7054115 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -16,6 +16,7 @@ from fastapi.responses import StreamingResponse import contextlib import json from io import BytesIO +from server.knowledge_base.utils import list_kbs_from_folder def set_httpx_timeout(timeout=60.0): @@ -468,6 +469,9 @@ class ApiRequest: if __name__ == "__main__": + from server.db.base import Base, engine + Base.metadata.create_all(bind=engine) + api = ApiRequest(no_remote_api=True) # print(api.chat_fastchat( @@ -488,5 +492,7 @@ if __name__ == "__main__": # print(api.list_knowledge_bases()) - for t in api.recreate_vector_store('kblog'): - print(t) + # recreate all vector store + for kb in list_kbs_from_folder(): + for t in api.recreate_vector_store(kb): + print(t)