From ff5f73e041d34fe79a51fe4387dee96072eb1772 Mon Sep 17 00:00:00 2001 From: NieLamu Date: Tue, 11 Jul 2023 19:52:52 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20fastapi=20=E6=8E=A5=E5=8F=A3=E4=BC=98?= =?UTF-8?q?=E5=8C=96=20(#684)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 接口增加参数校验,防止攻击 2. 优化接口参数和逻辑 3. 规范接口错误响应 4. 增加接口描述 Co-authored-by: imClumsyPanda --- api.py | 124 +++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 80 insertions(+), 44 deletions(-) diff --git a/api.py b/api.py index 3f50166..05fca50 100644 --- a/api.py +++ b/api.py @@ -79,23 +79,37 @@ class ChatMessage(BaseModel): } -def get_folder_path(local_doc_id: str): - return os.path.join(KB_ROOT_PATH, local_doc_id, "content") +def get_kb_path(local_doc_id: str): + return os.path.join(KB_ROOT_PATH, local_doc_id) + + +def get_doc_path(local_doc_id: str): + return os.path.join(get_kb_path(local_doc_id), "content") def get_vs_path(local_doc_id: str): - return os.path.join(KB_ROOT_PATH, local_doc_id, "vector_store") + return os.path.join(get_kb_path(local_doc_id), "vector_store") def get_file_path(local_doc_id: str, doc_name: str): - return os.path.join(KB_ROOT_PATH, local_doc_id, "content", doc_name) + return os.path.join(get_doc_path(local_doc_id), doc_name) + + +def validate_kb_name(knowledge_base_id: str) -> bool: + # 检查是否包含预期外的字符或路径攻击关键字 + if "../" in knowledge_base_id: + return False + return True async def upload_file( file: UploadFile = File(description="A single binary file"), knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"), ): - saved_path = get_folder_path(knowledge_base_id) + if not validate_kb_name(knowledge_base_id): + return BaseResponse(code=403, msg="Don't attack me", data=[]) + + saved_path = get_doc_path(knowledge_base_id) if not os.path.exists(saved_path): os.makedirs(saved_path) @@ -125,21 +139,25 @@ async def upload_files( ], knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"), ): - saved_path = get_folder_path(knowledge_base_id) + if not validate_kb_name(knowledge_base_id): + return BaseResponse(code=403, msg="Don't attack me", data=[]) + + saved_path = get_doc_path(knowledge_base_id) if not os.path.exists(saved_path): os.makedirs(saved_path) filelist = [] for file in files: file_content = '' file_path = os.path.join(saved_path, file.filename) - file_content = file.file.read() + file_content = await file.read() if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content): continue - with open(file_path, "ab+") as f: + with open(file_path, "wb") as f: f.write(file_content) filelist.append(file_path) if filelist: - vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, get_vs_path(knowledge_base_id)) + vs_path = get_vs_path(knowledge_base_id) + vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path) if len(loaded_files): file_status = f"documents {', '.join([os.path.split(i)[-1] for i in loaded_files])} upload success" return BaseResponse(code=200, msg=file_status) @@ -163,16 +181,24 @@ async def list_kbs(): async def list_docs( - knowledge_base_id: Optional[str] = Query(default=None, description="Knowledge Base Name", example="kb1") + knowledge_base_id: str = Query(..., description="Knowledge Base Name", example="kb1") ): - local_doc_folder = get_folder_path(knowledge_base_id) + if not validate_kb_name(knowledge_base_id): + return ListDocsResponse(code=403, msg="Don't attack me", data=[]) + + knowledge_base_id = urllib.parse.unquote(knowledge_base_id) + kb_path = get_kb_path(knowledge_base_id) + local_doc_folder = get_doc_path(knowledge_base_id) + if not os.path.exists(kb_path): + return ListDocsResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found", data=[]) if not os.path.exists(local_doc_folder): - return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} - all_doc_names = [ - doc - for doc in os.listdir(local_doc_folder) - if os.path.isfile(os.path.join(local_doc_folder, doc)) - ] + all_doc_names = [] + else: + all_doc_names = [ + doc + for doc in os.listdir(local_doc_folder) + if os.path.isfile(os.path.join(local_doc_folder, doc)) + ] return ListDocsResponse(data=all_doc_names) @@ -181,11 +207,15 @@ async def delete_kb( description="Knowledge Base Name", example="kb1"), ): + if not validate_kb_name(knowledge_base_id): + return BaseResponse(code=403, msg="Don't attack me") + # TODO: 确认是否支持批量删除知识库 knowledge_base_id = urllib.parse.unquote(knowledge_base_id) - if not os.path.exists(get_folder_path(knowledge_base_id)): - return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} - shutil.rmtree(get_folder_path(knowledge_base_id)) + kb_path = get_kb_path(knowledge_base_id) + if not os.path.exists(kb_path): + return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found") + shutil.rmtree(kb_path) return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} delete success") @@ -194,27 +224,30 @@ async def delete_doc( description="Knowledge Base Name", example="kb1"), doc_name: str = Query( - None, description="doc name", example="doc_name_1.pdf" + ..., description="doc name", example="doc_name_1.pdf" ), ): + if not validate_kb_name(knowledge_base_id): + return BaseResponse(code=403, msg="Don't attack me") + knowledge_base_id = urllib.parse.unquote(knowledge_base_id) - if not os.path.exists(get_folder_path(knowledge_base_id)): - return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} + if not os.path.exists(get_kb_path(knowledge_base_id)): + return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found") doc_path = get_file_path(knowledge_base_id, doc_name) if os.path.exists(doc_path): os.remove(doc_path) remain_docs = await list_docs(knowledge_base_id) if len(remain_docs.data) == 0: - shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True) + shutil.rmtree(get_kb_path(knowledge_base_id), ignore_errors=True) return BaseResponse(code=200, msg=f"document {doc_name} delete success") else: status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id)) if "success" in status: return BaseResponse(code=200, msg=f"document {doc_name} delete success") else: - return BaseResponse(code=1, msg=f"document {doc_name} delete fail") + return BaseResponse(code=500, msg=f"document {doc_name} delete fail") else: - return BaseResponse(code=1, msg=f"document {doc_name} not found") + return BaseResponse(code=404, msg=f"document {doc_name} not found") async def update_doc( @@ -222,23 +255,26 @@ async def update_doc( description="知识库名", example="kb1"), old_doc: str = Query( - None, description="待删除文件名,已存储在知识库中", example="doc_name_1.pdf" + ..., description="待删除文件名,已存储在知识库中", example="doc_name_1.pdf" ), new_doc: UploadFile = File(description="待上传文件"), ): + if not validate_kb_name(knowledge_base_id): + return BaseResponse(code=403, msg="Don't attack me") + knowledge_base_id = urllib.parse.unquote(knowledge_base_id) - if not os.path.exists(get_folder_path(knowledge_base_id)): - return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} + if not os.path.exists(get_kb_path(knowledge_base_id)): + return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found") doc_path = get_file_path(knowledge_base_id, old_doc) if not os.path.exists(doc_path): - return BaseResponse(code=1, msg=f"document {old_doc} not found") + return BaseResponse(code=404, msg=f"document {old_doc} not found") else: os.remove(doc_path) delete_status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id)) if "fail" in delete_status: - return BaseResponse(code=1, msg=f"document {old_doc} delete failed") + return BaseResponse(code=500, msg=f"document {old_doc} delete failed") else: - saved_path = get_folder_path(knowledge_base_id) + saved_path = get_doc_path(knowledge_base_id) if not os.path.exists(saved_path): os.makedirs(saved_path) @@ -279,7 +315,7 @@ async def local_doc_chat( ): vs_path = get_vs_path(knowledge_base_id) if not os.path.exists(vs_path): - # return BaseResponse(code=1, msg=f"Knowledge base {knowledge_base_id} not found") + # return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found") return ChatMessage( question=question, response=f"Knowledge base {knowledge_base_id} not found", @@ -467,7 +503,7 @@ def api_start(host, port, **kwargs): # 修改了stream_chat的接口,直接通过ws://localhost:7861/local_doc_qa/stream_chat建立连接,在请求体中选择knowledge_base_id app.websocket("/local_doc_qa/stream_chat")(stream_chat) - app.get("/", response_model=BaseResponse)(document) + app.get("/", response_model=BaseResponse, summary="swagger 文档")(document) # 增加基于bing搜索的流式问答 # 需要说明的是,如果想测试websocket的流式问答,需要使用支持websocket的测试工具,如postman,insomnia @@ -475,17 +511,17 @@ def api_start(host, port, **kwargs): # 在测试时选择new websocket request,并将url的协议改为ws,如ws://localhost:7861/local_doc_qa/stream_chat_bing app.websocket("/local_doc_qa/stream_chat_bing")(stream_chat_bing) - app.post("/chat", response_model=ChatMessage)(chat) + app.post("/chat", response_model=ChatMessage, summary="与模型对话")(chat) - app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file) - app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files) - app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat) - app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage)(bing_search_chat) - app.get("/local_doc_qa/list_knowledge_base", response_model=ListDocsResponse)(list_kbs) - app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs) - app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse)(delete_kb) - app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_doc) - app.post("/local_doc_qa/update_file", response_model=BaseResponse)(update_doc) + app.post("/local_doc_qa/upload_file", response_model=BaseResponse, summary="上传文件到知识库")(upload_file) + app.post("/local_doc_qa/upload_files", response_model=BaseResponse, summary="批量上传文件到知识库")(upload_files) + app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage, summary="与知识库对话")(local_doc_chat) + app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage, summary="与必应搜索对话")(bing_search_chat) + app.get("/local_doc_qa/list_knowledge_base", response_model=ListDocsResponse, summary="获取知识库列表")(list_kbs) + app.get("/local_doc_qa/list_files", response_model=ListDocsResponse, summary="获取知识库内的文件列表")(list_docs) + app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse, summary="删除知识库")(delete_kb) + app.delete("/local_doc_qa/delete_file", response_model=BaseResponse, summary="删除知识库内的文件")(delete_doc) + app.post("/local_doc_qa/update_file", response_model=BaseResponse, summary="上传文件到知识库,并删除另一个文件")(update_doc) local_doc_qa = LocalDocQA() local_doc_qa.init_cfg(