allow kb_doc_api.upload_doc to override existed file by parameter.

update_doc is not needed.
This commit is contained in:
liunux4odoo 2023-08-04 15:53:44 +08:00
parent b4f80ca370
commit 46c7d8d169
2 changed files with 18 additions and 10 deletions

View File

@ -30,6 +30,7 @@ async def list_docs(knowledge_base_name: str):
async def upload_doc(file: UploadFile = File(description="上传文件"),
knowledge_base_name: str = Form(..., description="知识库名称", example="kb1"),
override: bool = Form(False, description="覆盖已有文件", example=False),
):
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -41,7 +42,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
file_content = await file.read() # 读取上传文件的内容
file_path = os.path.join(saved_path, file.filename)
if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
if (os.path.exists(file_path)
and not override
and os.path.getsize(file_path) == len(file_content)
):
file_status = f"文件 {file.filename} 已存在。"
return BaseResponse(code=404, msg=file_status)

View File

@ -1,6 +1,4 @@
# 该文件包含webui通用工具可以被不同的webui使用
import tempfile
from typing import *
from pathlib import Path
import os
@ -17,6 +15,7 @@ from server.chat.openai_chat import OpenAiChatMsgIn
from fastapi.responses import StreamingResponse
import contextlib
import json
from io import BytesIO
def set_httpx_timeout(timeout=60.0):
@ -383,8 +382,10 @@ class ApiRequest:
def upload_kb_doc(
self,
file: Union[str, Path],
file: Union[str, Path, bytes],
knowledge_base_name: str,
filename: str = None,
override: bool = False,
no_remote_api: bool = None,
):
'''
@ -393,8 +394,11 @@ class ApiRequest:
if no_remote_api is None:
no_remote_api = self.no_remote_api
file = Path(file).absolute()
filename = file.name
if isinstance(file, bytes):
file = BytesIO(file)
else:
file = Path(file).absolute().open("rb")
filename = filename or file.name
if no_remote_api:
from server.knowledge_base.kb_doc_api import upload_doc
@ -402,18 +406,18 @@ class ApiRequest:
from tempfile import SpooledTemporaryFile
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
with file.open("rb") as fp:
temp_file.write(fp.read())
temp_file.write(file.read())
response = run_async(upload_doc(
UploadFile(temp_file, filename=filename),
knowledge_base_name,
override,
))
return response.dict()
else:
response = self.post(
"/knowledge_base/upload_doc",
data={"knowledge_base_name": knowledge_base_name},
files={"file": (filename, file.open("rb"))},
data={"knowledge_base_name": knowledge_base_name, "override": override},
files={"file": (filename, file)},
)
return response.json()