修改server.chat.openai_chat中的参数定义,使其与openai中/v1/chat/completions接口的入参保持一致,并按照model_config提供默认值。

openai_chat中的接口还要修改:openai根据参数stream有不同的返回值,本接口要与其对应。
This commit is contained in:
liunux4odoo 2023-07-30 08:56:49 +08:00
parent 1a7271e966
commit 179c2a9a92
1 changed files with 25 additions and 11 deletions

View File

@ -3,26 +3,40 @@ from fastapi.responses import StreamingResponse
from typing import List, Dict from typing import List, Dict
import openai import openai
from configs.model_config import llm_model_dict, LLM_MODEL from configs.model_config import llm_model_dict, LLM_MODEL
from pydantic import BaseModel
async def openai_chat(messages: List[Dict] = Body(...,
description="用户输入", class OpenAiMessage(BaseModel):
example=[{"role": "user", "content": "你好"}])): role: str = "user"
content: str = "hello"
class OpenAiChatMsgIn(BaseModel):
model: str = LLM_MODEL
messages: List[OpenAiMessage]
temperature: float = 0.7
n: int = 1
max_tokens: int = 1024
stop: List[str] = []
stream: bool = True
presence_penalty: int = 0
frequency_penalty: int = 0
async def openai_chat(msg: OpenAiChatMsgIn):
openai.api_key = llm_model_dict[LLM_MODEL]["api_key"] openai.api_key = llm_model_dict[LLM_MODEL]["api_key"]
print(f"{openai.api_key=}") print(f"{openai.api_key=}")
openai.api_base = llm_model_dict[LLM_MODEL]["api_base_url"] openai.api_base = llm_model_dict[LLM_MODEL]["api_base_url"]
print(f"{openai.api_base=}") print(f"{openai.api_base=}")
print(messages) print(msg)
async def get_response(messages: List[Dict]): async def get_response(msg):
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(**msg.dict())
model=LLM_MODEL,
messages=messages,
)
for chunk in response.choices[0].message.content: for chunk in response.choices[0].message.content:
print(chunk) print(chunk)
yield chunk yield chunk
return StreamingResponse( return StreamingResponse(
get_response(messages), get_response(msg),
media_type='text/event-stream', media_type='text/event-stream',
) )