增加传入矢量名称加载

This commit is contained in:
glide-the 2023-09-17 13:54:03 +08:00
parent 175c90c362
commit 902ba0c321
2 changed files with 18 additions and 17 deletions

View File

@ -60,19 +60,20 @@ class KBFaissPool(_FaissPool):
def load_vector_store( def load_vector_store(
self, self,
kb_name: str, kb_name: str,
vector_name: str = "vector_store",
create: bool = True, create: bool = True,
embed_model: str = EMBEDDING_MODEL, embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(), embed_device: str = embedding_device(),
) -> ThreadSafeFaiss: ) -> ThreadSafeFaiss:
self.atomic.acquire() self.atomic.acquire()
cache = self.get(kb_name) cache = self.get(kb_name+vector_name)
if cache is None: if cache is None:
item = ThreadSafeFaiss(kb_name, pool=self) item = ThreadSafeFaiss(kb_name, pool=self)
self.set(kb_name, item) self.set(kb_name+vector_name, item)
with item.acquire(msg="初始化"): with item.acquire(msg="初始化"):
self.atomic.release() self.atomic.release()
logger.info(f"loading vector store in '{kb_name}' from disk.") logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
vs_path = get_vs_path(kb_name) vs_path = get_vs_path(kb_name, vector_name)
if os.path.isfile(os.path.join(vs_path, "index.faiss")): if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device) embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
@ -89,7 +90,7 @@ class KBFaissPool(_FaissPool):
item.finish_loading() item.finish_loading()
else: else:
self.atomic.release() self.atomic.release()
return self.get(kb_name) return self.get(kb_name+vector_name)
class MemoFaissPool(_FaissPool): class MemoFaissPool(_FaissPool):

View File

@ -42,8 +42,8 @@ def get_doc_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "content") return os.path.join(get_kb_path(knowledge_base_name), "content")
def get_vs_path(knowledge_base_name: str): def get_vs_path(knowledge_base_name: str, vector_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "vector_store") return os.path.join(get_kb_path(knowledge_base_name), vector_name)
def get_file_path(knowledge_base_name: str, doc_name: str): def get_file_path(knowledge_base_name: str, doc_name: str):