增加传入矢量名称加载
This commit is contained in:
parent
175c90c362
commit
902ba0c321
|
|
@ -60,19 +60,20 @@ class KBFaissPool(_FaissPool):
|
|||
def load_vector_store(
|
||||
self,
|
||||
kb_name: str,
|
||||
vector_name: str = "vector_store",
|
||||
create: bool = True,
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
embed_device: str = embedding_device(),
|
||||
) -> ThreadSafeFaiss:
|
||||
self.atomic.acquire()
|
||||
cache = self.get(kb_name)
|
||||
cache = self.get(kb_name+vector_name)
|
||||
if cache is None:
|
||||
item = ThreadSafeFaiss(kb_name, pool=self)
|
||||
self.set(kb_name, item)
|
||||
self.set(kb_name+vector_name, item)
|
||||
with item.acquire(msg="初始化"):
|
||||
self.atomic.release()
|
||||
logger.info(f"loading vector store in '{kb_name}' from disk.")
|
||||
vs_path = get_vs_path(kb_name)
|
||||
logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
|
||||
vs_path = get_vs_path(kb_name, vector_name)
|
||||
|
||||
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
|
||||
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
|
||||
|
|
@ -89,7 +90,7 @@ class KBFaissPool(_FaissPool):
|
|||
item.finish_loading()
|
||||
else:
|
||||
self.atomic.release()
|
||||
return self.get(kb_name)
|
||||
return self.get(kb_name+vector_name)
|
||||
|
||||
|
||||
class MemoFaissPool(_FaissPool):
|
||||
|
|
|
|||
|
|
@ -42,8 +42,8 @@ def get_doc_path(knowledge_base_name: str):
|
|||
return os.path.join(get_kb_path(knowledge_base_name), "content")
|
||||
|
||||
|
||||
def get_vs_path(knowledge_base_name: str):
|
||||
return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
|
||||
def get_vs_path(knowledge_base_name: str, vector_name: str):
|
||||
return os.path.join(get_kb_path(knowledge_base_name), vector_name)
|
||||
|
||||
|
||||
def get_file_path(knowledge_base_name: str, doc_name: str):
|
||||
|
|
|
|||
Loading…
Reference in New Issue