diff --git a/server/chat/chat_openai_chain/chat_openai_chain.py b/server/chat/chat_openai_chain/chat_openai_chain.py index 17f1866..a989ccd 100644 --- a/server/chat/chat_openai_chain/chat_openai_chain.py +++ b/server/chat/chat_openai_chain/chat_openai_chain.py @@ -4,10 +4,14 @@ from typing import Any, Dict, List, Optional from langchain.chains.base import Chain from langchain.schema import ( BaseMessage, - messages_from_dict, + AIMessage, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, LLMResult ) -from langchain.chat_models import ChatOpenAI +from langchain.chat_models import ChatOpenAI, _convert_dict_to_message from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -16,6 +20,18 @@ from langchain.callbacks.manager import ( from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto +def _convert_dict_to_message(_dict: dict) -> BaseMessage: + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + return AIMessage(content=_dict["content"]) + elif role == "system": + return SystemMessage(content=_dict["content"]) + else: + return ChatMessage(content=_dict["content"], role=role) + + def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]: """ 前端消息传输对象DTO转换为chat消息传输对象DTO @@ -25,7 +41,7 @@ def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[Bas messages = [] for message_datum in message_data: messages.append(message_datum.dict()) - return messages_from_dict(messages) + return _convert_dict_to_message(messages) class BaseChatOpenAIChain(Chain, ABC):