Support p-tuning-v2

parent dc0cdfba90
commit 2cd52f6605
configs/model_config.py

@@ -24,6 +24,9 @@ llm_model_dict = {
 # LLM model name
 LLM_MODEL = "chatglm-6b"

+# Use p-tuning-v2 PrefixEncoder
+USE_PTUNING_V2 = False
+
 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
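The flag defaults to off. A user with a p-tuning-v2 checkpoint can either flip it here (a minimal sketch of that one-line edit) or tick the checkbox that webui.py gains later in this diff, which uses USE_PTUNING_V2 as its initial value.

    # configs/model_config.py, user-side edit (sketch): opt in to loading
    # the PrefixEncoder weights from the ptuning-v2/ folder at startup.
    USE_PTUNING_V2 = True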
models/chatglm_llm.py

@@ -1,7 +1,10 @@
+import json
+import os
+
 from langchain.llms.base import LLM
 from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, AutoConfig
 import torch
 from configs.model_config import LLM_DEVICE
@@ -51,15 +54,30 @@ class ChatGLM(LLM):

     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
-                   llm_device=LLM_DEVICE):
+                   llm_device=LLM_DEVICE,
+                   use_ptuning_v2=False):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name_or_path,
             trust_remote_code=True
         )
+
+        model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
+
+        if use_ptuning_v2:
+            try:
+                prefix_encoder_file = open('ptuning-v2/config.json', 'r')
+                prefix_encoder_config = json.loads(prefix_encoder_file.read())
+                prefix_encoder_file.close()
+                model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
+                model_config.prefix_projection = prefix_encoder_config['prefix_projection']
+            except Exception:
+                print("Failed to load the PrefixEncoder config.json")
+
         if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
+                    config=model_config,
                     trust_remote_code=True)
                 .half()
                 .cuda()
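load_model reads exactly two keys from ptuning-v2/config.json, pre_seq_len and prefix_projection, and copies them onto the Hugging Face model config so that ChatGLM builds its PrefixEncoder. A minimal sketch of a matching config file follows; the concrete values are assumptions (128 and false mirror the defaults in THUDM's ptuning scripts) and should come from your own fine-tuning run.

    # Sketch of the two fields load_model expects in ptuning-v2/config.json.
    # Values are assumptions; copy them from your p-tuning-v2 training config.
    import json

    example_config = {"pre_seq_len": 128, "prefix_projection": False}
    print(json.dumps(example_config))
    # -> {"pre_seq_len": 128, "prefix_projection": false}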
@@ -68,8 +86,22 @@ class ChatGLM(LLM):
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
+                    config=model_config,
                     trust_remote_code=True)
                 .float()
                 .to(llm_device)
             )
+
+        if use_ptuning_v2:
+            try:
+                prefix_state_dict = torch.load('ptuning-v2/pytorch_model.bin')
+                new_prefix_state_dict = {}
+                for k, v in prefix_state_dict.items():
+                    if k.startswith("transformer.prefix_encoder."):
+                        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
+                self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
+                self.model.transformer.prefix_encoder.float()
+            except Exception:
+                print("Failed to load the PrefixEncoder model parameters")
+
         self.model = self.model.eval()
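The re-keying loop above is needed because the p-tuning-v2 checkpoint stores the PrefixEncoder's parameters under the full transformer.prefix_encoder. prefix, while load_state_dict on the submodule expects bare parameter names. A standalone sketch of the same transformation, with tensors replaced by placeholder strings and illustrative key names:

    # Keys under "transformer.prefix_encoder." are kept and stripped of the
    # prefix; everything else in the checkpoint is ignored.
    prefix = "transformer.prefix_encoder."
    ckpt = {
        "transformer.prefix_encoder.embedding.weight": "tensor A",  # re-keyed
        "transformer.word_embeddings.weight": "tensor B",           # dropped
    }
    renamed = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}
    assert renamed == {"embedding.weight": "tensor A"}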
ptuning-v2/README.md (new file)

@@ -0,0 +1,5 @@
+If you fine-tuned the model with [p-tuning-v2](https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning), you can place the resulting PrefixEncoder in this folder.
+
+Only the model's *config.json* and *pytorch_model.bin* need to go here,
+
+and tick *"Use a model fine-tuned with p-tuning-v2"* when loading the model.
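Assembled as the README describes, the folder would look like this (a sketch; load_model reads only the two non-README files):

    ptuning-v2/
    ├── README.md           # the new file added above
    ├── config.json         # pre_seq_len, prefix_projection
    └── pytorch_model.bin   # PrefixEncoder weights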
webui.py (10 lines changed)
@@ -53,11 +53,12 @@ def init_model():
         return """The model failed to load. Please re-select it and click the "Load model" button."""


-def reinit_model(llm_model, embedding_model, llm_history_len, top_k, history):
+def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, history):
     try:
         local_doc_qa.init_cfg(llm_model=llm_model,
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
+                              use_ptuning_v2=use_ptuning_v2,
                               top_k=top_k)
         model_status = """The model was reloaded successfully. Please select a file and click the "Load file" button."""
     except:
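init_cfg itself is not part of this diff, so the hop between the UI and ChatGLM.load_model is implied rather than shown. A minimal sketch of the forwarding it presumably performs; the signature and body here are assumptions reconstructed from the call sites in this commit:

    # Sketch only: assumed forwarding inside local_doc_qa.init_cfg.
    def init_cfg(self, llm_model, embedding_model, llm_history_len,
                 use_ptuning_v2=False, top_k=6):
        self.llm = ChatGLM()
        self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
                            use_ptuning_v2=use_ptuning_v2)
        # ... embedding model setup elided ...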
@@ -97,7 +98,7 @@ webui_title = """
 """

 init_message = """Welcome to the langchain-ChatGLM Web UI. Before asking a question, please follow these 3 steps:
-1. Select the language model, the Embedding model, and the related parameters, then click "Reload model" and wait for the loading-complete message
+1. Select the language model, the Embedding model, and the related parameters; if the model was fine-tuned with p-tuning-v2, put the PrefixEncoder model in the ptuning-v2 folder and tick the corresponding option, then click "Reload model" and wait for the loading-complete message
 2. Upload files or select existing ones as the local knowledge documents, click "Reload documents", and wait for the loading-complete message
 3. Enter your question and press Enter to submit """
@@ -127,6 +128,9 @@ with gr.Blocks(css=block_css) as demo:
                                             step=1,
                                             label="LLM history len",
                                             interactive=True)
+                use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
+                                             label="Use a model fine-tuned with p-tuning-v2",
+                                             interactive=True)
                 embedding_model = gr.Radio(embedding_model_dict_list,
                                            label="Embedding model",
                                            value=EMBEDDING_MODEL,
@@ -152,7 +156,7 @@ with gr.Blocks(css=block_css) as demo:
         load_file_button = gr.Button("Load file")
         load_model_button.click(reinit_model,
                                 show_progress=True,
-                                inputs=[llm_model, embedding_model, llm_history_len, top_k, chatbot],
+                                inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, chatbot],
                                 outputs=chatbot
                                 )
         # Save uploaded files to the content folder and refresh the dropdown
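Outside the web UI, the same path can be exercised directly. A usage sketch, assuming the wrapper class lives at models/chatglm_llm.py as the class ChatGLM(LLM) hunks above suggest (the module path is an assumption):

    # Usage sketch: load chatglm-6b together with a p-tuning-v2 PrefixEncoder
    # placed in ptuning-v2/ (module path below is an assumption).
    from models.chatglm_llm import ChatGLM

    llm = ChatGLM()
    llm.load_model(model_name_or_path="THUDM/chatglm-6b",
                   use_ptuning_v2=True)  # reads ptuning-v2/config.json and pytorch_model.bin
    print(llm("Hello"))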