improve recreate_vector_store:
1. allow empty knowledge base 2. skip unspported files instead of exception
This commit is contained in:
parent
e918244159
commit
2b0f8caa62
|
|
@ -9,6 +9,7 @@ from server.knowledge_base.utils import KnowledgeFile, list_docs_from_folder
|
||||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||||
from server.knowledge_base.kb_service.base import SupportedVSType
|
from server.knowledge_base.kb_service.base import SupportedVSType
|
||||||
from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache
|
from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
async def list_docs(knowledge_base_name: str):
|
async def list_docs(knowledge_base_name: str):
|
||||||
|
|
@ -89,14 +90,23 @@ async def download_doc():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def recreate_vector_store(knowledge_base_name: str):
|
async def recreate_vector_store(
|
||||||
|
knowledge_base_name: str,
|
||||||
|
allow_empty_kb: bool = True,
|
||||||
|
vs_type: Union[str, SupportedVSType] = "faiss",
|
||||||
|
):
|
||||||
'''
|
'''
|
||||||
recreate vector store from the content.
|
recreate vector store from the content.
|
||||||
this is usefull when user can copy files to content folder directly instead of upload through network.
|
this is usefull when user can copy files to content folder directly instead of upload through network.
|
||||||
|
by default, get_service_by_name only return knowledge base in the info.db and having document files in it.
|
||||||
|
set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
|
||||||
'''
|
'''
|
||||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||||
if kb is None:
|
if kb is None:
|
||||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
if allow_empty_kb:
|
||||||
|
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type)
|
||||||
|
else:
|
||||||
|
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
async def output(kb):
|
async def output(kb):
|
||||||
kb.clear_vs()
|
kb.clear_vs()
|
||||||
|
|
@ -109,10 +119,13 @@ async def recreate_vector_store(knowledge_base_name: str):
|
||||||
"finished": i,
|
"finished": i,
|
||||||
"doc": filename,
|
"doc": filename,
|
||||||
})
|
})
|
||||||
kb_file = KnowledgeFile(filename=filename,
|
try:
|
||||||
knowledge_base_name=kb.kb_name)
|
kb_file = KnowledgeFile(filename=filename,
|
||||||
print(f"processing {kb_file.filepath} to vector store.")
|
knowledge_base_name=kb.kb_name)
|
||||||
kb.add_doc(kb_file)
|
print(f"processing {kb_file.filepath} to vector store.")
|
||||||
|
kb.add_doc(kb_file)
|
||||||
|
except ValueError as e:
|
||||||
|
print(e)
|
||||||
if kb.vs_type == SupportedVSType.FAISS:
|
if kb.vs_type == SupportedVSType.FAISS:
|
||||||
refresh_vs_cache(knowledge_base_name)
|
refresh_vs_cache(knowledge_base_name)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import os
|
||||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||||
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config)
|
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config)
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
def validate_kb_name(knowledge_base_id: str) -> bool:
|
def validate_kb_name(knowledge_base_id: str) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from fastapi.responses import StreamingResponse
|
||||||
import contextlib
|
import contextlib
|
||||||
import json
|
import json
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
from server.knowledge_base.utils import list_kbs_from_folder
|
||||||
|
|
||||||
|
|
||||||
def set_httpx_timeout(timeout=60.0):
|
def set_httpx_timeout(timeout=60.0):
|
||||||
|
|
@ -468,6 +469,9 @@ class ApiRequest:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
from server.db.base import Base, engine
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
api = ApiRequest(no_remote_api=True)
|
api = ApiRequest(no_remote_api=True)
|
||||||
|
|
||||||
# print(api.chat_fastchat(
|
# print(api.chat_fastchat(
|
||||||
|
|
@ -488,5 +492,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# print(api.list_knowledge_bases())
|
# print(api.list_knowledge_bases())
|
||||||
|
|
||||||
for t in api.recreate_vector_store('kblog'):
|
# recreate all vector store
|
||||||
print(t)
|
for kb in list_kbs_from_folder():
|
||||||
|
for t in api.recreate_vector_store(kb):
|
||||||
|
print(t)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue