diff --git a/chains/dialogue_answering/__init__.py b/chains/dialogue_answering/__init__.py
new file mode 100644
index 0000000..566e767
--- /dev/null
+++ b/chains/dialogue_answering/__init__.py
@@ -0,0 +1,7 @@
+from .base import (
+    DialogueWithSharedMemoryChains
+)
+
+__all__ = [
+    "DialogueWithSharedMemoryChains"
+]
diff --git a/chains/dialogue_answering/__main__.py b/chains/dialogue_answering/__main__.py
new file mode 100644
index 0000000..9b9f412
--- /dev/null
+++ b/chains/dialogue_answering/__main__.py
@@ -0,0 +1,36 @@
+import sys
+import os
+import argparse
+import asyncio
+from argparse import Namespace
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../../')
+from chains.dialogue_answering import *
+from langchain.llms import OpenAI
+from models.base import (BaseAnswer,
+                         AnswerResult)
+import models.shared as shared
+from models.loader.args import parser
+from models.loader import LoaderCheckPoint
+
+async def dispatch(args: Namespace):
+
+    args_dict = vars(args)
+    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
+    llm_model_ins = shared.loaderLLM()
+    if not os.path.isfile(args.dialogue_path):
+        raise FileNotFoundError(f'Invalid dialogue file path for demo mode: "{args.dialogue_path}"')
+    llm = OpenAI(temperature=0)
+    dialogue_instance = DialogueWithSharedMemoryChains(zero_shot_react_llm=llm, ask_llm=llm_model_ins, params=args_dict)
+
+    dialogue_instance.agent_chain.run(input="What did David say before? Summarize it.")
+
+
+if __name__ == '__main__':
+
+    parser.add_argument('--dialogue-path', default='', type=str, help='path to the dialogue log file')
+    parser.add_argument('--embedding-model', default='', type=str, help='name or path of the embedding model')
+    args = parser.parse_args(['--dialogue-path', '/home/dmeck/Downloads/log.txt',
+                              '--embedding-model', '/media/checkpoint/text2vec-large-chinese/'])
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    loop.run_until_complete(dispatch(args))
diff --git a/chains/dialogue_answering/base.py b/chains/dialogue_answering/base.py
new file mode 100644
index 0000000..6925f40
--- /dev/null
+++ b/chains/dialogue_answering/base.py
@@ -0,0 +1,99 @@
+from langchain.base_language import BaseLanguageModel
+from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
+from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
+from langchain.chains import LLMChain, RetrievalQA
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from langchain.prompts import PromptTemplate
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.vectorstores import Chroma
+
+from loader import DialogueLoader
+from chains.dialogue_answering.prompts import (
+    DIALOGUE_PREFIX,
+    DIALOGUE_SUFFIX,
+    SUMMARY_PROMPT
+)
+
+
+class DialogueWithSharedMemoryChains:
+    zero_shot_react_llm: BaseLanguageModel = None
+    ask_llm: BaseLanguageModel = None
+    embeddings: HuggingFaceEmbeddings = None
+    embedding_model: str = None
+    vector_search_top_k: int = 6
+    dialogue_path: str = None
+    dialogue_loader: DialogueLoader = None
+    device: str = None
+
+    def __init__(self, zero_shot_react_llm: BaseLanguageModel = None, ask_llm: BaseLanguageModel = None,
+                 params: dict = None):
+        self.zero_shot_react_llm = zero_shot_react_llm
+        self.ask_llm = ask_llm
+        params = params or {}
+        self.embedding_model = params.get('embedding_model', 'GanymedeNil/text2vec-large-chinese')
+        self.vector_search_top_k = params.get('vector_search_top_k', 6)
+        self.dialogue_path = params.get('dialogue_path', '')
+        self.device = 'cuda' if params.get('use_cuda', False) else 'cpu'
+
+        self.dialogue_loader = DialogueLoader(self.dialogue_path)
+        self._init_cfg()
+        self._init_state_of_history()
+        self.memory_chain, self.memory = self._agents_answer()
+        self.agent_chain = self._create_agent_chain()
+
+    def _init_cfg(self):
+        model_kwargs = {
+            'device': self.device
+        }
+        self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model, model_kwargs=model_kwargs)
+
+    def _init_state_of_history(self):
+        documents = self.dialogue_loader.load()
+        text_splitter = CharacterTextSplitter(chunk_size=3, chunk_overlap=1)
+        texts = text_splitter.split_documents(documents)
+        docsearch = Chroma.from_documents(texts, self.embeddings, collection_name="state-of-history")
+        self.state_of_history = RetrievalQA.from_chain_type(llm=self.ask_llm, chain_type="stuff",
+                                                            retriever=docsearch.as_retriever())
+
+    def _agents_answer(self):
+
+        memory = ConversationBufferMemory(memory_key="chat_history")
+        readonly_memory = ReadOnlySharedMemory(memory=memory)
+        memory_chain = LLMChain(
+            llm=self.ask_llm,
+            prompt=SUMMARY_PROMPT,
+            verbose=True,
+            memory=readonly_memory,  # use the read-only memory to prevent the tool from modifying the memory
+        )
+        return memory_chain, memory
+
+    def _create_agent_chain(self):
+        dialogue_participants = self.dialogue_loader.dialogue.participants_to_export()
+        tools = [
+            Tool(
+                name="State of Dialogue History System",
+                func=self.state_of_history.run,
+                description=f"Dialogue with {dialogue_participants} - useful for searching the chat content "
+                            f"exchanged between {dialogue_participants}. The input should be a "
+                            f"complete question."
+            ),
+            Tool(
+                name="Summary",
+                func=self.memory_chain.run,
+                description="Useful for summarizing the conversation. The input to this tool should be a string "
+                            "naming who will read the summary."
+            )
+        ]
+
+        prompt = ZeroShotAgent.create_prompt(
+            tools,
+            prefix=DIALOGUE_PREFIX,
+            suffix=DIALOGUE_SUFFIX,
+            input_variables=["input", "chat_history", "agent_scratchpad"]
+        )
+
+        llm_chain = LLMChain(llm=self.zero_shot_react_llm, prompt=prompt)
+        agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
+        agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=self.memory)
+
+        return agent_chain
diff --git a/chains/dialogue_answering/prompts.py b/chains/dialogue_answering/prompts.py
new file mode 100644
index 0000000..6cc7e8f
--- /dev/null
+++ b/chains/dialogue_answering/prompts.py
@@ -0,0 +1,22 @@
+from langchain.prompts.prompt import PromptTemplate
+
+
+SUMMARY_TEMPLATE = """This is a conversation between a human and a bot:
+
+{chat_history}
+
+Write a summary of the conversation for {input}:
+"""
+
+SUMMARY_PROMPT = PromptTemplate(
+    input_variables=["input", "chat_history"],
+    template=SUMMARY_TEMPLATE
+)
+
+DIALOGUE_PREFIX = """Have a conversation with a human, analyzing the content of the conversation.
+You have access to the following tools: """
+DIALOGUE_SUFFIX = """Begin!
+
+{chat_history}
+Question: {input}
+{agent_scratchpad}"""
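
For reviewers, a minimal usage sketch of the new DialogueWithSharedMemoryChains class (not part of the diff; the dialogue path below is a placeholder, and an OpenAI API key is assumed to be configured in the environment):

    from langchain.llms import OpenAI
    from chains.dialogue_answering import DialogueWithSharedMemoryChains

    llm = OpenAI(temperature=0)
    chains = DialogueWithSharedMemoryChains(
        zero_shot_react_llm=llm,  # drives the agent's ReAct loop
        ask_llm=llm,              # answers the retrieval and summary tool calls
        params={
            'dialogue_path': '/path/to/dialogue_log.txt',  # placeholder path
            'embedding_model': 'GanymedeNil/text2vec-large-chinese',
            'use_cuda': False,
        },
    )
    print(chains.agent_chain.run(input="Summarize what David said earlier."))

The two tools registered in _create_agent_chain share one ConversationBufferMemory: the agent executor writes to it, while the summary chain sees it only through ReadOnlySharedMemory, so tool calls cannot mutate the chat history.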