diff --git a/server/chat/chat_openai_chain/__init__.py b/server/chat/chat_openai_chain/__init__.py deleted file mode 100644 index d53d695..0000000 --- a/server/chat/chat_openai_chain/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from server.chat.chat_openai_chain.chat_openai_chain import BaseChatOpenAIChain - -__all__ = [ - "BaseChatOpenAIChain" -] diff --git a/server/chat/chat_openai_chain/chat_openai_chain.py b/server/chat/chat_openai_chain/chat_openai_chain.py deleted file mode 100644 index 7757f98..0000000 --- a/server/chat/chat_openai_chain/chat_openai_chain.py +++ /dev/null @@ -1,152 +0,0 @@ -from __future__ import annotations -from abc import ABC -from typing import Any, Dict, List, Optional -from langchain.chains.base import Chain -from langchain.schema import ( - BaseMessage, - AIMessage, - BaseMessage, - ChatMessage, - HumanMessage, - SystemMessage, - LLMResult -) -from langchain.chat_models import ChatOpenAI -from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, - Callbacks, -) -from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto - - -def _convert_dict_to_message(_dict: dict) -> BaseMessage: - role = _dict["role"] - if role == "user": - return HumanMessage(content=_dict["content"]) - elif role == "assistant": - return AIMessage(content=_dict["content"]) - elif role == "system": - return SystemMessage(content=_dict["content"]) - else: - return ChatMessage(content=_dict["content"], role=role) - - -def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]: - """ - 前端消息传输对象DTO转换为chat消息传输对象DTO - :param message_data: - :return: - """ - messages = [] - for message_datum in message_data: - messages.append(message_datum.dict()) - return _convert_dict_to_message(messages) - - -class BaseChatOpenAIChain(Chain, ABC): - chat: ChatOpenAI - message_dto_key: str = "message_dto" #: :meta private: - output_key: str = "text" #: :meta private: - - @classmethod - def from_chain( - cls, - model_name: str, - streaming: Optional[bool], - verbose: Optional[bool], - callbacks: Optional[Callbacks], - openai_api_key: Optional[str], - # openai_api_base: Optional[str], - **kwargs: Any, - ) -> BaseChatOpenAIChain: - chat = ChatOpenAI( - streaming=streaming, - verbose=verbose, - callbacks=callbacks, - openai_api_key=openai_api_key, - # openai_api_base=openai_api_base, - model_name=model_name - ) - return cls(chat=chat, **kwargs) - - @property - def _chain_type(self) -> str: - return "BaseChatOpenAIChain" - - @property - def input_keys(self) -> List[str]: - return [self.message_dto_key] - - @property - def output_keys(self) -> List[str]: - return [self.output_key] - - def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]: - """Create outputs from response.""" - return [ - # Get the text of the top generated string. - {self.output_key: generation[0].text} - for generation in response.generations - ] - - def _call( - self, - inputs: Dict[str, OpenAiChatMsgDto], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - msg_dto = inputs[self.message_dto_key] - openai_messages_dto = convert_message_processors(msg_dto.messages) - - _text = "openai_messages_dto after formatting:\n" + str(openai_messages_dto) - if run_manager: - run_manager.on_text(_text, end="\n", verbose=self.verbose) - - response = self.chat(messages=openai_messages_dto, stop=msg_dto.stop) - return self.create_outputs(response)[0] - - async def _acall( - self, - inputs: Dict[str, OpenAiChatMsgDto], - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - msg_dto = inputs[self.message_dto_key] - - openai_messages_dto = convert_message_processors(msg_dto.messages) - - _text = "openai_messages_dto after formatting:\n" + str(openai_messages_dto) - if run_manager: - run_manager.on_text(_text, end="\n", verbose=self.verbose) - - response = await self.chat(messages=openai_messages_dto, stop=msg_dto.stop) - return self.create_outputs(response)[0] - - -if __name__ == "__main__": - from langchain.callbacks import AsyncIteratorCallbackHandler - import json - - # Convert instances of the classes to dictionaries - message_dto1 = OpenAiMessageDto(type="human", data=BaseMessageDto(content="Hello!")) - message_dto2 = OpenAiMessageDto(type="system", data=BaseMessageDto(content="hi!")) - chat_msg_dto = OpenAiChatMsgDto(model_name="gpt-3.5-turbo", messages=[message_dto1, message_dto2]) - - chat_msg_json = json.dumps(chat_msg_dto.dict(), indent=2) - - print("OpenAiChatMsgDto JSON:") - print(chat_msg_json) - - callback = AsyncIteratorCallbackHandler() - - chains = BaseChatOpenAIChain.from_chain( - streaming=chat_msg_dto.stream, - verbose=True, - callbacks=[callback], - openai_api_key="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", - model_name=chat_msg_dto.model_name - - ) - - out = chains({"message_dto": chat_msg_dto}) - - print(out) diff --git a/server/chat/llmchain_with_history.py b/server/chat/llmchain_with_history.py new file mode 100644 index 0000000..3d36042 --- /dev/null +++ b/server/chat/llmchain_with_history.py @@ -0,0 +1,29 @@ +from langchain.chat_models import ChatOpenAI +from configs.model_config import llm_model_dict, LLM_MODEL +from langchain import LLMChain +from langchain.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, +) + +model = ChatOpenAI( + streaming=True, + verbose=True, + # callbacks=[callback], + openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], + openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], + model_name=LLM_MODEL +) + + +human_prompt = "{input}" +human_message_template = HumanMessagePromptTemplate.from_template(human_prompt) + +chat_prompt = ChatPromptTemplate.from_messages( + [("human", "我们来玩成语接龙,我先来,生龙活虎"), + ("ai", "虎头虎脑"), + ("human", "{input}")]) + + +chain = LLMChain(prompt=chat_prompt, llm=model, verbose=True) +print(chain({"input": "恼羞成怒"})) \ No newline at end of file diff --git a/server/model/chat_openai_chain.py b/server/model/chat_openai_chain.py deleted file mode 100644 index f4a0f55..0000000 --- a/server/model/chat_openai_chain.py +++ /dev/null @@ -1,44 +0,0 @@ -""" -前端消息传输结构 -""" -from pydantic import BaseModel, Field -from typing import Any, Dict, List, Optional - - -class BaseMessageDto(BaseModel): - """Message Dto.""" - - content: str - - -class OpenAiMessageDto(BaseModel): - """ - see @Link{langchain.schema._message_from_dict} - """ - type: Optional[str] = Field( - default="user" - ) - data: BaseMessageDto - - -class OpenAiChatMsgDto(BaseModel): - model_name: str - messages: List[OpenAiMessageDto] - temperature: Optional[float] = Field( - default=0.7 - ) - max_tokens: Optional[int] = Field( - default=512 - ) - stop: List[str] = Field( - default=[] - ) - stream: Optional[bool] = Field( - default=False - ) - presence_penalty: Optional[int] = Field( - default=0 - ) - frequency_penalty: Optional[int] = Field( - default=0 - )