# Intention/ernie/train.py
#
# Fine-tune an ERNIE sequence-classification model with PaddleNLP and export it
# as a static-graph inference model.
#
# NOTE: this file was copied from a web code viewer (original listing: 125 lines,
# 4.4 KiB). The viewer flagged ambiguous Unicode characters in the file; if that
# was not intentional, check the source for look-alike characters.

import paddle
from paddlenlp.datasets import MapDataset, load_dataset
from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
import yaml
import json
import numpy as np
import functools
from paddle.io import DataLoader
from paddle.nn import CrossEntropyLoss
from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.trainer import Trainer, TrainingArguments
import os
import json
import paddle
# Load the run configuration from data.yaml in the working directory.
# NOTE(review): the keys used in this script ("labels", "train", "test",
# "model_path", "nc") are assumed to be present in that file.
with open("data.yaml", "r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

# label name -> integer id (position in the "labels" list) and its inverse.
label_id = dict(zip(config["labels"], range(len(config["labels"]))))
id_label = dict(zip(label_id.values(), label_id.keys()))
def preprocess_function(examples, tokenizer, max_length, is_test=False):
    """Tokenize one record and, unless ``is_test``, attach its label.

    Args:
        examples: A record mapping with a ``"text"`` key and, for labelled
            data, a ``"label"`` key. Presumably a single example (the map
            calls below do not request batching) -- confirm if batching is added.
        tokenizer: Callable tokenizer (an ``ErnieTokenizer`` in this script).
        max_length: Sequence length to truncate to and pad up to.
        is_test: When True, no ``"labels"`` field is added.

    Returns:
        The tokenizer's output mapping. When ``is_test`` is False it also holds
        ``"labels"``: a 1-element int64 NumPy array wrapping the label.
    """
    encoded = tokenizer(
        examples["text"],
        max_length=max_length,
        truncation=True,
        padding='max_length',
    )
    if is_test:
        return encoded
    encoded["labels"] = np.array([examples["label"]], dtype="int64")
    return encoded
def read_local_dataset(path, label2id=None, is_test=False):
    """Yield examples from a local JSON file whose top level is a list of records.

    Records without ``"text"`` are skipped. When ``is_test`` is False, records
    without ``"label"`` are skipped as well.

    Args:
        path: Path to a UTF-8 JSON file containing a list of dicts.
        label2id: Mapping from label value to integer class id. If None, the
            raw label value is yielded unchanged.
        is_test: When True, yield only the text of each record.

    Yields:
        ``{"text": ...}`` when ``is_test`` is True, otherwise
        ``{"text": ..., "label": ...}`` with the label mapped through ``label2id``.

    Raises:
        ValueError: if a record's label is not a key of ``label2id``.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    for item in data:
        if "text" not in item:
            continue
        if is_test:
            yield {"text": item["text"]}
            continue
        if "label" not in item:
            continue
        label = item["label"]
        if label2id is not None:
            # An unknown label used to become -1, which is not a valid class id
            # for the loss and only failed later, mid-training. Fail at load time
            # with a message that names the offending label instead.
            if label not in label2id:
                raise ValueError(
                    f"Unknown label {label!r} in {path}; expected one of {list(label2id)}"
                )
            label = label2id[label]
        yield {"text": item["text"], "label": label}
# Fail fast if the declared class count disagrees with the label list.
# Otherwise label ids could index past the classifier head, or the head would
# have classes that no example is ever labelled with.
if config["nc"] != len(label_id):
    raise ValueError(
        f"data.yaml declares nc={config['nc']} but lists {len(label_id)} distinct labels"
    )

# Load the local datasets; the "test" split is used as the evaluation set below.
train_ds = load_dataset(read_local_dataset, path=config["train"], label2id=label_id, lazy=False)
test_ds = load_dataset(read_local_dataset, path=config["test"], label2id=label_id, lazy=False)

# Load the pretrained ERNIE classifier and its tokenizer from the configured path.
model = ErnieForSequenceClassification.from_pretrained(
    config["model_path"], num_classes=config["nc"], label2id=label_id, id2label=id_label
)
tokenizer = ErnieTokenizer.from_pretrained(config["model_path"])

# Tokenize both datasets. The sequence length may be set as "max_length" in
# data.yaml; it defaults to 256 (the previously hard-coded value).
trans_func = functools.partial(
    preprocess_function, tokenizer=tokenizer, max_length=config.get("max_length", 256)
)
train_ds = train_ds.map(trans_func)
test_ds = test_ds.map(trans_func)

# Batch collator, shared with the Trainer below.
data_collator = DataCollatorWithPadding(tokenizer)
# NOTE(review): these DataLoaders are never passed to the Trainer, which builds
# its own loaders from train_ds / test_ds; they appear unused in this script.
train_dataloader = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=data_collator)
test_dataloader = DataLoader(test_ds, batch_size=16, shuffle=False, collate_fn=data_collator)
# Trainer hyper-parameters, grouped by concern and unpacked into TrainingArguments.
# NOTE(review): eval_steps is not set, so with evaluation_strategy="steps" the
# evaluation interval falls back to the library default (presumably
# logging_steps) -- confirm against the paddlenlp TrainingArguments docs.
_training_config = {
    # Output locations.
    "output_dir": "./output",
    "logging_dir": "./logs",
    # Cadence: log every 100 steps, checkpoint every 500 steps, evaluate by steps.
    "logging_steps": 100,
    "save_steps": 500,
    "evaluation_strategy": "steps",
    # Optimisation schedule.
    "num_train_epochs": 100,
    "learning_rate": 5e-5,
    "weight_decay": 0.01,
    "gradient_accumulation_steps": 1,
    "per_device_train_batch_size": 16,
    "per_device_eval_batch_size": 16,
    # Keep the tqdm progress bar enabled.
    "disable_tqdm": False,
}
training_args = TrainingArguments(**_training_config)
def compute_metrics(p):
    """Compute classification accuracy for the Trainer's evaluation loop.

    Args:
        p: A ``(predictions, labels)`` pair (the Trainer passes an
            ``EvalPrediction``). ``predictions`` are logits whose last axis
            indexes classes, or a tuple whose first element holds them.
            ``labels`` hold one integer class id per example, either flat
            ``(N,)`` or as ``(N, 1)``.

    Returns:
        ``{"accuracy": float}``, or an accuracy of 0.0 when there are no examples.
    """
    predictions, labels = p
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # preprocess_function wraps each label in a 1-element array, so collated
    # labels presumably arrive as (N, 1). Comparing them directly with the (N,)
    # predictions would broadcast to an (N, N) matrix and inflate accuracy,
    # so flatten both sides first.
    labels = np.asarray(labels).reshape(-1)
    pred_labels = np.argmax(np.asarray(predictions), axis=-1).reshape(-1)
    if labels.size == 0:
        return {"accuracy": 0.0}
    return {"accuracy": float(np.mean(pred_labels == labels))}
# Wire the model, data, loss and metric into the paddlenlp Trainer, then train.
loss_fn = CrossEntropyLoss()
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    data_collator=data_collator,
    criterion=loss_fn,
    compute_metrics=compute_metrics,
)
trainer.train()
# Export the fine-tuned model as a static graph, plus tokenizer and label files.
export_model_dir = './output/export'
# Create the export directory if it does not exist yet.
os.makedirs(export_model_dir, exist_ok=True)

model_to_export = trainer.model
# The model may still be in training mode after trainer.train(). Switch to
# eval mode so dropout is not active in the exported inference program.
model_to_export.eval()
# Only input_ids is declared as an input; token_type_ids is left to the
# model's default. NOTE(review): confirm inference code feeds input_ids only.
input_spec = [paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids")]
paddle.jit.save(model_to_export, os.path.join(export_model_dir, 'model'), input_spec=input_spec)

# Save the tokenizer files next to the exported model.
tokenizer.save_pretrained(export_model_dir)

# Save both label mappings. JSON object keys are always strings, so the
# integer ids in id2label.json are written as "0", "1", ...
id2label_file = os.path.join(export_model_dir, 'id2label.json')
with open(id2label_file, 'w', encoding='utf-8') as f:
    json.dump(id_label, f, ensure_ascii=False)
label2id_file = os.path.join(export_model_dir, 'label2id.json')
with open(label2id_file, 'w', encoding='utf-8') as f:
    json.dump(label_id, f, ensure_ascii=False)
print(f'Model and tokenizer have been saved to {export_model_dir}')