diff --git a/server/chat/chat.py b/server/chat/chat.py
index 5783829..34350a4 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -1,5 +1,6 @@
 from fastapi import Body
 from sse_starlette.sse import EventSourceResponse
+from fastapi.responses import StreamingResponse
 from configs import LLM_MODELS, TEMPERATURE
 from server.utils import wrap_done, get_ChatOpenAI
 from langchain.chains import LLMChain
@@ -100,4 +101,6 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         await task

-    return EventSourceResponse(chat_iterator())
+    #return EventSourceResponse(chat_iterator())
+    return StreamingResponse(chat_iterator(),
+                             media_type="text/event-stream")

diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 79ed159..66afef2 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -1,4 +1,5 @@
 from fastapi import Body, Request
+from fastapi.responses import StreamingResponse
 from sse_starlette.sse import EventSourceResponse
 from fastapi.concurrency import run_in_threadpool
 from configs import (LLM_MODELS,
@@ -145,5 +146,12 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                                              ensure_ascii=False)
         await task

-    return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))
+    return StreamingResponse(knowledge_base_chat_iterator(query=query,
+                                                          top_k=top_k,
+                                                          history=history,
+                                                          model_name=model_name,
+                                                          prompt_name=prompt_name),
+                             media_type="text/event-stream")
+
+    #return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))
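
Note on consuming the patched endpoints: sse_starlette's EventSourceResponse wraps every value the iterator yields in SSE framing ("data: ..." lines), while FastAPI's StreamingResponse writes the yielded chunks to the socket verbatim and only sets the Content-Type header. Clients therefore now receive the bare JSON chunks with no "data:" prefix to strip. Below is a minimal client sketch; the host/port and the /chat/chat endpoint path are illustrative assumptions, not taken from the diff.

# Minimal client sketch for the patched streaming endpoints.
# Assumptions (not in the diff): the server listens on
# http://127.0.0.1:7861 and exposes POST /chat/chat taking a JSON
# body with a "query" field.
import httpx

def stream_chat(query: str) -> None:
    url = "http://127.0.0.1:7861/chat/chat"  # assumed endpoint path
    with httpx.stream("POST", url, json={"query": query}, timeout=None) as resp:
        resp.raise_for_status()
        # StreamingResponse sends each value yielded by chat_iterator()
        # as a raw chunk; unlike EventSourceResponse there is no SSE
        # "data:" framing, so each decoded chunk is printed as-is.
        for chunk in resp.iter_text():
            print(chunk, end="", flush=True)

if __name__ == "__main__":
    stream_chat("Hello")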