update api.py
This commit is contained in:
parent
a887df1715
commit
11dd2b5b84
73
api.py
73
api.py
|
|
@ -141,9 +141,9 @@ async def upload_files(
|
|||
if filelist:
|
||||
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, get_vs_path(knowledge_base_id))
|
||||
if len(loaded_files):
|
||||
file_status = f"已上传 {'、'.join([os.path.split(i)[-1] for i in loaded_files])} 至知识库,并已加载知识库,请开始提问"
|
||||
file_status = f"documents {', '.join([os.path.split(i)[-1] for i in loaded_files])} upload success"
|
||||
return BaseResponse(code=200, msg=file_status)
|
||||
file_status = "文件未成功加载,请重新上传文件"
|
||||
file_status = f"documents {', '.join([os.path.split(i)[-1] for i in loaded_files])} upload fail"
|
||||
return BaseResponse(code=500, msg=file_status)
|
||||
|
||||
|
||||
|
|
@ -176,7 +176,7 @@ async def list_docs(
|
|||
return ListDocsResponse(data=all_doc_names)
|
||||
|
||||
|
||||
async def delete_kbs(
|
||||
async def delete_kb(
|
||||
knowledge_base_id: str = Query(...,
|
||||
description="Knowledge Base Name",
|
||||
example="kb1"),
|
||||
|
|
@ -189,7 +189,7 @@ async def delete_kbs(
|
|||
return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} delete success")
|
||||
|
||||
|
||||
async def delete_docs(
|
||||
async def delete_doc(
|
||||
knowledge_base_id: str = Query(...,
|
||||
description="Knowledge Base Name",
|
||||
example="kb1"),
|
||||
|
|
@ -197,28 +197,72 @@ async def delete_docs(
|
|||
None, description="doc name", example="doc_name_1.pdf"
|
||||
),
|
||||
):
|
||||
# TODO: 确认是否支持批量删除文件
|
||||
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
||||
if not os.path.exists(get_folder_path(knowledge_base_id)):
|
||||
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
||||
doc_path = get_file_path(knowledge_base_id, doc_name)
|
||||
if os.path.exists(doc_path):
|
||||
os.remove(doc_path)
|
||||
|
||||
# 删除上传的文件后重新生成知识库(FAISS)内的数据
|
||||
# TODO: 删除向量库中对应文件
|
||||
remain_docs = await list_docs(knowledge_base_id)
|
||||
if len(remain_docs.data) == 0:
|
||||
shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True)
|
||||
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
|
||||
else:
|
||||
local_doc_qa.init_knowledge_vector_store(
|
||||
get_folder_path(knowledge_base_id), get_vs_path(knowledge_base_id)
|
||||
)
|
||||
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
|
||||
status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id))
|
||||
if "success" in status:
|
||||
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
|
||||
else:
|
||||
return BaseResponse(code=1, msg=f"document {doc_name} delete fail")
|
||||
else:
|
||||
return BaseResponse(code=1, msg=f"document {doc_name} not found")
|
||||
|
||||
|
||||
async def update_doc(
        knowledge_base_id: str = Query(...,
                                       description="知识库名",
                                       example="kb1"),
        old_doc: str = Query(
            None, description="待删除文件名,已存储在知识库中", example="doc_name_1.pdf"
        ),
        new_doc: UploadFile = File(description="待上传文件"),
):
    """Replace one document of a knowledge base with a newly uploaded file.

    The old file is removed from disk and from the vector store, then the
    uploaded file is written into the knowledge base folder and loaded into
    the vector store.

    Args:
        knowledge_base_id: URL-encoded knowledge base name.
        old_doc: Name of the stored document to remove.
        new_doc: Uploaded file that replaces ``old_doc``.

    Returns:
        A ``BaseResponse``: code 200 on success (or when an identical-size
        file with the new name is already present), code 1 when the knowledge
        base, the old document or the new file name is missing or the
        vector-store deletion fails, code 500 when the new file could not be
        loaded into the vector store.
    """
    knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
    kb_folder = get_folder_path(knowledge_base_id)
    if not os.path.exists(kb_folder):
        # Same response type as every other branch of this endpoint.
        return BaseResponse(code=1, msg=f"Knowledge base {knowledge_base_id} not found")

    # ``old_doc`` defaults to None; refuse it before building a path from it.
    if not old_doc:
        return BaseResponse(code=1, msg=f"document {old_doc} not found")

    # The file name comes from the client: keep only its last path component
    # so the upload cannot be written outside the knowledge base folder.
    # Validate it before anything is deleted.
    new_name = os.path.basename(new_doc.filename or "")
    if not new_name:
        return BaseResponse(code=1, msg="new document has no file name")

    doc_path = get_file_path(knowledge_base_id, old_doc)
    if not os.path.exists(doc_path):
        return BaseResponse(code=1, msg=f"document {old_doc} not found")

    os.remove(doc_path)
    delete_status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id))
    if "fail" in delete_status:
        return BaseResponse(code=1, msg=f"document {old_doc} delete failed")

    os.makedirs(kb_folder, exist_ok=True)

    file_content = await new_doc.read()  # read the content of the uploaded file

    file_path = os.path.join(kb_folder, new_name)
    # A file with the same name and byte size is treated as already uploaded.
    if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
        return BaseResponse(code=200, msg=f"document {new_name} already exists")

    with open(file_path, "wb") as f:
        f.write(file_content)

    vs_path = get_vs_path(knowledge_base_id)
    vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store([file_path], vs_path)
    if len(loaded_files) > 0:
        file_status = f"document {old_doc} delete and document {new_name} upload success"
        return BaseResponse(code=200, msg=file_status)

    file_status = f"document {old_doc} delete success but document {new_name} upload fail"
    return BaseResponse(code=500, msg=file_status)
|
||||
|
||||
|
||||
|
||||
async def local_doc_chat(
|
||||
knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
|
||||
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
||||
|
|
@ -394,8 +438,11 @@ def api_start(host, port):
|
|||
app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
|
||||
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
|
||||
app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage)(bing_search_chat)
|
||||
app.get("/local_doc_qa/list_knowledge_base", response_model=ListDocsResponse)(list_kbs)
|
||||
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
|
||||
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)
|
||||
app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse)(delete_kb)
|
||||
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_doc)
|
||||
app.post("/local_doc_qa/update_file", response_model=BaseResponse)(update_doc)
|
||||
|
||||
local_doc_qa = LocalDocQA()
|
||||
local_doc_qa.init_cfg(
|
||||
|
|
|
|||
|
|
@ -282,6 +282,21 @@ class LocalDocQA:
|
|||
"source_documents": result_docs}
|
||||
yield response, history
|
||||
|
||||
def delete_file_from_vector_store(self,
                                  filepath: str or List[str],
                                  vs_path):
    """Delete the entries stored for ``filepath`` from the vector store at ``vs_path``.

    The store is obtained with ``load_vector_store(vs_path, self.embeddings)``
    and the deletion is delegated to its ``delete_doc`` method; whatever that
    method returns is handed back to the caller unchanged.
    """
    # NOTE(review): the annotation ``str or List[str]`` evaluates to plain
    # ``str`` at definition time - confirm whether lists are really supported.
    store = load_vector_store(vs_path, self.embeddings)
    return store.delete_doc(filepath)
|
||||
|
||||
def update_file_from_vector_store(self,
                                  filepath: str or List[str],
                                  vs_path,
                                  docs: List[Document],):
    """Replace the entries stored for ``filepath`` in the vector store at ``vs_path``.

    The store is obtained with ``load_vector_store(vs_path, self.embeddings)``
    and the work is delegated to its ``update_doc(filepath, docs)`` method;
    that method's return value is handed back to the caller unchanged.
    """
    # NOTE(review): the annotation ``str or List[str]`` evaluates to plain
    # ``str`` at definition time - confirm whether lists are really supported.
    store = load_vector_store(vs_path, self.embeddings)
    return store.update_doc(filepath, docs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 初始化消息
|
||||
|
|
|
|||
|
|
@ -108,14 +108,20 @@ class MyFAISS(FAISS, VectorStore):
|
|||
return docs
|
||||
|
||||
def delete_doc(self, source):
    """Remove every stored document whose ``metadata["source"]`` equals ``source``.

    Only the bookkeeping maps are edited: the matching ids are popped from
    ``self.index_to_docstore_id`` and from ``self.docstore._dict``.

    Args:
        source: Value compared (with ``==``) against each document's
            ``metadata["source"]``.

    Returns:
        ``"docs delete success"`` when all matching documents were removed
        (including the case where nothing matched), ``"docs delete fail"``
        when any error occurred; removals done before the error are kept.
    """
    try:
        doc_ids = [
            doc_id
            for doc_id, doc in self.docstore._dict.items()
            if doc.metadata["source"] == source
        ]
        # Build the reverse map once instead of re-scanning the mapping for
        # every id; ``setdefault`` keeps the first position seen for an id.
        position_of = {}
        for position, doc_id in self.index_to_docstore_id.items():
            position_of.setdefault(doc_id, position)
        for doc_id in doc_ids:
            self.index_to_docstore_id.pop(position_of[doc_id])
            self.docstore._dict.pop(doc_id)
        return "docs delete success"
    except Exception:
        # Callers look for "fail" in the status string rather than catching
        # exceptions, so ordinary errors are reported through the return value.
        return "docs delete fail"
|
||||
|
||||
def update_doc(self, source, new_docs):
    """Replace the documents stored for ``source`` with ``new_docs``.

    Calls ``self.delete_doc(source)`` first and only adds ``new_docs`` through
    ``self.add_documents`` when that deletion did not report a failure.

    Args:
        source: Source value identifying the documents to replace.
        new_docs: Documents passed to ``self.add_documents``.

    Returns:
        ``"docs update success"`` or ``"docs update fail"``.
    """
    try:
        # ``delete_doc`` signals failure through its status string instead of
        # raising, so the string has to be inspected here.
        if "fail" in self.delete_doc(source):
            return "docs update fail"
        self.add_documents(new_docs)
        return "docs update success"
    except Exception:
        return "docs update fail"
|
||||
|
|
|
|||
Loading…
Reference in New Issue