Support p-tuning-v2
Parent: dc0cdfba90
Commit: 2cd52f6605
configs/model_config.py

```diff
@@ -24,6 +24,9 @@ llm_model_dict = {
 # LLM model name
 LLM_MODEL = "chatglm-6b"
 
+# Use p-tuning-v2 PrefixEncoder
+USE_PTUNING_V2 = False
+
 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
```
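The new flag defaults to off. When it is enabled, `load_model` (below) expects a checkpoint in the `ptuning-v2/` folder; a quick pre-flight check, hypothetical and not part of this commit, could look like this:

```python
import os

# Hypothetical helper, not part of this commit: verify the files that
# USE_PTUNING_V2 = True expects before enabling the flag.
required = ["ptuning-v2/config.json", "ptuning-v2/pytorch_model.bin"]
missing = [p for p in required if not os.path.exists(p)]
if missing:
    print(f"p-tuning-v2 checkpoint incomplete, missing: {missing}")
```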
chatglm_llm.py

```diff
@@ -1,7 +1,10 @@
+import json
+import os
+
 from langchain.llms.base import LLM
 from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, AutoConfig
 import torch
 from configs.model_config import LLM_DEVICE
```
```diff
@@ -51,15 +54,30 @@ class ChatGLM(LLM):
     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
-                   llm_device=LLM_DEVICE):
+                   llm_device=LLM_DEVICE,
+                   use_ptuning_v2=False):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name_or_path,
             trust_remote_code=True
         )
+
+        model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
+
+        if use_ptuning_v2:
+            try:
+                prefix_encoder_file = open('ptuning-v2/config.json', 'r')
+                prefix_encoder_config = json.loads(prefix_encoder_file.read())
+                prefix_encoder_file.close()
+                model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
+                model_config.prefix_projection = prefix_encoder_config['prefix_projection']
+            except Exception:
+                print("加载PrefixEncoder config.json失败")
+
         if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
+                    config=model_config,
                     trust_remote_code=True)
                 .half()
                 .cuda()
```
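For reference, the config patching in this hunk can be sketched standalone; a minimal sketch, assuming a checkpoint produced by ChatGLM-6B's ptuning scripts:

```python
import json

from transformers import AutoConfig

# Load the base model config; trust_remote_code pulls in ChatGLM's custom classes.
model_config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# A p-tuning-v2 checkpoint ships a config.json recording how it was trained.
with open("ptuning-v2/config.json", "r") as f:
    prefix_encoder_config = json.load(f)

# Setting pre_seq_len makes ChatGLM's modeling code construct a PrefixEncoder;
# prefix_projection selects the MLP variant over a plain prefix embedding.
model_config.pre_seq_len = prefix_encoder_config["pre_seq_len"]
model_config.prefix_projection = prefix_encoder_config["prefix_projection"]
```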
```diff
@@ -68,8 +86,22 @@ class ChatGLM(LLM):
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
+                    config=model_config,
                     trust_remote_code=True)
                 .float()
                 .to(llm_device)
             )
+
+        if use_ptuning_v2:
+            try:
+                prefix_state_dict = torch.load('ptuning-v2/pytorch_model.bin')
+                new_prefix_state_dict = {}
+                for k, v in prefix_state_dict.items():
+                    if k.startswith("transformer.prefix_encoder."):
+                        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
+                self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
+                self.model.transformer.prefix_encoder.float()
+            except Exception:
+                print("加载PrefixEncoder模型参数失败")
+
         self.model = self.model.eval()
```
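End to end, the new parameter is exercised like this; a minimal sketch assuming the wrapper module is importable as `chatglm_llm` and the `ptuning-v2/` folder is populated:

```python
from chatglm_llm import ChatGLM  # module name assumed from this commit's layout

llm = ChatGLM()
# Loads the base weights, then overlays the PrefixEncoder weights from
# ptuning-v2/pytorch_model.bin (keys are stripped of the
# "transformer.prefix_encoder." prefix before load_state_dict).
llm.load_model(model_name_or_path="THUDM/chatglm-6b",
               use_ptuning_v2=True)
```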
New file in the ptuning-v2 folder:

```diff
@@ -0,0 +1,5 @@
+If the model was fine-tuned with [p-tuning-v2](https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning), the resulting PrefixEncoder can be placed in this folder.
+
+Only the model's *config.json* and *pytorch_model.bin* need to be placed here.
+
+Then tick *"使用p-tuning-v2微调过的模型"* (use a model fine-tuned with p-tuning-v2) when loading the model.
```
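For reference, the layout the loader expects, with the two paths exactly as hard-coded in `load_model` above:

```
ptuning-v2/
├── config.json        # records pre_seq_len and prefix_projection
└── pytorch_model.bin  # PrefixEncoder weights saved by the ptuning scripts
```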
webui.py

```diff
@@ -53,11 +53,12 @@ def init_model():
         return """模型未成功加载,请重新选择后点击"加载模型"按钮"""
 
 
-def reinit_model(llm_model, embedding_model, llm_history_len, top_k, history):
+def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, history):
     try:
         local_doc_qa.init_cfg(llm_model=llm_model,
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
+                              use_ptuning_v2=use_ptuning_v2,
                               top_k=top_k)
         model_status = """模型已成功重新加载,请选择文件后点击"加载文件"按钮"""
     except:
```
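`init_cfg` presumably forwards the flag to the ChatGLM wrapper; the hunk for that file is not part of this excerpt, so the pass-through below is a hypothetical sketch, with all names assumed:

```python
# Hypothetical sketch, not part of the excerpt above: how init_cfg would
# forward use_ptuning_v2 down to ChatGLM.load_model.
from configs.model_config import llm_model_dict, LLM_DEVICE
from chatglm_llm import ChatGLM  # module path assumed


class LocalDocQA:
    def init_cfg(self, llm_model, embedding_model, llm_history_len,
                 use_ptuning_v2=False, top_k=6):
        self.llm = ChatGLM()
        # Forward the checkbox state to the model loader.
        self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
                            llm_device=LLM_DEVICE,
                            use_ptuning_v2=use_ptuning_v2)
        self.llm.history_len = llm_history_len
        self.top_k = top_k
```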
```diff
@@ -97,7 +98,7 @@ webui_title = """
 """
 
 init_message = """欢迎使用 langchain-ChatGLM Web UI,开始提问前,请依次如下 3 个步骤:
-1. 选择语言模型、Embedding 模型及相关参数后点击"重新加载模型",并等待加载完成提示
+1. 选择语言模型、Embedding 模型及相关参数,如果使用ptuning-v2方式微调过模型,将PrefixEncoder模型放在ptuning-v2文件夹里并勾选相关选项,然后点击"重新加载模型",并等待加载完成提示
 2. 上传或选择已有文件作为本地知识文档输入后点击"重新加载文档",并等待加载完成提示
 3. 输入要提交的问题后,点击回车提交 """
```
```diff
@@ -127,6 +128,9 @@ with gr.Blocks(css=block_css) as demo:
                                     step=1,
                                     label="LLM history len",
                                     interactive=True)
+        use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
+                                     label="使用p-tuning-v2微调过的模型",
+                                     interactive=True)
         embedding_model = gr.Radio(embedding_model_dict_list,
                                    label="Embedding 模型",
                                    value=EMBEDDING_MODEL,
```
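`gr.Checkbox` takes its initial value as the first positional argument, so the checkbox starts from the `USE_PTUNING_V2` default in configs/model_config.py; a minimal standalone illustration:

```python
import gradio as gr

USE_PTUNING_V2 = False  # imported from configs.model_config in webui.py

with gr.Blocks() as demo:
    # First positional argument is the initial state of the checkbox.
    use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
                                 label="使用p-tuning-v2微调过的模型",
                                 interactive=True)
```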
```diff
@@ -152,7 +156,7 @@ with gr.Blocks(css=block_css) as demo:
             load_file_button = gr.Button("加载文件")
     load_model_button.click(reinit_model,
                             show_progress=True,
-                            inputs=[llm_model, embedding_model, llm_history_len, top_k, chatbot],
+                            inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, chatbot],
                             outputs=chatbot
                             )
     # 将上传的文件保存到content文件夹下,并更新下拉框
```
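Note that Gradio maps the `inputs` list onto the handler positionally, which is why `use_ptuning_v2` is inserted before `top_k` in both places:

```python
# inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, chatbot]
# maps one-to-one onto the handler's positional parameters:
def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, history):
    ...
```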