增加传入矢量名称加载

This commit is contained in:
glide-the 2023-09-17 13:54:03 +08:00
parent 175c90c362
commit 902ba0c321
2 changed files with 18 additions and 17 deletions

View File

@ -58,21 +58,22 @@ class _FaissPool(CachePool):
class KBFaissPool(_FaissPool):
def load_vector_store(
self,
kb_name: str,
create: bool = True,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
self,
kb_name: str,
vector_name: str = "vector_store",
create: bool = True,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> ThreadSafeFaiss:
self.atomic.acquire()
cache = self.get(kb_name)
cache = self.get(kb_name+vector_name)
if cache is None:
item = ThreadSafeFaiss(kb_name, pool=self)
self.set(kb_name, item)
self.set(kb_name+vector_name, item)
with item.acquire(msg="初始化"):
self.atomic.release()
logger.info(f"loading vector store in '{kb_name}' from disk.")
vs_path = get_vs_path(kb_name)
logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
vs_path = get_vs_path(kb_name, vector_name)
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
@ -89,7 +90,7 @@ class KBFaissPool(_FaissPool):
item.finish_loading()
else:
self.atomic.release()
return self.get(kb_name)
return self.get(kb_name+vector_name)
class MemoFaissPool(_FaissPool):
@ -144,7 +145,7 @@ if __name__ == "__main__":
if r == 3: # delete docs
logger.warning(f"清除 {vs_name} by {name}")
kb_faiss_pool.get(vs_name).clear()
threads = []
for n in range(1, 30):
t = threading.Thread(target=worker,
@ -152,6 +153,6 @@ if __name__ == "__main__":
daemon=True)
t.start()
threads.append(t)
for t in threads:
t.join()

View File

@ -8,8 +8,8 @@ from configs import (
CHUNK_SIZE,
OVERLAP_SIZE,
ZH_TITLE_ENHANCE,
logger,
log_verbose,
logger,
log_verbose,
text_splitter_dict,
LLM_MODEL,
TEXT_SPLITTER_NAME,
@ -42,8 +42,8 @@ def get_doc_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "content")
def get_vs_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
def get_vs_path(knowledge_base_name: str, vector_name: str):
return os.path.join(get_kb_path(knowledge_base_name), vector_name)
def get_file_path(knowledge_base_name: str, doc_name: str):
@ -391,4 +391,4 @@ if __name__ == "__main__":
pprint(docs[-1])
docs = kb_file.file2text()
pprint(docs[-1])
pprint(docs[-1])