feat: fastapi 接口优化 (#684)
1. 接口增加参数校验,防止攻击 2. 优化接口参数和逻辑 3. 规范接口错误响应 4. 增加接口描述 Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>
This commit is contained in:
parent
a5ca4bf26a
commit
ff5f73e041
124
api.py
124
api.py
|
|
@ -79,23 +79,37 @@ class ChatMessage(BaseModel):
|
|||
}
|
||||
|
||||
|
||||
def get_folder_path(local_doc_id: str):
    """Build the path of the "content" directory of a knowledge base.

    Joins KB_ROOT_PATH, ``local_doc_id`` and ``"content"``. The path is only
    computed; nothing is created or checked on disk.
    """
    path_parts = (KB_ROOT_PATH, local_doc_id, "content")
    return os.path.join(*path_parts)
|
||||
def get_kb_path(local_doc_id: str):
    """Build the root directory path of the knowledge base ``local_doc_id``.

    The result is KB_ROOT_PATH joined with ``local_doc_id``; the path is not
    created or checked on disk.
    """
    kb_root = KB_ROOT_PATH
    return os.path.join(kb_root, local_doc_id)
|
||||
|
||||
|
||||
def get_doc_path(local_doc_id: str):
    """Build the path of the "content" sub-directory of a knowledge base.

    The knowledge-base directory comes from ``get_kb_path``; ``"content"`` is
    appended to it.
    """
    kb_dir = get_kb_path(local_doc_id)
    return os.path.join(kb_dir, "content")
|
||||
|
||||
|
||||
def get_vs_path(local_doc_id: str):
    """Build the path of the "vector_store" sub-directory of a knowledge base.

    The span previously held two consecutive ``return`` statements that
    computed the same path, leaving the second unreachable. Only the
    ``get_kb_path``-based form is kept, so the directory layout is defined in
    one place, in the same way ``get_doc_path`` does it.
    """
    return os.path.join(get_kb_path(local_doc_id), "vector_store")
|
||||
|
||||
|
||||
def get_file_path(local_doc_id: str, doc_name: str):
    """Build the path of the document ``doc_name`` inside a knowledge base.

    The span previously held two consecutive ``return`` statements that
    computed the same path, leaving the second unreachable. Only the
    ``get_doc_path``-based form is kept, so the "content" directory name is
    not repeated here.

    NOTE(review): ``doc_name`` is joined as-is; callers that pass a
    user-supplied file name should validate it against path traversal first.
    """
    return os.path.join(get_doc_path(local_doc_id), doc_name)
|
||||
|
||||
|
||||
def validate_kb_name(knowledge_base_id: str) -> bool:
    """Return True when ``knowledge_base_id`` is safe to join under the KB root.

    Checks for unexpected characters and path-attack keywords. The id is
    examined both as received and after one round of percent-decoding,
    because the request handlers call ``urllib.parse.unquote`` on it *after*
    validation; without that, ``%2e%2e%2f`` would pass and then decode to
    ``../``.

    An id is rejected when it:
      * is empty, or resolves to nothing but ``.`` components (that would
        address the knowledge-base root itself);
      * contains a NUL byte;
      * is absolute (leading ``/`` or ``\\``), since ``os.path.join`` discards
        the root for absolute components;
      * contains a ``..`` component anywhere, with either ``/`` or ``\\`` as
        the separator.
    """
    # Local import so the function does not depend on how the module
    # happens to import urllib.
    from urllib.parse import unquote

    if not knowledge_base_id:
        return False

    for candidate in (knowledge_base_id, unquote(knowledge_base_id)):
        # Treat Windows-style separators the same as POSIX ones.
        normalized = candidate.replace("\\", "/")
        if "\x00" in normalized or normalized.startswith("/"):
            return False
        components = [part for part in normalized.split("/") if part not in ("", ".")]
        if not components or ".." in components:
            return False
    return True
|
||||
|
||||
|
||||
async def upload_file(
|
||||
file: UploadFile = File(description="A single binary file"),
|
||||
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
|
||||
):
|
||||
saved_path = get_folder_path(knowledge_base_id)
|
||||
if not validate_kb_name(knowledge_base_id):
|
||||
return BaseResponse(code=403, msg="Don't attack me", data=[])
|
||||
|
||||
saved_path = get_doc_path(knowledge_base_id)
|
||||
if not os.path.exists(saved_path):
|
||||
os.makedirs(saved_path)
|
||||
|
||||
|
|
@ -125,21 +139,25 @@ async def upload_files(
|
|||
],
|
||||
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
|
||||
):
|
||||
saved_path = get_folder_path(knowledge_base_id)
|
||||
if not validate_kb_name(knowledge_base_id):
|
||||
return BaseResponse(code=403, msg="Don't attack me", data=[])
|
||||
|
||||
saved_path = get_doc_path(knowledge_base_id)
|
||||
if not os.path.exists(saved_path):
|
||||
os.makedirs(saved_path)
|
||||
filelist = []
|
||||
for file in files:
|
||||
file_content = ''
|
||||
file_path = os.path.join(saved_path, file.filename)
|
||||
file_content = file.file.read()
|
||||
file_content = await file.read()
|
||||
if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
|
||||
continue
|
||||
with open(file_path, "ab+") as f:
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
filelist.append(file_path)
|
||||
if filelist:
|
||||
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, get_vs_path(knowledge_base_id))
|
||||
vs_path = get_vs_path(knowledge_base_id)
|
||||
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path)
|
||||
if len(loaded_files):
|
||||
file_status = f"documents {', '.join([os.path.split(i)[-1] for i in loaded_files])} upload success"
|
||||
return BaseResponse(code=200, msg=file_status)
|
||||
|
|
@ -163,16 +181,24 @@ async def list_kbs():
|
|||
|
||||
|
||||
async def list_docs(
|
||||
knowledge_base_id: Optional[str] = Query(default=None, description="Knowledge Base Name", example="kb1")
|
||||
knowledge_base_id: str = Query(..., description="Knowledge Base Name", example="kb1")
|
||||
):
|
||||
local_doc_folder = get_folder_path(knowledge_base_id)
|
||||
if not validate_kb_name(knowledge_base_id):
|
||||
return ListDocsResponse(code=403, msg="Don't attack me", data=[])
|
||||
|
||||
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
||||
kb_path = get_kb_path(knowledge_base_id)
|
||||
local_doc_folder = get_doc_path(knowledge_base_id)
|
||||
if not os.path.exists(kb_path):
|
||||
return ListDocsResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found", data=[])
|
||||
if not os.path.exists(local_doc_folder):
|
||||
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
||||
all_doc_names = [
|
||||
doc
|
||||
for doc in os.listdir(local_doc_folder)
|
||||
if os.path.isfile(os.path.join(local_doc_folder, doc))
|
||||
]
|
||||
all_doc_names = []
|
||||
else:
|
||||
all_doc_names = [
|
||||
doc
|
||||
for doc in os.listdir(local_doc_folder)
|
||||
if os.path.isfile(os.path.join(local_doc_folder, doc))
|
||||
]
|
||||
return ListDocsResponse(data=all_doc_names)
|
||||
|
||||
|
||||
|
|
@ -181,11 +207,15 @@ async def delete_kb(
|
|||
description="Knowledge Base Name",
|
||||
example="kb1"),
|
||||
):
|
||||
if not validate_kb_name(knowledge_base_id):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
# TODO: 确认是否支持批量删除知识库
|
||||
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
||||
if not os.path.exists(get_folder_path(knowledge_base_id)):
|
||||
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
||||
shutil.rmtree(get_folder_path(knowledge_base_id))
|
||||
kb_path = get_kb_path(knowledge_base_id)
|
||||
if not os.path.exists(kb_path):
|
||||
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
|
||||
shutil.rmtree(kb_path)
|
||||
return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} delete success")
|
||||
|
||||
|
||||
|
|
@ -194,27 +224,30 @@ async def delete_doc(
|
|||
description="Knowledge Base Name",
|
||||
example="kb1"),
|
||||
doc_name: str = Query(
|
||||
None, description="doc name", example="doc_name_1.pdf"
|
||||
..., description="doc name", example="doc_name_1.pdf"
|
||||
),
|
||||
):
|
||||
if not validate_kb_name(knowledge_base_id):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
||||
if not os.path.exists(get_folder_path(knowledge_base_id)):
|
||||
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
||||
if not os.path.exists(get_kb_path(knowledge_base_id)):
|
||||
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
|
||||
doc_path = get_file_path(knowledge_base_id, doc_name)
|
||||
if os.path.exists(doc_path):
|
||||
os.remove(doc_path)
|
||||
remain_docs = await list_docs(knowledge_base_id)
|
||||
if len(remain_docs.data) == 0:
|
||||
shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True)
|
||||
shutil.rmtree(get_kb_path(knowledge_base_id), ignore_errors=True)
|
||||
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
|
||||
else:
|
||||
status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id))
|
||||
if "success" in status:
|
||||
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
|
||||
else:
|
||||
return BaseResponse(code=1, msg=f"document {doc_name} delete fail")
|
||||
return BaseResponse(code=500, msg=f"document {doc_name} delete fail")
|
||||
else:
|
||||
return BaseResponse(code=1, msg=f"document {doc_name} not found")
|
||||
return BaseResponse(code=404, msg=f"document {doc_name} not found")
|
||||
|
||||
|
||||
async def update_doc(
|
||||
|
|
@ -222,23 +255,26 @@ async def update_doc(
|
|||
description="知识库名",
|
||||
example="kb1"),
|
||||
old_doc: str = Query(
|
||||
None, description="待删除文件名,已存储在知识库中", example="doc_name_1.pdf"
|
||||
..., description="待删除文件名,已存储在知识库中", example="doc_name_1.pdf"
|
||||
),
|
||||
new_doc: UploadFile = File(description="待上传文件"),
|
||||
):
|
||||
if not validate_kb_name(knowledge_base_id):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
||||
if not os.path.exists(get_folder_path(knowledge_base_id)):
|
||||
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
||||
if not os.path.exists(get_kb_path(knowledge_base_id)):
|
||||
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
|
||||
doc_path = get_file_path(knowledge_base_id, old_doc)
|
||||
if not os.path.exists(doc_path):
|
||||
return BaseResponse(code=1, msg=f"document {old_doc} not found")
|
||||
return BaseResponse(code=404, msg=f"document {old_doc} not found")
|
||||
else:
|
||||
os.remove(doc_path)
|
||||
delete_status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id))
|
||||
if "fail" in delete_status:
|
||||
return BaseResponse(code=1, msg=f"document {old_doc} delete failed")
|
||||
return BaseResponse(code=500, msg=f"document {old_doc} delete failed")
|
||||
else:
|
||||
saved_path = get_folder_path(knowledge_base_id)
|
||||
saved_path = get_doc_path(knowledge_base_id)
|
||||
if not os.path.exists(saved_path):
|
||||
os.makedirs(saved_path)
|
||||
|
||||
|
|
@ -279,7 +315,7 @@ async def local_doc_chat(
|
|||
):
|
||||
vs_path = get_vs_path(knowledge_base_id)
|
||||
if not os.path.exists(vs_path):
|
||||
# return BaseResponse(code=1, msg=f"Knowledge base {knowledge_base_id} not found")
|
||||
# return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
|
||||
return ChatMessage(
|
||||
question=question,
|
||||
response=f"Knowledge base {knowledge_base_id} not found",
|
||||
|
|
@ -467,7 +503,7 @@ def api_start(host, port, **kwargs):
|
|||
# 修改了stream_chat的接口,直接通过ws://localhost:7861/local_doc_qa/stream_chat建立连接,在请求体中选择knowledge_base_id
|
||||
app.websocket("/local_doc_qa/stream_chat")(stream_chat)
|
||||
|
||||
app.get("/", response_model=BaseResponse)(document)
|
||||
app.get("/", response_model=BaseResponse, summary="swagger 文档")(document)
|
||||
|
||||
# 增加基于bing搜索的流式问答
|
||||
# 需要说明的是,如果想测试websocket的流式问答,需要使用支持websocket的测试工具,如postman,insomnia
|
||||
|
|
@ -475,17 +511,17 @@ def api_start(host, port, **kwargs):
|
|||
# 在测试时选择new websocket request,并将url的协议改为ws,如ws://localhost:7861/local_doc_qa/stream_chat_bing
|
||||
app.websocket("/local_doc_qa/stream_chat_bing")(stream_chat_bing)
|
||||
|
||||
app.post("/chat", response_model=ChatMessage)(chat)
|
||||
app.post("/chat", response_model=ChatMessage, summary="与模型对话")(chat)
|
||||
|
||||
app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)
|
||||
app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
|
||||
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
|
||||
app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage)(bing_search_chat)
|
||||
app.get("/local_doc_qa/list_knowledge_base", response_model=ListDocsResponse)(list_kbs)
|
||||
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
|
||||
app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse)(delete_kb)
|
||||
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_doc)
|
||||
app.post("/local_doc_qa/update_file", response_model=BaseResponse)(update_doc)
|
||||
app.post("/local_doc_qa/upload_file", response_model=BaseResponse, summary="上传文件到知识库")(upload_file)
|
||||
app.post("/local_doc_qa/upload_files", response_model=BaseResponse, summary="批量上传文件到知识库")(upload_files)
|
||||
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage, summary="与知识库对话")(local_doc_chat)
|
||||
app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage, summary="与必应搜索对话")(bing_search_chat)
|
||||
app.get("/local_doc_qa/list_knowledge_base", response_model=ListDocsResponse, summary="获取知识库列表")(list_kbs)
|
||||
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse, summary="获取知识库内的文件列表")(list_docs)
|
||||
app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse, summary="删除知识库")(delete_kb)
|
||||
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse, summary="删除知识库内的文件")(delete_doc)
|
||||
app.post("/local_doc_qa/update_file", response_model=BaseResponse, summary="上传文件到知识库,并删除另一个文件")(update_doc)
|
||||
|
||||
local_doc_qa = LocalDocQA()
|
||||
local_doc_qa.init_cfg(
|
||||
|
|
|
|||
Loading…
Reference in New Issue