Bug fix: FAISS vector store was not saved when recreated

This commit is contained in:
liunux4odoo 2023-08-31 16:43:42 +08:00
parent 4bdfb6e154
commit 80590ef5dc
2 changed files with 16 additions and 5 deletions

View File

@ -28,7 +28,7 @@ def load_faiss_vector_store(
embed_device: str = EMBEDDING_DEVICE, embed_device: str = EMBEDDING_DEVICE,
embeddings: Embeddings = None, embeddings: Embeddings = None,
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed. tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
): ) -> FAISS:
print(f"loading vector store in '{knowledge_base_name}'.") print(f"loading vector store in '{knowledge_base_name}'.")
vs_path = get_vs_path(knowledge_base_name) vs_path = get_vs_path(knowledge_base_name)
if embeddings is None: if embeddings is None:
@ -74,13 +74,18 @@ class FaissKBService(KBService):
def get_kb_path(self): def get_kb_path(self):
return os.path.join(KB_ROOT_PATH, self.kb_name) return os.path.join(KB_ROOT_PATH, self.kb_name)
def load_vector_store(self): def load_vector_store(self) -> FAISS:
return load_faiss_vector_store( return load_faiss_vector_store(
knowledge_base_name=self.kb_name, knowledge_base_name=self.kb_name,
embed_model=self.embed_model, embed_model=self.embed_model,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0), tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
) )
def save_vector_store(self, vector_store: FAISS = None):
vector_store = vector_store or self.load_vector_store()
vector_store.save_local(self.vs_path)
return vector_store
def refresh_vs_cache(self): def refresh_vs_cache(self):
refresh_vs_cache(self.kb_name) refresh_vs_cache(self.kb_name)
@ -117,11 +122,11 @@ class FaissKBService(KBService):
if not kwargs.get("not_refresh_vs_cache"): if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path) vector_store.save_local(self.vs_path)
self.refresh_vs_cache() self.refresh_vs_cache()
return vector_store
def do_delete_doc(self, def do_delete_doc(self,
kb_file: KnowledgeFile, kb_file: KnowledgeFile,
**kwargs): **kwargs):
embeddings = self._load_embeddings()
vector_store = self.load_vector_store() vector_store = self.load_vector_store()
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
@ -133,7 +138,7 @@ class FaissKBService(KBService):
vector_store.save_local(self.vs_path) vector_store.save_local(self.vs_path)
self.refresh_vs_cache() self.refresh_vs_cache()
return True return vector_store
def do_clear_vs(self): def do_clear_vs(self):
shutil.rmtree(self.vs_path) shutil.rmtree(self.vs_path)

View File

@ -69,6 +69,7 @@ def folder2db(
print(result) print(result)
if kb.vs_type() == SupportedVSType.FAISS: if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache() kb.refresh_vs_cache()
elif mode == "fill_info_only": elif mode == "fill_info_only":
files = list_files_from_folder(kb_name) files = list_files_from_folder(kb_name)
@ -85,6 +86,7 @@ def folder2db(
kb.update_doc(kb_file, not_refresh_vs_cache=True) kb.update_doc(kb_file, not_refresh_vs_cache=True)
if kb.vs_type() == SupportedVSType.FAISS: if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache() kb.refresh_vs_cache()
elif mode == "increament": elif mode == "increament":
db_files = kb.list_files() db_files = kb.list_files()
@ -102,6 +104,7 @@ def folder2db(
print(result) print(result)
if kb.vs_type() == SupportedVSType.FAISS: if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache() kb.refresh_vs_cache()
else: else:
print(f"unspported migrate mode: {mode}") print(f"unspported migrate mode: {mode}")
@ -131,7 +134,10 @@ def prune_db_files(kb_name: str):
files = list(set(files_in_db) - set(files_in_folder)) files = list(set(files_in_db) - set(files_in_folder))
kb_files = file_to_kbfile(kb_name, files) kb_files = file_to_kbfile(kb_name, files)
for kb_file in kb_files: for kb_file in kb_files:
kb.delete_doc(kb_file) kb.delete_doc(kb_file, not_refresh_vs_cache=True)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
return kb_files return kb_files
def prune_folder_files(kb_name: str): def prune_folder_files(kb_name: str):