Add streaming output to the LLM chat and knowledge base chat endpoints

超能刚哥 2023-06-18 17:30:41 +08:00 committed by GitHub
parent 6981bdee62
commit f21e00bf16
1 changed file with 54 additions and 27 deletions

api.py

@@ -8,11 +8,13 @@ import urllib
 import nltk
 import pydantic
 import uvicorn
-from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
+from fastapi import Body, Request, FastAPI, File, Form, Query, UploadFile, WebSocket
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from typing_extensions import Annotated
 from starlette.responses import RedirectResponse
 from sse_starlette.sse import EventSourceResponse
 from chains.local_doc_qa import LocalDocQA
 from configs.model_config import (KB_ROOT_PATH, EMBEDDING_DEVICE,
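The newly imported fastapi.responses.StreamingResponse is what both endpoints use to push the answer to the client as it is generated: it accepts any iterator or generator and writes each yielded chunk straight into the response body. A minimal sketch of that mechanism in isolation (the standalone app and the /demo route are illustrative assumptions, not part of api.py):

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

def token_generator():
    # Each yielded chunk is flushed to the client as soon as it is produced.
    for chunk in ["工伤", "保险", "是", "什么", "?"]:
        yield chunk

@app.get("/demo")
async def demo():
    # StreamingResponse iterates the generator and streams the chunks;
    # media_type is optional and set here only for clarity.
    return StreamingResponse(token_generator(), media_type="text/plain")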
@@ -266,6 +268,7 @@ async def update_doc(
 async def local_doc_chat(
         knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
         question: str = Body(..., description="Question", example="工伤保险是什么?"),
+        stream: bool = Body(False, description="是否开启流式输出默认false有些模型可能不支持。"),
         history: List[List[str]] = Body(
             [],
             description="History of previous questions and answers",
@@ -287,22 +290,34 @@ async def local_doc_chat(
             source_documents=[],
         )
     else:
-        for resp, history in local_doc_qa.get_knowledge_based_answer(
-                query=question, vs_path=vs_path, chat_history=history, streaming=True
-        ):
-            pass
-        source_documents = [
-            f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}\n\n{doc.page_content}\n\n"""
-            f"""相关度:{doc.metadata['score']}\n\n"""
-            for inum, doc in enumerate(resp["source_documents"])
-        ]
+        if (stream):
+            def generate_answer ():
+                last_print_len = 0
+                for resp, next_history in local_doc_qa.get_knowledge_based_answer(
+                        query=question, vs_path=vs_path, chat_history=history, streaming=True
+                ):
+                    yield resp["result"][last_print_len:]
+                    last_print_len=len(resp["result"])
-        return ChatMessage(
-            question=question,
-            response=resp["result"],
-            history=history,
-            source_documents=source_documents,
-        )
+            return StreamingResponse(generate_answer())
+        else:
+            for resp, next_history in local_doc_qa.get_knowledge_based_answer(
+                    query=question, vs_path=vs_path, chat_history=history, streaming=True
+            ):
+                pass
+            source_documents = [
+                f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}\n\n{doc.page_content}\n\n"""
+                f"""相关度:{doc.metadata['score']}\n\n"""
+                for inum, doc in enumerate(resp["source_documents"])
+            ]
+            return ChatMessage(
+                question=question,
+                response=resp["result"],
+                history=next_history,
+                source_documents=source_documents,
+            )

 async def bing_search_chat(
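When stream=true is passed, /local_doc_chat no longer returns the ChatMessage JSON; it streams the incremental answer text as a chunked body, and the source_documents list built in the non-streaming branch is not part of the stream. A hedged client-side sketch using the requests library; the host, port, and route are assumptions about how api.py is deployed:

import requests

# Assumed local deployment; adjust the URL to wherever api.py is actually served.
URL = "http://127.0.0.1:7861/local_doc_chat"

payload = {
    "knowledge_base_id": "kb1",
    "question": "工伤保险是什么?",
    "stream": True,
    "history": [],
}

# stream=True keeps requests from buffering the whole body before returning.
with requests.post(URL, json=payload, stream=True) as resp:
    resp.raise_for_status()
    resp.encoding = "utf-8"  # the API streams UTF-8 text
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
print()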
@@ -337,6 +352,7 @@ async def bing_search_chat(
 async def chat(
         question: str = Body(..., description="Question", example="工伤保险是什么?"),
+        stream: bool = Body(False, description="是否开启流式输出默认false有些模型可能不支持。"),
         history: List[List[str]] = Body(
             [],
             description="History of previous questions and answers",
@@ -348,18 +364,29 @@ async def chat(
             ],
         ),
 ):
-    for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
-                                                           streaming=True):
-        resp = answer_result.llm_output["answer"]
-        history = answer_result.history
-        pass
-    return ChatMessage(
-        question=question,
-        response=resp,
-        history=history,
-        source_documents=[],
-    )
+    if (stream):
+        def generate_answer ():
+            last_print_len = 0
+            for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
+                                                                   streaming=True):
+                yield answer_result.llm_output["answer"][last_print_len:]
+                last_print_len = len(answer_result.llm_output["answer"])
+        return StreamingResponse(generate_answer())
+    else:
+        for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
+                                                               streaming=True):
+            resp = answer_result.llm_output["answer"]
+            history = answer_result.history
+            pass
+        return ChatMessage(
+            question=question,
+            response=resp,
+            history=history,
+            source_documents=[],
+        )

 async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
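Both generate_answer helpers rely on the same delta trick: generatorAnswer and get_knowledge_based_answer yield the full answer accumulated so far on every iteration, so the generator remembers last_print_len and yields only the new suffix, leaving the client with no repeated text. A self-contained sketch of that pattern; model_steps is a hypothetical stand-in for the cumulative generator, not part of the project:

from typing import Iterator

def model_steps() -> Iterator[str]:
    # Stand-in for an LLM that re-emits the entire answer so far at each step.
    answer = ""
    for token in ["工伤", "保险", "是", "一种", "社会", "保险", "制度", "。"]:
        answer += token
        yield answer

def stream_deltas() -> Iterator[str]:
    # Same slicing as generate_answer(): send only what the client has not seen yet.
    last_print_len = 0
    for full_answer in model_steps():
        yield full_answer[last_print_len:]
        last_print_len = len(full_answer)

if __name__ == "__main__":
    # The concatenated deltas reproduce the final answer exactly once.
    assert "".join(stream_deltas()) == "工伤保险是一种社会保险制度。"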