优化FAISS向量库多文件操作;修复recreate_vector_store,大量文件时不再超时。
This commit is contained in:
parent
7bd644701c
commit
f40bb69224
|
|
@ -47,6 +47,7 @@ async def list_docs(
|
||||||
async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
||||||
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
|
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
|
||||||
override: bool = Form(False, description="覆盖已有文件"),
|
override: bool = Form(False, description="覆盖已有文件"),
|
||||||
|
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
|
||||||
) -> BaseResponse:
|
) -> BaseResponse:
|
||||||
if not validate_kb_name(knowledge_base_name):
|
if not validate_kb_name(knowledge_base_name):
|
||||||
return BaseResponse(code=403, msg="Don't attack me")
|
return BaseResponse(code=403, msg="Don't attack me")
|
||||||
|
|
@ -76,7 +77,7 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
||||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
|
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
kb.add_doc(kb_file)
|
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
|
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
|
||||||
|
|
@ -87,6 +88,7 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
||||||
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||||
doc_name: str = Body(..., examples=["file_name.md"]),
|
doc_name: str = Body(..., examples=["file_name.md"]),
|
||||||
delete_content: bool = Body(False),
|
delete_content: bool = Body(False),
|
||||||
|
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||||
) -> BaseResponse:
|
) -> BaseResponse:
|
||||||
if not validate_kb_name(knowledge_base_name):
|
if not validate_kb_name(knowledge_base_name):
|
||||||
return BaseResponse(code=403, msg="Don't attack me")
|
return BaseResponse(code=403, msg="Don't attack me")
|
||||||
|
|
@ -102,7 +104,7 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||||
try:
|
try:
|
||||||
kb_file = KnowledgeFile(filename=doc_name,
|
kb_file = KnowledgeFile(filename=doc_name,
|
||||||
knowledge_base_name=knowledge_base_name)
|
knowledge_base_name=knowledge_base_name)
|
||||||
kb.delete_doc(kb_file, delete_content)
|
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
|
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
|
||||||
|
|
@ -113,6 +115,7 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||||
async def update_doc(
|
async def update_doc(
|
||||||
knowledge_base_name: str = Body(..., examples=["samples"]),
|
knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||||
file_name: str = Body(..., examples=["file_name"]),
|
file_name: str = Body(..., examples=["file_name"]),
|
||||||
|
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||||
) -> BaseResponse:
|
) -> BaseResponse:
|
||||||
'''
|
'''
|
||||||
更新知识库文档
|
更新知识库文档
|
||||||
|
|
@ -128,7 +131,7 @@ async def update_doc(
|
||||||
kb_file = KnowledgeFile(filename=file_name,
|
kb_file = KnowledgeFile(filename=file_name,
|
||||||
knowledge_base_name=knowledge_base_name)
|
knowledge_base_name=knowledge_base_name)
|
||||||
if os.path.exists(kb_file.filepath):
|
if os.path.exists(kb_file.filepath):
|
||||||
kb.update_doc(kb_file)
|
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||||
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
|
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
@ -205,7 +208,5 @@ async def recreate_vector_store(
|
||||||
"code": 500,
|
"code": 500,
|
||||||
"msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
|
"msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
|
||||||
})
|
})
|
||||||
import asyncio
|
|
||||||
await asyncio.sleep(5)
|
|
||||||
|
|
||||||
return StreamingResponse(output(), media_type="text/event-stream")
|
return StreamingResponse(output(), media_type="text/event-stream")
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ class KBService(ABC):
|
||||||
status = delete_kb_from_db(self.kb_name)
|
status = delete_kb_from_db(self.kb_name)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
def add_doc(self, kb_file: KnowledgeFile):
|
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||||
"""
|
"""
|
||||||
向知识库添加文件
|
向知识库添加文件
|
||||||
"""
|
"""
|
||||||
|
|
@ -79,29 +79,29 @@ class KBService(ABC):
|
||||||
if docs:
|
if docs:
|
||||||
self.delete_doc(kb_file)
|
self.delete_doc(kb_file)
|
||||||
embeddings = self._load_embeddings()
|
embeddings = self._load_embeddings()
|
||||||
self.do_add_doc(docs, embeddings)
|
self.do_add_doc(docs, embeddings, **kwargs)
|
||||||
status = add_doc_to_db(kb_file)
|
status = add_doc_to_db(kb_file)
|
||||||
else:
|
else:
|
||||||
status = False
|
status = False
|
||||||
return status
|
return status
|
||||||
|
|
||||||
def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False):
|
def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs):
|
||||||
"""
|
"""
|
||||||
从知识库删除文件
|
从知识库删除文件
|
||||||
"""
|
"""
|
||||||
self.do_delete_doc(kb_file)
|
self.do_delete_doc(kb_file, **kwargs)
|
||||||
status = delete_file_from_db(kb_file)
|
status = delete_file_from_db(kb_file)
|
||||||
if delete_content and os.path.exists(kb_file.filepath):
|
if delete_content and os.path.exists(kb_file.filepath):
|
||||||
os.remove(kb_file.filepath)
|
os.remove(kb_file.filepath)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
def update_doc(self, kb_file: KnowledgeFile):
|
def update_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||||
"""
|
"""
|
||||||
使用content中的文件更新向量库
|
使用content中的文件更新向量库
|
||||||
"""
|
"""
|
||||||
if os.path.exists(kb_file.filepath):
|
if os.path.exists(kb_file.filepath):
|
||||||
self.delete_doc(kb_file)
|
self.delete_doc(kb_file, **kwargs)
|
||||||
return self.add_doc(kb_file)
|
return self.add_doc(kb_file, **kwargs)
|
||||||
|
|
||||||
def exist_doc(self, file_name: str):
|
def exist_doc(self, file_name: str):
|
||||||
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
|
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
|
||||||
|
|
|
||||||
|
|
@ -66,6 +66,7 @@ def refresh_vs_cache(kb_name: str):
|
||||||
make vector store cache refreshed when next loading
|
make vector store cache refreshed when next loading
|
||||||
"""
|
"""
|
||||||
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
|
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
|
||||||
|
print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
|
||||||
|
|
||||||
|
|
||||||
class FaissKBService(KBService):
|
class FaissKBService(KBService):
|
||||||
|
|
@ -111,17 +112,20 @@ class FaissKBService(KBService):
|
||||||
def do_add_doc(self,
|
def do_add_doc(self,
|
||||||
docs: List[Document],
|
docs: List[Document],
|
||||||
embeddings: Embeddings,
|
embeddings: Embeddings,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
vector_store = load_vector_store(self.kb_name,
|
vector_store = load_vector_store(self.kb_name,
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
|
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
|
||||||
vector_store.add_documents(docs)
|
vector_store.add_documents(docs)
|
||||||
torch_gc()
|
torch_gc()
|
||||||
vector_store.save_local(self.vs_path)
|
if not kwargs.get("not_refresh_vs_cache"):
|
||||||
refresh_vs_cache(self.kb_name)
|
vector_store.save_local(self.vs_path)
|
||||||
|
refresh_vs_cache(self.kb_name)
|
||||||
|
|
||||||
def do_delete_doc(self,
|
def do_delete_doc(self,
|
||||||
kb_file: KnowledgeFile):
|
kb_file: KnowledgeFile,
|
||||||
|
**kwargs):
|
||||||
embeddings = self._load_embeddings()
|
embeddings = self._load_embeddings()
|
||||||
vector_store = load_vector_store(self.kb_name,
|
vector_store = load_vector_store(self.kb_name,
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
|
|
@ -132,8 +136,9 @@ class FaissKBService(KBService):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
vector_store.delete(ids)
|
vector_store.delete(ids)
|
||||||
vector_store.save_local(self.vs_path)
|
if not kwargs.get("not_refresh_vs_cache"):
|
||||||
refresh_vs_cache(self.kb_name)
|
vector_store.save_local(self.vs_path)
|
||||||
|
refresh_vs_cache(self.kb_name)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -138,8 +138,10 @@ def knowledge_base_page(api: ApiRequest):
|
||||||
# use_container_width=True,
|
# use_container_width=True,
|
||||||
disabled=len(files) == 0,
|
disabled=len(files) == 0,
|
||||||
):
|
):
|
||||||
for f in files:
|
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
|
||||||
ret = api.upload_kb_doc(f, kb)
|
data[-1]["not_refresh_vs_cache"]=False
|
||||||
|
for k in data:
|
||||||
|
ret = api.upload_kb_doc(**k)
|
||||||
if msg := check_success_msg(ret):
|
if msg := check_success_msg(ret):
|
||||||
st.toast(msg, icon="✔")
|
st.toast(msg, icon="✔")
|
||||||
elif msg := check_error_msg(ret):
|
elif msg := check_error_msg(ret):
|
||||||
|
|
|
||||||
|
|
@ -496,6 +496,7 @@ class ApiRequest:
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
filename: str = None,
|
filename: str = None,
|
||||||
override: bool = False,
|
override: bool = False,
|
||||||
|
not_refresh_vs_cache: bool = False,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
|
|
@ -529,7 +530,11 @@ class ApiRequest:
|
||||||
else:
|
else:
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/knowledge_base/upload_doc",
|
"/knowledge_base/upload_doc",
|
||||||
data={"knowledge_base_name": knowledge_base_name, "override": override},
|
data={
|
||||||
|
"knowledge_base_name": knowledge_base_name,
|
||||||
|
"override": override,
|
||||||
|
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||||
|
},
|
||||||
files={"file": (filename, file)},
|
files={"file": (filename, file)},
|
||||||
)
|
)
|
||||||
return self._check_httpx_json_response(response)
|
return self._check_httpx_json_response(response)
|
||||||
|
|
@ -539,6 +544,7 @@ class ApiRequest:
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
doc_name: str,
|
doc_name: str,
|
||||||
delete_content: bool = False,
|
delete_content: bool = False,
|
||||||
|
not_refresh_vs_cache: bool = False,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
|
|
@ -551,6 +557,7 @@ class ApiRequest:
|
||||||
"knowledge_base_name": knowledge_base_name,
|
"knowledge_base_name": knowledge_base_name,
|
||||||
"doc_name": doc_name,
|
"doc_name": doc_name,
|
||||||
"delete_content": delete_content,
|
"delete_content": delete_content,
|
||||||
|
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
|
|
@ -568,6 +575,7 @@ class ApiRequest:
|
||||||
self,
|
self,
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
file_name: str,
|
file_name: str,
|
||||||
|
not_refresh_vs_cache: bool = False,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
|
|
@ -583,7 +591,11 @@ class ApiRequest:
|
||||||
else:
|
else:
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/knowledge_base/update_doc",
|
"/knowledge_base/update_doc",
|
||||||
json={"knowledge_base_name": knowledge_base_name, "file_name": file_name},
|
json={
|
||||||
|
"knowledge_base_name": knowledge_base_name,
|
||||||
|
"file_name": file_name,
|
||||||
|
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
return self._check_httpx_json_response(response)
|
return self._check_httpx_json_response(response)
|
||||||
|
|
||||||
|
|
@ -617,7 +629,7 @@ class ApiRequest:
|
||||||
"/knowledge_base/recreate_vector_store",
|
"/knowledge_base/recreate_vector_store",
|
||||||
json=data,
|
json=data,
|
||||||
stream=True,
|
stream=True,
|
||||||
timeout=False,
|
timeout=None,
|
||||||
)
|
)
|
||||||
return self._httpx_stream2generator(response, as_json=True)
|
return self._httpx_stream2generator(response, as_json=True)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue