Add streaming output to the LLM chat and knowledge base chat endpoints (LLM对话和知识库对话接口增加流式输出功能)
parent 6981bdee62
commit f21e00bf16

api.py | 81 lines changed
@@ -8,11 +8,13 @@ import urllib
 import nltk
 import pydantic
 import uvicorn
-from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
+from fastapi import Body, Request, FastAPI, File, Form, Query, UploadFile, WebSocket
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from typing_extensions import Annotated
 from starlette.responses import RedirectResponse
+from sse_starlette.sse import EventSourceResponse
 
 from chains.local_doc_qa import LocalDocQA
 from configs.model_config import (KB_ROOT_PATH, EMBEDDING_DEVICE,
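The hunk above imports both StreamingResponse and EventSourceResponse, but only StreamingResponse is used in the hunks below. As a point of comparison, here is a minimal standalone sketch (not part of this commit) contrasting the two response types on a toy FastAPI app; the route paths and the fake token source are made up for illustration.

# Standalone sketch, not from this commit: the imported response types side by side.
# Routes and fake_tokens() are hypothetical placeholders.
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

def fake_tokens():
    # stand-in for incremental model output
    for piece in ["工伤", "保险", "是", "..."]:
        yield piece

@app.get("/demo/stream")
def demo_stream():
    # StreamingResponse sends raw chunks; the client reads the body as it arrives
    return StreamingResponse(fake_tokens(), media_type="text/plain")

@app.get("/demo/sse")
def demo_sse():
    # EventSourceResponse wraps each yielded item as a Server-Sent Event ("data: ..." lines)
    return EventSourceResponse(fake_tokens())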
@@ -266,6 +268,7 @@ async def update_doc(
 async def local_doc_chat(
         knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
         question: str = Body(..., description="Question", example="工伤保险是什么?"),
+        stream: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"),
         history: List[List[str]] = Body(
             [],
             description="History of previous questions and answers",
@@ -287,22 +290,34 @@ async def local_doc_chat(
             source_documents=[],
         )
     else:
-        for resp, history in local_doc_qa.get_knowledge_based_answer(
-                query=question, vs_path=vs_path, chat_history=history, streaming=True
-        ):
-            pass
-        source_documents = [
-            f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
-            f"""相关度:{doc.metadata['score']}\n\n"""
-            for inum, doc in enumerate(resp["source_documents"])
-        ]
-
-        return ChatMessage(
-            question=question,
-            response=resp["result"],
-            history=history,
-            source_documents=source_documents,
-        )
+        if (stream):
+            def generate_answer():
+                last_print_len = 0
+                for resp, next_history in local_doc_qa.get_knowledge_based_answer(
+                        query=question, vs_path=vs_path, chat_history=history, streaming=True
+                ):
+                    yield resp["result"][last_print_len:]
+                    last_print_len = len(resp["result"])
+
+            return StreamingResponse(generate_answer())
+        else:
+            for resp, next_history in local_doc_qa.get_knowledge_based_answer(
+                    query=question, vs_path=vs_path, chat_history=history, streaming=True
+            ):
+                pass
+
+            source_documents = [
+                f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+                f"""相关度:{doc.metadata['score']}\n\n"""
+                for inum, doc in enumerate(resp["source_documents"])
+            ]
+
+            return ChatMessage(
+                question=question,
+                response=resp["result"],
+                history=next_history,
+                source_documents=source_documents,
+            )
 
 
 async def bing_search_chat(
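With stream set to true, the knowledge-base endpoint now returns chunked plain text instead of the usual ChatMessage JSON. A minimal client-side sketch of consuming it follows; the base URL, port, and route path are assumptions based on the handler name, not confirmed by this diff.

# Client sketch, assuming the handler above is reachable at this URL;
# host, port, path, and the knowledge base name are hypothetical.
import requests

payload = {
    "knowledge_base_id": "kb1",
    "question": "工伤保险是什么?",
    "stream": True,   # the new flag added by this commit
    "history": [],
}

with requests.post("http://localhost:7861/local_doc_chat", json=payload, stream=True) as r:
    r.raise_for_status()
    # each chunk is the newly generated suffix yielded by generate_answer()
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        if chunk:
            print(chunk, end="", flush=True)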
@@ -337,6 +352,7 @@ async def bing_search_chat(
 
 async def chat(
         question: str = Body(..., description="Question", example="工伤保险是什么?"),
+        stream: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"),
         history: List[List[str]] = Body(
             [],
             description="History of previous questions and answers",
@@ -348,18 +364,29 @@ async def chat(
         ],
     ),
 ):
-    for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
-                                                           streaming=True):
-        resp = answer_result.llm_output["answer"]
-        history = answer_result.history
-        pass
-
-    return ChatMessage(
-        question=question,
-        response=resp,
-        history=history,
-        source_documents=[],
-    )
+    if (stream):
+        def generate_answer():
+            last_print_len = 0
+            for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
+                                                                   streaming=True):
+                yield answer_result.llm_output["answer"][last_print_len:]
+                last_print_len = len(answer_result.llm_output["answer"])
+
+        return StreamingResponse(generate_answer())
+    else:
+        for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
+                                                               streaming=True):
+            resp = answer_result.llm_output["answer"]
+            history = answer_result.history
+            pass
+
+        return ChatMessage(
+            question=question,
+            response=resp,
+            history=history,
+            source_documents=[],
+        )
 
 
 async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
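Both endpoints use the same streaming pattern: the underlying model returns the full answer-so-far on every step, and the generator yields only the new suffix (tracked with last_print_len), so the client receives each increment exactly once. A self-contained toy version of that pattern, with a fake model in place of local_doc_qa.llm.generatorAnswer, is sketched below; names and the route are illustrative only.

# Toy illustration of the yield-the-delta pattern used above; fake_generator_answer
# mimics generatorAnswer by returning the full text-so-far at each step.
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

def fake_generator_answer(prompt: str):
    text = ""
    for token in ["流式", "输出", "示例"]:
        text += token
        yield {"answer": text}   # full text so far, not just the new token

@app.post("/demo/chat")
def demo_chat(prompt: str = "hi"):
    def generate_answer():
        last_print_len = 0
        for step in fake_generator_answer(prompt):
            # emit only the part the client has not seen yet
            yield step["answer"][last_print_len:]
            last_print_len = len(step["answer"])
    return StreamingResponse(generate_answer(), media_type="text/plain")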