BaseChatOpenAIChain,支持基础的ChatOpenAI对话的Chain接入

This commit is contained in:
glide-the 2023-08-07 22:57:13 +08:00
parent 08493bffbb
commit 823eb06c5d
1 changed files with 19 additions and 3 deletions

View File

@ -4,10 +4,14 @@ from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.schema import ( from langchain.schema import (
BaseMessage, BaseMessage,
messages_from_dict, AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
LLMResult LLMResult
) )
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI, _convert_dict_to_message
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun, CallbackManagerForChainRun,
@ -16,6 +20,18 @@ from langchain.callbacks.manager import (
from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto
def _convert_dict_to_message(_dict: dict) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict["content"])
elif role == "system":
return SystemMessage(content=_dict["content"])
else:
return ChatMessage(content=_dict["content"], role=role)
def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]: def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]:
""" """
前端消息传输对象DTO转换为chat消息传输对象DTO 前端消息传输对象DTO转换为chat消息传输对象DTO
@ -25,7 +41,7 @@ def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[Bas
messages = [] messages = []
for message_datum in message_data: for message_datum in message_data:
messages.append(message_datum.dict()) messages.append(message_datum.dict())
return messages_from_dict(messages) return _convert_dict_to_message(messages)
class BaseChatOpenAIChain(Chain, ABC): class BaseChatOpenAIChain(Chain, ABC):