diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 3601b57..d2e8525 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -28,7 +28,7 @@ def load_faiss_vector_store( embed_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None, tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed. -): +) -> FAISS: print(f"loading vector store in '{knowledge_base_name}'.") vs_path = get_vs_path(knowledge_base_name) if embeddings is None: @@ -74,13 +74,18 @@ class FaissKBService(KBService): def get_kb_path(self): return os.path.join(KB_ROOT_PATH, self.kb_name) - def load_vector_store(self): + def load_vector_store(self) -> FAISS: return load_faiss_vector_store( knowledge_base_name=self.kb_name, embed_model=self.embed_model, tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0), ) + def save_vector_store(self, vector_store: FAISS = None): + vector_store = vector_store or self.load_vector_store() + vector_store.save_local(self.vs_path) + return vector_store + def refresh_vs_cache(self): refresh_vs_cache(self.kb_name) @@ -117,11 +122,11 @@ class FaissKBService(KBService): if not kwargs.get("not_refresh_vs_cache"): vector_store.save_local(self.vs_path) self.refresh_vs_cache() + return vector_store def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): - embeddings = self._load_embeddings() vector_store = self.load_vector_store() ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] @@ -133,7 +138,7 @@ class FaissKBService(KBService): vector_store.save_local(self.vs_path) self.refresh_vs_cache() - return True + return vector_store def do_clear_vs(self): shutil.rmtree(self.vs_path) diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index af506e2..4285b79 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -69,6 +69,7 @@ def folder2db( print(result) if kb.vs_type() == SupportedVSType.FAISS: + kb.save_vector_store() kb.refresh_vs_cache() elif mode == "fill_info_only": files = list_files_from_folder(kb_name) @@ -85,6 +86,7 @@ def folder2db( kb.update_doc(kb_file, not_refresh_vs_cache=True) if kb.vs_type() == SupportedVSType.FAISS: + kb.save_vector_store() kb.refresh_vs_cache() elif mode == "increament": db_files = kb.list_files() @@ -102,6 +104,7 @@ def folder2db( print(result) if kb.vs_type() == SupportedVSType.FAISS: + kb.save_vector_store() kb.refresh_vs_cache() else: print(f"unspported migrate mode: {mode}") @@ -131,7 +134,10 @@ def prune_db_files(kb_name: str): files = list(set(files_in_db) - set(files_in_folder)) kb_files = file_to_kbfile(kb_name, files) for kb_file in kb_files: - kb.delete_doc(kb_file) + kb.delete_doc(kb_file, not_refresh_vs_cache=True) + if kb.vs_type() == SupportedVSType.FAISS: + kb.save_vector_store() + kb.refresh_vs_cache() return kb_files def prune_folder_files(kb_name: str):