修复faiss_pool知识库缓存key错误 (#1507)
This commit is contained in:
parent
ec85cd1954
commit
1bae930691
|
|
@ -22,7 +22,11 @@ class ThreadSafeObject:
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
cls = type(self).__name__
|
cls = type(self).__name__
|
||||||
return f"<{cls}: key: {self._key}, obj: {self._obj}>"
|
return f"<{cls}: key: {self.key}, obj: {self._obj}>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def key(self):
|
||||||
|
return self._key
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def acquire(self, owner: str = "", msg: str = ""):
|
def acquire(self, owner: str = "", msg: str = ""):
|
||||||
|
|
@ -30,13 +34,13 @@ class ThreadSafeObject:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
if self._pool is not None:
|
if self._pool is not None:
|
||||||
self._pool._cache.move_to_end(self._key)
|
self._pool._cache.move_to_end(self.key)
|
||||||
if log_verbose:
|
if log_verbose:
|
||||||
logger.info(f"{owner} 开始操作:{self._key}。{msg}")
|
logger.info(f"{owner} 开始操作:{self.key}。{msg}")
|
||||||
yield self._obj
|
yield self._obj
|
||||||
finally:
|
finally:
|
||||||
if log_verbose:
|
if log_verbose:
|
||||||
logger.info(f"{owner} 结束操作:{self._key}。{msg}")
|
logger.info(f"{owner} 结束操作:{self.key}。{msg}")
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def start_loading(self):
|
def start_loading(self):
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import os
|
||||||
class ThreadSafeFaiss(ThreadSafeObject):
|
class ThreadSafeFaiss(ThreadSafeObject):
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
cls = type(self).__name__
|
cls = type(self).__name__
|
||||||
return f"<{cls}: key: {self._key}, obj: {self._obj}, docs_count: {self.docs_count()}>"
|
return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"
|
||||||
|
|
||||||
def docs_count(self) -> int:
|
def docs_count(self) -> int:
|
||||||
return len(self._obj.docstore._dict)
|
return len(self._obj.docstore._dict)
|
||||||
|
|
@ -17,7 +17,7 @@ class ThreadSafeFaiss(ThreadSafeObject):
|
||||||
if not os.path.isdir(path) and create_path:
|
if not os.path.isdir(path) and create_path:
|
||||||
os.makedirs(path)
|
os.makedirs(path)
|
||||||
ret = self._obj.save_local(path)
|
ret = self._obj.save_local(path)
|
||||||
logger.info(f"已将向量库 {self._key} 保存到磁盘")
|
logger.info(f"已将向量库 {self.key} 保存到磁盘")
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
|
|
@ -27,7 +27,7 @@ class ThreadSafeFaiss(ThreadSafeObject):
|
||||||
if ids:
|
if ids:
|
||||||
ret = self._obj.delete(ids)
|
ret = self._obj.delete(ids)
|
||||||
assert len(self._obj.docstore._dict) == 0
|
assert len(self._obj.docstore._dict) == 0
|
||||||
logger.info(f"已将向量库 {self._key} 清空")
|
logger.info(f"已将向量库 {self.key} 清空")
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -66,10 +66,10 @@ class KBFaissPool(_FaissPool):
|
||||||
embed_device: str = embedding_device(),
|
embed_device: str = embedding_device(),
|
||||||
) -> ThreadSafeFaiss:
|
) -> ThreadSafeFaiss:
|
||||||
self.atomic.acquire()
|
self.atomic.acquire()
|
||||||
cache = self.get(kb_name+vector_name)
|
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
|
||||||
if cache is None:
|
if cache is None:
|
||||||
item = ThreadSafeFaiss(kb_name, pool=self)
|
item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
|
||||||
self.set(kb_name+vector_name, item)
|
self.set((kb_name, vector_name), item)
|
||||||
with item.acquire(msg="初始化"):
|
with item.acquire(msg="初始化"):
|
||||||
self.atomic.release()
|
self.atomic.release()
|
||||||
logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
|
logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
|
||||||
|
|
@ -90,7 +90,7 @@ class KBFaissPool(_FaissPool):
|
||||||
item.finish_loading()
|
item.finish_loading()
|
||||||
else:
|
else:
|
||||||
self.atomic.release()
|
self.atomic.release()
|
||||||
return self.get(kb_name+vector_name)
|
return self.get((kb_name, vector_name))
|
||||||
|
|
||||||
|
|
||||||
class MemoFaissPool(_FaissPool):
|
class MemoFaissPool(_FaissPool):
|
||||||
|
|
|
||||||
|
|
@ -18,18 +18,21 @@ from server.utils import torch_gc
|
||||||
class FaissKBService(KBService):
|
class FaissKBService(KBService):
|
||||||
vs_path: str
|
vs_path: str
|
||||||
kb_path: str
|
kb_path: str
|
||||||
|
vector_name: str = "vector_store"
|
||||||
|
|
||||||
def vs_type(self) -> str:
|
def vs_type(self) -> str:
|
||||||
return SupportedVSType.FAISS
|
return SupportedVSType.FAISS
|
||||||
|
|
||||||
def get_vs_path(self):
|
def get_vs_path(self):
|
||||||
return os.path.join(self.get_kb_path(), "vector_store")
|
return os.path.join(self.get_kb_path(), self.vector_name)
|
||||||
|
|
||||||
def get_kb_path(self):
|
def get_kb_path(self):
|
||||||
return os.path.join(KB_ROOT_PATH, self.kb_name)
|
return os.path.join(KB_ROOT_PATH, self.kb_name)
|
||||||
|
|
||||||
def load_vector_store(self) -> ThreadSafeFaiss:
|
def load_vector_store(self) -> ThreadSafeFaiss:
|
||||||
return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, embed_model=self.embed_model)
|
return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
|
||||||
|
vector_name=self.vector_name,
|
||||||
|
embed_model=self.embed_model)
|
||||||
|
|
||||||
def save_vector_store(self):
|
def save_vector_store(self):
|
||||||
self.load_vector_store().save(self.vs_path)
|
self.load_vector_store().save(self.vs_path)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue