diff --git a/api.py b/api.py index 1ca5d75..9c93c0e 100644 --- a/api.py +++ b/api.py @@ -88,6 +88,32 @@ def get_vs_path(local_doc_id: str): def get_file_path(local_doc_id: str, doc_name: str): return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name) +async def upload_file( + file: UploadFile = File(description="A single binary file"), + knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"), +): + saved_path = get_folder_path(knowledge_base_id) + if not os.path.exists(saved_path): + os.makedirs(saved_path) + + file_content = await file.read() # 读取上传文件的内容 + + file_path = os.path.join(saved_path, file.filename) + if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content): + file_status = f"文件 {file.filename} 已存在。" + return BaseResponse(code=200, msg=file_status) + + with open(file_path, "wb") as f: + f.write(file_content) + + vs_path = get_vs_path(knowledge_base_id) + vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store([file_path], vs_path) + if len(loaded_files) > 0: + file_status = f"文件 {file.filename} 已上传至新的知识库,并已加载知识库,请开始提问。" + return BaseResponse(code=200, msg=file_status) + else: + file_status = "文件上传失败,请重新上传" + return BaseResponse(code=500, msg=file_status) async def upload_files( files: Annotated[ @@ -306,6 +332,7 @@ def main(): app.post("/chat", response_model=ChatMessage)(chat) + app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file) app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files) app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat) app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)