From b4aefca555025d47dd11fab7877c8696122fd6b6 Mon Sep 17 00:00:00 2001
From: imClumsyPanda
Date: Wed, 26 Apr 2023 22:29:20 +0800
Subject: [PATCH] add stream support to cli_demo.py

---
 chains/local_doc_qa.py | 83 +++++++++++++++++++-----------------
 cli_demo.py            | 18 ++++++---
 models/chatglm_llm.py  | 73 +++++++++++++++++++++----------------
 3 files changed, 90 insertions(+), 84 deletions(-)

diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index 01bf520..6467c0b 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -2,7 +2,6 @@ from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
-from langchain.vectorstores.base import VectorStoreRetriever
 from langchain.document_loaders import UnstructuredFileLoader
 from models.chatglm_llm import ChatGLM
 import sentence_transformers
@@ -34,23 +33,21 @@ def load_file(filepath):
     docs = loader.load_and_split(text_splitter=textsplitter)
     return docs
 
+def generate_prompt(related_docs: List[str],
+                    query: str,
+                    prompt_template=PROMPT_TEMPLATE) -> str:
+    context = "\n".join([doc.page_content for doc in related_docs])
+    prompt = prompt_template.replace("{question}", query).replace("{context}", context)
+    return prompt
 
-def get_relevant_documents(self, query: str) -> List[Document]:
-    if self.search_type == "similarity":
-        docs = self.vectorstore._similarity_search_with_relevance_scores(query, **self.search_kwargs)
-        for doc in docs:
-            doc[0].metadata["score"] = doc[1]
-        docs = [doc[0] for doc in docs]
-    elif self.search_type == "mmr":
-        docs = self.vectorstore.max_marginal_relevance_search(
-            query, **self.search_kwargs
-        )
-    else:
-        raise ValueError(f"search_type of {self.search_type} not allowed.")
+
+def get_docs_with_score(docs_with_score):
+    docs=[]
+    for doc, score in docs_with_score:
+        doc.metadata["score"] = score
+        docs.append(doc)
     return docs
-
-
 class LocalDocQA:
     llm: object = None
     embeddings: object = None
@@ -73,8 +70,6 @@ class LocalDocQA:
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                 model_kwargs={'device': embedding_device})
-        # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
-        #                                                                    device=embedding_device)
         self.top_k = top_k
 
     def init_knowledge_vector_store(self,
@@ -134,34 +129,30 @@ class LocalDocQA:
     def get_knowledge_based_answer(self,
                                    query,
                                    vs_path,
-                                   chat_history=[], ):
-        prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
-    如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
-
-    已知内容:
-    {context}
-
-    问题:
-    {question}"""
-        prompt = PromptTemplate(
-            template=prompt_template,
-            input_variables=["context", "question"]
-        )
-        self.llm.history = chat_history
+                                   chat_history=[],
+                                   streaming=True):
+        self.llm.streaming = streaming
         vector_store = FAISS.load_local(vs_path, self.embeddings)
-        vs_r = vector_store.as_retriever(search_type="mmr",
-                                         search_kwargs={"k": self.top_k})
-        # VectorStoreRetriever.get_relevant_documents = get_relevant_documents
-        knowledge_chain = RetrievalQA.from_llm(
-            llm=self.llm,
-            retriever=vs_r,
-            prompt=prompt
-        )
-        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
-            input_variables=["page_content"], template="{page_content}"
-        )
+        related_docs_with_score = vector_store.similarity_search_with_score(query,
+                                                                            k=self.top_k)
+        related_docs = get_docs_with_score(related_docs_with_score)
+        prompt = generate_prompt(related_docs, query)
 
-        knowledge_chain.return_source_documents = True
-        result = knowledge_chain({"query": query})
-        self.llm.history[-1][0] = query
-        return result, self.llm.history
+        if streaming:
+            for result, history in self.llm._call(prompt=prompt,
+                                                  history=chat_history):
+                history[-1] = list(history[-1])
+                history[-1][0] = query
+                response = {"query": query,
+                            "result": result,
+                            "source_documents": related_docs}
+                yield response, history
+        else:
+            result, history = self.llm._call(prompt=prompt,
+                                             history=chat_history)
+            history[-1] = list(history[-1])
+            history[-1][0] = query
+            response = {"query": query,
+                        "result": result,
+                        "source_documents": related_docs}
+            return response, history
diff --git a/cli_demo.py b/cli_demo.py
index b594380..232f75c 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -28,10 +28,16 @@ if __name__ == "__main__":
     history = []
     while True:
         query = input("Input your question 请输入问题:")
-        resp, history = local_doc_qa.get_knowledge_based_answer(query=query,
-                                                                vs_path=vs_path,
-                                                                chat_history=history)
+        last_print_len = 0
+        for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
+                                                                     vs_path=vs_path,
+                                                                     chat_history=history,
+                                                                     streaming=True):
+            print(resp["result"][last_print_len:], end="", flush=True)
+            last_print_len = len(resp["result"])
         if REPLY_WITH_SOURCE:
-            print(resp)
-        else:
-            print(resp["result"])
+            source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+                           # f"""相关度:{doc.metadata['score']}\n\n"""
+                           for inum, doc in
+                           enumerate(resp["source_documents"])]
+            print("\n\n" + "\n\n".join(source_text))
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index 1c020de..eb15c33 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -5,7 +5,8 @@ from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel, AutoConfig
 import torch
 from configs.model_config import LLM_DEVICE
-
+from langchain.callbacks.base import CallbackManager
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from typing import Dict, Tuple, Union, Optional
 
 DEVICE = LLM_DEVICE
@@ -54,10 +55,12 @@ class ChatGLM(LLM):
     max_token: int = 10000
     temperature: float = 0.01
     top_p = 0.9
-    history = []
+    # history = []
     tokenizer: object = None
     model: object = None
     history_len: int = 10
+    streaming: bool = True
+    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
 
     def __init__(self):
         super().__init__()
@@ -68,46 +71,45 @@ class ChatGLM(LLM):
 
     def _call(self,
               prompt: str,
-              stop: Optional[List[str]] = None,
-              stream=True) -> str:
-        if stream:
-            self.history = self.history + [[None, ""]]
-            for response, history in self.model.stream_chat(
-                    self.tokenizer,
-                    prompt,
-                    history=self.history[-self.history_len:] if self.history_len > 0 else [],
-                    max_length=self.max_token,
-                    temperature=self.temperature,
+              history: List[List[str]] = [],
+              stop: Optional[List[str]] = None) -> str:
+        if self.streaming:
+            history = history + [[None, ""]]
+            for stream_resp, history in self.model.stream_chat(
+                    self.tokenizer,
+                    prompt,
+                    history=history[-self.history_len:] if self.history_len > 0 else [],
+                    max_length=self.max_token,
+                    temperature=self.temperature,
             ):
-                torch_gc()
-                self.history[-1][-1] = response
-                yield response
+                yield stream_resp, history
+
         else:
             response, _ = self.model.chat(
                 self.tokenizer,
                 prompt,
-                history=self.history[-self.history_len:] if self.history_len > 0 else [],
+                history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
             )
             torch_gc()
             if stop is not None:
                 response = enforce_stop_tokens(response, stop)
-            self.history = self.history + [[None, response]]
-            return response
+            history = history + [[None, response]]
+            return response, history
 
-    def chat(self,
-             prompt: str) -> str:
-        response, _ = self.model.chat(
-            self.tokenizer,
-            prompt,
-            history=self.history[-self.history_len:] if self.history_len > 0 else [],
-            max_length=self.max_token,
-            temperature=self.temperature,
-        )
-        torch_gc()
-        self.history = self.history + [[None, response]]
-        return response
+    # def chat(self,
+    #          prompt: str) -> str:
+    #     response, _ = self.model.chat(
+    #         self.tokenizer,
+    #         prompt,
+    #         history=self.history[-self.history_len:] if self.history_len > 0 else [],
+    #         max_length=self.max_token,
+    #         temperature=self.temperature,
+    #     )
+    #     torch_gc()
+    #     self.history = self.history + [[None, response]]
+    #     return response
 
     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
@@ -149,7 +151,13 @@ class ChatGLM(LLM):
             else:
                 from accelerate import dispatch_model
 
-                model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True, **kwargs).half()
+                model = (
+                    AutoModel.from_pretrained(
+                        model_name_or_path,
+                        trust_remote_code=True,
+                        config=model_config,
+                        **kwargs)
+                    .half())
                 # 可传入device_map自定义每张卡的部署情况
                 if device_map is None:
                     device_map = auto_configure_device_map(num_gpus)
@@ -160,7 +168,8 @@
                 AutoModel.from_pretrained(
                     model_name_or_path,
                     config=model_config,
-                    trust_remote_code=True)
+                    trust_remote_code=True,
+                    **kwargs)
                 .float()
                 .to(llm_device)
             )
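
Note on the new interface: after this change, ChatGLM._call takes the conversation history as an explicit argument and, when streaming is enabled, acts as a generator yielding (stream_resp, history) pairs straight from model.stream_chat, where stream_resp is the full answer accumulated so far rather than a delta. A rough sketch of driving the reworked _call directly (illustration only, not part of the patch: the prompt text and variable names are made up, and it assumes load_model() succeeds on the local hardware):

    from models.chatglm_llm import ChatGLM

    llm = ChatGLM()
    llm.load_model(model_name_or_path="THUDM/chatglm-6b")
    llm.streaming = True

    history = []
    last_len = 0
    for stream_resp, history in llm._call(prompt="Hello", history=history):
        # stream_resp holds the whole answer so far, so print only the new suffix
        print(stream_resp[last_len:], end="", flush=True)
        last_len = len(stream_resp)
    print()

This is the same print-only-the-new-suffix pattern the patch applies in cli_demo.py via last_print_len, since stream_chat re-sends the accumulated response on every iteration.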