diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py index 325c7bb..8cc3c31 100644 --- a/server/knowledge_base/kb_cache/faiss_cache.py +++ b/server/knowledge_base/kb_cache/faiss_cache.py @@ -58,21 +58,22 @@ class _FaissPool(CachePool): class KBFaissPool(_FaissPool): def load_vector_store( - self, - kb_name: str, - create: bool = True, - embed_model: str = EMBEDDING_MODEL, - embed_device: str = embedding_device(), + self, + kb_name: str, + vector_name: str = "vector_store", + create: bool = True, + embed_model: str = EMBEDDING_MODEL, + embed_device: str = embedding_device(), ) -> ThreadSafeFaiss: self.atomic.acquire() - cache = self.get(kb_name) + cache = self.get(kb_name+vector_name) if cache is None: item = ThreadSafeFaiss(kb_name, pool=self) - self.set(kb_name, item) + self.set(kb_name+vector_name, item) with item.acquire(msg="初始化"): self.atomic.release() - logger.info(f"loading vector store in '{kb_name}' from disk.") - vs_path = get_vs_path(kb_name) + logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.") + vs_path = get_vs_path(kb_name, vector_name) if os.path.isfile(os.path.join(vs_path, "index.faiss")): embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device) @@ -89,7 +90,7 @@ class KBFaissPool(_FaissPool): item.finish_loading() else: self.atomic.release() - return self.get(kb_name) + return self.get(kb_name+vector_name) class MemoFaissPool(_FaissPool): @@ -144,7 +145,7 @@ if __name__ == "__main__": if r == 3: # delete docs logger.warning(f"清除 {vs_name} by {name}") kb_faiss_pool.get(vs_name).clear() - + threads = [] for n in range(1, 30): t = threading.Thread(target=worker, @@ -152,6 +153,6 @@ if __name__ == "__main__": daemon=True) t.start() threads.append(t) - + for t in threads: t.join() diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 033b48d..d38d9c6 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -8,8 +8,8 @@ from configs import ( CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, - logger, - log_verbose, + logger, + log_verbose, text_splitter_dict, LLM_MODEL, TEXT_SPLITTER_NAME, @@ -42,8 +42,8 @@ def get_doc_path(knowledge_base_name: str): return os.path.join(get_kb_path(knowledge_base_name), "content") -def get_vs_path(knowledge_base_name: str): - return os.path.join(get_kb_path(knowledge_base_name), "vector_store") +def get_vs_path(knowledge_base_name: str, vector_name: str): + return os.path.join(get_kb_path(knowledge_base_name), vector_name) def get_file_path(knowledge_base_name: str, doc_name: str): @@ -391,4 +391,4 @@ if __name__ == "__main__": pprint(docs[-1]) docs = kb_file.file2text() - pprint(docs[-1]) \ No newline at end of file + pprint(docs[-1])