From f40bb69224df37db171a6dbd05afba11b30a353d Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Sun, 20 Aug 2023 19:10:29 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96FAISS=E5=90=91=E9=87=8F?= =?UTF-8?q?=E5=BA=93=E5=A4=9A=E6=96=87=E4=BB=B6=E6=93=8D=E4=BD=9C=EF=BC=9B?= =?UTF-8?q?=E4=BF=AE=E5=A4=8Drecreate=5Fvector=5Fstore=EF=BC=8C=E5=A4=A7?= =?UTF-8?q?=E9=87=8F=E6=96=87=E4=BB=B6=E6=97=B6=E4=B8=8D=E5=86=8D=E8=B6=85?= =?UTF-8?q?=E6=97=B6=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/knowledge_base/kb_doc_api.py | 11 ++++++----- server/knowledge_base/kb_service/base.py | 14 +++++++------- .../kb_service/faiss_kb_service.py | 15 ++++++++++----- webui_pages/knowledge_base/knowledge_base.py | 6 ++++-- webui_pages/utils.py | 18 +++++++++++++++--- 5 files changed, 42 insertions(+), 22 deletions(-) diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 0d74fd6..74edc98 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -47,6 +47,7 @@ async def list_docs( async def upload_doc(file: UploadFile = File(..., description="上传文件"), knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]), override: bool = Form(False, description="覆盖已有文件"), + not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"), ) -> BaseResponse: if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -76,7 +77,7 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"), return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}") try: - kb.add_doc(kb_file) + kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) except Exception as e: print(e) return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}") @@ -87,6 +88,7 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"), async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), doc_name: str = Body(..., examples=["file_name.md"]), delete_content: bool = Body(False), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), ) -> BaseResponse: if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -102,7 +104,7 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), try: kb_file = KnowledgeFile(filename=doc_name, knowledge_base_name=knowledge_base_name) - kb.delete_doc(kb_file, delete_content) + kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache) except Exception as e: print(e) return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}") @@ -113,6 +115,7 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), async def update_doc( knowledge_base_name: str = Body(..., examples=["samples"]), file_name: str = Body(..., examples=["file_name"]), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), ) -> BaseResponse: ''' 更新知识库文档 @@ -128,7 +131,7 @@ async def update_doc( kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name) if os.path.exists(kb_file.filepath): - kb.update_doc(kb_file) + kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}") except Exception as e: print(e) @@ -205,7 +208,5 @@ async def recreate_vector_store( "code": 500, "msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。", }) - import asyncio - await asyncio.sleep(5) return StreamingResponse(output(), media_type="text/event-stream") diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 09766e6..9af5b0e 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -71,7 +71,7 @@ class KBService(ABC): status = delete_kb_from_db(self.kb_name) return status - def add_doc(self, kb_file: KnowledgeFile): + def add_doc(self, kb_file: KnowledgeFile, **kwargs): """ 向知识库添加文件 """ @@ -79,29 +79,29 @@ class KBService(ABC): if docs: self.delete_doc(kb_file) embeddings = self._load_embeddings() - self.do_add_doc(docs, embeddings) + self.do_add_doc(docs, embeddings, **kwargs) status = add_doc_to_db(kb_file) else: status = False return status - def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False): + def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs): """ 从知识库删除文件 """ - self.do_delete_doc(kb_file) + self.do_delete_doc(kb_file, **kwargs) status = delete_file_from_db(kb_file) if delete_content and os.path.exists(kb_file.filepath): os.remove(kb_file.filepath) return status - def update_doc(self, kb_file: KnowledgeFile): + def update_doc(self, kb_file: KnowledgeFile, **kwargs): """ 使用content中的文件更新向量库 """ if os.path.exists(kb_file.filepath): - self.delete_doc(kb_file) - return self.add_doc(kb_file) + self.delete_doc(kb_file, **kwargs) + return self.add_doc(kb_file, **kwargs) def exist_doc(self, file_name: str): return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name, diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 5953c3a..9fccfa2 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -66,6 +66,7 @@ def refresh_vs_cache(kb_name: str): make vector store cache refreshed when next loading """ _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1 + print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}") class FaissKBService(KBService): @@ -111,17 +112,20 @@ class FaissKBService(KBService): def do_add_doc(self, docs: List[Document], embeddings: Embeddings, + **kwargs, ): vector_store = load_vector_store(self.kb_name, embeddings=embeddings, tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0)) vector_store.add_documents(docs) torch_gc() - vector_store.save_local(self.vs_path) - refresh_vs_cache(self.kb_name) + if not kwargs.get("not_refresh_vs_cache"): + vector_store.save_local(self.vs_path) + refresh_vs_cache(self.kb_name) def do_delete_doc(self, - kb_file: KnowledgeFile): + kb_file: KnowledgeFile, + **kwargs): embeddings = self._load_embeddings() vector_store = load_vector_store(self.kb_name, embeddings=embeddings, @@ -132,8 +136,9 @@ class FaissKBService(KBService): return None vector_store.delete(ids) - vector_store.save_local(self.vs_path) - refresh_vs_cache(self.kb_name) + if not kwargs.get("not_refresh_vs_cache"): + vector_store.save_local(self.vs_path) + refresh_vs_cache(self.kb_name) return True diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index 3bd531a..4351e95 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -138,8 +138,10 @@ def knowledge_base_page(api: ApiRequest): # use_container_width=True, disabled=len(files) == 0, ): - for f in files: - ret = api.upload_kb_doc(f, kb) + data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files] + data[-1]["not_refresh_vs_cache"]=False + for k in data: + ret = api.upload_kb_doc(**k) if msg := check_success_msg(ret): st.toast(msg, icon="✔") elif msg := check_error_msg(ret): diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 18a24e4..c666d45 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -496,6 +496,7 @@ class ApiRequest: knowledge_base_name: str, filename: str = None, override: bool = False, + not_refresh_vs_cache: bool = False, no_remote_api: bool = None, ): ''' @@ -529,7 +530,11 @@ class ApiRequest: else: response = self.post( "/knowledge_base/upload_doc", - data={"knowledge_base_name": knowledge_base_name, "override": override}, + data={ + "knowledge_base_name": knowledge_base_name, + "override": override, + "not_refresh_vs_cache": not_refresh_vs_cache, + }, files={"file": (filename, file)}, ) return self._check_httpx_json_response(response) @@ -539,6 +544,7 @@ class ApiRequest: knowledge_base_name: str, doc_name: str, delete_content: bool = False, + not_refresh_vs_cache: bool = False, no_remote_api: bool = None, ): ''' @@ -551,6 +557,7 @@ class ApiRequest: "knowledge_base_name": knowledge_base_name, "doc_name": doc_name, "delete_content": delete_content, + "not_refresh_vs_cache": not_refresh_vs_cache, } if no_remote_api: @@ -568,6 +575,7 @@ class ApiRequest: self, knowledge_base_name: str, file_name: str, + not_refresh_vs_cache: bool = False, no_remote_api: bool = None, ): ''' @@ -583,7 +591,11 @@ class ApiRequest: else: response = self.post( "/knowledge_base/update_doc", - json={"knowledge_base_name": knowledge_base_name, "file_name": file_name}, + json={ + "knowledge_base_name": knowledge_base_name, + "file_name": file_name, + "not_refresh_vs_cache": not_refresh_vs_cache, + }, ) return self._check_httpx_json_response(response) @@ -617,7 +629,7 @@ class ApiRequest: "/knowledge_base/recreate_vector_store", json=data, stream=True, - timeout=False, + timeout=None, ) return self._httpx_stream2generator(response, as_json=True)