diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 749f7c2..2773143 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -30,6 +30,7 @@ async def list_docs(knowledge_base_name: str): async def upload_doc(file: UploadFile = File(description="上传文件"), knowledge_base_name: str = Form(..., description="知识库名称", example="kb1"), + override: bool = Form(False, description="覆盖已有文件", example=False), ): if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -41,7 +42,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"), file_content = await file.read() # 读取上传文件的内容 file_path = os.path.join(saved_path, file.filename) - if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content): + if (os.path.exists(file_path) + and not override + and os.path.getsize(file_path) == len(file_content) + ): file_status = f"文件 {file.filename} 已存在。" return BaseResponse(code=404, msg=file_status) diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 1f7ab5d..8388e25 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -1,6 +1,4 @@ # 该文件包含webui通用工具,可以被不同的webui使用 - -import tempfile from typing import * from pathlib import Path import os @@ -17,6 +15,7 @@ from server.chat.openai_chat import OpenAiChatMsgIn from fastapi.responses import StreamingResponse import contextlib import json +from io import BytesIO def set_httpx_timeout(timeout=60.0): @@ -383,8 +382,10 @@ class ApiRequest: def upload_kb_doc( self, - file: Union[str, Path], + file: Union[str, Path, bytes], knowledge_base_name: str, + filename: str = None, + override: bool = False, no_remote_api: bool = None, ): ''' @@ -393,8 +394,11 @@ class ApiRequest: if no_remote_api is None: no_remote_api = self.no_remote_api - file = Path(file).absolute() - filename = file.name + if isinstance(file, bytes): + file = BytesIO(file) + else: + file = Path(file).absolute().open("rb") + filename = filename or file.name if no_remote_api: from server.knowledge_base.kb_doc_api import upload_doc @@ -402,18 +406,18 @@ class ApiRequest: from tempfile import SpooledTemporaryFile temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024) - with file.open("rb") as fp: - temp_file.write(fp.read()) + temp_file.write(file.read()) response = run_async(upload_doc( UploadFile(temp_file, filename=filename), knowledge_base_name, + override, )) return response.dict() else: response = self.post( "/knowledge_base/upload_doc", - data={"knowledge_base_name": knowledge_base_name}, - files={"file": (filename, file.open("rb"))}, + data={"knowledge_base_name": knowledge_base_name, "override": override}, + files={"file": (filename, file)}, ) return response.json()