# Intention/ernie/train.py
import functools
import json
import os

import numpy as np
import paddle
import yaml
from paddle.nn import CrossEntropyLoss
from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import load_dataset
from paddlenlp.trainer import Trainer, TrainingArguments
from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
from sklearn.metrics import precision_score, recall_score, f1_score

def load_config(config_path):
    """Load the YAML config file."""
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)
    except Exception as e:
        raise ValueError(f"Error reading config file: {str(e)}")
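
# Assumed layout of data.yaml, inferred from the keys accessed in main()
# below (labels, model_path, train, test); the values are placeholders,
# adjust them to your actual files:
#
#   labels: ["intent_a", "intent_b"]
#   model_path: "ernie-3.0-medium-zh"
#   train: "data/train.json"
#   test: "data/test.json"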

def generate_label_mappings(labels):
    """Build the label2id and id2label mappings."""
    label_id = {label: idx for idx, label in enumerate(labels)}
    id_label = {idx: label for label, idx in label_id.items()}
    return label_id, id_label
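
# Note the ids are 0-based, e.g. generate_label_mappings(["pos", "neg"])
# returns ({"pos": 0, "neg": 1}, {0: "pos", 1: "neg"}).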

def preprocess_function(examples, tokenizer, max_length, is_test=False):
    """Tokenize one example; attach its label unless running on test data."""
    result = tokenizer(examples["text"], max_length=max_length, truncation=True, padding="max_length")
    if not is_test:
        result["labels"] = np.array([examples["label"]], dtype="int64")
    return result
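
# Note: padding="max_length" already pads every sequence to max_length, so
# the DataCollatorWithPadding used in main() has nothing left to pad.
# Dropping padding="max_length" here and letting the collator pad each
# batch dynamically would be the more economical choice.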

def read_local_dataset(path, label2id=None, is_test=False):
    """Read a local JSON dataset and yield examples one by one."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        for item in data:
            if is_test:
                if "text" in item:
                    yield {"text": item["text"]}
            else:
                if "text" in item and "label" in item:
                    # Unknown labels map to -1; filter these upstream, since
                    # -1 is not a valid CrossEntropyLoss target.
                    yield {"text": item["text"], "label": label2id.get(item["label"], -1)}
    except Exception as e:
        raise ValueError(f"Error reading dataset: {str(e)}")
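
# Expected JSON layout for the train/test files, inferred from the reader
# above (the example text is illustrative):
#   [{"text": "turn on the light", "label": "intent_a"}, ...]
# Test files may omit the "label" field.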

def load_and_preprocess_dataset(path, label2id, tokenizer, max_length, is_test=False):
    """Load a dataset and apply tokenization to every example."""
    try:
        dataset = load_dataset(read_local_dataset, path=path, label2id=label2id, lazy=False, is_test=is_test)
        trans_func = functools.partial(preprocess_function, tokenizer=tokenizer, max_length=max_length, is_test=is_test)
        return dataset.map(trans_func)
    except Exception as e:
        raise ValueError(f"Error loading or preprocessing dataset: {str(e)}")

def export_model(trainer, export_model_dir):
    """Export the model and tokenizer for static-graph inference."""
    os.makedirs(export_model_dir, exist_ok=True)
    model_to_export = trainer.model
    input_spec = [paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids")]
    paddle.jit.save(model_to_export, os.path.join(export_model_dir, "model"), input_spec=input_spec)
    trainer.tokenizer.save_pretrained(export_model_dir)

    # Save the id2label and label2id mapping files alongside the model.
    id2label_file = os.path.join(export_model_dir, "id2label.json")
    label2id_file = os.path.join(export_model_dir, "label2id.json")
    with open(id2label_file, "w", encoding="utf-8") as f:
        json.dump(trainer.model.id2label, f, ensure_ascii=False)
    with open(label2id_file, "w", encoding="utf-8") as f:
        json.dump(trainer.model.label2id, f, ensure_ascii=False)
    print(f"Model and tokenizer have been saved to {export_model_dir}")

def compute_metrics(p):
    """Compute evaluation metrics."""
    predictions, labels = p
    # Labels may arrive as a column vector (each example stores its label as
    # a length-1 array), so flatten before comparing.
    labels = np.asarray(labels).flatten()
    # argmax already yields the 0-based class ids produced by
    # generate_label_mappings, so no "+ 1" offset is applied here.
    pred_labels = np.argmax(predictions, axis=1)
    accuracy = np.sum(pred_labels == labels) / len(labels)
    precision = precision_score(labels, pred_labels, average="macro")
    recall = recall_score(labels, pred_labels, average="macro")
    f1 = f1_score(labels, pred_labels, average="macro")

    metrics = {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
    print("Computed metrics:", metrics)  # print the computed metrics
    return metrics

def main():
    try:
        # Read the configuration.
        config = load_config("data.yaml")
        label_id, id_label = generate_label_mappings(config["labels"])

        # Load the datasets. The eval split must keep its labels
        # (is_test=False), otherwise Trainer.evaluate() has no targets to
        # score compute_metrics against.
        tokenizer = ErnieTokenizer.from_pretrained(config["model_path"])
        train_ds = load_and_preprocess_dataset(config["train"], label_id, tokenizer, max_length=256)
        test_ds = load_and_preprocess_dataset(config["test"], label_id, tokenizer, max_length=256)

        # Load the model.
        model = ErnieForSequenceClassification.from_pretrained(config["model_path"], num_classes=len(label_id),
                                                               label2id=label_id, id2label=id_label)

        # Batch collation.
        data_collator = DataCollatorWithPadding(tokenizer)

        # Training arguments.
        training_args = TrainingArguments(
            output_dir="./output",
            evaluation_strategy="epoch",  # evaluate once per epoch
            save_strategy="epoch",
            eval_steps=100,  # ignored while evaluation_strategy="epoch"; only used with the "steps" strategy
            save_steps=500,  # likewise only used with save_strategy="steps"
            logging_dir="./logs",
            logging_steps=50,  # log every 50 steps
            num_train_epochs=10,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            gradient_accumulation_steps=1,
            learning_rate=5e-5,
            weight_decay=0.01,
            disable_tqdm=False,
            greater_is_better=True,  # higher metric is better; only takes effect alongside metric_for_best_model
        )

        # Create the Trainer.
        trainer = Trainer(
            model=model,
            tokenizer=tokenizer,
            args=training_args,
            criterion=CrossEntropyLoss(),
            train_dataset=train_ds,
            eval_dataset=test_ds,
            data_collator=data_collator,
            compute_metrics=compute_metrics,  # use the custom evaluation metrics
        )

        # Train and save.
        trainer.train()
        trainer.save_model("./saved_model_static")  # saves weights, config, and tokenizer to this directory
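
        # Optionally export a static-graph model for inference deployment
        # using the (otherwise unused) export_model helper defined above;
        # the target directory here is an arbitrary example:
        # export_model(trainer, "./exported_model")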
    except Exception as e:
        print(f"Error during training: {str(e)}")

if __name__ == "__main__":
    main()