优化FAISS向量库多文件操作;修复recreate_vector_store,大量文件时不再超时。

This commit is contained in:
liunux4odoo 2023-08-20 19:10:29 +08:00
parent 7bd644701c
commit f40bb69224
5 changed files with 42 additions and 22 deletions

View File

@ -47,6 +47,7 @@ async def list_docs(
async def upload_doc(file: UploadFile = File(..., description="上传文件"), async def upload_doc(file: UploadFile = File(..., description="上传文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]), knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
override: bool = Form(False, description="覆盖已有文件"), override: bool = Form(False, description="覆盖已有文件"),
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse: ) -> BaseResponse:
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
@ -76,7 +77,7 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}") return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
try: try:
kb.add_doc(kb_file) kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e: except Exception as e:
print(e) print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}") return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
@ -87,6 +88,7 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
doc_name: str = Body(..., examples=["file_name.md"]), doc_name: str = Body(..., examples=["file_name.md"]),
delete_content: bool = Body(False), delete_content: bool = Body(False),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse: ) -> BaseResponse:
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
@ -102,7 +104,7 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
try: try:
kb_file = KnowledgeFile(filename=doc_name, kb_file = KnowledgeFile(filename=doc_name,
knowledge_base_name=knowledge_base_name) knowledge_base_name=knowledge_base_name)
kb.delete_doc(kb_file, delete_content) kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e: except Exception as e:
print(e) print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}") return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
@ -113,6 +115,7 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
async def update_doc( async def update_doc(
knowledge_base_name: str = Body(..., examples=["samples"]), knowledge_base_name: str = Body(..., examples=["samples"]),
file_name: str = Body(..., examples=["file_name"]), file_name: str = Body(..., examples=["file_name"]),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse: ) -> BaseResponse:
''' '''
更新知识库文档 更新知识库文档
@ -128,7 +131,7 @@ async def update_doc(
kb_file = KnowledgeFile(filename=file_name, kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name) knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath): if os.path.exists(kb_file.filepath):
kb.update_doc(kb_file) kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}") return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
except Exception as e: except Exception as e:
print(e) print(e)
@ -205,7 +208,5 @@ async def recreate_vector_store(
"code": 500, "code": 500,
"msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。", "msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
}) })
import asyncio
await asyncio.sleep(5)
return StreamingResponse(output(), media_type="text/event-stream") return StreamingResponse(output(), media_type="text/event-stream")

View File

@ -71,7 +71,7 @@ class KBService(ABC):
status = delete_kb_from_db(self.kb_name) status = delete_kb_from_db(self.kb_name)
return status return status
def add_doc(self, kb_file: KnowledgeFile): def add_doc(self, kb_file: KnowledgeFile, **kwargs):
""" """
向知识库添加文件 向知识库添加文件
""" """
@ -79,29 +79,29 @@ class KBService(ABC):
if docs: if docs:
self.delete_doc(kb_file) self.delete_doc(kb_file)
embeddings = self._load_embeddings() embeddings = self._load_embeddings()
self.do_add_doc(docs, embeddings) self.do_add_doc(docs, embeddings, **kwargs)
status = add_doc_to_db(kb_file) status = add_doc_to_db(kb_file)
else: else:
status = False status = False
return status return status
def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False): def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs):
""" """
从知识库删除文件 从知识库删除文件
""" """
self.do_delete_doc(kb_file) self.do_delete_doc(kb_file, **kwargs)
status = delete_file_from_db(kb_file) status = delete_file_from_db(kb_file)
if delete_content and os.path.exists(kb_file.filepath): if delete_content and os.path.exists(kb_file.filepath):
os.remove(kb_file.filepath) os.remove(kb_file.filepath)
return status return status
def update_doc(self, kb_file: KnowledgeFile): def update_doc(self, kb_file: KnowledgeFile, **kwargs):
""" """
使用content中的文件更新向量库 使用content中的文件更新向量库
""" """
if os.path.exists(kb_file.filepath): if os.path.exists(kb_file.filepath):
self.delete_doc(kb_file) self.delete_doc(kb_file, **kwargs)
return self.add_doc(kb_file) return self.add_doc(kb_file, **kwargs)
def exist_doc(self, file_name: str): def exist_doc(self, file_name: str):
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name, return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,

View File

@ -66,6 +66,7 @@ def refresh_vs_cache(kb_name: str):
make vector store cache refreshed when next loading make vector store cache refreshed when next loading
""" """
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1 _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
class FaissKBService(KBService): class FaissKBService(KBService):
@ -111,17 +112,20 @@ class FaissKBService(KBService):
def do_add_doc(self, def do_add_doc(self,
docs: List[Document], docs: List[Document],
embeddings: Embeddings, embeddings: Embeddings,
**kwargs,
): ):
vector_store = load_vector_store(self.kb_name, vector_store = load_vector_store(self.kb_name,
embeddings=embeddings, embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0)) tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
vector_store.add_documents(docs) vector_store.add_documents(docs)
torch_gc() torch_gc()
vector_store.save_local(self.vs_path) if not kwargs.get("not_refresh_vs_cache"):
refresh_vs_cache(self.kb_name) vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
def do_delete_doc(self, def do_delete_doc(self,
kb_file: KnowledgeFile): kb_file: KnowledgeFile,
**kwargs):
embeddings = self._load_embeddings() embeddings = self._load_embeddings()
vector_store = load_vector_store(self.kb_name, vector_store = load_vector_store(self.kb_name,
embeddings=embeddings, embeddings=embeddings,
@ -132,8 +136,9 @@ class FaissKBService(KBService):
return None return None
vector_store.delete(ids) vector_store.delete(ids)
vector_store.save_local(self.vs_path) if not kwargs.get("not_refresh_vs_cache"):
refresh_vs_cache(self.kb_name) vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
return True return True

View File

@ -138,8 +138,10 @@ def knowledge_base_page(api: ApiRequest):
# use_container_width=True, # use_container_width=True,
disabled=len(files) == 0, disabled=len(files) == 0,
): ):
for f in files: data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
ret = api.upload_kb_doc(f, kb) data[-1]["not_refresh_vs_cache"]=False
for k in data:
ret = api.upload_kb_doc(**k)
if msg := check_success_msg(ret): if msg := check_success_msg(ret):
st.toast(msg, icon="") st.toast(msg, icon="")
elif msg := check_error_msg(ret): elif msg := check_error_msg(ret):

View File

@ -496,6 +496,7 @@ class ApiRequest:
knowledge_base_name: str, knowledge_base_name: str,
filename: str = None, filename: str = None,
override: bool = False, override: bool = False,
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None, no_remote_api: bool = None,
): ):
''' '''
@ -529,7 +530,11 @@ class ApiRequest:
else: else:
response = self.post( response = self.post(
"/knowledge_base/upload_doc", "/knowledge_base/upload_doc",
data={"knowledge_base_name": knowledge_base_name, "override": override}, data={
"knowledge_base_name": knowledge_base_name,
"override": override,
"not_refresh_vs_cache": not_refresh_vs_cache,
},
files={"file": (filename, file)}, files={"file": (filename, file)},
) )
return self._check_httpx_json_response(response) return self._check_httpx_json_response(response)
@ -539,6 +544,7 @@ class ApiRequest:
knowledge_base_name: str, knowledge_base_name: str,
doc_name: str, doc_name: str,
delete_content: bool = False, delete_content: bool = False,
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None, no_remote_api: bool = None,
): ):
''' '''
@ -551,6 +557,7 @@ class ApiRequest:
"knowledge_base_name": knowledge_base_name, "knowledge_base_name": knowledge_base_name,
"doc_name": doc_name, "doc_name": doc_name,
"delete_content": delete_content, "delete_content": delete_content,
"not_refresh_vs_cache": not_refresh_vs_cache,
} }
if no_remote_api: if no_remote_api:
@ -568,6 +575,7 @@ class ApiRequest:
self, self,
knowledge_base_name: str, knowledge_base_name: str,
file_name: str, file_name: str,
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None, no_remote_api: bool = None,
): ):
''' '''
@ -583,7 +591,11 @@ class ApiRequest:
else: else:
response = self.post( response = self.post(
"/knowledge_base/update_doc", "/knowledge_base/update_doc",
json={"knowledge_base_name": knowledge_base_name, "file_name": file_name}, json={
"knowledge_base_name": knowledge_base_name,
"file_name": file_name,
"not_refresh_vs_cache": not_refresh_vs_cache,
},
) )
return self._check_httpx_json_response(response) return self._check_httpx_json_response(response)
@ -617,7 +629,7 @@ class ApiRequest:
"/knowledge_base/recreate_vector_store", "/knowledge_base/recreate_vector_store",
json=data, json=data,
stream=True, stream=True,
timeout=False, timeout=None,
) )
return self._httpx_stream2generator(response, as_json=True) return self._httpx_stream2generator(response, as_json=True)