allow kb_doc_api.upload_doc to override existed file by parameter.

update_doc is not needed.
This commit is contained in:
liunux4odoo 2023-08-04 15:53:44 +08:00
parent b4f80ca370
commit 46c7d8d169
2 changed files with 18 additions and 10 deletions

View File

@ -30,6 +30,7 @@ async def list_docs(knowledge_base_name: str):
async def upload_doc(file: UploadFile = File(description="上传文件"), async def upload_doc(file: UploadFile = File(description="上传文件"),
knowledge_base_name: str = Form(..., description="知识库名称", example="kb1"), knowledge_base_name: str = Form(..., description="知识库名称", example="kb1"),
override: bool = Form(False, description="覆盖已有文件", example=False),
): ):
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
@ -41,7 +42,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
file_content = await file.read() # 读取上传文件的内容 file_content = await file.read() # 读取上传文件的内容
file_path = os.path.join(saved_path, file.filename) file_path = os.path.join(saved_path, file.filename)
if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content): if (os.path.exists(file_path)
and not override
and os.path.getsize(file_path) == len(file_content)
):
file_status = f"文件 {file.filename} 已存在。" file_status = f"文件 {file.filename} 已存在。"
return BaseResponse(code=404, msg=file_status) return BaseResponse(code=404, msg=file_status)

View File

@ -1,6 +1,4 @@
# 该文件包含webui通用工具可以被不同的webui使用 # 该文件包含webui通用工具可以被不同的webui使用
import tempfile
from typing import * from typing import *
from pathlib import Path from pathlib import Path
import os import os
@ -17,6 +15,7 @@ from server.chat.openai_chat import OpenAiChatMsgIn
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
import contextlib import contextlib
import json import json
from io import BytesIO
def set_httpx_timeout(timeout=60.0): def set_httpx_timeout(timeout=60.0):
@ -383,8 +382,10 @@ class ApiRequest:
def upload_kb_doc( def upload_kb_doc(
self, self,
file: Union[str, Path], file: Union[str, Path, bytes],
knowledge_base_name: str, knowledge_base_name: str,
filename: str = None,
override: bool = False,
no_remote_api: bool = None, no_remote_api: bool = None,
): ):
''' '''
@ -393,8 +394,11 @@ class ApiRequest:
if no_remote_api is None: if no_remote_api is None:
no_remote_api = self.no_remote_api no_remote_api = self.no_remote_api
file = Path(file).absolute() if isinstance(file, bytes):
filename = file.name file = BytesIO(file)
else:
file = Path(file).absolute().open("rb")
filename = filename or file.name
if no_remote_api: if no_remote_api:
from server.knowledge_base.kb_doc_api import upload_doc from server.knowledge_base.kb_doc_api import upload_doc
@ -402,18 +406,18 @@ class ApiRequest:
from tempfile import SpooledTemporaryFile from tempfile import SpooledTemporaryFile
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024) temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
with file.open("rb") as fp: temp_file.write(file.read())
temp_file.write(fp.read())
response = run_async(upload_doc( response = run_async(upload_doc(
UploadFile(temp_file, filename=filename), UploadFile(temp_file, filename=filename),
knowledge_base_name, knowledge_base_name,
override,
)) ))
return response.dict() return response.dict()
else: else:
response = self.post( response = self.post(
"/knowledge_base/upload_doc", "/knowledge_base/upload_doc",
data={"knowledge_base_name": knowledge_base_name}, data={"knowledge_base_name": knowledge_base_name, "override": override},
files={"file": (filename, file.open("rb"))}, files={"file": (filename, file)},
) )
return response.json() return response.json()