diff --git a/server/api.py b/server/api.py index eecd478..8a7a2c1 100644 --- a/server/api.py +++ b/server/api.py @@ -94,11 +94,11 @@ def create_app(): summary="删除知识库内的文件" )(delete_doc) - # app.post("/knowledge_base/update_doc", - # tags=["Knowledge Base Management"], - # response_model=BaseResponse, - # summary="上传文件到知识库,并删除另一个文件" - # )(update_doc) + app.post("/knowledge_base/update_doc", + tags=["Knowledge Base Management"], + response_model=BaseResponse, + summary="更新现有文件到知识库" + )(update_doc) app.post("/knowledge_base/recreate_vector_store", tags=["Knowledge Base Management"], diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 636e606..864662b 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -1,8 +1,8 @@ import os import urllib -from fastapi import File, Form, UploadFile +from fastapi import File, Form, Body, UploadFile from server.utils import BaseResponse, ListResponse -from server.knowledge_base.utils import (validate_kb_name) +from server.knowledge_base.utils import (get_file_path, validate_kb_name) from fastapi.responses import StreamingResponse import json from server.knowledge_base.utils import KnowledgeFile, list_docs_from_folder @@ -58,9 +58,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"), return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}") -async def delete_doc(knowledge_base_name: str, - doc_name: str, - ): +async def delete_doc(knowledge_base_name: str = Body(...), + doc_name: str = Body(...), + delete_content: bool = Body(...), + ): if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -73,14 +74,33 @@ async def delete_doc(knowledge_base_name: str, return BaseResponse(code=404, msg=f"未找到文件 {doc_name}") kb_file = KnowledgeFile(filename=doc_name, knowledge_base_name=knowledge_base_name) - kb.delete_doc(kb_file) + kb.delete_doc(kb_file, delete_content) return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功") # return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败") -async def update_doc(): - # TODO: 替换文件 - pass +async def update_doc( + knowledge_base_name: str = Body(...), + file_name: str = Body(...), + ): + ''' + 更新知识库文档 + ''' + if not validate_kb_name(knowledge_base_name): + return BaseResponse(code=403, msg="Don't attack me") + + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + + kb_file = KnowledgeFile(filename=file_name, + knowledge_base_name=knowledge_base_name) + + if os.path.exists(kb_file.filepath): + kb.update_doc(kb_file) + return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}") + else: + return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}") async def download_doc(): @@ -89,9 +109,9 @@ async def download_doc(): async def recreate_vector_store( - knowledge_base_name: str, - allow_empty_kb: bool = True, - vs_type: str = "faiss", + knowledge_base_name: str = Body(...), + allow_empty_kb: bool = Body(True), + vs_type: str = Body("faiss"), ): ''' recreate vector store from the content. diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 5af47e7..8a675f3 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -65,21 +65,32 @@ class KBService(ABC): 向知识库添加文件 """ docs = kb_file.file2text() - embeddings = self._load_embeddings() - self.do_add_doc(docs, embeddings) - status = add_doc_to_db(kb_file) + if docs: + embeddings = self._load_embeddings() + self.do_add_doc(docs, embeddings) + status = add_doc_to_db(kb_file) + else: + status = False return status - def delete_doc(self, kb_file: KnowledgeFile): + def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False): """ 从知识库删除文件 """ - if os.path.exists(kb_file.filepath): + if delete_content and os.path.exists(kb_file.filepath): os.remove(kb_file.filepath) self.do_delete_doc(kb_file) status = delete_file_from_db(kb_file) return status + def update_doc(self, kb_file: KnowledgeFile): + """ + 使用content中的文件更新向量库 + """ + if os.path.exists(kb_file.filepath): + self.delete_doc(kb_file) + return self.add_doc(kb_file) + def exist_doc(self, file_name: str): return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name, filename=file_name)) diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 064eb43..87eca2b 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -136,3 +136,13 @@ class FaissKBService(KBService): def do_clear_vs(self): shutil.rmtree(self.vs_path) os.makedirs(self.vs_path) + + def exist_doc(self, file_name: str): + if super().exist_doc(file_name): + return "in_db" + + content_path = os.path.join(self.kb_path, "content") + if os.path.isfile(os.path.join(content_path, file_name)): + return "in_folder" + else: + return False diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 99f9472..5db5097 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -3,6 +3,7 @@ import os from langchain.embeddings.huggingface import HuggingFaceEmbeddings from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config) from functools import lru_cache +import langchain.document_loaders import sys diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index 83bedbd..d6399a8 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -88,7 +88,7 @@ def config_aggrid( gb = GridOptionsBuilder.from_dataframe(df) gb.configure_column("No", width=50) for k, v in titles.items(): - gb.configure_column(k, v, maxWidth=100) + gb.configure_column(k, v, maxWidth=100, wrapHeaderText=True) gb.configure_selection(selection_mode, use_checkbox, pre_selected_rows=[0]) return gb @@ -149,7 +149,6 @@ def knowledge_base_page(api: ApiRequest): files = st.file_uploader("上传知识文件", ["docx", "txt", "md", "csv", "xlsx", "pdf"], accept_multiple_files=True, - key="files", ) if st.button( "添加文件到知识库", @@ -199,7 +198,7 @@ def knowledge_base_page(api: ApiRequest): cols = st.columns(3) selected_rows = doc_grid.get("selected_rows", []) - cols = st.columns([2, 3, 2]) + cols = st.columns(4) if selected_rows: file_name = selected_rows[0]["file_name"] file_path = get_file_path(kb, file_name) @@ -207,9 +206,20 @@ def knowledge_base_page(api: ApiRequest): cols[0].download_button("下载选中文档", fp, file_name=file_name) else: cols[0].download_button("下载选中文档", "", disabled=True) - if cols[2].button("删除选中文档!", type="primary"): + + if cols[1].button("入库", disabled=len(selected_rows)==0): for row in selected_rows: - ret = api.delete_kb_doc(kb, row["file_name"]) + api.update_kb_doc(kb, row["file_name"]) + st.experimental_rerun() + + if cols[2].button("出库", disabled=len(selected_rows)==0): + for row in selected_rows: + api.delete_kb_doc(kb, row["file_name"]) + st.experimental_rerun() + + if cols[3].button("删除选中文档!", type="primary"): + for row in selected_rows: + ret = api.delete_kb_doc(kb, row["file_name"], True) st.toast(ret["msg"]) st.experimental_rerun() diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 10c905e..a273d8b 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -397,9 +397,11 @@ class ApiRequest: if no_remote_api is None: no_remote_api = self.no_remote_api - if isinstance(file, bytes): + if isinstance(file, bytes): # raw bytes file = BytesIO(file) - else: + elif hasattr(file, "read"): # a file io like object + filename = filename or file.name + else: # a local path file = Path(file).absolute().open("rb") filename = filename or file.name @@ -410,6 +412,7 @@ class ApiRequest: temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024) temp_file.write(file.read()) + temp_file.seek(0) response = run_async(upload_doc( UploadFile(temp_file, filename=filename), knowledge_base_name, @@ -428,6 +431,7 @@ class ApiRequest: self, knowledge_base_name: str, doc_name: str, + delete_content: bool = False, no_remote_api: bool = None, ): ''' @@ -438,11 +442,34 @@ class ApiRequest: if no_remote_api: from server.knowledge_base.kb_doc_api import delete_doc - response = run_async(delete_doc(knowledge_base_name, doc_name)) + response = run_async(delete_doc(knowledge_base_name, doc_name, delete_content)) return response.dict() else: response = self.delete( "/knowledge_base/delete_doc", + json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name, "delete_content": delete_content}, + ) + return response.json() + + def update_kb_doc( + self, + knowledge_base_name: str, + doc_name: str, + no_remote_api: bool = None, + ): + ''' + 对应api.py/knowledge_base/update_doc接口 + ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + + if no_remote_api: + from server.knowledge_base.kb_doc_api import update_doc + response = run_async(update_doc(knowledge_base_name, doc_name)) + return response.dict() + else: + response = self.delete( + "/knowledge_base/update_doc", json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name}, ) return response.json()