From 323fc13d4cef06f5e877b2d804d20fd4bd5cee2b Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Wed, 9 Aug 2023 18:15:14 +0800 Subject: [PATCH] make parameter examples available in openapi docs --- server/chat/chat.py | 6 +++--- server/chat/knowledge_base_chat.py | 8 ++++---- server/chat/search_engine_chat.py | 20 ++++++++++---------- server/knowledge_base/kb_api.py | 13 ++++++++----- server/knowledge_base/kb_doc_api.py | 20 +++++++++++--------- 5 files changed, 36 insertions(+), 31 deletions(-) diff --git a/server/chat/chat.py b/server/chat/chat.py index 2b163d4..fd7e0db 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -12,12 +12,12 @@ from typing import List, Optional from server.chat.utils import History -def chat(query: str = Body(..., description="用户输入", example="恼羞成怒"), +def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), history: List[History] = Body([], description="历史对话", - example=[ + examples=[[ {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", "content": "虎头虎脑"}] + {"role": "assistant", "content": "虎头虎脑"}]] ), ): history = [History(**h) if isinstance(h, dict) else h for h in history] diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 17c045d..f67af14 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -16,16 +16,16 @@ from server.knowledge_base.kb_service.base import KBService, KBServiceFactory import json -def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"), - knowledge_base_name: str = Body(..., description="知识库名称", example="samples"), +def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]), + knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), history: List[History] = Body([], description="历史对话", - example=[ + examples=[[ {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}, {"role": "assistant", - "content": "虎头虎脑"}] + "content": "虎头虎脑"}]] ), ): kb = KBServiceFactory.get_service_by_name(knowledge_base_name) diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index 70e7a3a..b713b24 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -57,17 +57,17 @@ def lookup_search_engine( return docs -def search_engine_chat(query: str = Body(..., description="用户输入", example="你好"), - search_engine_name: str = Body(..., description="搜索引擎名称", example="duckduckgo"), +def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]), + search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]), top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"), - history: Optional[List[History]] = Body(..., - description="历史对话", - example=[ - {"role": "user", - "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", - "content": "虎头虎脑"}] - ), + history: List[History] = Body([], + description="历史对话", + examples=[[ + {"role": "user", + "content": "我们来玩成语接龙,我先来,生龙活虎"}, + {"role": "assistant", + "content": "虎头虎脑"}]] + ), ): if search_engine_name not in SEARCH_ENGINES.keys(): return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}") diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py index 84c8298..5186dbc 100644 --- a/server/knowledge_base/kb_api.py +++ b/server/knowledge_base/kb_api.py @@ -4,6 +4,7 @@ from server.knowledge_base.utils import validate_kb_name from server.knowledge_base.kb_service.base import KBServiceFactory from server.db.repository.knowledge_base_repository import list_kbs_from_db from configs.model_config import EMBEDDING_MODEL +from fastapi import Body async def list_kbs(): @@ -11,9 +12,9 @@ async def list_kbs(): return ListResponse(data=list_kbs_from_db()) -async def create_kb(knowledge_base_name: str, - vector_store_type: str = "faiss", - embed_model: str = EMBEDDING_MODEL, +async def create_kb(knowledge_base_name: str = Body(..., examples=["kb_name"]), + vector_store_type: str = Body("faiss"), + embed_model: str = Body(EMBEDDING_MODEL), ): # Create selected knowledge base if not validate_kb_name(knowledge_base_name): @@ -21,14 +22,16 @@ async def create_kb(knowledge_base_name: str, if knowledge_base_name is None or knowledge_base_name.strip() == "": return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称") - kb = KBServiceFactory.get_service(knowledge_base_name, "faiss") + kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model) if kb is not None: return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") kb.create() return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") -async def delete_kb(knowledge_base_name: str): +async def delete_kb( + knowledge_base_name: str = Body(..., examples=["kb_name"]) + ): # Delete selected knowledge base if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 864662b..8a256a8 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -11,7 +11,9 @@ from server.knowledge_base.kb_service.base import SupportedVSType from typing import Union -async def list_docs(knowledge_base_name: str): +async def list_docs( + knowledge_base_name: str = Body(..., examples=["kb_name"]) +): if not validate_kb_name(knowledge_base_name): return ListResponse(code=403, msg="Don't attack me", data=[]) @@ -25,8 +27,8 @@ async def list_docs(knowledge_base_name: str): async def upload_doc(file: UploadFile = File(description="上传文件"), - knowledge_base_name: str = Form(..., description="知识库名称", example="kb1"), - override: bool = Form(False, description="覆盖已有文件", example=False), + knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]), + override: bool = Form(False, description="覆盖已有文件"), ): if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -58,9 +60,9 @@ async def upload_doc(file: UploadFile = File(description="上传文件"), return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}") -async def delete_doc(knowledge_base_name: str = Body(...), - doc_name: str = Body(...), - delete_content: bool = Body(...), +async def delete_doc(knowledge_base_name: str = Body(..., examples=["kb_name"]), + doc_name: str = Body(..., examples=["file_name"]), + delete_content: bool = Body(False), ): if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -80,8 +82,8 @@ async def delete_doc(knowledge_base_name: str = Body(...), async def update_doc( - knowledge_base_name: str = Body(...), - file_name: str = Body(...), + knowledge_base_name: str = Body(..., examples=["kb_name"]), + file_name: str = Body(..., examples=["file_name"]), ): ''' 更新知识库文档 @@ -109,7 +111,7 @@ async def download_doc(): async def recreate_vector_store( - knowledge_base_name: str = Body(...), + knowledge_base_name: str = Body(..., examples=["kb_name"]), allow_empty_kb: bool = Body(True), vs_type: str = Body("faiss"), ):