From 823eb06c5d5a74987c5c903554b9ca1f1b0dd435 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Mon, 7 Aug 2023 22:57:13 +0800 Subject: [PATCH] =?UTF-8?q?BaseChatOpenAIChain,=E6=94=AF=E6=8C=81=E5=9F=BA?= =?UTF-8?q?=E7=A1=80=E7=9A=84ChatOpenAI=E5=AF=B9=E8=AF=9D=E7=9A=84Chain?= =?UTF-8?q?=E6=8E=A5=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chat_openai_chain/chat_openai_chain.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/server/chat/chat_openai_chain/chat_openai_chain.py b/server/chat/chat_openai_chain/chat_openai_chain.py index 17f1866..a989ccd 100644 --- a/server/chat/chat_openai_chain/chat_openai_chain.py +++ b/server/chat/chat_openai_chain/chat_openai_chain.py @@ -4,10 +4,14 @@ from typing import Any, Dict, List, Optional from langchain.chains.base import Chain from langchain.schema import ( BaseMessage, - messages_from_dict, + AIMessage, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, LLMResult ) -from langchain.chat_models import ChatOpenAI +from langchain.chat_models import ChatOpenAI, _convert_dict_to_message from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, @@ -16,6 +20,18 @@ from langchain.callbacks.manager import ( from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto +def _convert_dict_to_message(_dict: dict) -> BaseMessage: + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + return AIMessage(content=_dict["content"]) + elif role == "system": + return SystemMessage(content=_dict["content"]) + else: + return ChatMessage(content=_dict["content"], role=role) + + def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]: """ 前端消息传输对象DTO转换为chat消息传输对象DTO @@ -25,7 +41,7 @@ def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[Bas messages = [] for message_datum in message_data: messages.append(message_datum.dict()) - return messages_from_dict(messages) + return _convert_dict_to_message(messages) class BaseChatOpenAIChain(Chain, ABC):