import asyncio
from typing import AsyncIterable

from fastapi import Body
from fastapi.responses import StreamingResponse
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate

from configs.model_config import llm_model_dict, LLM_MODEL
from server.chat.utils import wrap_done


async def chat(query: str = Body(..., description="User input", example="Hello")):
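    """Stream the LLM's reply to `query` as server-sent events."""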

    async def chat_iterator(message: str) -> AsyncIterable[str]:
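        # Collects tokens from the model's streaming callbacks and exposes
        # them through an async iterator (callback.aiter() below).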
        callback = AsyncIteratorCallbackHandler()
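
        # Build a streaming chat model; the API key and base URL come from the
        # llm_model_dict entry for LLM_MODEL in configs/model_config.py.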
        model = ChatOpenAI(
            streaming=True,
            verbose=True,
            callbacks=[callback],
            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
            model_name=LLM_MODEL,
        )
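
        # Alternative kept for reference: a plain completion model instead of
        # the chat model.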
        # llm = OpenAI(model_name=LLM_MODEL,
        #              openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
        #              openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
        #              streaming=True)
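
        # A pass-through prompt: the chain forwards the user's message verbatim.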
        prompt = PromptTemplate(input_variables=["input"], template="{input}")
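        # Wire the prompt and the streaming model into a single-step chain.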
        chain = LLMChain(prompt=prompt, llm=model)

        # Begin a task that runs in the background.
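        # wrap_done (from server.chat.utils) is assumed to await the chain call
        # and then signal callback.done so the iterator below knows it finished.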
        task = asyncio.create_task(
            wrap_done(chain.acall(message), callback.done)
        )

        async for token in callback.aiter():
            # Use server-sent events to stream the response.
            yield token
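
        # Await the background task so exceptions from the chain surface here.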
        await task

    return StreamingResponse(chat_iterator(query), media_type="text/event-stream")
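

# Minimal mounting sketch (hypothetical route path; the real app registers its
# routes in its own API setup code):
#     from fastapi import FastAPI
#     app = FastAPI()
#     app.post("/chat/chat")(chat)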