diff --git a/api.py b/api.py index c675543..e4917b8 100644 --- a/api.py +++ b/api.py @@ -265,17 +265,17 @@ async def chat( async def stream_chat(websocket: WebSocket, knowledge_base_id: str): await websocket.accept() - vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id) - - if not os.path.exists(vs_path): - await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"}) - await websocket.close() - return - - history = [] turn = 1 while True: - question = await websocket.receive_text() + input_json = await websocket.receive_json() + question, history, knowledge_base_id = input_json[""], input_json["history"], input_json["knowledge_base_id"] + vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id) + + if not os.path.exists(vs_path): + await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"}) + await websocket.close() + return + await websocket.send_json({"question": question, "turn": turn, "flag": "start"}) last_print_len = 0 @@ -304,7 +304,6 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str): ) turn += 1 - async def document(): return RedirectResponse(url="/docs")