add llmchain_with_history.py as example
This commit is contained in:
parent
135af5f3ff
commit
c8b6d53ede
|
|
@ -1,5 +0,0 @@
|
||||||
from server.chat.chat_openai_chain.chat_openai_chain import BaseChatOpenAIChain
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"BaseChatOpenAIChain"
|
|
||||||
]
|
|
||||||
|
|
@ -1,152 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
from abc import ABC
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
from langchain.chains.base import Chain
|
|
||||||
from langchain.schema import (
|
|
||||||
BaseMessage,
|
|
||||||
AIMessage,
|
|
||||||
BaseMessage,
|
|
||||||
ChatMessage,
|
|
||||||
HumanMessage,
|
|
||||||
SystemMessage,
|
|
||||||
LLMResult
|
|
||||||
)
|
|
||||||
from langchain.chat_models import ChatOpenAI
|
|
||||||
from langchain.callbacks.manager import (
|
|
||||||
AsyncCallbackManagerForChainRun,
|
|
||||||
CallbackManagerForChainRun,
|
|
||||||
Callbacks,
|
|
||||||
)
|
|
||||||
from server.model.chat_openai_chain import OpenAiChatMsgDto, OpenAiMessageDto, BaseMessageDto
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_dict_to_message(_dict: dict) -> BaseMessage:
|
|
||||||
role = _dict["role"]
|
|
||||||
if role == "user":
|
|
||||||
return HumanMessage(content=_dict["content"])
|
|
||||||
elif role == "assistant":
|
|
||||||
return AIMessage(content=_dict["content"])
|
|
||||||
elif role == "system":
|
|
||||||
return SystemMessage(content=_dict["content"])
|
|
||||||
else:
|
|
||||||
return ChatMessage(content=_dict["content"], role=role)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_message_processors(message_data: List[OpenAiMessageDto]) -> List[BaseMessage]:
|
|
||||||
"""
|
|
||||||
前端消息传输对象DTO转换为chat消息传输对象DTO
|
|
||||||
:param message_data:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
messages = []
|
|
||||||
for message_datum in message_data:
|
|
||||||
messages.append(message_datum.dict())
|
|
||||||
return _convert_dict_to_message(messages)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseChatOpenAIChain(Chain, ABC):
|
|
||||||
chat: ChatOpenAI
|
|
||||||
message_dto_key: str = "message_dto" #: :meta private:
|
|
||||||
output_key: str = "text" #: :meta private:
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_chain(
|
|
||||||
cls,
|
|
||||||
model_name: str,
|
|
||||||
streaming: Optional[bool],
|
|
||||||
verbose: Optional[bool],
|
|
||||||
callbacks: Optional[Callbacks],
|
|
||||||
openai_api_key: Optional[str],
|
|
||||||
# openai_api_base: Optional[str],
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> BaseChatOpenAIChain:
|
|
||||||
chat = ChatOpenAI(
|
|
||||||
streaming=streaming,
|
|
||||||
verbose=verbose,
|
|
||||||
callbacks=callbacks,
|
|
||||||
openai_api_key=openai_api_key,
|
|
||||||
# openai_api_base=openai_api_base,
|
|
||||||
model_name=model_name
|
|
||||||
)
|
|
||||||
return cls(chat=chat, **kwargs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _chain_type(self) -> str:
|
|
||||||
return "BaseChatOpenAIChain"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_keys(self) -> List[str]:
|
|
||||||
return [self.message_dto_key]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def output_keys(self) -> List[str]:
|
|
||||||
return [self.output_key]
|
|
||||||
|
|
||||||
def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
|
|
||||||
"""Create outputs from response."""
|
|
||||||
return [
|
|
||||||
# Get the text of the top generated string.
|
|
||||||
{self.output_key: generation[0].text}
|
|
||||||
for generation in response.generations
|
|
||||||
]
|
|
||||||
|
|
||||||
def _call(
|
|
||||||
self,
|
|
||||||
inputs: Dict[str, OpenAiChatMsgDto],
|
|
||||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
|
||||||
) -> Dict[str, str]:
|
|
||||||
msg_dto = inputs[self.message_dto_key]
|
|
||||||
openai_messages_dto = convert_message_processors(msg_dto.messages)
|
|
||||||
|
|
||||||
_text = "openai_messages_dto after formatting:\n" + str(openai_messages_dto)
|
|
||||||
if run_manager:
|
|
||||||
run_manager.on_text(_text, end="\n", verbose=self.verbose)
|
|
||||||
|
|
||||||
response = self.chat(messages=openai_messages_dto, stop=msg_dto.stop)
|
|
||||||
return self.create_outputs(response)[0]
|
|
||||||
|
|
||||||
async def _acall(
|
|
||||||
self,
|
|
||||||
inputs: Dict[str, OpenAiChatMsgDto],
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
|
||||||
) -> Dict[str, str]:
|
|
||||||
msg_dto = inputs[self.message_dto_key]
|
|
||||||
|
|
||||||
openai_messages_dto = convert_message_processors(msg_dto.messages)
|
|
||||||
|
|
||||||
_text = "openai_messages_dto after formatting:\n" + str(openai_messages_dto)
|
|
||||||
if run_manager:
|
|
||||||
run_manager.on_text(_text, end="\n", verbose=self.verbose)
|
|
||||||
|
|
||||||
response = await self.chat(messages=openai_messages_dto, stop=msg_dto.stop)
|
|
||||||
return self.create_outputs(response)[0]
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
|
||||||
import json
|
|
||||||
|
|
||||||
# Convert instances of the classes to dictionaries
|
|
||||||
message_dto1 = OpenAiMessageDto(type="human", data=BaseMessageDto(content="Hello!"))
|
|
||||||
message_dto2 = OpenAiMessageDto(type="system", data=BaseMessageDto(content="hi!"))
|
|
||||||
chat_msg_dto = OpenAiChatMsgDto(model_name="gpt-3.5-turbo", messages=[message_dto1, message_dto2])
|
|
||||||
|
|
||||||
chat_msg_json = json.dumps(chat_msg_dto.dict(), indent=2)
|
|
||||||
|
|
||||||
print("OpenAiChatMsgDto JSON:")
|
|
||||||
print(chat_msg_json)
|
|
||||||
|
|
||||||
callback = AsyncIteratorCallbackHandler()
|
|
||||||
|
|
||||||
chains = BaseChatOpenAIChain.from_chain(
|
|
||||||
streaming=chat_msg_dto.stream,
|
|
||||||
verbose=True,
|
|
||||||
callbacks=[callback],
|
|
||||||
openai_api_key="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",
|
|
||||||
model_name=chat_msg_dto.model_name
|
|
||||||
|
|
||||||
)
|
|
||||||
|
|
||||||
out = chains({"message_dto": chat_msg_dto})
|
|
||||||
|
|
||||||
print(out)
|
|
||||||
|
|
@ -0,0 +1,29 @@
|
||||||
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
from configs.model_config import llm_model_dict, LLM_MODEL
|
||||||
|
from langchain import LLMChain
|
||||||
|
from langchain.prompts.chat import (
|
||||||
|
ChatPromptTemplate,
|
||||||
|
HumanMessagePromptTemplate,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = ChatOpenAI(
|
||||||
|
streaming=True,
|
||||||
|
verbose=True,
|
||||||
|
# callbacks=[callback],
|
||||||
|
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||||
|
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||||
|
model_name=LLM_MODEL
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
human_prompt = "{input}"
|
||||||
|
human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)
|
||||||
|
|
||||||
|
chat_prompt = ChatPromptTemplate.from_messages(
|
||||||
|
[("human", "我们来玩成语接龙,我先来,生龙活虎"),
|
||||||
|
("ai", "虎头虎脑"),
|
||||||
|
("human", "{input}")])
|
||||||
|
|
||||||
|
|
||||||
|
chain = LLMChain(prompt=chat_prompt, llm=model, verbose=True)
|
||||||
|
print(chain({"input": "恼羞成怒"}))
|
||||||
|
|
@ -1,44 +0,0 @@
|
||||||
"""
|
|
||||||
前端消息传输结构
|
|
||||||
"""
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class BaseMessageDto(BaseModel):
|
|
||||||
"""Message Dto."""
|
|
||||||
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAiMessageDto(BaseModel):
|
|
||||||
"""
|
|
||||||
see @Link{langchain.schema._message_from_dict}
|
|
||||||
"""
|
|
||||||
type: Optional[str] = Field(
|
|
||||||
default="user"
|
|
||||||
)
|
|
||||||
data: BaseMessageDto
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAiChatMsgDto(BaseModel):
|
|
||||||
model_name: str
|
|
||||||
messages: List[OpenAiMessageDto]
|
|
||||||
temperature: Optional[float] = Field(
|
|
||||||
default=0.7
|
|
||||||
)
|
|
||||||
max_tokens: Optional[int] = Field(
|
|
||||||
default=512
|
|
||||||
)
|
|
||||||
stop: List[str] = Field(
|
|
||||||
default=[]
|
|
||||||
)
|
|
||||||
stream: Optional[bool] = Field(
|
|
||||||
default=False
|
|
||||||
)
|
|
||||||
presence_penalty: Optional[int] = Field(
|
|
||||||
default=0
|
|
||||||
)
|
|
||||||
frequency_penalty: Optional[int] = Field(
|
|
||||||
default=0
|
|
||||||
)
|
|
||||||
Loading…
Reference in New Issue