qwen-vl-finetune-bonus/train.py
import torch
from datasets import Dataset
from modelscope import snapshot_download, AutoTokenizer
from swanlab.integration.transformers import SwanLabCallback
from qwen_vl_utils import process_vision_info
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from transformers import (
    TrainingArguments,  # type: ignore
    Trainer,  # type: ignore
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
)
from transformers.data.data_collator import DataCollatorForSeq2Seq
import swanlab
import json
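
# Expected record layout in data_vl.json (inferred from process_func below;
# the path and caption are illustrative placeholders):
# {
#   "conversations": [
#     {"from": "user", "value": "COCO Yes: <|vision_start|>/path/to/img.jpg<|vision_end|>"},
#     {"from": "assistant", "value": "A caption describing the image."}
#   ]
# }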
def process_func(example):
    """
    Preprocess a single dataset record into model inputs.
    """
    MAX_LENGTH = 8192
    input_ids, attention_mask, labels = [], [], []
    conversation = example["conversations"]
    input_content = conversation[0]["value"]
    output_content = conversation[1]["value"]
    # Extract the image path between <|vision_start|> and <|vision_end|>
    file_path = input_content.split("<|vision_start|>")[1].split("<|vision_end|>")[0]
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": f"{file_path}",
                    "resized_height": 280,
                    "resized_width": 280,
                },
                {"type": "text", "text": "COCO Yes:"},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )  # Build the prompt text
    image_inputs, video_inputs = process_vision_info(messages)  # Preprocessed image inputs
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = {key: value.tolist() for key, value in inputs.items()}  # tensor -> list, so the sequences can be concatenated
    instruction = inputs
    response = tokenizer(f"{output_content}", add_special_tokens=False)
    input_ids = (
        instruction["input_ids"][0] + response["input_ids"] + [tokenizer.pad_token_id]
    )
    attention_mask = instruction["attention_mask"][0] + response["attention_mask"] + [1]
    # Mask the prompt tokens with -100 so the loss is computed only on the response
    labels = (
        [-100] * len(instruction["input_ids"][0])
        + response["input_ids"]
        + [tokenizer.pad_token_id]
    )
    if len(input_ids) > MAX_LENGTH:  # Truncate over-long sequences
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)
    labels = torch.tensor(labels)
    inputs["pixel_values"] = torch.tensor(inputs["pixel_values"])
    inputs["image_grid_thw"] = torch.tensor(inputs["image_grid_thw"]).squeeze(0)  # (1, h, w) -> (h, w)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "pixel_values": inputs["pixel_values"],
        "image_grid_thw": inputs["image_grid_thw"],
    }
def predict(messages, model):
    # Prepare inputs for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")
    # Generate, then strip the prompt tokens from the output
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]
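
# Example call (sketch; "./demo.jpg" is a placeholder path):
# demo_messages = [{"role": "user", "content": [
#     {"type": "image", "image": "./demo.jpg"},
#     {"type": "text", "text": "COCO Yes:"},
# ]}]
# print(predict(demo_messages, model))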
def load_and_convert_data(file_path):
    """Load a JSONL file and convert it into Dataset-friendly dicts."""
    loaded_data = []
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            loaded_data.append(json.loads(line))
    # Convert loaded_data into a format suitable for Dataset
    dataset_dicts = []
    for item in loaded_data:
        user_content = item[0]["content"]
        assistant_content = item[1]["content"]
        # Extract the image and text entries
        image_info = next((x for x in user_content if x["type"] == "image"), None)
        text_info = next((x for x in user_content if x["type"] == "text"), None)
        # Build the new record
        dataset_entry = {
            "role": "user",
            "image_path": image_info["image"] if image_info else None,
            "question": text_info["text"] if text_info else None,
            "assistant_answer": assistant_content,
        }
        dataset_dicts.append(dataset_entry)
    return dataset_dicts
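
# Note: the JSONL branch further below (`else:`) is currently unreachable
# because of the `if True:` guard, and its records use image_path/question
# keys rather than the "conversations" schema process_func expects, so it
# would need adapting before use.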
# Download the Qwen2-VL model from ModelScope into a local directory
# model_dir = snapshot_download("Qwen/Qwen2-VL-2B-Instruct", cache_dir="./", revision="master")
min_pixel = 256 * 28 * 28
max_pixel = 1280 * 28 * 28
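# Note (assumption): min_pixel/max_pixel are defined but never used below.
# The Qwen2.5-VL processor accepts them as min_pixels/max_pixels to bound
# the per-image visual token budget, e.g.:
# processor = AutoProcessor.from_pretrained(
#     "/home/gyk/models/Qwen2.5-VL-7B-Instruct/",
#     min_pixels=min_pixel, max_pixels=max_pixel,
# )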
# Load the model weights with Transformers
tokenizer = AutoTokenizer.from_pretrained("/home/gyk/models/Qwen2.5-VL-7B-Instruct/", use_fast=False, trust_remote_code=True)
processor = AutoProcessor.from_pretrained("/home/gyk/models/Qwen2.5-VL-7B-Instruct/")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "/home/gyk/models/Qwen2.5-VL-7B-Instruct/",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model.enable_input_require_grads()  # Required when gradient checkpointing is enabled
# Read the dataset JSON file and split it into train/test sets,
# saved as data_vl_train.json and data_vl_test.json
if True:
    train_json_path = "data_vl.json"
    with open(train_json_path, "r") as f:
        data = json.load(f)
    train_data = data[:-4]
    test_data = data[-4:]
    with open("data_vl_train.json", "w") as f:
        json.dump(train_data, f)
    with open("data_vl_test.json", "w") as f:
        json.dump(test_data, f)
    train_ds = Dataset.from_json("data_vl_train.json")
    train_dataset = train_ds.map(process_func)  # type: ignore
else:
    # Load the test and val datasets separately
    test_data_path = "data_test.jsonl"
    val_data_path = "data_val.jsonl"
    test_dataset_dicts = load_and_convert_data(test_data_path)
    val_dataset_dicts = load_and_convert_data(val_data_path)
    # Create Dataset objects
    test_tmp_dataset = Dataset.from_list(test_dataset_dicts)
    val_tmp_dataset = Dataset.from_list(val_dataset_dicts)
    test_tmp_dataset = test_tmp_dataset.select(list(range(1000)))
    val_tmp_dataset = val_tmp_dataset.select(list(range(50)))
    test_dataset = test_tmp_dataset.map(process_func, batched=True, batch_size=4)
    val_dataset = val_tmp_dataset.map(process_func, batched=True, batch_size=4)
# Configure LoRA
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False,  # Training mode
    r=64,  # LoRA rank
    lora_alpha=16,  # LoRA alpha; see the LoRA paper for its role
    lora_dropout=0.05,  # Dropout rate
    bias="none",
)
# Wrap the base model with LoRA adapters
peft_model = get_peft_model(model, config)
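# Optional sanity check (assumption: recent PEFT versions expose this helper);
# it prints the trainable vs. total parameter counts:
# peft_model.print_trainable_parameters()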
# Configure training arguments
args = TrainingArguments(
    output_dir="./output/Qwen2.5-VL-7B",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    logging_steps=10,
    logging_first_step=True,  # NB: this was originally 5
    num_train_epochs=2,
    save_steps=100,
    learning_rate=1e-4,
    save_on_each_node=True,
    gradient_checkpointing=True,
    report_to="none",
)
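# The effective batch size per optimizer step is
# per_device_train_batch_size * gradient_accumulation_steps = 4 * 4 = 16
# per device (multiplied again by the number of devices, if any).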
# Set up the SwanLab callback
swanlab_callback = SwanLabCallback(
    project="Qwen2_5-VL-finetune",
    experiment_name="qwen2.5-vl-coco2014",
    config={
        "model": "https://modelscope.cn/models/Qwen/Qwen2.5-VL-7B-Instruct",
        "dataset": "https://modelscope.cn/datasets/modelscope/coco_2014_caption/quickstart",
        "github": "https://github.com/datawhalechina/self-llm",
        "prompt": "COCO Yes: ",
        "train_data_number": len(train_data),
        "lora_rank": 64,
        "lora_alpha": 16,
        "lora_dropout": 0.05,  # Matches the LoraConfig above
    },
)
# Configure the Trainer
trainer = Trainer(
    model=peft_model,
    args=args,
    train_dataset=train_dataset,  # type: ignore
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    callbacks=[swanlab_callback],
)
# Start training
trainer.train()
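# To persist the final LoRA adapter explicitly (sketch; checkpoints are
# already written to output_dir every save_steps):
# peft_model.save_pretrained("./output/Qwen2.5-VL-7B/final_adapter")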
# ==================== Evaluation ====================
# Configure the LoRA settings for inference
val_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=True,  # Inference mode
    r=64,  # LoRA rank
    lora_alpha=16,  # LoRA alpha; see the LoRA paper for its role
    lora_dropout=0.05,  # Dropout rate
    bias="none",
)
# Load the fine-tuned checkpoint for evaluation
check_point_path = "/home/gyk/Qwen2.5-VL/output/Qwen2.5-VL-7B/checkpoint-126/"
val_peft_model = PeftModel.from_pretrained(model, check_point_path, config=val_config)
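# Alternative (assumption: no further training is needed): merge the adapter
# weights into the base model for slightly faster inference:
# merged_model = val_peft_model.merge_and_unload()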
# Read the test data
with open("data_vl_test.json", "r") as f:
    test_dataset = json.load(f)

test_image_list = []
for item in test_dataset:
    input_image_prompt = item["conversations"][0]["value"]
    # Strip the surrounding <|vision_start|> and <|vision_end|> markers
    origin_image_path = input_image_prompt.split("<|vision_start|>")[1].split("<|vision_end|>")[0]
    messages = [{
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": origin_image_path,
            },
            {
                "type": "text",
                "text": "COCO Yes:",
            },
        ],
    }]
    response = predict(messages, val_peft_model)
    messages.append({"role": "assistant", "content": f"{response}"})
    print(messages[-1])
    test_image_list.append(swanlab.Image(origin_image_path, caption=response))
swanlab.log({"Prediction": test_image_list})
# When running in a Jupyter Notebook, call swanlab.finish() to stop SwanLab logging
swanlab.finish()