BaseChatOpenAIChain,支持基础的ChatOpenAI对话的Chain接入
This commit is contained in:
parent
08493bffbb
commit
823eb06c5d
|
|
@ -4,10 +4,14 @@ from typing import Any, Dict, List, Optional
|
|||
from langchain.chains.base import Chain
|
||||
from langchain.schema import (
|
||||
BaseMessage,
|
||||
messages_from_dict,
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
LLMResult
|
||||
)
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.chat_models import ChatOpenAI, _convert_dict_to_message
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
|
|
@ -16,6 +20,18 @@ from langchain.callbacks.manager import (
|
|||
from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: dict) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
elif role == "assistant":
|
||||
return AIMessage(content=_dict["content"])
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_dict["content"])
|
||||
else:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
|
||||
|
||||
def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]:
|
||||
"""
|
||||
前端消息传输对象DTO转换为chat消息传输对象DTO
|
||||
|
|
@ -25,7 +41,7 @@ def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[Bas
|
|||
messages = []
|
||||
for message_datum in message_data:
|
||||
messages.append(message_datum.dict())
|
||||
return messages_from_dict(messages)
|
||||
return _convert_dict_to_message(messages)
|
||||
|
||||
|
||||
class BaseChatOpenAIChain(Chain, ABC):
|
||||
|
|
|
|||
Loading…
Reference in New Issue