增加传入矢量名称加载

This commit is contained in:
glide-the 2023-09-17 13:54:03 +08:00
parent 175c90c362
commit 902ba0c321
2 changed files with 18 additions and 17 deletions

View File

@@ -58,21 +58,22 @@ class _FaissPool(CachePool):
class KBFaissPool(_FaissPool): class KBFaissPool(_FaissPool):
def load_vector_store( def load_vector_store(
self, self,
kb_name: str, kb_name: str,
create: bool = True, vector_name: str = "vector_store",
embed_model: str = EMBEDDING_MODEL, create: bool = True,
embed_device: str = embedding_device(), embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> ThreadSafeFaiss: ) -> ThreadSafeFaiss:
self.atomic.acquire() self.atomic.acquire()
cache = self.get(kb_name) cache = self.get(kb_name+vector_name)
if cache is None: if cache is None:
item = ThreadSafeFaiss(kb_name, pool=self) item = ThreadSafeFaiss(kb_name, pool=self)
self.set(kb_name, item) self.set(kb_name+vector_name, item)
with item.acquire(msg="初始化"): with item.acquire(msg="初始化"):
self.atomic.release() self.atomic.release()
logger.info(f"loading vector store in '{kb_name}' from disk.") logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
vs_path = get_vs_path(kb_name) vs_path = get_vs_path(kb_name, vector_name)
if os.path.isfile(os.path.join(vs_path, "index.faiss")): if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device) embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
@@ -89,7 +90,7 @@ class KBFaissPool(_FaissPool):
item.finish_loading() item.finish_loading()
else: else:
self.atomic.release() self.atomic.release()
return self.get(kb_name) return self.get(kb_name+vector_name)
class MemoFaissPool(_FaissPool): class MemoFaissPool(_FaissPool):
@@ -144,7 +145,7 @@ if __name__ == "__main__":
if r == 3: # delete docs if r == 3: # delete docs
logger.warning(f"清除 {vs_name} by {name}") logger.warning(f"清除 {vs_name} by {name}")
kb_faiss_pool.get(vs_name).clear() kb_faiss_pool.get(vs_name).clear()
threads = [] threads = []
for n in range(1, 30): for n in range(1, 30):
t = threading.Thread(target=worker, t = threading.Thread(target=worker,
@@ -152,6 +153,6 @@ if __name__ == "__main__":
daemon=True) daemon=True)
t.start() t.start()
threads.append(t) threads.append(t)
for t in threads: for t in threads:
t.join() t.join()

View File

@@ -8,8 +8,8 @@ from configs import (
CHUNK_SIZE, CHUNK_SIZE,
OVERLAP_SIZE, OVERLAP_SIZE,
ZH_TITLE_ENHANCE, ZH_TITLE_ENHANCE,
logger, logger,
log_verbose, log_verbose,
text_splitter_dict, text_splitter_dict,
LLM_MODEL, LLM_MODEL,
TEXT_SPLITTER_NAME, TEXT_SPLITTER_NAME,
@@ -42,8 +42,8 @@ def get_doc_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "content") return os.path.join(get_kb_path(knowledge_base_name), "content")
def get_vs_path(knowledge_base_name: str): def get_vs_path(knowledge_base_name: str, vector_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "vector_store") return os.path.join(get_kb_path(knowledge_base_name), vector_name)
def get_file_path(knowledge_base_name: str, doc_name: str): def get_file_path(knowledge_base_name: str, doc_name: str):
@@ -391,4 +391,4 @@ if __name__ == "__main__":
pprint(docs[-1]) pprint(docs[-1])
docs = kb_file.file2text() docs = kb_file.file2text()
pprint(docs[-1]) pprint(docs[-1])