from typing import List

import openai
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from configs.model_config import llm_model_dict, LLM_MODEL, logger, log_verbose


class OpenAiMessage(BaseModel):
    role: str = "user"
    content: str = "hello"


class OpenAiChatMsgIn(BaseModel):
    model: str = LLM_MODEL
    messages: List[OpenAiMessage]
    temperature: float = 0.7
    n: int = 1
    max_tokens: int = 1024
    stop: List[str] = []
    stream: bool = False
    presence_penalty: int = 0
    frequency_penalty: int = 0
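

# A minimal sketch of building the request model in Python (hypothetical
# values; `messages` is the only field without a default):
#
#     msg = OpenAiChatMsgIn(
#         messages=[OpenAiMessage(role="user", content="hello")],
#         stream=True,
#     )
#     msg.dict()  # keyword arguments passed on to openai.ChatCompletion.acreate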


async def openai_chat(msg: OpenAiChatMsgIn):
    # Point the openai client at the configured model's endpoint.
    openai.api_key = llm_model_dict[LLM_MODEL]["api_key"]
    print(f"{openai.api_key=}")
    openai.api_base = llm_model_dict[LLM_MODEL]["api_base_url"]
    print(f"{openai.api_base=}")
    print(msg)

    async def get_response(msg):
        data = msg.dict()
        try:
            response = await openai.ChatCompletion.acreate(**data)
            if msg.stream:
                # Streaming mode: forward each non-empty content delta as it arrives.
                async for chunk in response:
                    if choices := chunk.choices:
                        if content := choices[0].get("delta", {}).get("content"):
                            print(content, end="", flush=True)
                            yield content
            else:
                # Non-streaming mode: emit the complete answer in one piece.
                if response.choices:
                    answer = response.choices[0].message.content
                    print(answer)
                    yield answer
        except Exception as e:
            err_msg = f"Error while fetching ChatCompletion: {e}"
            logger.error(f'{e.__class__.__name__}: {err_msg}',
                         exc_info=e if log_verbose else None)

    return StreamingResponse(
        get_response(msg),
        media_type='text/event-stream',
    )
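

# A minimal, hypothetical way to mount the handler above on a FastAPI app and
# try it out; the app instance, route path, host, and port are assumptions for
# illustration, not part of this module.
if __name__ == "__main__":
    import uvicorn
    from fastapi import FastAPI

    app = FastAPI()
    # Register the coroutine as a POST endpoint; FastAPI reads OpenAiChatMsgIn
    # from the JSON request body and streams the response back as SSE.
    app.post("/chat/fastchat")(openai_chat)
    uvicorn.run(app, host="127.0.0.1", port=8000)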