Fix faiss_pool knowledge base cache key error (#1507)

liunux4odoo 2023-09-17 16:31:44 +08:00 committed by GitHub
parent ec85cd1954
commit 1bae930691
3 changed files with 20 additions and 13 deletions


@@ -22,7 +22,11 @@ class ThreadSafeObject:
     def __repr__(self) -> str:
         cls = type(self).__name__
-        return f"<{cls}: key: {self._key}, obj: {self._obj}>"
+        return f"<{cls}: key: {self.key}, obj: {self._obj}>"
+
+    @property
+    def key(self):
+        return self._key
 
     @contextmanager
     def acquire(self, owner: str = "", msg: str = ""):
@@ -30,13 +34,13 @@ class ThreadSafeObject:
         try:
             self._lock.acquire()
             if self._pool is not None:
-                self._pool._cache.move_to_end(self._key)
+                self._pool._cache.move_to_end(self.key)
             if log_verbose:
-                logger.info(f"{owner} 开始操作:{self._key}{msg}")
+                logger.info(f"{owner} 开始操作:{self.key}{msg}")
             yield self._obj
         finally:
             if log_verbose:
-                logger.info(f"{owner} 结束操作:{self._key}{msg}")
+                logger.info(f"{owner} 结束操作:{self.key}{msg}")
             self._lock.release()
 
     def start_loading(self):
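The hunks above replace direct reads of the private `self._key` attribute with a public, read-only `key` property. A minimal sketch of that accessor pattern, using a stand-in class with illustrative names rather than the project's full ThreadSafeObject:

from threading import RLock

class CachedItem:
    """Minimal sketch of the read-only `key` accessor pattern (illustrative names)."""

    def __init__(self, key, obj=None):
        self._key = key      # set once at construction; not meant to be reassigned
        self._obj = obj
        self._lock = RLock()

    @property
    def key(self):
        # Public accessor; a getter-only property rejects `item.key = ...`.
        return self._key

item = CachedItem(("samples", "vector_store"))
print(item.key)          # ('samples', 'vector_store')
# item.key = "other"     # would raise AttributeError: read-only property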


@@ -7,7 +7,7 @@ import os
 class ThreadSafeFaiss(ThreadSafeObject):
     def __repr__(self) -> str:
         cls = type(self).__name__
-        return f"<{cls}: key: {self._key}, obj: {self._obj}, docs_count: {self.docs_count()}>"
+        return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"
 
     def docs_count(self) -> int:
         return len(self._obj.docstore._dict)
@@ -17,7 +17,7 @@ class ThreadSafeFaiss(ThreadSafeObject):
         if not os.path.isdir(path) and create_path:
             os.makedirs(path)
         ret = self._obj.save_local(path)
-        logger.info(f"已将向量库 {self._key} 保存到磁盘")
+        logger.info(f"已将向量库 {self.key} 保存到磁盘")
         return ret
 
     def clear(self):
@@ -27,7 +27,7 @@ class ThreadSafeFaiss(ThreadSafeObject):
         if ids:
             ret = self._obj.delete(ids)
             assert len(self._obj.docstore._dict) == 0
-        logger.info(f"已将向量库 {self._key} 清空")
+        logger.info(f"已将向量库 {self.key} 清空")
         return ret
@@ -66,10 +66,10 @@ class KBFaissPool(_FaissPool):
             embed_device: str = embedding_device(),
     ) -> ThreadSafeFaiss:
         self.atomic.acquire()
-        cache = self.get(kb_name+vector_name)
+        cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
         if cache is None:
-            item = ThreadSafeFaiss(kb_name, pool=self)
-            self.set(kb_name+vector_name, item)
+            item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
+            self.set((kb_name, vector_name), item)
             with item.acquire(msg="初始化"):
                 self.atomic.release()
                 logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
@@ -90,7 +90,7 @@ class KBFaissPool(_FaissPool):
                 item.finish_loading()
         else:
             self.atomic.release()
-        return self.get(kb_name+vector_name)
+        return self.get((kb_name, vector_name))
 
 
 class MemoFaissPool(_FaissPool):
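The inline comment on the new cache lookup ("用元组比拼接字符串好一些") notes that a tuple key is preferable to concatenating the strings. The sketch below uses illustrative names and a plain OrderedDict as a stand-in for the pool's cache to show the collision that concatenated keys allow and tuple keys avoid:

from collections import OrderedDict

cache = OrderedDict()   # stand-in for the pool's internal cache

# Two distinct (kb_name, vector_name) pairs whose concatenations are identical.
a = ("kb1", "v_store")
b = ("kb1v", "_store")

# Old scheme: kb_name + vector_name — the second entry overwrites the first,
# because "kb1" + "v_store" == "kb1v" + "_store" == "kb1v_store".
cache[a[0] + a[1]] = "faiss store for kb1/v_store"
cache[b[0] + b[1]] = "faiss store for kb1v/_store"
assert len(cache) == 1

# New scheme: the tuple itself is the key (tuples of strings are hashable),
# so the two knowledge bases keep separate cache entries.
cache.clear()
cache[a] = "faiss store for kb1/v_store"
cache[b] = "faiss store for kb1v/_store"
assert len(cache) == 2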


@@ -18,18 +18,21 @@ from server.utils import torch_gc
 class FaissKBService(KBService):
     vs_path: str
     kb_path: str
+    vector_name: str = "vector_store"
 
     def vs_type(self) -> str:
         return SupportedVSType.FAISS
 
     def get_vs_path(self):
-        return os.path.join(self.get_kb_path(), "vector_store")
+        return os.path.join(self.get_kb_path(), self.vector_name)
 
     def get_kb_path(self):
         return os.path.join(KB_ROOT_PATH, self.kb_name)
 
     def load_vector_store(self) -> ThreadSafeFaiss:
-        return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, embed_model=self.embed_model)
+        return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
+                                               vector_name=self.vector_name,
+                                               embed_model=self.embed_model)
 
     def save_vector_store(self):
         self.load_vector_store().save(self.vs_path)
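FaissKBService now exposes the vector store directory name as a `vector_name` class attribute (defaulting to "vector_store") instead of a hard-coded string, and forwards it to kb_faiss_pool.load_vector_store. A minimal sketch of the resulting path logic, using a stand-in class and an assumed KB_ROOT_PATH rather than the project's actual service API:

import os

KB_ROOT_PATH = "/tmp/knowledge_base"   # assumption: stand-in for the configured root

class MiniFaissKB:
    vector_name: str = "vector_store"  # default matches the previously hard-coded name

    def __init__(self, kb_name: str):
        self.kb_name = kb_name

    def get_kb_path(self) -> str:
        return os.path.join(KB_ROOT_PATH, self.kb_name)

    def get_vs_path(self) -> str:
        # before this commit: os.path.join(self.get_kb_path(), "vector_store")
        return os.path.join(self.get_kb_path(), self.vector_name)

class BgeVectorKB(MiniFaissKB):
    vector_name = "bge_vector_store"   # hypothetical second vector store per knowledge base

print(MiniFaissKB("samples").get_vs_path())   # /tmp/knowledge_base/samples/vector_store
print(BgeVectorKB("samples").get_vs_path())   # /tmp/knowledge_base/samples/bge_vector_store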