BaseChatOpenAIChain,支持基础的ChatOpenAI对话的Chain接入
This commit is contained in:
parent
08493bffbb
commit
823eb06c5d
|
|
@ -4,10 +4,14 @@ from typing import Any, Dict, List, Optional
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
messages_from_dict,
|
AIMessage,
|
||||||
|
BaseMessage,
|
||||||
|
ChatMessage,
|
||||||
|
HumanMessage,
|
||||||
|
SystemMessage,
|
||||||
LLMResult
|
LLMResult
|
||||||
)
|
)
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI, _convert_dict_to_message
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForChainRun,
|
AsyncCallbackManagerForChainRun,
|
||||||
CallbackManagerForChainRun,
|
CallbackManagerForChainRun,
|
||||||
|
|
@ -16,6 +20,18 @@ from langchain.callbacks.manager import (
|
||||||
from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto
|
from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_dict_to_message(_dict: dict) -> BaseMessage:
|
||||||
|
role = _dict["role"]
|
||||||
|
if role == "user":
|
||||||
|
return HumanMessage(content=_dict["content"])
|
||||||
|
elif role == "assistant":
|
||||||
|
return AIMessage(content=_dict["content"])
|
||||||
|
elif role == "system":
|
||||||
|
return SystemMessage(content=_dict["content"])
|
||||||
|
else:
|
||||||
|
return ChatMessage(content=_dict["content"], role=role)
|
||||||
|
|
||||||
|
|
||||||
def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]:
|
def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]:
|
||||||
"""
|
"""
|
||||||
前端消息传输对象DTO转换为chat消息传输对象DTO
|
前端消息传输对象DTO转换为chat消息传输对象DTO
|
||||||
|
|
@ -25,7 +41,7 @@ def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[Bas
|
||||||
messages = []
|
messages = []
|
||||||
for message_datum in message_data:
|
for message_datum in message_data:
|
||||||
messages.append(message_datum.dict())
|
messages.append(message_datum.dict())
|
||||||
return messages_from_dict(messages)
|
return _convert_dict_to_message(messages)
|
||||||
|
|
||||||
|
|
||||||
class BaseChatOpenAIChain(Chain, ABC):
|
class BaseChatOpenAIChain(Chain, ABC):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue