From ec984205aece0932e6465aaf643d750472d5aef4 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Mon, 14 Aug 2023 19:09:02 +0800 Subject: [PATCH] fix knowledge base management: 1. docs in database were not deleted when clearing vector store 2. disable buttons when local doc file does not exist. --- .../repository/knowledge_file_repository.py | 12 ++++++++ server/knowledge_base/kb_api.py | 14 +++++---- server/knowledge_base/kb_service/base.py | 9 ++++-- webui_pages/knowledge_base/knowledge_base.py | 30 ++++++++++++++----- 4 files changed, 49 insertions(+), 16 deletions(-) diff --git a/server/db/repository/knowledge_file_repository.py b/server/db/repository/knowledge_file_repository.py index 13c60c8..f5e912f 100644 --- a/server/db/repository/knowledge_file_repository.py +++ b/server/db/repository/knowledge_file_repository.py @@ -49,6 +49,18 @@ def delete_file_from_db(session, kb_file: KnowledgeFile): return True +@with_session +def delete_files_from_db(session, knowledge_base_name: str): + session.query(KnowledgeFileModel).filter_by(kb_name=knowledge_base_name).delete() + + kb = session.query(KnowledgeBaseModel).filter_by(kb_name=knowledge_base_name).first() + if kb: + kb.file_count = 0 + + session.commit() + return True + + @with_session def doc_exists(session, kb_file: KnowledgeFile): existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename, diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py index 8df14cb..4753ba4 100644 --- a/server/knowledge_base/kb_api.py +++ b/server/knowledge_base/kb_api.py @@ -44,8 +44,12 @@ async def delete_kb( if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - status = kb.drop_kb() - if status: - return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}") - else: - return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}") + try: + status = kb.clear_vs() + status = kb.drop_kb() + if status: + return BaseResponse(code=200, msg=f"成功删除知识库 
{knowledge_base_name}") + except Exception as e: + print(e) + + return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}") diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index dcd18cc..ec1c692 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -9,8 +9,8 @@ from server.db.repository.knowledge_base_repository import ( load_kb_from_db, get_kb_detail, ) from server.db.repository.knowledge_file_repository import ( - add_doc_to_db, delete_file_from_db, doc_exists, - list_docs_from_db, get_file_detail + add_doc_to_db, delete_file_from_db, delete_files_from_db, doc_exists, + list_docs_from_db, get_file_detail, delete_file_from_db ) from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, @@ -56,9 +56,12 @@ class KBService(ABC): def clear_vs(self): """ - 用知识库中已上传文件重建向量库 + 删除向量库中所有内容 """ self.do_clear_vs() + status = delete_files_from_db(self.kb_name) + return status + def drop_kb(self): """ diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index c7b9d14..f963a67 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -1,5 +1,3 @@ -import sqlite3 - import streamlit as st from webui_pages.utils import * from st_aggrid import AgGrid, JsCode @@ -9,6 +7,9 @@ from server.knowledge_base.utils import get_file_path, LOADER_DICT from server.knowledge_base.kb_service.base import get_kb_details, get_kb_doc_details from typing import Literal, Dict, Tuple from configs.model_config import embedding_model_dict, kbs_config, EMBEDDING_MODEL, DEFAULT_VS_TYPE +import os +import time + # SENTENCE_SIZE = 100 @@ -33,6 +34,19 @@ def config_aggrid( return gb +def file_exists(kb: str, selected_rows: List) -> Tuple[str, str]: + ''' + check whether a doc file exists in local knowledge base folder. + return the file's name and path if it exists. 
+ ''' + if selected_rows: + file_name = selected_rows[0]["file_name"] + file_path = get_file_path(kb, file_name) + if os.path.isfile(file_path): + return file_name, file_path + return "", "" + + def knowledge_base_page(api: ApiRequest): try: kb_list = get_kb_details() @@ -174,9 +188,8 @@ def knowledge_base_page(api: ApiRequest): selected_rows = doc_grid.get("selected_rows", []) cols = st.columns(4) - if selected_rows: - file_name = selected_rows[0]["file_name"] - file_path = get_file_path(kb, file_name) + file_name, file_path = file_exists(kb, selected_rows) + if file_path: with open(file_path, "rb") as fp: cols[0].download_button( "下载选中文档", @@ -194,7 +207,7 @@ def knowledge_base_page(api: ApiRequest): # 将文件分词并加载到向量库中 if cols[1].button( "重新添加至向量库" if selected_rows and (pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库", - disabled=len(selected_rows) == 0, + disabled=not file_exists(kb, selected_rows)[0], use_container_width=True, ): for row in selected_rows: @@ -204,7 +217,7 @@ def knowledge_base_page(api: ApiRequest): # 将文件从向量库中删除,但不删除文件本身。 if cols[2].button( "从向量库删除", - disabled=len(selected_rows) == 0, + disabled=not (selected_rows and selected_rows[0]["in_db"]), use_container_width=True, ): for row in selected_rows: @@ -245,5 +258,6 @@ def knowledge_base_page(api: ApiRequest): use_container_width=True, ): ret = api.delete_knowledge_base(kb) - st.experimental_rerun() st.toast(ret["msg"]) + time.sleep(1) + st.experimental_rerun()