增加传入矢量名称加载

This commit is contained in:
glide-the 2023-09-17 13:54:03 +08:00
parent 175c90c362
commit 902ba0c321
2 changed files with 18 additions and 17 deletions

View File

@ -60,19 +60,20 @@ class KBFaissPool(_FaissPool):
def load_vector_store(
self,
kb_name: str,
vector_name: str = "vector_store",
create: bool = True,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> ThreadSafeFaiss:
self.atomic.acquire()
cache = self.get(kb_name)
cache = self.get(kb_name+vector_name)
if cache is None:
item = ThreadSafeFaiss(kb_name, pool=self)
self.set(kb_name, item)
self.set(kb_name+vector_name, item)
with item.acquire(msg="初始化"):
self.atomic.release()
logger.info(f"loading vector store in '{kb_name}' from disk.")
vs_path = get_vs_path(kb_name)
logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
vs_path = get_vs_path(kb_name, vector_name)
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
@ -89,7 +90,7 @@ class KBFaissPool(_FaissPool):
item.finish_loading()
else:
self.atomic.release()
return self.get(kb_name)
return self.get(kb_name+vector_name)
class MemoFaissPool(_FaissPool):

View File

@ -42,8 +42,8 @@ def get_doc_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "content")
def get_vs_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
def get_vs_path(knowledge_base_name: str, vector_name: str):
return os.path.join(get_kb_path(knowledge_base_name), vector_name)
def get_file_path(knowledge_base_name: str, doc_name: str):