diff --git a/server/chat/chat_openai_chain/__init__.py b/server/chat/chat_openai_chain/__init__.py
new file mode 100644
index 0000000..d53d695
--- /dev/null
+++ b/server/chat/chat_openai_chain/__init__.py
@@ -0,0 +1,5 @@
+from server.chat.chat_openai_chain.chat_openai_chain import BaseChatOpenAIChain
+
+__all__ = [
+    "BaseChatOpenAIChain"
+]
diff --git a/server/chat/chat_openai_chain/chat_openai_chain.py b/server/chat/chat_openai_chain/chat_openai_chain.py
new file mode 100644
index 0000000..17f1866
--- /dev/null
+++ b/server/chat/chat_openai_chain/chat_openai_chain.py
@@ -0,0 +1,137 @@
+from __future__ import annotations
+from abc import ABC
+from typing import Any, Dict, List, Optional
+from langchain.chains.base import Chain
+from langchain.schema import (
+    BaseMessage,
+    messages_from_dict,
+    LLMResult
+)
+from langchain.chat_models import ChatOpenAI
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForChainRun,
+    CallbackManagerForChainRun,
+    Callbacks,
+)
+from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto
+
+
+def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]:
+    """
+    Convert front-end message DTOs into langchain chat messages.
+
+    :param message_data: message DTOs received from the front end
+    :return: langchain ``BaseMessage`` objects, in the same order
+    """
+    return messages_from_dict([dto.dict() for dto in message_data])
+
+
+class BaseChatOpenAIChain(Chain, ABC):
+    """Chain that feeds front-end chat DTOs to a ``ChatOpenAI`` model."""
+
+    chat: ChatOpenAI
+    message_dto_key: str = "message_dto"  #: :meta private:
+    output_key: str = "text"  #: :meta private:
+
+    @classmethod
+    def from_chain(
+        cls,
+        model_name: str,
+        streaming: Optional[bool],
+        verbose: Optional[bool],
+        callbacks: Optional[Callbacks],
+        openai_api_key: Optional[str],
+        **kwargs: Any,
+    ) -> BaseChatOpenAIChain:
+        """Alternate constructor: build the underlying ``ChatOpenAI`` client."""
+        chat = ChatOpenAI(
+            streaming=streaming,
+            verbose=verbose,
+            callbacks=callbacks,
+            openai_api_key=openai_api_key,
+            model_name=model_name
+        )
+        return cls(chat=chat, **kwargs)
+
+    @property
+    def _chain_type(self) -> str:
+        return "BaseChatOpenAIChain"
+
+    @property
+    def input_keys(self) -> List[str]:
+        return [self.message_dto_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        return [self.output_key]
+
+    def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
+        """Create outputs from response."""
+        return [
+            # Get the text of the top generated string.
+            {self.output_key: generation[0].text}
+            for generation in response.generations
+        ]
+
+    def _call(
+        self,
+        inputs: Dict[str, OpenAiChatMsgDto],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        msg_dto = inputs[self.message_dto_key]
+        openai_messages_dto = convert_message_processors(msg_dto.messages)
+
+        _text = "openai_messages_dto after formatting:\n" + str(openai_messages_dto)
+        if run_manager:
+            run_manager.on_text(_text, end="\n", verbose=self.verbose)
+
+        # generate() returns the LLMResult that create_outputs() expects;
+        # calling self.chat(...) directly would return a single BaseMessage.
+        # An empty stop list is mapped to None so nothing is sent to the API.
+        response = self.chat.generate([openai_messages_dto], stop=msg_dto.stop or None)
+        return self.create_outputs(response)[0]
+
+    async def _acall(
+        self,
+        inputs: Dict[str, OpenAiChatMsgDto],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        msg_dto = inputs[self.message_dto_key]
+        openai_messages_dto = convert_message_processors(msg_dto.messages)
+
+        _text = "openai_messages_dto after formatting:\n" + str(openai_messages_dto)
+        if run_manager:
+            # async run managers expose coroutine callbacks — must be awaited
+            await run_manager.on_text(_text, end="\n", verbose=self.verbose)
+
+        # A ChatOpenAI instance is not awaitable; use the async generate API.
+        response = await self.chat.agenerate([openai_messages_dto], stop=msg_dto.stop or None)
+        return self.create_outputs(response)[0]
+
+
+if __name__ == "__main__":
+    import json
+    import os
+
+    from langchain.callbacks import AsyncIteratorCallbackHandler
+
+    # Convert instances of the classes to dictionaries
+    message_dto1 = OpenAiMessageDto(type="human", data=BaseMessageDto(content="Hello!"))
+    message_dto2 = OpenAiMessageDto(type="system", data=BaseMessageDto(content="hi!"))
+    chat_msg_dto = OpenAiChatMsgDto(
+        model_name="gpt-3.5-turbo", messages=[message_dto1, message_dto2]
+    )
+
+    print("OpenAiChatMsgDto JSON:")
+    print(json.dumps(chat_msg_dto.dict(), indent=2))
+
+    callback = AsyncIteratorCallbackHandler()
+
+    chains = BaseChatOpenAIChain.from_chain(
+        streaming=chat_msg_dto.stream,
+        verbose=True,
+        callbacks=[callback],
+        # SECURITY: never commit API keys; the previously hard-coded key must be revoked.
+        openai_api_key=os.environ.get("OPENAI_API_KEY"),
+        model_name=chat_msg_dto.model_name,
+    )
+
+    out = chains({"message_dto": chat_msg_dto})
+    print(out)
diff --git a/server/model/chat_openai_chain.py b/server/model/chat_openai_chain.py
new file mode 100644
index 0000000..f4a0f55
--- /dev/null
+++ b/server/model/chat_openai_chain.py
@@ -0,0 +1,47 @@
+"""
+Front-end message transfer structures.
+"""
+from pydantic import BaseModel, Field
+from typing import Any, Dict, List, Optional
+
+
+class BaseMessageDto(BaseModel):
+    """Message Dto."""
+
+    content: str
+
+
+class OpenAiMessageDto(BaseModel):
+    """
+    see @Link{langchain.schema._message_from_dict}
+    """
+    # langchain's messages_from_dict recognises "human" / "ai" / "system" /
+    # "chat"; "user" is not a valid langchain message type and raises ValueError.
+    type: Optional[str] = Field(
+        default="human"
+    )
+    data: BaseMessageDto
+
+
+class OpenAiChatMsgDto(BaseModel):
+    model_name: str
+    messages: List[OpenAiMessageDto]
+    temperature: Optional[float] = Field(
+        default=0.7
+    )
+    max_tokens: Optional[int] = Field(
+        default=512
+    )
+    # default_factory avoids one shared mutable list default across instances
+    stop: List[str] = Field(
+        default_factory=list
+    )
+    stream: Optional[bool] = Field(
+        default=False
+    )
+    presence_penalty: Optional[int] = Field(
+        default=0
+    )
+    frequency_penalty: Optional[int] = Field(
+        default=0
+    )