Add streaming output to the LLM chat and knowledge base chat APIs
commit f21e00bf16 (parent 6981bdee62)

api.py | 81 lines changed
@@ -8,11 +8,13 @@ import urllib
 import nltk
 import pydantic
 import uvicorn
-from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
+from fastapi import Body, Request, FastAPI, File, Form, Query, UploadFile, WebSocket
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from typing_extensions import Annotated
 from starlette.responses import RedirectResponse
+from sse_starlette.sse import EventSourceResponse
 
 from chains.local_doc_qa import LocalDocQA
 from configs.model_config import (KB_ROOT_PATH, EMBEDDING_DEVICE,
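The imports added here are StreamingResponse, which the new code paths below use to stream a generator's output as chunked HTTP, and EventSourceResponse from sse_starlette, which is imported but not used in the hunks shown. A minimal sketch of the StreamingResponse pattern; the app and route are illustrative only and not part of this commit:

    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse

    app = FastAPI()  # illustrative app, separate from the one defined in api.py

    @app.post("/demo_stream")
    async def demo_stream():
        def token_generator():
            # Each yielded string becomes one chunk of the HTTP response body.
            for token in ["streamed ", "piece ", "by ", "piece"]:
                yield token
        # The commit's own calls omit media_type; text/plain is set here for clarity.
        return StreamingResponse(token_generator(), media_type="text/plain")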
@@ -266,6 +268,7 @@ async def update_doc(
 async def local_doc_chat(
     knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
     question: str = Body(..., description="Question", example="工伤保险是什么?"),
+    stream: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"),
     history: List[List[str]] = Body(
         [],
         description="History of previous questions and answers",
@@ -287,22 +290,34 @@ async def local_doc_chat(
             source_documents=[],
         )
     else:
-        for resp, history in local_doc_qa.get_knowledge_based_answer(
-                query=question, vs_path=vs_path, chat_history=history, streaming=True
-        ):
-            pass
-        source_documents = [
-            f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
-            f"""相关度:{doc.metadata['score']}\n\n"""
-            for inum, doc in enumerate(resp["source_documents"])
-        ]
+        if (stream):
+            def generate_answer():
+                last_print_len = 0
+                for resp, next_history in local_doc_qa.get_knowledge_based_answer(
+                        query=question, vs_path=vs_path, chat_history=history, streaming=True
+                ):
+                    yield resp["result"][last_print_len:]
+                    last_print_len = len(resp["result"])
 
-        return ChatMessage(
-            question=question,
-            response=resp["result"],
-            history=history,
-            source_documents=source_documents,
-        )
+            return StreamingResponse(generate_answer())
+        else:
+            for resp, next_history in local_doc_qa.get_knowledge_based_answer(
+                    query=question, vs_path=vs_path, chat_history=history, streaming=True
+            ):
+                pass
+
+            source_documents = [
+                f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+                f"""相关度:{doc.metadata['score']}\n\n"""
+                for inum, doc in enumerate(resp["source_documents"])
+            ]
+
+            return ChatMessage(
+                question=question,
+                response=resp["result"],
+                history=next_history,
+                source_documents=source_documents,
+            )
 
 
 async def bing_search_chat(
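With stream set to true, local_doc_chat returns a StreamingResponse carrying only the incremental text of resp["result"]; history and source_documents are only available from the non-streaming ChatMessage branch. A sketch of reading the chunks on the client side with requests; the host, port, and route path below are assumptions, not taken from this diff:

    import codecs
    import requests

    # Hypothetical address; point this at wherever api.py is actually served.
    url = "http://localhost:7861/local_doc_chat"
    payload = {
        "knowledge_base_id": "kb1",
        "question": "工伤保险是什么?",
        "stream": True,
        "history": [],
    }

    decoder = codecs.getincrementaldecoder("utf-8")()
    with requests.post(url, json=payload, stream=True) as r:
        r.raise_for_status()
        # iter_content yields raw byte chunks as the server's generator produces them;
        # the incremental decoder handles multi-byte characters split across chunks.
        for chunk in r.iter_content(chunk_size=None):
            print(decoder.decode(chunk), end="", flush=True)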
@@ -337,6 +352,7 @@ async def bing_search_chat(
 
 async def chat(
     question: str = Body(..., description="Question", example="工伤保险是什么?"),
+    stream: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"),
     history: List[List[str]] = Body(
         [],
         description="History of previous questions and answers",
@@ -348,18 +364,29 @@ async def chat(
         ],
     ),
 ):
-    for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
-                                                          streaming=True):
-        resp = answer_result.llm_output["answer"]
-        history = answer_result.history
-        pass
-
-    return ChatMessage(
-        question=question,
-        response=resp,
-        history=history,
-        source_documents=[],
-    )
+    if (stream):
+        def generate_answer():
+            last_print_len = 0
+            for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
+                                                                  streaming=True):
+                yield answer_result.llm_output["answer"][last_print_len:]
+                last_print_len = len(answer_result.llm_output["answer"])
+
+        return StreamingResponse(generate_answer())
+    else:
+        for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
+                                                              streaming=True):
+            resp = answer_result.llm_output["answer"]
+            history = answer_result.history
+            pass
+
+        return ChatMessage(
+            question=question,
+            response=resp,
+            history=history,
+            source_documents=[],
+        )
 
 
 async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
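Both streaming branches rely on the same delta technique: generatorAnswer and get_knowledge_based_answer yield the full answer accumulated so far on each step, so the wrapper tracks last_print_len and yields only the newly appended suffix. A standalone sketch of that pattern, with a made-up cumulative generator standing in for the model:

    from typing import Iterator

    def cumulative_answers() -> Iterator[str]:
        # Stand-in for the model: each step yields the whole answer so far.
        yield "工伤"
        yield "工伤保险"
        yield "工伤保险是一种社会保险。"

    def deltas(cumulative: Iterator[str]) -> Iterator[str]:
        # Mirror the commit's last_print_len bookkeeping: emit only what was
        # appended since the previous step.
        last_print_len = 0
        for text in cumulative:
            yield text[last_print_len:]
            last_print_len = len(text)

    print(list(deltas(cumulative_answers())))
    # ['工伤', '保险', '是一种社会保险。']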