增加传入矢量名称加载
This commit is contained in:
parent
175c90c362
commit
902ba0c321
|
|
@ -58,21 +58,22 @@ class _FaissPool(CachePool):
|
|||
|
||||
class KBFaissPool(_FaissPool):
|
||||
def load_vector_store(
|
||||
self,
|
||||
kb_name: str,
|
||||
create: bool = True,
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
embed_device: str = embedding_device(),
|
||||
self,
|
||||
kb_name: str,
|
||||
vector_name: str = "vector_store",
|
||||
create: bool = True,
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
embed_device: str = embedding_device(),
|
||||
) -> ThreadSafeFaiss:
|
||||
self.atomic.acquire()
|
||||
cache = self.get(kb_name)
|
||||
cache = self.get(kb_name+vector_name)
|
||||
if cache is None:
|
||||
item = ThreadSafeFaiss(kb_name, pool=self)
|
||||
self.set(kb_name, item)
|
||||
self.set(kb_name+vector_name, item)
|
||||
with item.acquire(msg="初始化"):
|
||||
self.atomic.release()
|
||||
logger.info(f"loading vector store in '{kb_name}' from disk.")
|
||||
vs_path = get_vs_path(kb_name)
|
||||
logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
|
||||
vs_path = get_vs_path(kb_name, vector_name)
|
||||
|
||||
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
|
||||
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
|
||||
|
|
@ -89,7 +90,7 @@ class KBFaissPool(_FaissPool):
|
|||
item.finish_loading()
|
||||
else:
|
||||
self.atomic.release()
|
||||
return self.get(kb_name)
|
||||
return self.get(kb_name+vector_name)
|
||||
|
||||
|
||||
class MemoFaissPool(_FaissPool):
|
||||
|
|
@ -144,7 +145,7 @@ if __name__ == "__main__":
|
|||
if r == 3: # delete docs
|
||||
logger.warning(f"清除 {vs_name} by {name}")
|
||||
kb_faiss_pool.get(vs_name).clear()
|
||||
|
||||
|
||||
threads = []
|
||||
for n in range(1, 30):
|
||||
t = threading.Thread(target=worker,
|
||||
|
|
@ -152,6 +153,6 @@ if __name__ == "__main__":
|
|||
daemon=True)
|
||||
t.start()
|
||||
threads.append(t)
|
||||
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ from configs import (
|
|||
CHUNK_SIZE,
|
||||
OVERLAP_SIZE,
|
||||
ZH_TITLE_ENHANCE,
|
||||
logger,
|
||||
log_verbose,
|
||||
logger,
|
||||
log_verbose,
|
||||
text_splitter_dict,
|
||||
LLM_MODEL,
|
||||
TEXT_SPLITTER_NAME,
|
||||
|
|
@ -42,8 +42,8 @@ def get_doc_path(knowledge_base_name: str):
|
|||
return os.path.join(get_kb_path(knowledge_base_name), "content")
|
||||
|
||||
|
||||
def get_vs_path(knowledge_base_name: str):
|
||||
return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
|
||||
def get_vs_path(knowledge_base_name: str, vector_name: str):
|
||||
return os.path.join(get_kb_path(knowledge_base_name), vector_name)
|
||||
|
||||
|
||||
def get_file_path(knowledge_base_name: str, doc_name: str):
|
||||
|
|
@ -391,4 +391,4 @@ if __name__ == "__main__":
|
|||
pprint(docs[-1])
|
||||
|
||||
docs = kb_file.file2text()
|
||||
pprint(docs[-1])
|
||||
pprint(docs[-1])
|
||||
|
|
|
|||
Loading…
Reference in New Issue