2023-07-27 23:22:07 +08:00
|
|
|
|
import asyncio
|
2023-08-08 23:54:51 +08:00
|
|
|
|
from pydantic import BaseModel, Field
|
2023-08-23 08:35:26 +08:00
|
|
|
|
from langchain.prompts.chat import ChatMessagePromptTemplate
|
2023-09-08 20:48:31 +08:00
|
|
|
|
from configs import logger, log_verbose
|
2023-09-15 17:52:22 +08:00
|
|
|
|
from server.utils import get_model_worker_config, fschat_openai_api_address
|
|
|
|
|
|
from langchain.chat_models import ChatOpenAI
|
|
|
|
|
|
from typing import Awaitable, List, Tuple, Dict, Union, Callable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_ChatOpenAI(
        model_name: str,
        temperature: float,
        callbacks: List[Callable] = None,
) -> ChatOpenAI:
    """Build a streaming ChatOpenAI client pointed at the local fschat
    OpenAI-compatible API endpoint.

    Args:
        model_name: name of the model worker; its config supplies the API key
            and optional proxy.
        temperature: sampling temperature forwarded to the model.
        callbacks: optional langchain callback handlers. Defaults to None
            (no callbacks). The previous default was a shared mutable list
            (``[]``), a Python anti-pattern; ``None`` is now the sentinel.

    Returns:
        A ChatOpenAI instance with streaming and verbose logging enabled.
    """
    config = get_model_worker_config(model_name)
    model = ChatOpenAI(
        streaming=True,
        verbose=True,
        callbacks=callbacks if callbacks is not None else [],
        # fschat does not validate the key; "EMPTY" is the conventional placeholder
        openai_api_key=config.get("api_key", "EMPTY"),
        openai_api_base=fschat_openai_api_address(),
        model_name=model_name,
        temperature=temperature,
        openai_proxy=config.get("openai_proxy"),
    )
    return model
|
2023-07-27 23:22:07 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def wrap_done(fn: Awaitable, event: asyncio.Event):
    """Await *fn*, then set *event* to signal completion.

    Any exception raised by the awaitable is logged and swallowed; the event
    is set regardless of success or failure so consumers of the stream know
    to stop iterating.
    """
    try:
        await fn
    except Exception as exc:
        # TODO: handle exception
        msg = f"Caught exception: {exc}"
        logger.error(f'{exc.__class__.__name__}: {msg}',
                     exc_info=exc if log_verbose else None)
    finally:
        # Signal the aiter to stop.
        event.set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class History(BaseModel):
    """A single chat-history message.

    Can be constructed from a dict, e.g.::

        h = History(**{"role": "user", "content": "你好"})

    and converted to a tuple, e.g.::

        h.to_msg_tuple() == ("human", "你好")

    (The original docstring misspelled the method as ``to_msy_tuple`` and
    showed it as an attribute assignment; it is a method call.)
    """
    # Message author: e.g. "user"/"human", "assistant"/"ai", or a custom role.
    role: str = Field(...)
    # Raw message text.
    content: str = Field(...)

    def to_msg_tuple(self):
        """Return a (role, content) tuple using langchain's "human"/"ai" names."""
        return "ai" if self.role == "assistant" else "human", self.content

    def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:
        """Convert this message into a jinja2-based ChatMessagePromptTemplate.

        Args:
            is_raw: when True (default), wrap the content in ``{% raw %}`` tags
                so jinja2 does not interpolate braces found in the history text.
        """
        role_maps = {
            "ai": "assistant",
            "human": "user",
        }
        role = role_maps.get(self.role, self.role)
        if is_raw:  # history messages are assumed to be plain text without input variables
            content = "{% raw %}" + self.content + "{% endraw %}"
        else:
            content = self.content

        return ChatMessagePromptTemplate.from_template(
            content,
            "jinja2",
            role=role,
        )

    @classmethod
    def from_data(cls, h: Union[List, Tuple, Dict]) -> "History":
        """Build a History from a (role, content) sequence or a role/content dict.

        Inputs matching neither shape are returned unchanged — callers may
        already pass a History instance.
        """
        if isinstance(h, (list, tuple)) and len(h) >= 2:
            h = cls(role=h[0], content=h[1])
        elif isinstance(h, dict):
            h = cls(**h)

        return h
|