make parameter examples available in openapi docs
This commit is contained in:
parent
c7b91bfaf1
commit
323fc13d4c
|
|
@ -12,12 +12,12 @@ from typing import List, Optional
|
||||||
from server.chat.utils import History
|
from server.chat.utils import History
|
||||||
|
|
||||||
|
|
||||||
def chat(query: str = Body(..., description="用户输入", example="恼羞成怒"),
|
def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
|
||||||
history: List[History] = Body([],
|
history: List[History] = Body([],
|
||||||
description="历史对话",
|
description="历史对话",
|
||||||
example=[
|
examples=[[
|
||||||
{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
|
{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
|
||||||
{"role": "assistant", "content": "虎头虎脑"}]
|
{"role": "assistant", "content": "虎头虎脑"}]]
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
history = [History(**h) if isinstance(h, dict) else h for h in history]
|
history = [History(**h) if isinstance(h, dict) else h for h in history]
|
||||||
|
|
|
||||||
|
|
@ -16,16 +16,16 @@ from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
|
def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
|
||||||
knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
|
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||||
history: List[History] = Body([],
|
history: List[History] = Body([],
|
||||||
description="历史对话",
|
description="历史对话",
|
||||||
example=[
|
examples=[[
|
||||||
{"role": "user",
|
{"role": "user",
|
||||||
"content": "我们来玩成语接龙,我先来,生龙活虎"},
|
"content": "我们来玩成语接龙,我先来,生龙活虎"},
|
||||||
{"role": "assistant",
|
{"role": "assistant",
|
||||||
"content": "虎头虎脑"}]
|
"content": "虎头虎脑"}]]
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||||
|
|
|
||||||
|
|
@ -57,16 +57,16 @@ def lookup_search_engine(
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
|
||||||
def search_engine_chat(query: str = Body(..., description="用户输入", example="你好"),
|
def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
|
||||||
search_engine_name: str = Body(..., description="搜索引擎名称", example="duckduckgo"),
|
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
|
||||||
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
|
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
|
||||||
history: Optional[List[History]] = Body(...,
|
history: List[History] = Body([],
|
||||||
description="历史对话",
|
description="历史对话",
|
||||||
example=[
|
examples=[[
|
||||||
{"role": "user",
|
{"role": "user",
|
||||||
"content": "我们来玩成语接龙,我先来,生龙活虎"},
|
"content": "我们来玩成语接龙,我先来,生龙活虎"},
|
||||||
{"role": "assistant",
|
{"role": "assistant",
|
||||||
"content": "虎头虎脑"}]
|
"content": "虎头虎脑"}]]
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
if search_engine_name not in SEARCH_ENGINES.keys():
|
if search_engine_name not in SEARCH_ENGINES.keys():
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from server.knowledge_base.utils import validate_kb_name
|
||||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||||
from server.db.repository.knowledge_base_repository import list_kbs_from_db
|
from server.db.repository.knowledge_base_repository import list_kbs_from_db
|
||||||
from configs.model_config import EMBEDDING_MODEL
|
from configs.model_config import EMBEDDING_MODEL
|
||||||
|
from fastapi import Body
|
||||||
|
|
||||||
|
|
||||||
async def list_kbs():
|
async def list_kbs():
|
||||||
|
|
@ -11,9 +12,9 @@ async def list_kbs():
|
||||||
return ListResponse(data=list_kbs_from_db())
|
return ListResponse(data=list_kbs_from_db())
|
||||||
|
|
||||||
|
|
||||||
async def create_kb(knowledge_base_name: str,
|
async def create_kb(knowledge_base_name: str = Body(..., examples=["kb_name"]),
|
||||||
vector_store_type: str = "faiss",
|
vector_store_type: str = Body("faiss"),
|
||||||
embed_model: str = EMBEDDING_MODEL,
|
embed_model: str = Body(EMBEDDING_MODEL),
|
||||||
):
|
):
|
||||||
# Create selected knowledge base
|
# Create selected knowledge base
|
||||||
if not validate_kb_name(knowledge_base_name):
|
if not validate_kb_name(knowledge_base_name):
|
||||||
|
|
@ -21,14 +22,16 @@ async def create_kb(knowledge_base_name: str,
|
||||||
if knowledge_base_name is None or knowledge_base_name.strip() == "":
|
if knowledge_base_name is None or knowledge_base_name.strip() == "":
|
||||||
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
|
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
|
||||||
|
|
||||||
kb = KBServiceFactory.get_service(knowledge_base_name, "faiss")
|
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
|
||||||
if kb is not None:
|
if kb is not None:
|
||||||
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
||||||
kb.create()
|
kb.create()
|
||||||
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
|
|
||||||
async def delete_kb(knowledge_base_name: str):
|
async def delete_kb(
|
||||||
|
knowledge_base_name: str = Body(..., examples=["kb_name"])
|
||||||
|
):
|
||||||
# Delete selected knowledge base
|
# Delete selected knowledge base
|
||||||
if not validate_kb_name(knowledge_base_name):
|
if not validate_kb_name(knowledge_base_name):
|
||||||
return BaseResponse(code=403, msg="Don't attack me")
|
return BaseResponse(code=403, msg="Don't attack me")
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,9 @@ from server.knowledge_base.kb_service.base import SupportedVSType
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
async def list_docs(knowledge_base_name: str):
|
async def list_docs(
|
||||||
|
knowledge_base_name: str = Body(..., examples=["kb_name"])
|
||||||
|
):
|
||||||
if not validate_kb_name(knowledge_base_name):
|
if not validate_kb_name(knowledge_base_name):
|
||||||
return ListResponse(code=403, msg="Don't attack me", data=[])
|
return ListResponse(code=403, msg="Don't attack me", data=[])
|
||||||
|
|
||||||
|
|
@ -25,8 +27,8 @@ async def list_docs(knowledge_base_name: str):
|
||||||
|
|
||||||
|
|
||||||
async def upload_doc(file: UploadFile = File(description="上传文件"),
|
async def upload_doc(file: UploadFile = File(description="上传文件"),
|
||||||
knowledge_base_name: str = Form(..., description="知识库名称", example="kb1"),
|
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
|
||||||
override: bool = Form(False, description="覆盖已有文件", example=False),
|
override: bool = Form(False, description="覆盖已有文件"),
|
||||||
):
|
):
|
||||||
if not validate_kb_name(knowledge_base_name):
|
if not validate_kb_name(knowledge_base_name):
|
||||||
return BaseResponse(code=403, msg="Don't attack me")
|
return BaseResponse(code=403, msg="Don't attack me")
|
||||||
|
|
@ -58,9 +60,9 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
|
||||||
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
|
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
|
||||||
|
|
||||||
|
|
||||||
async def delete_doc(knowledge_base_name: str = Body(...),
|
async def delete_doc(knowledge_base_name: str = Body(..., examples=["kb_name"]),
|
||||||
doc_name: str = Body(...),
|
doc_name: str = Body(..., examples=["file_name"]),
|
||||||
delete_content: bool = Body(...),
|
delete_content: bool = Body(False),
|
||||||
):
|
):
|
||||||
if not validate_kb_name(knowledge_base_name):
|
if not validate_kb_name(knowledge_base_name):
|
||||||
return BaseResponse(code=403, msg="Don't attack me")
|
return BaseResponse(code=403, msg="Don't attack me")
|
||||||
|
|
@ -80,8 +82,8 @@ async def delete_doc(knowledge_base_name: str = Body(...),
|
||||||
|
|
||||||
|
|
||||||
async def update_doc(
|
async def update_doc(
|
||||||
knowledge_base_name: str = Body(...),
|
knowledge_base_name: str = Body(..., examples=["kb_name"]),
|
||||||
file_name: str = Body(...),
|
file_name: str = Body(..., examples=["file_name"]),
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
更新知识库文档
|
更新知识库文档
|
||||||
|
|
@ -109,7 +111,7 @@ async def download_doc():
|
||||||
|
|
||||||
|
|
||||||
async def recreate_vector_store(
|
async def recreate_vector_store(
|
||||||
knowledge_base_name: str = Body(...),
|
knowledge_base_name: str = Body(..., examples=["kb_name"]),
|
||||||
allow_empty_kb: bool = Body(True),
|
allow_empty_kb: bool = Body(True),
|
||||||
vs_type: str = Body("faiss"),
|
vs_type: str = Body("faiss"),
|
||||||
):
|
):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue