add llmchain_with_history.py as example
This commit is contained in:
parent
135af5f3ff
commit
c8b6d53ede
server/chat/chat_openai_chain/__init__.py (deleted)
@@ -1,5 +0,0 @@
from server.chat.chat_openai_chain.chat_openai_chain import BaseChatOpenAIChain

__all__ = [
    "BaseChatOpenAIChain"
]
server/chat/chat_openai_chain/chat_openai_chain.py (deleted)
@@ -1,152 +0,0 @@
from __future__ import annotations
from abc import ABC
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
    LLMResult
)
from langchain.chat_models import ChatOpenAI
from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
    Callbacks,
)
from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto


def _convert_dict_to_message(_dict: dict) -> BaseMessage:
    """Map a {"role": ..., "content": ...} dict onto the matching message class."""
    role = _dict["role"]
    if role == "user":
        return HumanMessage(content=_dict["content"])
    elif role == "assistant":
        return AIMessage(content=_dict["content"])
    elif role == "system":
        return SystemMessage(content=_dict["content"])
    else:
        return ChatMessage(content=_dict["content"], role=role)


def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]:
    """
    Convert front-end message DTOs into langchain chat messages.
    :param message_data: messages as received from the front end
    :return: a list of BaseMessage objects
    """
    messages = []
    for message_datum in message_data:
        # Flatten the DTO shape {"type": ..., "data": {"content": ...}} into the
        # {"role": ..., "content": ...} shape that _convert_dict_to_message expects.
        messages.append(_convert_dict_to_message(
            {"role": message_datum.type, "content": message_datum.data.content}
        ))
    return messages
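
# A quick sanity check of the two converters above (hypothetical usage;
# OpenAiMessageDto and BaseMessageDto are the DTOs defined in
# server/model/chat_openai_chain.py further down):
#
#     dto = OpenAiMessageDto(type="user", data=BaseMessageDto(content="Hello!"))
#     convert_message_processors([dto])
#     # -> roughly [HumanMessage(content='Hello!')]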


class BaseChatOpenAIChain(Chain, ABC):
    chat: ChatOpenAI
    message_dto_key: str = "message_dto"  #: :meta private:
    output_key: str = "text"  #: :meta private:

    @classmethod
    def from_chain(
        cls,
        model_name: str,
        streaming: Optional[bool],
        verbose: Optional[bool],
        callbacks: Optional[Callbacks],
        openai_api_key: Optional[str],
        # openai_api_base: Optional[str],
        **kwargs: Any,
    ) -> BaseChatOpenAIChain:
        chat = ChatOpenAI(
            streaming=streaming,
            verbose=verbose,
            callbacks=callbacks,
            openai_api_key=openai_api_key,
            # openai_api_base=openai_api_base,
            model_name=model_name
        )
        return cls(chat=chat, **kwargs)

    @property
    def _chain_type(self) -> str:
        return "BaseChatOpenAIChain"

    @property
    def input_keys(self) -> List[str]:
        return [self.message_dto_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
        """Create outputs from response."""
        return [
            # Get the text of the top generated string.
            {self.output_key: generation[0].text}
            for generation in response.generations
        ]

    def _call(
        self,
        inputs: Dict[str, OpenAiChatMsgDto],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        msg_dto = inputs[self.message_dto_key]
        openai_messages = convert_message_processors(msg_dto.messages)

        _text = "messages after formatting:\n" + str(openai_messages)
        if run_manager:
            run_manager.on_text(_text, end="\n", verbose=self.verbose)

        # generate() returns the LLMResult that create_outputs() expects;
        # calling self.chat(...) directly would return a single BaseMessage.
        response = self.chat.generate([openai_messages], stop=msg_dto.stop)
        return self.create_outputs(response)[0]

    async def _acall(
        self,
        inputs: Dict[str, OpenAiChatMsgDto],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        msg_dto = inputs[self.message_dto_key]
        openai_messages = convert_message_processors(msg_dto.messages)

        _text = "messages after formatting:\n" + str(openai_messages)
        if run_manager:
            # the async run manager's on_text is a coroutine and must be awaited
            await run_manager.on_text(_text, end="\n", verbose=self.verbose)

        # agenerate() is the async counterpart of generate() and returns an LLMResult
        response = await self.chat.agenerate([openai_messages], stop=msg_dto.stop)
        return self.create_outputs(response)[0]


if __name__ == "__main__":
    from langchain.callbacks import AsyncIteratorCallbackHandler
    import json

    # Convert instances of the classes to dictionaries
    message_dto1 = OpenAiMessageDto(type="user", data=BaseMessageDto(content="Hello!"))
    message_dto2 = OpenAiMessageDto(type="system", data=BaseMessageDto(content="hi!"))
    chat_msg_dto = OpenAiChatMsgDto(model_name="gpt-3.5-turbo", messages=[message_dto1, message_dto2])

    chat_msg_json = json.dumps(chat_msg_dto.dict(), indent=2)

    print("OpenAiChatMsgDto JSON:")
    print(chat_msg_json)

    callback = AsyncIteratorCallbackHandler()

    chains = BaseChatOpenAIChain.from_chain(
        streaming=chat_msg_dto.stream,
        verbose=True,
        callbacks=[callback],
        openai_api_key="sk-...",  # redacted; supply your own key
        model_name=chat_msg_dto.model_name
    )

    out = chains({"message_dto": chat_msg_dto})
    print(out)
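
# Note that the demo builds an AsyncIteratorCallbackHandler but drives the
# chain synchronously, so no tokens stream through it. A minimal sketch of the
# async path (hypothetical usage; it only streams when the DTO sets
# stream=True so that ChatOpenAI is created with streaming enabled):
#
#     import asyncio
#
#     async def stream_demo():
#         task = asyncio.create_task(chains.acall({"message_dto": chat_msg_dto}))
#         async for token in callback.aiter():  # yields tokens as they arrive
#             print(token, end="", flush=True)
#         print(await task)
#
#     asyncio.run(stream_demo())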
llmchain_with_history.py (added)
@@ -0,0 +1,29 @@
from langchain.chat_models import ChatOpenAI
from configs.model_config import llm_model_dict, LLM_MODEL
from langchain import LLMChain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)

model = ChatOpenAI(
    streaming=True,
    verbose=True,
    # callbacks=[callback],
    openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
    openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
    model_name=LLM_MODEL
)

human_prompt = "{input}"
human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)

# Seed the prompt with one prior round of a Chinese idiom-chain game
# ("成语接龙"): the human opens with 生龙活虎, the AI replies 虎头虎脑,
# and the current human turn is supplied through {input}.
chat_prompt = ChatPromptTemplate.from_messages(
    [("human", "我们来玩成语接龙,我先来,生龙活虎"),
     ("ai", "虎头虎脑"),
     human_message_template])

chain = LLMChain(prompt=chat_prompt, llm=model, verbose=True)
print(chain({"input": "恼羞成怒"}))
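
For multi-turn use, the same pattern generalizes: instead of hard-coding the
first two turns into the prompt, the history can be injected through a
MessagesPlaceholder backed by memory. A minimal sketch, assuming the same
`model` as above (ConversationBufferMemory and MessagesPlaceholder are
standard langchain APIs of this era):

    from langchain import LLMChain
    from langchain.memory import ConversationBufferMemory
    from langchain.prompts.chat import ChatPromptTemplate, MessagesPlaceholder

    # return_messages=True keeps the history as chat messages rather than a string
    memory = ConversationBufferMemory(memory_key="history", return_messages=True)
    prompt = ChatPromptTemplate.from_messages(
        [MessagesPlaceholder(variable_name="history"),
         ("human", "{input}")])

    chain = LLMChain(prompt=prompt, llm=model, memory=memory, verbose=True)
    chain({"input": "我们来玩成语接龙,我先来,生龙活虎"})  # stored in memory
    chain({"input": "虎头虎脑"})  # this call sees the previous turn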
server/model/chat_openai_chain.py (deleted)
@@ -1,44 +0,0 @@
"""
Front-end message transfer structures.
"""
from pydantic import BaseModel, Field
from typing import List, Optional


class BaseMessageDto(BaseModel):
    """Message Dto."""

    content: str


class OpenAiMessageDto(BaseModel):
    """
    See ``langchain.schema._message_from_dict`` for the shape this mirrors.
    """
    type: Optional[str] = Field(default="user")
    data: BaseMessageDto


class OpenAiChatMsgDto(BaseModel):
    model_name: str
    messages: List[OpenAiMessageDto]
    temperature: Optional[float] = Field(default=0.7)
    max_tokens: Optional[int] = Field(default=512)
    stop: List[str] = Field(default=[])
    stream: Optional[bool] = Field(default=False)
    presence_penalty: Optional[int] = Field(default=0)
    frequency_penalty: Optional[int] = Field(default=0)
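
# A round-trip of the DTOs through JSON, which is how the front end would
# deliver them (hypothetical usage; parse_raw is pydantic v1 API):
#
#     import json
#
#     payload = json.dumps({
#         "model_name": "gpt-3.5-turbo",
#         "messages": [{"type": "user", "data": {"content": "Hello!"}}],
#     })
#     dto = OpenAiChatMsgDto.parse_raw(payload)
#     print(dto.messages[0].data.content)  # Hello!
#     print(dto.temperature, dto.stream)   # 0.7 False (defaults)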