Add streaming output to the LLM chat and knowledge base chat endpoints

超能刚哥 2023-06-18 17:30:41 +08:00 committed by GitHub
parent 6981bdee62
commit f21e00bf16
1 changed file with 54 additions and 27 deletions

api.py

@@ -8,11 +8,13 @@ import urllib
 import nltk
 import pydantic
 import uvicorn
-from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
+from fastapi import Body, Request, FastAPI, File, Form, Query, UploadFile, WebSocket
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from typing_extensions import Annotated
 from starlette.responses import RedirectResponse
+from sse_starlette.sse import EventSourceResponse
 from chains.local_doc_qa import LocalDocQA
 from configs.model_config import (KB_ROOT_PATH, EMBEDDING_DEVICE,
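
The newly imported StreamingResponse is what carries the feature in the hunks below: it wraps a generator and writes each yielded value to the HTTP response body as soon as it is produced. A minimal, self-contained sketch of that pattern, using a hypothetical /demo_stream route that is not part of this commit:

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.post("/demo_stream")  # hypothetical route, for illustration only
async def demo_stream():
    def generate():
        # each yielded string is sent to the client as one response chunk
        for token in ["Hello", ", ", "world"]:
            yield token
    return StreamingResponse(generate(), media_type="text/plain")
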
@@ -266,6 +268,7 @@ async def update_doc(
 async def local_doc_chat(
     knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
     question: str = Body(..., description="Question", example="工伤保险是什么?"),
+    stream: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"),
     history: List[List[str]] = Body(
         [],
         description="History of previous questions and answers",
@@ -287,22 +290,34 @@ async def local_doc_chat(
             source_documents=[],
         )
     else:
-        for resp, history in local_doc_qa.get_knowledge_based_answer(
-            query=question, vs_path=vs_path, chat_history=history, streaming=True
-        ):
-            pass
-        source_documents = [
-            f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}\n\n{doc.page_content}\n\n"""
-            f"""相关度:{doc.metadata['score']}\n\n"""
-            for inum, doc in enumerate(resp["source_documents"])
-        ]
-        return ChatMessage(
-            question=question,
-            response=resp["result"],
-            history=history,
-            source_documents=source_documents,
-        )
+        if (stream):
+            def generate_answer():
+                last_print_len = 0
+                for resp, next_history in local_doc_qa.get_knowledge_based_answer(
+                        query=question, vs_path=vs_path, chat_history=history, streaming=True
+                ):
+                    yield resp["result"][last_print_len:]
+                    last_print_len = len(resp["result"])
+
+            return StreamingResponse(generate_answer())
+        else:
+            for resp, next_history in local_doc_qa.get_knowledge_based_answer(
+                    query=question, vs_path=vs_path, chat_history=history, streaming=True
+            ):
+                pass
+            source_documents = [
+                f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}\n\n{doc.page_content}\n\n"""
+                f"""相关度:{doc.metadata['score']}\n\n"""
+                for inum, doc in enumerate(resp["source_documents"])
+            ]
+            return ChatMessage(
+                question=question,
+                response=resp["result"],
+                history=next_history,
+                source_documents=source_documents,
+            )

 async def bing_search_chat(
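
A hedged client-side sketch for consuming the streamed knowledge-base answer; the base URL, port, and route path are assumptions and may differ from how this API is actually served:

import requests  # assumes the requests package is installed

resp = requests.post(
    "http://localhost:7861/local_doc_chat",  # assumed host/port/path
    json={
        "knowledge_base_id": "kb1",
        "question": "工伤保险是什么?",
        "stream": True,
        "history": [],
    },
    stream=True,  # let requests hand chunks over as they arrive
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)
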
@@ -337,6 +352,7 @@ async def bing_search_chat(
 async def chat(
     question: str = Body(..., description="Question", example="工伤保险是什么?"),
+    stream: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"),
     history: List[List[str]] = Body(
         [],
         description="History of previous questions and answers",
@@ -348,18 +364,29 @@ async def chat(
         ],
     ),
 ):
-    for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
-                                                           streaming=True):
-        resp = answer_result.llm_output["answer"]
-        history = answer_result.history
-        pass
-    return ChatMessage(
-        question=question,
-        response=resp,
-        history=history,
-        source_documents=[],
-    )
+    if (stream):
+        def generate_answer():
+            last_print_len = 0
+            for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
+                                                                   streaming=True):
+                yield answer_result.llm_output["answer"][last_print_len:]
+                last_print_len = len(answer_result.llm_output["answer"])
+
+        return StreamingResponse(generate_answer())
+    else:
+        for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
+                                                               streaming=True):
+            resp = answer_result.llm_output["answer"]
+            history = answer_result.history
+            pass
+        return ChatMessage(
+            question=question,
+            response=resp,
+            history=history,
+            source_documents=[],
+        )

 async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
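
Both endpoints use the same delta technique: the underlying generator returns the full answer accumulated so far on every step, so last_print_len is used to forward only the new suffix to the client. A standalone sketch of that pattern, with a hypothetical fake_llm_stream standing in for the model:

def fake_llm_stream():
    # stand-in for generatorAnswer / get_knowledge_based_answer:
    # each step yields the cumulative answer text, not just the new token
    answer = ""
    for token in ["工伤", "保险", "是", "一种", "社会保险", "。"]:
        answer += token
        yield {"answer": answer}

def generate_answer():
    last_print_len = 0
    for step in fake_llm_stream():
        yield step["answer"][last_print_len:]  # emit only what has not been sent yet
        last_print_len = len(step["answer"])

print("".join(generate_answer()))  # reassembles the complete answer exactly once
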