add llmchain_with_history.py as example

This commit is contained in:
imClumsyPanda 2023-08-08 22:14:58 +08:00
parent 135af5f3ff
commit c8b6d53ede
4 changed files with 29 additions and 201 deletions

View File

@ -1,5 +0,0 @@
from server.chat.chat_openai_chain.chat_openai_chain import BaseChatOpenAIChain
__all__ = [
"BaseChatOpenAIChain"
]

View File

@ -1,152 +0,0 @@
from __future__ import annotations
from abc import ABC
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.schema import (
BaseMessage,
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
LLMResult
)
from langchain.chat_models import ChatOpenAI
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Callbacks,
)
from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto
def _convert_dict_to_message(_dict: dict) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict["content"])
elif role == "system":
return SystemMessage(content=_dict["content"])
else:
return ChatMessage(content=_dict["content"], role=role)
def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]:
"""
前端消息传输对象DTO转换为chat消息传输对象DTO
:param message_data:
:return:
"""
messages = []
for message_datum in message_data:
messages.append(message_datum.dict())
return _convert_dict_to_message(messages)
class BaseChatOpenAIChain(Chain, ABC):
chat: ChatOpenAI
message_dto_key: str = "message_dto" #: :meta private:
output_key: str = "text" #: :meta private:
@classmethod
def from_chain(
cls,
model_name: str,
streaming: Optional[bool],
verbose: Optional[bool],
callbacks: Optional[Callbacks],
openai_api_key: Optional[str],
# openai_api_base: Optional[str],
**kwargs: Any,
) -> BaseChatOpenAIChain:
chat = ChatOpenAI(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
openai_api_key=openai_api_key,
# openai_api_base=openai_api_base,
model_name=model_name
)
return cls(chat=chat, **kwargs)
@property
def _chain_type(self) -> str:
return "BaseChatOpenAIChain"
@property
def input_keys(self) -> List[str]:
return [self.message_dto_key]
@property
def output_keys(self) -> List[str]:
return [self.output_key]
def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
"""Create outputs from response."""
return [
# Get the text of the top generated string.
{self.output_key: generation[0].text}
for generation in response.generations
]
def _call(
self,
inputs: Dict[str, OpenAiChatMsgDto],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
msg_dto = inputs[self.message_dto_key]
openai_messages_dto = convert_message_processors(msg_dto.messages)
_text = "openai_messages_dto after formatting:\n" + str(openai_messages_dto)
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
response = self.chat(messages=openai_messages_dto, stop=msg_dto.stop)
return self.create_outputs(response)[0]
async def _acall(
self,
inputs: Dict[str, OpenAiChatMsgDto],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
msg_dto = inputs[self.message_dto_key]
openai_messages_dto = convert_message_processors(msg_dto.messages)
_text = "openai_messages_dto after formatting:\n" + str(openai_messages_dto)
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
response = await self.chat(messages=openai_messages_dto, stop=msg_dto.stop)
return self.create_outputs(response)[0]
if __name__ == "__main__":
from langchain.callbacks import AsyncIteratorCallbackHandler
import json
# Convert instances of the classes to dictionaries
message_dto1 = OpenAiMessageDto(type="human", data=BaseMessageDto(content="Hello!"))
message_dto2 = OpenAiMessageDto(type="system", data=BaseMessageDto(content="hi!"))
chat_msg_dto = OpenAiChatMsgDto(model_name="gpt-3.5-turbo", messages=[message_dto1, message_dto2])
chat_msg_json = json.dumps(chat_msg_dto.dict(), indent=2)
print("OpenAiChatMsgDto JSON:")
print(chat_msg_json)
callback = AsyncIteratorCallbackHandler()
chains = BaseChatOpenAIChain.from_chain(
streaming=chat_msg_dto.stream,
verbose=True,
callbacks=[callback],
openai_api_key="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",
model_name=chat_msg_dto.model_name
)
out = chains({"message_dto": chat_msg_dto})
print(out)

View File

@ -0,0 +1,29 @@
from langchain.chat_models import ChatOpenAI
from configs.model_config import llm_model_dict, LLM_MODEL
from langchain import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
model = ChatOpenAI(
streaming=True,
verbose=True,
# callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
)
human_prompt = "{input}"
human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)
chat_prompt = ChatPromptTemplate.from_messages(
[("human", "我们来玩成语接龙,我先来,生龙活虎"),
("ai", "虎头虎脑"),
("human", "{input}")])
chain = LLMChain(prompt=chat_prompt, llm=model, verbose=True)
print(chain({"input": "恼羞成怒"}))

View File

@ -1,44 +0,0 @@
"""
前端消息传输结构
"""
from pydantic import BaseModel, Field
from typing import Any, Dict, List, Optional
class BaseMessageDto(BaseModel):
"""Message Dto."""
content: str
class OpenAiMessageDto(BaseModel):
"""
see @Link{langchain.schema._message_from_dict}
"""
type: Optional[str] = Field(
default="user"
)
data: BaseMessageDto
class OpenAiChatMsgDto(BaseModel):
model_name: str
messages: List[OpenAiMessageDto]
temperature: Optional[float] = Field(
default=0.7
)
max_tokens: Optional[int] = Field(
default=512
)
stop: List[str] = Field(
default=[]
)
stream: Optional[bool] = Field(
default=False
)
presence_penalty: Optional[int] = Field(
default=0
)
frequency_penalty: Optional[int] = Field(
default=0
)