Intention/ernie/train.py

import functools
import json
import os

import numpy as np
import paddle
import yaml
from paddle.nn import CrossEntropyLoss
from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import load_dataset
from paddlenlp.trainer import Trainer, TrainingArguments
from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
from sklearn.metrics import precision_score, recall_score, f1_score


def load_config(config_path):
    """Load the YAML configuration file."""
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)
    except Exception as e:
        raise ValueError(f"Error reading config file: {str(e)}")
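
# A minimal sketch of the data.yaml this script expects, inferred from the
# keys read in main() below; the values here are illustrative assumptions:
#
#   labels: ["label_a", "label_b", "label_c"]
#   model_path: "ernie-3.0-medium-zh"   # any ERNIE checkpoint name or local path
#   train: "data/train.json"
#   val: "data/val.json"
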
def generate_label_mappings(labels):
    """Build the label2id and id2label mappings."""
    label_id = {label: idx for idx, label in enumerate(labels)}
    id_label = {idx: label for label, idx in label_id.items()}
    return label_id, id_label
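
# Example: generate_label_mappings(["label_a", "label_b"]) returns
# ({"label_a": 0, "label_b": 1}, {0: "label_a", 1: "label_b"}).
# The ids are 0-based, which compute_metrics below relies on.
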
def preprocess_function(examples, tokenizer, max_length, is_test=False):
    """Tokenize a single example; attach its label unless in test mode."""
    result = tokenizer(examples["text"], max_length=max_length, truncation=True, padding="max_length")
    if not is_test:
        result["labels"] = np.array([examples["label"]], dtype="int64")
    return result
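
# The tokenizer call above returns a dict of encoded features (input_ids and,
# for ERNIE models, token_type_ids); the int64 "labels" entry added here is
# what the Trainer's loss and compute_metrics consume.
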
def read_local_dataset(path, label2id=None, is_test=False):
    """Read a local JSON dataset and yield one example at a time."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        for item in data:
            if is_test:
                if "text" in item:
                    yield {"text": item["text"]}
            else:
                if "text" in item and "label" in item:
                    yield {"text": item["text"], "label": label2id.get(item["label"], -1)}
    except Exception as e:
        raise ValueError(f"Error reading dataset: {str(e)}")
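
# Expected file layout, inferred from the loop above: a single JSON array, e.g.
# [{"text": "some sentence", "label": "label_a"}, ...]. Labels absent from
# label2id fall back to -1, which makes unknown classes easy to spot downstream.
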
def load_and_preprocess_dataset(path, label2id, tokenizer, max_length, is_test=False):
    """Load a dataset and apply tokenization/label preprocessing."""
    try:
        dataset = load_dataset(read_local_dataset, path=path, label2id=label2id, lazy=False, is_test=is_test)
        trans_func = functools.partial(preprocess_function, tokenizer=tokenizer, max_length=max_length, is_test=is_test)
        return dataset.map(trans_func)
    except Exception as e:
        raise ValueError(f"Error loading and preprocessing dataset: {str(e)}")
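
# paddlenlp's load_dataset accepts a custom reader function like the one above;
# lazy=False materializes the whole dataset in memory so it can be mapped eagerly.
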
def export_model(trainer, export_model_dir):
    """Export the trained model to static graph format, along with the tokenizer."""
    os.makedirs(export_model_dir, exist_ok=True)
    model_to_export = trainer.model
    input_spec = [paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids")]
    paddle.jit.save(model_to_export, os.path.join(export_model_dir, "model"), input_spec=input_spec)
    trainer.tokenizer.save_pretrained(export_model_dir)

    # Save the id2label and label2id files alongside the exported model.
    id2label_file = os.path.join(export_model_dir, "id2label.json")
    label2id_file = os.path.join(export_model_dir, "label2id.json")
    with open(id2label_file, "w", encoding="utf-8") as f:
        json.dump(trainer.model.id2label, f, ensure_ascii=False)
    with open(label2id_file, "w", encoding="utf-8") as f:
        json.dump(trainer.model.label2id, f, ensure_ascii=False)
    print(f"Model and tokenizer have been saved to {export_model_dir}")
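
# Note: export_model is not invoked in main() below; a typical (hypothetical)
# post-training call would be export_model(trainer, "./exported_model").
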
def compute_metrics(p):
    """Compute evaluation metrics.

    p unpacks to (predictions, labels): predictions are logits of shape
    [num_examples, num_classes]; labels are the gold class ids.
    """
    predictions, labels = p
    # argmax gives 0-based class ids, matching the 0-based label2id mapping
    # built in generate_label_mappings (no +1 shift, which would misalign them).
    pred_labels = np.argmax(predictions, axis=1)
    accuracy = np.sum(pred_labels == labels) / len(labels)
    precision = precision_score(labels, pred_labels, average="macro")
    recall = recall_score(labels, pred_labels, average="macro")
    f1 = f1_score(labels, pred_labels, average="macro")

    metrics = {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
    print("Computed metrics:", metrics)  # Log the computed metrics
    return metrics


def main():
    try:
        # Load the configuration.
        config = load_config("data.yaml")
        label_id, id_label = generate_label_mappings(config["labels"])
        # Load the datasets.
        tokenizer = ErnieTokenizer.from_pretrained(config["model_path"])
        train_ds = load_and_preprocess_dataset(config["train"], label_id, tokenizer, max_length=256)
        # Keep labels on the validation set (is_test=False): the Trainer's
        # per-epoch evaluation and compute_metrics need gold labels to score against.
        test_ds = load_and_preprocess_dataset(config["val"], label_id, tokenizer, max_length=256)
        # Load the model.
        model = ErnieForSequenceClassification.from_pretrained(config["model_path"], num_classes=len(label_id),
                                                               label2id=label_id, id2label=id_label)
        # Define the data collator.
        data_collator = DataCollatorWithPadding(tokenizer)
        # Define the training arguments.
        training_args = TrainingArguments(
            output_dir="./output_temp",
            evaluation_strategy="epoch",
            save_strategy="epoch",
            eval_steps=2000,  # Evaluate every 2000 steps (only takes effect when evaluation_strategy="steps")
            save_steps=2000,  # Save every 2000 steps (only takes effect when save_strategy="steps")
            logging_dir="./logs",
            logging_steps=100,  # Log every 100 steps
            num_train_epochs=10,  # Number of training epochs
            per_device_train_batch_size=32,
            per_device_eval_batch_size=32,
            gradient_accumulation_steps=1,
            learning_rate=5e-5,
            weight_decay=0.01,
            disable_tqdm=False,
            greater_is_better=True,  # Higher metric values are better
        )
        # Create the Trainer.
        trainer = Trainer(
            model=model,
            tokenizer=tokenizer,
            args=training_args,
            criterion=CrossEntropyLoss(),
            train_dataset=train_ds,
            eval_dataset=test_ds,
            data_collator=data_collator,
            compute_metrics=compute_metrics,  # Use the custom evaluation metrics
        )
        # Train the model.
        trainer.train()
        # Save the model.
        trainer.save_model("./saved_model_static")  # Saves to the directory passed in
    except Exception as e:
        print(f"Error during training: {str(e)}")


if __name__ == "__main__":
    main()
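
# Run with: python train.py (expects data.yaml in the working directory,
# since load_config("data.yaml") above uses a relative path).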