增加传入矢量名称加载
This commit is contained in:
parent
175c90c362
commit
902ba0c321
|
|
@ -60,19 +60,20 @@ class KBFaissPool(_FaissPool):
|
||||||
def load_vector_store(
|
def load_vector_store(
|
||||||
self,
|
self,
|
||||||
kb_name: str,
|
kb_name: str,
|
||||||
|
vector_name: str = "vector_store",
|
||||||
create: bool = True,
|
create: bool = True,
|
||||||
embed_model: str = EMBEDDING_MODEL,
|
embed_model: str = EMBEDDING_MODEL,
|
||||||
embed_device: str = embedding_device(),
|
embed_device: str = embedding_device(),
|
||||||
) -> ThreadSafeFaiss:
|
) -> ThreadSafeFaiss:
|
||||||
self.atomic.acquire()
|
self.atomic.acquire()
|
||||||
cache = self.get(kb_name)
|
cache = self.get(kb_name+vector_name)
|
||||||
if cache is None:
|
if cache is None:
|
||||||
item = ThreadSafeFaiss(kb_name, pool=self)
|
item = ThreadSafeFaiss(kb_name, pool=self)
|
||||||
self.set(kb_name, item)
|
self.set(kb_name+vector_name, item)
|
||||||
with item.acquire(msg="初始化"):
|
with item.acquire(msg="初始化"):
|
||||||
self.atomic.release()
|
self.atomic.release()
|
||||||
logger.info(f"loading vector store in '{kb_name}' from disk.")
|
logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
|
||||||
vs_path = get_vs_path(kb_name)
|
vs_path = get_vs_path(kb_name, vector_name)
|
||||||
|
|
||||||
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
|
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
|
||||||
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
|
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
|
||||||
|
|
@ -89,7 +90,7 @@ class KBFaissPool(_FaissPool):
|
||||||
item.finish_loading()
|
item.finish_loading()
|
||||||
else:
|
else:
|
||||||
self.atomic.release()
|
self.atomic.release()
|
||||||
return self.get(kb_name)
|
return self.get(kb_name+vector_name)
|
||||||
|
|
||||||
|
|
||||||
class MemoFaissPool(_FaissPool):
|
class MemoFaissPool(_FaissPool):
|
||||||
|
|
|
||||||
|
|
@ -42,8 +42,8 @@ def get_doc_path(knowledge_base_name: str):
|
||||||
return os.path.join(get_kb_path(knowledge_base_name), "content")
|
return os.path.join(get_kb_path(knowledge_base_name), "content")
|
||||||
|
|
||||||
|
|
||||||
def get_vs_path(knowledge_base_name: str):
|
def get_vs_path(knowledge_base_name: str, vector_name: str):
|
||||||
return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
|
return os.path.join(get_kb_path(knowledge_base_name), vector_name)
|
||||||
|
|
||||||
|
|
||||||
def get_file_path(knowledge_base_name: str, doc_name: str):
|
def get_file_path(knowledge_base_name: str, doc_name: str):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue