优化FAISS向量库多文件操作;修复recreate_vector_store,大量文件时不再超时。
This commit is contained in:
parent
7bd644701c
commit
f40bb69224
|
|
@ -47,6 +47,7 @@ async def list_docs(
|
|||
async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
||||
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
|
||||
override: bool = Form(False, description="覆盖已有文件"),
|
||||
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
|
@ -76,7 +77,7 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
|||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
|
||||
|
||||
try:
|
||||
kb.add_doc(kb_file)
|
||||
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
|
||||
|
|
@ -87,6 +88,7 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
|||
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
doc_name: str = Body(..., examples=["file_name.md"]),
|
||||
delete_content: bool = Body(False),
|
||||
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
|
@ -102,7 +104,7 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
|||
try:
|
||||
kb_file = KnowledgeFile(filename=doc_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
kb.delete_doc(kb_file, delete_content)
|
||||
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
|
||||
|
|
@ -113,6 +115,7 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
|||
async def update_doc(
|
||||
knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
file_name: str = Body(..., examples=["file_name"]),
|
||||
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
更新知识库文档
|
||||
|
|
@ -128,7 +131,7 @@ async def update_doc(
|
|||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
if os.path.exists(kb_file.filepath):
|
||||
kb.update_doc(kb_file)
|
||||
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
|
@ -205,7 +208,5 @@ async def recreate_vector_store(
|
|||
"code": 500,
|
||||
"msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
|
||||
})
|
||||
import asyncio
|
||||
await asyncio.sleep(5)
|
||||
|
||||
return StreamingResponse(output(), media_type="text/event-stream")
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class KBService(ABC):
|
|||
status = delete_kb_from_db(self.kb_name)
|
||||
return status
|
||||
|
||||
def add_doc(self, kb_file: KnowledgeFile):
|
||||
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||
"""
|
||||
向知识库添加文件
|
||||
"""
|
||||
|
|
@ -79,29 +79,29 @@ class KBService(ABC):
|
|||
if docs:
|
||||
self.delete_doc(kb_file)
|
||||
embeddings = self._load_embeddings()
|
||||
self.do_add_doc(docs, embeddings)
|
||||
self.do_add_doc(docs, embeddings, **kwargs)
|
||||
status = add_doc_to_db(kb_file)
|
||||
else:
|
||||
status = False
|
||||
return status
|
||||
|
||||
def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False):
|
||||
def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs):
|
||||
"""
|
||||
从知识库删除文件
|
||||
"""
|
||||
self.do_delete_doc(kb_file)
|
||||
self.do_delete_doc(kb_file, **kwargs)
|
||||
status = delete_file_from_db(kb_file)
|
||||
if delete_content and os.path.exists(kb_file.filepath):
|
||||
os.remove(kb_file.filepath)
|
||||
return status
|
||||
|
||||
def update_doc(self, kb_file: KnowledgeFile):
|
||||
def update_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||
"""
|
||||
使用content中的文件更新向量库
|
||||
"""
|
||||
if os.path.exists(kb_file.filepath):
|
||||
self.delete_doc(kb_file)
|
||||
return self.add_doc(kb_file)
|
||||
self.delete_doc(kb_file, **kwargs)
|
||||
return self.add_doc(kb_file, **kwargs)
|
||||
|
||||
def exist_doc(self, file_name: str):
|
||||
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
|
||||
|
|
|
|||
|
|
@ -66,6 +66,7 @@ def refresh_vs_cache(kb_name: str):
|
|||
make vector store cache refreshed when next loading
|
||||
"""
|
||||
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
|
||||
print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
|
||||
|
||||
|
||||
class FaissKBService(KBService):
|
||||
|
|
@ -111,17 +112,20 @@ class FaissKBService(KBService):
|
|||
def do_add_doc(self,
|
||||
docs: List[Document],
|
||||
embeddings: Embeddings,
|
||||
**kwargs,
|
||||
):
|
||||
vector_store = load_vector_store(self.kb_name,
|
||||
embeddings=embeddings,
|
||||
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
|
||||
vector_store.add_documents(docs)
|
||||
torch_gc()
|
||||
vector_store.save_local(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
if not kwargs.get("not_refresh_vs_cache"):
|
||||
vector_store.save_local(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
|
||||
def do_delete_doc(self,
|
||||
kb_file: KnowledgeFile):
|
||||
kb_file: KnowledgeFile,
|
||||
**kwargs):
|
||||
embeddings = self._load_embeddings()
|
||||
vector_store = load_vector_store(self.kb_name,
|
||||
embeddings=embeddings,
|
||||
|
|
@ -132,8 +136,9 @@ class FaissKBService(KBService):
|
|||
return None
|
||||
|
||||
vector_store.delete(ids)
|
||||
vector_store.save_local(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
if not kwargs.get("not_refresh_vs_cache"):
|
||||
vector_store.save_local(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
|
||||
return True
|
||||
|
||||
|
|
|
|||
|
|
@ -138,8 +138,10 @@ def knowledge_base_page(api: ApiRequest):
|
|||
# use_container_width=True,
|
||||
disabled=len(files) == 0,
|
||||
):
|
||||
for f in files:
|
||||
ret = api.upload_kb_doc(f, kb)
|
||||
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
|
||||
data[-1]["not_refresh_vs_cache"]=False
|
||||
for k in data:
|
||||
ret = api.upload_kb_doc(**k)
|
||||
if msg := check_success_msg(ret):
|
||||
st.toast(msg, icon="✔")
|
||||
elif msg := check_error_msg(ret):
|
||||
|
|
|
|||
|
|
@ -496,6 +496,7 @@ class ApiRequest:
|
|||
knowledge_base_name: str,
|
||||
filename: str = None,
|
||||
override: bool = False,
|
||||
not_refresh_vs_cache: bool = False,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
|
|
@ -529,7 +530,11 @@ class ApiRequest:
|
|||
else:
|
||||
response = self.post(
|
||||
"/knowledge_base/upload_doc",
|
||||
data={"knowledge_base_name": knowledge_base_name, "override": override},
|
||||
data={
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"override": override,
|
||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||
},
|
||||
files={"file": (filename, file)},
|
||||
)
|
||||
return self._check_httpx_json_response(response)
|
||||
|
|
@ -539,6 +544,7 @@ class ApiRequest:
|
|||
knowledge_base_name: str,
|
||||
doc_name: str,
|
||||
delete_content: bool = False,
|
||||
not_refresh_vs_cache: bool = False,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
|
|
@ -551,6 +557,7 @@ class ApiRequest:
|
|||
"knowledge_base_name": knowledge_base_name,
|
||||
"doc_name": doc_name,
|
||||
"delete_content": delete_content,
|
||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||
}
|
||||
|
||||
if no_remote_api:
|
||||
|
|
@ -568,6 +575,7 @@ class ApiRequest:
|
|||
self,
|
||||
knowledge_base_name: str,
|
||||
file_name: str,
|
||||
not_refresh_vs_cache: bool = False,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
|
|
@ -583,7 +591,11 @@ class ApiRequest:
|
|||
else:
|
||||
response = self.post(
|
||||
"/knowledge_base/update_doc",
|
||||
json={"knowledge_base_name": knowledge_base_name, "file_name": file_name},
|
||||
json={
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"file_name": file_name,
|
||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||
},
|
||||
)
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
|
|
@ -617,7 +629,7 @@ class ApiRequest:
|
|||
"/knowledge_base/recreate_vector_store",
|
||||
json=data,
|
||||
stream=True,
|
||||
timeout=False,
|
||||
timeout=None,
|
||||
)
|
||||
return self._httpx_stream2generator(response, as_json=True)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue