bug fix: faiss vector store not saved when recreate

This commit is contained in:
liunux4odoo 2023-08-31 16:43:42 +08:00
parent 4bdfb6e154
commit 80590ef5dc
2 changed files with 16 additions and 5 deletions

View File

@ -28,7 +28,7 @@ def load_faiss_vector_store(
embed_device: str = EMBEDDING_DEVICE,
embeddings: Embeddings = None,
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
):
) -> FAISS:
print(f"loading vector store in '{knowledge_base_name}'.")
vs_path = get_vs_path(knowledge_base_name)
if embeddings is None:
@ -74,13 +74,18 @@ class FaissKBService(KBService):
def get_kb_path(self):
return os.path.join(KB_ROOT_PATH, self.kb_name)
def load_vector_store(self):
def load_vector_store(self) -> FAISS:
return load_faiss_vector_store(
knowledge_base_name=self.kb_name,
embed_model=self.embed_model,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
)
def save_vector_store(self, vector_store: FAISS = None):
vector_store = vector_store or self.load_vector_store()
vector_store.save_local(self.vs_path)
return vector_store
def refresh_vs_cache(self):
refresh_vs_cache(self.kb_name)
@ -117,11 +122,11 @@ class FaissKBService(KBService):
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
self.refresh_vs_cache()
return vector_store
def do_delete_doc(self,
kb_file: KnowledgeFile,
**kwargs):
embeddings = self._load_embeddings()
vector_store = self.load_vector_store()
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
@ -133,7 +138,7 @@ class FaissKBService(KBService):
vector_store.save_local(self.vs_path)
self.refresh_vs_cache()
return True
return vector_store
def do_clear_vs(self):
shutil.rmtree(self.vs_path)

View File

@ -69,6 +69,7 @@ def folder2db(
print(result)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
elif mode == "fill_info_only":
files = list_files_from_folder(kb_name)
@ -85,6 +86,7 @@ def folder2db(
kb.update_doc(kb_file, not_refresh_vs_cache=True)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
elif mode == "increament":
db_files = kb.list_files()
@ -102,6 +104,7 @@ def folder2db(
print(result)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
else:
print(f"unspported migrate mode: {mode}")
@ -131,7 +134,10 @@ def prune_db_files(kb_name: str):
files = list(set(files_in_db) - set(files_in_folder))
kb_files = file_to_kbfile(kb_name, files)
for kb_file in kb_files:
kb.delete_doc(kb_file)
kb.delete_doc(kb_file, not_refresh_vs_cache=True)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
return kb_files
def prune_folder_files(kb_name: str):