From 1678392ceef9aec5c875f639993f0698149c10da Mon Sep 17 00:00:00 2001 From: royd <1615218238@qq.com> Date: Mon, 15 May 2023 19:11:00 +0800 Subject: [PATCH] Update api.py (#357) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 更改目前流式对话的传入模式,使其可以支持多轮对话,体现在新增参数:history、knowledge_base_id --- api.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/api.py b/api.py index c675543..e4917b8 100644 --- a/api.py +++ b/api.py @@ -265,17 +265,17 @@ async def chat( async def stream_chat(websocket: WebSocket, knowledge_base_id: str): await websocket.accept() - vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id) - - if not os.path.exists(vs_path): - await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"}) - await websocket.close() - return - - history = [] turn = 1 while True: - question = await websocket.receive_text() + input_json = await websocket.receive_json() + question, history, knowledge_base_id = input_json[""], input_json["history"], input_json["knowledge_base_id"] + vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id) + + if not os.path.exists(vs_path): + await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"}) + await websocket.close() + return + await websocket.send_json({"question": question, "turn": turn, "flag": "start"}) last_print_len = 0 @@ -304,7 +304,6 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str): ) turn += 1 - async def document(): return RedirectResponse(url="/docs")