From 2cd52f660586b05024edef151d98e24424f6efe6 Mon Sep 17 00:00:00 2001
From: Thaumstrial <2738722282@qq.com>
Date: Sat, 15 Apr 2023 14:43:12 +0800
Subject: [PATCH] Support p-tuning-v2

---
 configs/model_config.py |  3 +++
 models/chatglm_llm.py   | 36 ++++++++++++++++++++++++++++++++++--
 ptuning-v2/readme.md    |  5 +++++
 webui.py                | 10 +++++++---
 4 files changed, 49 insertions(+), 5 deletions(-)
 create mode 100644 ptuning-v2/readme.md

diff --git a/configs/model_config.py b/configs/model_config.py
index fd309e1..45bcff7 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -24,6 +24,9 @@ llm_model_dict = {
 # LLM model name
 LLM_MODEL = "chatglm-6b"
 
+# Use p-tuning-v2 PrefixEncoder
+USE_PTUNING_V2 = False
+
 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index 7cf3b24..22b9694 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -1,7 +1,10 @@
+import json
+import os
+
 from langchain.llms.base import LLM
 from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, AutoConfig
 import torch
 
 from configs.model_config import LLM_DEVICE
@@ -51,15 +54,30 @@
 
     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
-                   llm_device=LLM_DEVICE):
+                   llm_device=LLM_DEVICE,
+                   use_ptuning_v2=False):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name_or_path,
             trust_remote_code=True
         )
+
+        model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
+
+        if use_ptuning_v2:
+            try:
+                prefix_encoder_file = open('ptuning-v2/config.json', 'r')
+                prefix_encoder_config = json.loads(prefix_encoder_file.read())
+                prefix_encoder_file.close()
+                model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
+                model_config.prefix_projection = prefix_encoder_config['prefix_projection']
+            except Exception:
+                print("加载PrefixEncoder config.json失败")
+
         if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
+                    config=model_config,
                     trust_remote_code=True)
                 .half()
                 .cuda()
@@ -68,8 +86,22 @@
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
+                    config=model_config,
                     trust_remote_code=True)
                 .float()
                 .to(llm_device)
             )
+
+        if use_ptuning_v2:
+            try:
+                prefix_state_dict = torch.load('ptuning-v2/pytorch_model.bin')
+                new_prefix_state_dict = {}
+                for k, v in prefix_state_dict.items():
+                    if k.startswith("transformer.prefix_encoder."):
+                        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
+                self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
+                self.model.transformer.prefix_encoder.float()
+            except Exception:
+                print("加载PrefixEncoder模型参数失败")
+
         self.model = self.model.eval()
diff --git a/ptuning-v2/readme.md b/ptuning-v2/readme.md
new file mode 100644
index 0000000..7479d2d
--- /dev/null
+++ b/ptuning-v2/readme.md
@@ -0,0 +1,5 @@
+If you fine-tuned the model with [p-tuning-v2](https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning), place the resulting PrefixEncoder checkpoint in this folder.
+
+Only the checkpoint's *config.json* and *pytorch_model.bin* are needed.
+
+Then tick *"使用p-tuning-v2微调过的模型"* ("use a model fine-tuned with p-tuning-v2") when loading the model.
\ No newline at end of file
diff --git a/webui.py b/webui.py
index 4678db6..b2ed760 100644
--- a/webui.py
+++ b/webui.py
@@ -53,11 +53,12 @@ def init_model():
         return """模型未成功加载,请重新选择后点击"加载模型"按钮"""
 
 
-def reinit_model(llm_model, embedding_model, llm_history_len, top_k, history):
+def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, history):
     try:
         local_doc_qa.init_cfg(llm_model=llm_model,
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
+                              use_ptuning_v2=use_ptuning_v2,
                               top_k=top_k)
         model_status = """模型已成功重新加载,请选择文件后点击"加载文件"按钮"""
     except:
@@ -97,7 +98,7 @@ webui_title = """
 """
 
 init_message = """欢迎使用 langchain-ChatGLM Web UI,开始提问前,请依次如下 3 个步骤:
-1. 选择语言模型、Embedding 模型及相关参数后点击"重新加载模型",并等待加载完成提示
+1. 选择语言模型、Embedding 模型及相关参数,如果使用ptuning-v2方式微调过模型,将PrefixEncoder模型放在ptuning-v2文件夹里并勾选相关选项,然后点击"重新加载模型",并等待加载完成提示
 2. 上传或选择已有文件作为本地知识文档输入后点击"重新加载文档",并等待加载完成提示
 3. 输入要提交的问题后,点击回车提交
 """
@@ -127,6 +128,9 @@ with gr.Blocks(css=block_css) as demo:
                                         step=1,
                                         label="LLM history len",
                                         interactive=True)
+            use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
+                                         label="使用p-tuning-v2微调过的模型",
+                                         interactive=True)
             embedding_model = gr.Radio(embedding_model_dict_list,
                                        label="Embedding 模型",
                                        value=EMBEDDING_MODEL,
@@ -152,7 +156,7 @@
                 load_file_button = gr.Button("加载文件")
     load_model_button.click(reinit_model,
                             show_progress=True,
-                            inputs=[llm_model, embedding_model, llm_history_len, top_k, chatbot],
+                            inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, chatbot],
                             outputs=chatbot
                             )
     # 将上传的文件保存到content文件夹下,并更新下拉框
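
For reference, the load path this patch wires into `ChatGLM.load_model` can be exercised standalone. The sketch below is a minimal reconstruction, not part of the patch itself: it reads `pre_seq_len` and `prefix_projection` from the fine-tuning checkpoint's config.json, builds ChatGLM-6B with that config so the PrefixEncoder module is created, then loads only the `transformer.prefix_encoder.*` weights and keeps them in fp32. The `with open`/`json.load` form and `map_location="cpu"` are small robustness substitutions of mine that the patch does not use; paths follow the `ptuning-v2/` convention from the readme, and a CUDA device is assumed for `.half().cuda()`.

# Standalone sketch of the p-tuning-v2 loading flow added by this patch.
import json

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

MODEL_NAME = "THUDM/chatglm-6b"
CHECKPOINT_DIR = "ptuning-v2"  # folder described in ptuning-v2/readme.md

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Graft the prefix settings saved during fine-tuning onto the base config
# before instantiation, so the model is built with a PrefixEncoder module.
model_config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
with open(f"{CHECKPOINT_DIR}/config.json") as f:
    prefix_encoder_config = json.load(f)
model_config.pre_seq_len = prefix_encoder_config["pre_seq_len"]
model_config.prefix_projection = prefix_encoder_config["prefix_projection"]

model = AutoModel.from_pretrained(MODEL_NAME,
                                  config=model_config,
                                  trust_remote_code=True).half().cuda()

# The checkpoint is keyed by full module path; keep only the
# "transformer.prefix_encoder." entries and strip that prefix before loading.
# map_location="cpu" is an added safeguard (not in the patch) in case the
# checkpoint was saved on a different device.
prefix_state_dict = torch.load(f"{CHECKPOINT_DIR}/pytorch_model.bin",
                               map_location="cpu")
new_prefix_state_dict = {
    k[len("transformer.prefix_encoder."):]: v
    for k, v in prefix_state_dict.items()
    if k.startswith("transformer.prefix_encoder.")
}
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
model.transformer.prefix_encoder.float()  # prefix encoder runs in fp32
model = model.eval()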