Langchain-Chatchat/server/chat/chat.py

from fastapi import Body
from fastapi.responses import StreamingResponse
from configs.model_config import llm_model_dict, LLM_MODEL
from server.chat.utils import wrap_done
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts import PromptTemplate


def chat(query: str = Body(..., description="User input", example="你好")):
    async def chat_iterator(message: str) -> AsyncIterable[str]:
        # Callback handler that exposes the model's tokens as an async iterator.
        callback = AsyncIteratorCallbackHandler()
        model = ChatOpenAI(
            streaming=True,
            verbose=True,
            callbacks=[callback],
            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
            model_name=LLM_MODEL
        )
        # llm = OpenAI(model_name=LLM_MODEL,
        #              openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
        #              openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
        #              streaming=True)
        prompt = PromptTemplate(input_variables=["input"], template="{input}")
        chain = LLMChain(prompt=prompt, llm=model)

        # Begin a task that runs in the background; wrap_done sets callback.done
        # when the chain call finishes.
        task = asyncio.create_task(wrap_done(
            chain.acall(message),
            callback.done),
        )

        async for token in callback.aiter():
            # Use server-sent-events to stream the response
            yield token

        # Make sure the background task has completed before closing the stream.
        await task

    return StreamingResponse(chat_iterator(query), media_type="text/event-stream")
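

# --- Usage sketch (assumption, not part of this file) ------------------------
# In the repository this handler is registered on the FastAPI app elsewhere
# (e.g. server/api.py). The app instance, route path, and port below are
# hypothetical, shown only to illustrate how the streaming endpoint could be
# mounted; the response is streamed to clients as text/event-stream.
#
#     import uvicorn
#     from fastapi import FastAPI
#
#     app = FastAPI()
#     app.post("/chat/chat")(chat)
#     uvicorn.run(app, host="127.0.0.1", port=7861)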