diff --git a/api.py b/api.py
index 0413c95..89afe86 100644
--- a/api.py
+++ b/api.py
@@ -8,11 +8,13 @@
 import urllib
 import nltk
 import pydantic
 import uvicorn
-from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
+from fastapi import Body, Request, FastAPI, File, Form, Query, UploadFile, WebSocket
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from typing_extensions import Annotated
 from starlette.responses import RedirectResponse
+from sse_starlette.sse import EventSourceResponse
 from chains.local_doc_qa import LocalDocQA
 from configs.model_config import (KB_ROOT_PATH, EMBEDDING_DEVICE,
@@ -266,6 +268,7 @@ async def update_doc(
 async def local_doc_chat(
     knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
     question: str = Body(..., description="Question", example="工伤保险是什么?"),
+    stream: bool = Body(False, description="Whether to enable streaming output. Defaults to false; some models may not support it."),
     history: List[List[str]] = Body(
         [],
         description="History of previous questions and answers",
@@ -287,22 +290,34 @@ async def local_doc_chat(
             source_documents=[],
         )
     else:
-        for resp, history in local_doc_qa.get_knowledge_based_answer(
-            query=question, vs_path=vs_path, chat_history=history, streaming=True
-        ):
-            pass
-        source_documents = [
-            f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
-            f"""相关度:{doc.metadata['score']}\n\n"""
-            for inum, doc in enumerate(resp["source_documents"])
-        ]
-
-        return ChatMessage(
-            question=question,
-            response=resp["result"],
-            history=history,
-            source_documents=source_documents,
-        )
+        if stream:
+            def generate_answer():
+                last_print_len = 0
+                for resp, next_history in local_doc_qa.get_knowledge_based_answer(
+                    query=question, vs_path=vs_path, chat_history=history, streaming=True
+                ):
+                    yield resp["result"][last_print_len:]
+                    last_print_len = len(resp["result"])
+
+            return StreamingResponse(generate_answer())
+        else:
+            for resp, next_history in local_doc_qa.get_knowledge_based_answer(
+                query=question, vs_path=vs_path, chat_history=history, streaming=True
+            ):
+                pass
+
+            source_documents = [
+                f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+                f"""相关度:{doc.metadata['score']}\n\n"""
+                for inum, doc in enumerate(resp["source_documents"])
+            ]
+
+            return ChatMessage(
+                question=question,
+                response=resp["result"],
+                history=next_history,
+                source_documents=source_documents,
+            )
 
 
 async def bing_search_chat(
@@ -337,6 +352,7 @@ async def bing_search_chat(
 
 async def chat(
     question: str = Body(..., description="Question", example="工伤保险是什么?"),
+    stream: bool = Body(False, description="Whether to enable streaming output. Defaults to false; some models may not support it."),
     history: List[List[str]] = Body(
         [],
         description="History of previous questions and answers",
@@ -348,18 +364,29 @@ async def chat(
         ],
     ),
 ):
-    for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
-                                                          streaming=True):
-        resp = answer_result.llm_output["answer"]
-        history = answer_result.history
-        pass
-    return ChatMessage(
-        question=question,
-        response=resp,
-        history=history,
-        source_documents=[],
-    )
+    if stream:
+        def generate_answer():
+            last_print_len = 0
+            for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
+                                                                  streaming=True):
+                yield answer_result.llm_output["answer"][last_print_len:]
+                last_print_len = len(answer_result.llm_output["answer"])
+
+        return StreamingResponse(generate_answer())
+    else:
+        for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
+                                                              streaming=True):
+            resp = answer_result.llm_output["answer"]
+            history = answer_result.history
+            pass
+
+        return ChatMessage(
+            question=question,
+            response=resp,
+            history=history,
+            source_documents=[],
+        )
 
 
 async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
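
For reviewers, a minimal client sketch of how the new `stream` flag could be exercised against the `chat` endpoint. The route path and port are not visible in these hunks, so `http://localhost:7861/chat` is an assumption; adjust it to the actual route registration in api.py.

```python
# Hypothetical smoke test for the new streaming mode. The endpoint URL is an
# assumption -- the route registrations are outside the hunks shown above.
import requests

payload = {
    "question": "工伤保险是什么?",
    "history": [],
    "stream": True,  # the Body parameter added by this patch
}

with requests.post("http://localhost:7861/chat", json=payload, stream=True) as r:
    r.raise_for_status()
    r.encoding = "utf-8"
    # With stream=True the endpoint returns a StreamingResponse whose body is
    # the sequence of incremental slices yielded by generate_answer().
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
```

Because `generatorAnswer` and `get_knowledge_based_answer` yield the cumulative answer on each step, slicing from `last_print_len` turns it into deltas, so a client like the one above receives only newly generated text per chunk. With `stream=False` (the default) both endpoints keep returning the full `ChatMessage` JSON, leaving existing clients unaffected.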