BaseChatOpenAIChain,支持基础的ChatOpenAI对话的Chain接入

This commit is contained in:
glide-the 2023-08-07 22:57:13 +08:00
parent 08493bffbb
commit 823eb06c5d
1 changed files with 19 additions and 3 deletions

View File

@ -4,10 +4,14 @@ from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.schema import (
BaseMessage,
messages_from_dict,
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
LLMResult
)
from langchain.chat_models import ChatOpenAI
from langchain.chat_models import ChatOpenAI, _convert_dict_to_message
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
@ -16,6 +20,18 @@ from langchain.callbacks.manager import (
from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto
def _convert_dict_to_message(_dict: dict) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict["content"])
elif role == "system":
return SystemMessage(content=_dict["content"])
else:
return ChatMessage(content=_dict["content"], role=role)
def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]:
"""
前端消息传输对象DTO转换为chat消息传输对象DTO
@ -25,7 +41,7 @@ def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[Bas
messages = []
for message_datum in message_data:
messages.append(message_datum.dict())
return messages_from_dict(messages)
return _convert_dict_to_message(messages)
class BaseChatOpenAIChain(Chain, ABC):