Langchain-Chatchat/server/chat/openai_chat.py

from fastapi.responses import StreamingResponse
from typing import List, Optional
import openai
from configs import LLM_MODEL, logger, log_verbose
from server.utils import get_model_worker_config, fschat_openai_api_address
from pydantic import BaseModel


class OpenAiMessage(BaseModel):
    role: str = "user"
    content: str = "hello"


class OpenAiChatMsgIn(BaseModel):
    # Request body mirroring the OpenAI ChatCompletion parameters.
    model: str = LLM_MODEL
    messages: List[OpenAiMessage]
    temperature: float = 0.7
    n: int = 1
    max_tokens: Optional[int] = None
    stop: List[str] = []
    stream: bool = False
    presence_penalty: int = 0
    frequency_penalty: int = 0


async def openai_chat(msg: OpenAiChatMsgIn):
    # Point the openai client at the configured model worker
    # (falls back to the local fschat openai-api server).
    config = get_model_worker_config(msg.model)
    openai.api_key = config.get("api_key", "EMPTY")
    print(f"{openai.api_key=}")
    openai.api_base = config.get("api_base_url", fschat_openai_api_address())
    print(f"{openai.api_base=}")
    print(msg)

    async def get_response(msg):
        data = msg.dict()

        try:
            response = await openai.ChatCompletion.acreate(**data)
            if msg.stream:
                # Stream mode: relay the delta content of each chunk as it arrives.
                async for data in response:
                    if choices := data.choices:
                        if chunk := choices[0].get("delta", {}).get("content"):
                            print(chunk, end="", flush=True)
                            yield chunk
            else:
                # Non-stream mode: yield the complete answer in one piece.
                if response.choices:
                    answer = response.choices[0].message.content
                    print(answer)
                    yield answer
        except Exception as e:
            err_msg = f"Error while getting ChatCompletion: {e}"
            logger.error(f'{e.__class__.__name__}: {err_msg}',
                         exc_info=e if log_verbose else None)

    return StreamingResponse(
        get_response(msg),
        media_type='text/event-stream',
    )
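

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the project's API surface):
    # assumes valid project configs and a running fschat openai-api server.
    # It calls openai_chat() directly and drains the StreamingResponse it
    # returns; the HTTP route registration for this handler lives outside
    # this module.
    import asyncio

    async def _demo():
        req = OpenAiChatMsgIn(
            messages=[OpenAiMessage(role="user", content="hello")],
            stream=True,
        )
        response = await openai_chat(req)  # fastapi.responses.StreamingResponse
        # The response wraps the get_response() async generator, so its chunks
        # can be consumed directly via body_iterator.
        async for chunk in response.body_iterator:
            print(chunk, end="", flush=True)

    asyncio.run(_demo())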