修复faiss_pool知识库缓存key错误 (#1507)

This commit is contained in:
liunux4odoo 2023-09-17 16:31:44 +08:00 committed by GitHub
parent ec85cd1954
commit 1bae930691
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 13 deletions

View File

@ -22,7 +22,11 @@ class ThreadSafeObject:
def __repr__(self) -> str: def __repr__(self) -> str:
cls = type(self).__name__ cls = type(self).__name__
return f"<{cls}: key: {self._key}, obj: {self._obj}>" return f"<{cls}: key: {self.key}, obj: {self._obj}>"
@property
def key(self):
return self._key
@contextmanager @contextmanager
def acquire(self, owner: str = "", msg: str = ""): def acquire(self, owner: str = "", msg: str = ""):
@ -30,13 +34,13 @@ class ThreadSafeObject:
try: try:
self._lock.acquire() self._lock.acquire()
if self._pool is not None: if self._pool is not None:
self._pool._cache.move_to_end(self._key) self._pool._cache.move_to_end(self.key)
if log_verbose: if log_verbose:
logger.info(f"{owner} 开始操作:{self._key}{msg}") logger.info(f"{owner} 开始操作:{self.key}{msg}")
yield self._obj yield self._obj
finally: finally:
if log_verbose: if log_verbose:
logger.info(f"{owner} 结束操作:{self._key}{msg}") logger.info(f"{owner} 结束操作:{self.key}{msg}")
self._lock.release() self._lock.release()
def start_loading(self): def start_loading(self):

View File

@ -7,7 +7,7 @@ import os
class ThreadSafeFaiss(ThreadSafeObject): class ThreadSafeFaiss(ThreadSafeObject):
def __repr__(self) -> str: def __repr__(self) -> str:
cls = type(self).__name__ cls = type(self).__name__
return f"<{cls}: key: {self._key}, obj: {self._obj}, docs_count: {self.docs_count()}>" return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"
def docs_count(self) -> int: def docs_count(self) -> int:
return len(self._obj.docstore._dict) return len(self._obj.docstore._dict)
@ -17,7 +17,7 @@ class ThreadSafeFaiss(ThreadSafeObject):
if not os.path.isdir(path) and create_path: if not os.path.isdir(path) and create_path:
os.makedirs(path) os.makedirs(path)
ret = self._obj.save_local(path) ret = self._obj.save_local(path)
logger.info(f"已将向量库 {self._key} 保存到磁盘") logger.info(f"已将向量库 {self.key} 保存到磁盘")
return ret return ret
def clear(self): def clear(self):
@ -27,7 +27,7 @@ class ThreadSafeFaiss(ThreadSafeObject):
if ids: if ids:
ret = self._obj.delete(ids) ret = self._obj.delete(ids)
assert len(self._obj.docstore._dict) == 0 assert len(self._obj.docstore._dict) == 0
logger.info(f"已将向量库 {self._key} 清空") logger.info(f"已将向量库 {self.key} 清空")
return ret return ret
@ -66,10 +66,10 @@ class KBFaissPool(_FaissPool):
embed_device: str = embedding_device(), embed_device: str = embedding_device(),
) -> ThreadSafeFaiss: ) -> ThreadSafeFaiss:
self.atomic.acquire() self.atomic.acquire()
cache = self.get(kb_name+vector_name) cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
if cache is None: if cache is None:
item = ThreadSafeFaiss(kb_name, pool=self) item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
self.set(kb_name+vector_name, item) self.set((kb_name, vector_name), item)
with item.acquire(msg="初始化"): with item.acquire(msg="初始化"):
self.atomic.release() self.atomic.release()
logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.") logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
@ -90,7 +90,7 @@ class KBFaissPool(_FaissPool):
item.finish_loading() item.finish_loading()
else: else:
self.atomic.release() self.atomic.release()
return self.get(kb_name+vector_name) return self.get((kb_name, vector_name))
class MemoFaissPool(_FaissPool): class MemoFaissPool(_FaissPool):

View File

@ -18,18 +18,21 @@ from server.utils import torch_gc
class FaissKBService(KBService): class FaissKBService(KBService):
vs_path: str vs_path: str
kb_path: str kb_path: str
vector_name: str = "vector_store"
def vs_type(self) -> str: def vs_type(self) -> str:
return SupportedVSType.FAISS return SupportedVSType.FAISS
def get_vs_path(self): def get_vs_path(self):
return os.path.join(self.get_kb_path(), "vector_store") return os.path.join(self.get_kb_path(), self.vector_name)
def get_kb_path(self): def get_kb_path(self):
return os.path.join(KB_ROOT_PATH, self.kb_name) return os.path.join(KB_ROOT_PATH, self.kb_name)
def load_vector_store(self) -> ThreadSafeFaiss: def load_vector_store(self) -> ThreadSafeFaiss:
return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, embed_model=self.embed_model) return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
vector_name=self.vector_name,
embed_model=self.embed_model)
def save_vector_store(self): def save_vector_store(self):
self.load_vector_store().save(self.vs_path) self.load_vector_store().save(self.vs_path)