From 12ee17f3b3319562183335512c915b4f75ceb8c4 Mon Sep 17 00:00:00 2001
From: imClumsyPanda
Date: Mon, 10 Apr 2023 22:55:22 +0800
Subject: [PATCH] use RetrievalQA instead of ChatVectorDBChain

---
 chatglm_llm.py             | 56 ++++++++++++++------------------
 knowledge_based_chatglm.py | 52 +++++++++++++++++++----------------
 2 files changed, 49 insertions(+), 59 deletions(-)

diff --git a/chatglm_llm.py b/chatglm_llm.py
index 810e98a..55c0177 100644
--- a/chatglm_llm.py
+++ b/chatglm_llm.py
@@ -16,28 +16,14 @@ def torch_gc():
         torch.cuda.ipc_collect()
 
 
-tokenizer = AutoTokenizer.from_pretrained(
-    "/Users/liuqian/Downloads/ChatGLM-6B/chatglm_hf_model",
-    # "THUDM/chatglm-6b",
-    trust_remote_code=True
-)
-model = (
-    AutoModel.from_pretrained(
-        "/Users/liuqian/Downloads/ChatGLM-6B/chatglm_hf_model",
-        # "THUDM/chatglm-6b",
-        trust_remote_code=True)
-    .float()
-    .to("mps")
-    # .half()
-    # .cuda()
-)
-
-
 class ChatGLM(LLM):
     max_token: int = 10000
     temperature: float = 0.1
     top_p = 0.9
     history = []
+    tokenizer: object = None
+    model: object = None
+    history_len: int = 10
 
     def __init__(self):
         super().__init__()
@@ -49,31 +35,29 @@ class ChatGLM(LLM):
     def _call(self,
               prompt: str,
              stop: Optional[List[str]] = None) -> str:
-        response, updated_history = model.chat(
-            tokenizer,
+        response, _ = self.model.chat(
+            self.tokenizer,
             prompt,
-            history=self.history,
+            history=self.history[-self.history_len:],
             max_length=self.max_token,
             temperature=self.temperature,
         )
         torch_gc()
-        print("history: ", self.history)
         if stop is not None:
             response = enforce_stop_tokens(response, stop)
-        self.history = updated_history
+        self.history = self.history+[[None, response]]
         return response
 
-    def get_num_tokens(self, text: str) -> int:
-        tokenized_text = tokenizer.tokenize(text)
-        return len(tokenized_text)
-
-if __name__ == "__main__":
-    history = []
-    while True:
-        query = input("Input your question 请输入问题:")
-        resp, history = model.chat(tokenizer,
-                                   query,
-                                   history=history,
-                                   temperature=0.01,
-                                   max_length=100000)
-        print(resp)
\ No newline at end of file
+    def load_model(self,
+                   model_name_or_path: str = "THUDM/chatglm-6b"):
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name_or_path,
+            trust_remote_code=True
+        )
+        self.model = (
+            AutoModel.from_pretrained(
+                model_name_or_path,
+                trust_remote_code=True)
+            .half()
+            .cuda()
+        )
diff --git a/knowledge_based_chatglm.py b/knowledge_based_chatglm.py
index ae5d456..0d136d4 100644
--- a/knowledge_based_chatglm.py
+++ b/knowledge_based_chatglm.py
@@ -1,5 +1,4 @@
-from langchain.prompts.prompt import PromptTemplate
-from langchain.chains import ChatVectorDBChain, ConversationalRetrievalChain
+from langchain.chains import RetrievalQA
 from langchain.prompts.chat import (
     ChatPromptTemplate,
     SystemMessagePromptTemplate,
@@ -10,19 +9,34 @@ from langchain.vectorstores import FAISS
 from langchain.document_loaders import UnstructuredFileLoader
 from chatglm_llm import ChatGLM
 
+# Global Parameters
+EMBEDDING_MODEL = "text2vec"
+VECTOR_SEARCH_TOP_K = 6
+LLM_MODEL = "chatglm-6b"
+LLM_HISTORY_LEN = 3
+
+# Show reply with source text from input document
+REPLY_WITH_SOURCE = True
+
+
 embedding_model_dict = {
     "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
     "ernie-base": "nghuyong/ernie-3.0-base-zh",
-    "text2vec": "/Users/liuqian/Downloads/ChatGLM-6B/chatglm_embedding"#"GanymedeNil/text2vec-large-chinese"
+    "text2vec": "GanymedeNil/text2vec-large-chinese",
 }
-
+llm_model_dict = {
+    "chatglm-6b": "THUDM/chatglm-6b",
+    "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
+    "chatglm-6b-int4-qe":"THUDM/chatglm-6b-int4-qe",
+}
 
 chatglm = ChatGLM()
-
+chatglm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL])
+chatglm.history_len = LLM_HISTORY_LEN
 
 
 def init_knowledge_vector_store(filepath):
-    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict["text2vec"], )
+    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL], )
     loader = UnstructuredFileLoader(filepath, mode="elements")
     docs = loader.load()
@@ -43,28 +57,17 @@ def get_knowledge_based_answer(query, vector_store, chat_history=[]):
     ]
     prompt = ChatPromptTemplate.from_messages(messages)
-    condese_propmt_template = """任务: 给一段对话和一个后续问题,将后续问题改写成一个独立的问题。确保问题是完整的,没有模糊的指代。
-    ----------------
-    聊天记录:
-    {chat_history}
-    ----------------
-    后续问题:{question}
-    ----------------
-    改写后的独立、完整的问题:"""
-    new_question_prompt = PromptTemplate.from_template(condese_propmt_template)
     chatglm.history = chat_history
-    knowledge_chain = ConversationalRetrievalChain.from_llm(
+    knowledge_chain = RetrievalQA.from_llm(
         llm=chatglm,
-        retriever=vector_store.as_retriever(),
-        qa_prompt=prompt,
-        condense_question_prompt=new_question_prompt,
+        retriever=vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
+        prompt=prompt
     )
     knowledge_chain.return_source_documents = True
-    # knowledge_chain.top_k_docs_for_context = 10
-    knowledge_chain.max_tokens_limit = 10000
-    result = knowledge_chain({"question": query, "chat_history": chat_history})
+    result = knowledge_chain({"query": query})
+    chatglm.history[-1][0] = query
 
     return result, chatglm.history
@@ -77,4 +80,7 @@ if __name__ == "__main__":
         resp, history = get_knowledge_based_answer(query=query,
                                                    vector_store=vector_store,
                                                    chat_history=history)
-        print(resp)
+        if REPLY_WITH_SOURCE:
+            print(resp)
+        else:
+            print(resp["result"])