修复faiss_pool知识库缓存key错误 (#1507)
This commit is contained in:
parent
ec85cd1954
commit
1bae930691
|
|
@ -22,7 +22,11 @@ class ThreadSafeObject:
|
|||
|
||||
def __repr__(self) -> str:
|
||||
cls = type(self).__name__
|
||||
return f"<{cls}: key: {self._key}, obj: {self._obj}>"
|
||||
return f"<{cls}: key: {self.key}, obj: {self._obj}>"
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
return self._key
|
||||
|
||||
@contextmanager
|
||||
def acquire(self, owner: str = "", msg: str = ""):
|
||||
|
|
@ -30,13 +34,13 @@ class ThreadSafeObject:
|
|||
try:
|
||||
self._lock.acquire()
|
||||
if self._pool is not None:
|
||||
self._pool._cache.move_to_end(self._key)
|
||||
self._pool._cache.move_to_end(self.key)
|
||||
if log_verbose:
|
||||
logger.info(f"{owner} 开始操作:{self._key}。{msg}")
|
||||
logger.info(f"{owner} 开始操作:{self.key}。{msg}")
|
||||
yield self._obj
|
||||
finally:
|
||||
if log_verbose:
|
||||
logger.info(f"{owner} 结束操作:{self._key}。{msg}")
|
||||
logger.info(f"{owner} 结束操作:{self.key}。{msg}")
|
||||
self._lock.release()
|
||||
|
||||
def start_loading(self):
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import os
|
|||
class ThreadSafeFaiss(ThreadSafeObject):
|
||||
def __repr__(self) -> str:
|
||||
cls = type(self).__name__
|
||||
return f"<{cls}: key: {self._key}, obj: {self._obj}, docs_count: {self.docs_count()}>"
|
||||
return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"
|
||||
|
||||
def docs_count(self) -> int:
|
||||
return len(self._obj.docstore._dict)
|
||||
|
|
@ -17,7 +17,7 @@ class ThreadSafeFaiss(ThreadSafeObject):
|
|||
if not os.path.isdir(path) and create_path:
|
||||
os.makedirs(path)
|
||||
ret = self._obj.save_local(path)
|
||||
logger.info(f"已将向量库 {self._key} 保存到磁盘")
|
||||
logger.info(f"已将向量库 {self.key} 保存到磁盘")
|
||||
return ret
|
||||
|
||||
def clear(self):
|
||||
|
|
@ -27,7 +27,7 @@ class ThreadSafeFaiss(ThreadSafeObject):
|
|||
if ids:
|
||||
ret = self._obj.delete(ids)
|
||||
assert len(self._obj.docstore._dict) == 0
|
||||
logger.info(f"已将向量库 {self._key} 清空")
|
||||
logger.info(f"已将向量库 {self.key} 清空")
|
||||
return ret
|
||||
|
||||
|
||||
|
|
@ -66,10 +66,10 @@ class KBFaissPool(_FaissPool):
|
|||
embed_device: str = embedding_device(),
|
||||
) -> ThreadSafeFaiss:
|
||||
self.atomic.acquire()
|
||||
cache = self.get(kb_name+vector_name)
|
||||
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
|
||||
if cache is None:
|
||||
item = ThreadSafeFaiss(kb_name, pool=self)
|
||||
self.set(kb_name+vector_name, item)
|
||||
item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
|
||||
self.set((kb_name, vector_name), item)
|
||||
with item.acquire(msg="初始化"):
|
||||
self.atomic.release()
|
||||
logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
|
||||
|
|
@ -90,7 +90,7 @@ class KBFaissPool(_FaissPool):
|
|||
item.finish_loading()
|
||||
else:
|
||||
self.atomic.release()
|
||||
return self.get(kb_name+vector_name)
|
||||
return self.get((kb_name, vector_name))
|
||||
|
||||
|
||||
class MemoFaissPool(_FaissPool):
|
||||
|
|
|
|||
|
|
@ -18,18 +18,21 @@ from server.utils import torch_gc
|
|||
class FaissKBService(KBService):
|
||||
vs_path: str
|
||||
kb_path: str
|
||||
vector_name: str = "vector_store"
|
||||
|
||||
def vs_type(self) -> str:
|
||||
return SupportedVSType.FAISS
|
||||
|
||||
def get_vs_path(self):
|
||||
return os.path.join(self.get_kb_path(), "vector_store")
|
||||
return os.path.join(self.get_kb_path(), self.vector_name)
|
||||
|
||||
def get_kb_path(self):
|
||||
return os.path.join(KB_ROOT_PATH, self.kb_name)
|
||||
|
||||
def load_vector_store(self) -> ThreadSafeFaiss:
|
||||
return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, embed_model=self.embed_model)
|
||||
return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
|
||||
vector_name=self.vector_name,
|
||||
embed_model=self.embed_model)
|
||||
|
||||
def save_vector_store(self):
|
||||
self.load_vector_store().save(self.vs_path)
|
||||
|
|
|
|||
Loading…
Reference in New Issue