diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index d506f63..09766e6 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -77,6 +77,7 @@ class KBService(ABC): """ docs = kb_file.file2text() if docs: + self.delete_doc(kb_file) embeddings = self._load_embeddings() self.do_add_doc(docs, embeddings) status = add_doc_to_db(kb_file) diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 5c8376f..5953c3a 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -41,7 +41,23 @@ def load_vector_store( vs_path = get_vs_path(knowledge_base_name) if embeddings is None: embeddings = load_embeddings(embed_model, embed_device) - search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True) + + if not os.path.exists(vs_path): + os.makedirs(vs_path) + + if "index.faiss" in os.listdir(vs_path): + search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True) + else: + # create an empty vector store + doc = Document(page_content="init", metadata={}) + search_index = FAISS.from_documents([doc], embeddings, normalize_L2=True) + ids = [k for k, v in search_index.docstore._dict.items()] + search_index.delete(ids) + search_index.save_local(vs_path) + + if tick == 0: # vector store is loaded first time + _VECTOR_STORE_TICKS[knowledge_base_name] = 0 + return search_index @@ -74,8 +90,10 @@ class FaissKBService(KBService): def do_create_kb(self): if not os.path.exists(self.vs_path): os.makedirs(self.vs_path) + load_vector_store(self.kb_name) def do_drop_kb(self): + self.clear_vs() shutil.rmtree(self.kb_path) def do_search(self, @@ -94,37 +112,35 @@ class FaissKBService(KBService): docs: List[Document], embeddings: Embeddings, ): - if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): - vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True) - vector_store.add_documents(docs) - torch_gc() - else: - if not os.path.exists(self.vs_path): - os.makedirs(self.vs_path) - vector_store = FAISS.from_documents( - docs, embeddings, normalize_L2=True) # docs 为Document列表 - torch_gc() + vector_store = load_vector_store(self.kb_name, + embeddings=embeddings, + tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0)) + vector_store.add_documents(docs) + torch_gc() vector_store.save_local(self.vs_path) refresh_vs_cache(self.kb_name) def do_delete_doc(self, kb_file: KnowledgeFile): embeddings = self._load_embeddings() - if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): - vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True) - ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] - if len(ids) == 0: - return None - vector_store.delete(ids) - vector_store.save_local(self.vs_path) - refresh_vs_cache(self.kb_name) - return True - else: + vector_store = load_vector_store(self.kb_name, + embeddings=embeddings, + tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0)) + + ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] + if len(ids) == 0: return None + vector_store.delete(ids) + vector_store.save_local(self.vs_path) + refresh_vs_cache(self.kb_name) + + return True + def do_clear_vs(self): shutil.rmtree(self.vs_path) os.makedirs(self.vs_path) + refresh_vs_cache(self.kb_name) def exist_doc(self, file_name: str): if super().exist_doc(file_name): diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py index c09d57f..5a8b97d 100644 --- a/tests/api/test_kb_api.py +++ b/tests/api/test_kb_api.py @@ -8,6 +8,7 @@ root_path = Path(__file__).parent.parent.parent sys.path.append(str(root_path)) from configs.server_config import api_address from configs.model_config import VECTOR_SEARCH_TOP_K +from server.knowledge_base.utils import get_kb_path from pprint import pprint @@ -22,8 +23,11 @@ test_files = { def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"): + if not Path(get_kb_path(kb)).exists(): + return + url = api_base_url + api - print("\n删除知识库") + print("\n测试知识库存在,需要删除") r = requests.post(url, json=kb) data = r.json() pprint(data)