# Intention/uie/train1.py

import json
import os
import yaml
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional
import paddle
from paddlenlp.metrics import SpanEvaluator
from sklearn.metrics import classification_report
from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import MapDataset
from paddlenlp.trainer import CompressionArguments, Trainer, get_last_checkpoint
from paddlenlp.transformers import AutoTokenizer, UIE, UIEM, export_model
from paddlenlp.utils.log import logger


def load_config(config_path):
    """Load a YAML config file."""
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


# === 1. Load the data ===
def load_dataset(data_path):
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return MapDataset(data)
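
# The loader above is used later with preprocess_function, which reads
# example["text"] and example["annotations"] (each annotation carrying
# "start", "end", "label"). A minimal record therefore looks like the sketch
# below -- the values are illustrative placeholders, not project data:
#
#   {
#     "text": "sample sentence",
#     "annotations": [
#       {"start": 0, "end": 6, "label": "<one of config['labels']>"}
#     ]
#   }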


@dataclass
class DataArguments:
    train_path: str
    dev_path: str
    max_seq_length: Optional[int] = 512
    dynamic_max_length: Optional[List[int]] = None


@dataclass
class ModelArguments:
    model_name_or_path: str = "uie-base"
    export_model_dir: Optional[str] = None
    multilingual: bool = False


def preprocess_function(example, tokenizer):
    # Tokenize the text
    inputs = tokenizer(example["text"], max_length=512, truncation=True, return_offsets_mapping=True)
    offset_mapping = inputs["offset_mapping"]
    # Initialize label_ids (0 means the O tag)
    label_ids = [0] * len(offset_mapping)  # 0: O; 1..N: B-<label>; N+1..2N: I-<label>
    # Process entity annotations
    if "annotations" in example:
        for entity in example["annotations"]:
            start, end, entity_label = entity["start"], entity["end"], entity["label"]
            # Skip entities whose label is not in the configured label set
            if entity_label not in config["labels"]:
                continue
            # Map the entity type to a 1-based class index (0 is reserved for O)
            entity_class = config["labels"].index(entity_label) + 1
            # Mark the tokens covered by the entity span
            entity_started = False  # whether the entity has already started
            for idx, (char_start, char_end) in enumerate(offset_mapping):
                token = inputs["input_ids"][idx]
                # Skip the [CLS] and [SEP] special tokens
                if token == tokenizer.cls_token_id or token == tokenizer.sep_token_id:
                    continue
                if char_start >= start and char_end <= end:
                    if not entity_started:
                        label_ids[idx] = entity_class  # B-entity
                        entity_started = True
                    else:
                        label_ids[idx] = entity_class + len(config["labels"])  # I-entity
    # Attach the labels to the model inputs
    inputs["labels"] = label_ids
    del inputs["offset_mapping"]  # offset_mapping is no longer needed
    return inputs


# Load the config file
config = load_config("data.yaml")
# Extract the argument groups from the config
model_args = ModelArguments(**config["model_args"])
data_args = DataArguments(**config["data_args"])
# Make sure the learning rate is a float (YAML may leave values like "1e-5" as strings)
if isinstance(config["training_args"]["learning_rate"], str):
    config["training_args"]["learning_rate"] = float(config["training_args"]["learning_rate"])
training_args = CompressionArguments(**config["training_args"])
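
# Sketch of the data.yaml layout consumed above. The section names
# (model_args / data_args / training_args) and the top-level "labels" list
# follow how the config is read in this script; the concrete values are
# illustrative placeholders only:
#
#   model_args:
#     model_name_or_path: uie-base
#     multilingual: false
#   data_args:
#     train_path: data/train.json
#     dev_path: data/val.json
#     max_seq_length: 512
#   training_args:
#     output_dir: ./checkpoint
#     learning_rate: 1e-5
#     do_train: true
#     do_eval: true
#   labels:
#     - <entity label 1>
#     - <entity label 2>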

# Print the model, data, and training configs
print(f"Model config: {model_args}")
print(f"Data config: {data_args}")
print(f"Training config: {training_args}")

paddle.set_device(training_args.device)
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
    + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)

# Detect a checkpoint from a previous run
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )
    elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )

tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
if model_args.multilingual:
    model = UIEM.from_pretrained(model_args.model_name_or_path)
else:
    model = UIE.from_pretrained(model_args.model_name_or_path)

# === 4. Load the datasets ===
train_dataset = load_dataset(data_args.train_path)  # training set
dev_dataset = load_dataset(data_args.dev_path)  # validation set

# === 5. Preprocess the data ===
train_ds = train_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
dev_ds = dev_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)

if training_args.device == "npu":
    data_collator = DataCollatorWithPadding(tokenizer, padding="longest")
else:
    data_collator = DataCollatorWithPadding(tokenizer)

criterion = paddle.nn.BCELoss()


def uie_loss_func(outputs, labels):
    # Average the binary cross-entropy over the start- and end-position heads
    start_ids, end_ids = labels
    start_prob, end_prob = outputs
    start_ids = paddle.cast(start_ids, "float32")
    end_ids = paddle.cast(end_ids, "float32")
    loss_start = criterion(start_prob, start_ids)
    loss_end = criterion(end_prob, end_ids)
    loss = (loss_start + loss_end) / 2.0
    return loss


def compute_metrics(p):
    metric = SpanEvaluator()
    start_prob, end_prob = p.predictions
    start_ids, end_ids = p.label_ids
    metric.reset()
    num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
    metric.update(num_correct, num_infer, num_label)
    precision, recall, f1 = metric.accumulate()
    metric.reset()
    return {"precision": precision, "recall": recall, "f1": f1}


trainer = Trainer(
    model=model,
    criterion=uie_loss_func,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_ds if training_args.do_train or training_args.do_compress else None,
    eval_dataset=dev_ds if training_args.do_eval or training_args.do_compress else None,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.optimizer = paddle.optimizer.AdamW(
    learning_rate=training_args.learning_rate, parameters=model.parameters()
)

checkpoint = None
if training_args.resume_from_checkpoint is not None:
    checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
    checkpoint = last_checkpoint

# Training
if training_args.do_train:
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    trainer.save_model()
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

# Evaluation
if training_args.do_eval:
    eval_metrics = trainer.evaluate()
    trainer.log_metrics("eval", eval_metrics)

# Export the inference model
if training_args.do_export:
    if training_args.device == "npu":
        input_spec_dtype = "int32"
    else:
        input_spec_dtype = "int64"
    input_spec = [
        paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="input_ids"),
        paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="position_ids"),
    ]
    if model_args.export_model_dir is None:
        model_args.export_model_dir = os.path.join(training_args.output_dir, "export")
    export_model(model=trainer.model, input_spec=input_spec, path=model_args.export_model_dir)
    trainer.tokenizer.save_pretrained(model_args.export_model_dir)

# Model compression, if requested
if training_args.do_compress:

    @paddle.no_grad()
    def custom_evaluate(self, model, data_loader):
        metric = SpanEvaluator()
        model.eval()
        metric.reset()
        for batch in data_loader:
            if model_args.multilingual:
                logits = model(input_ids=batch["input_ids"], position_ids=batch["position_ids"])
            else:
                logits = model(
                    input_ids=batch["input_ids"],
                    token_type_ids=batch["token_type_ids"],
                    position_ids=batch["position_ids"],
                    attention_mask=batch["attention_mask"],
                )
            start_prob, end_prob = logits
            start_ids, end_ids = batch["start_positions"], batch["end_positions"]
            num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
            metric.update(num_correct, num_infer, num_label)
        precision, recall, f1 = metric.accumulate()
        logger.info("f1: %s, precision: %s, recall: %s" % (f1, precision, recall))
        model.train()
        return f1

    trainer.compress(custom_evaluate=custom_evaluate)
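
# Usage sketch (assumption: the script is launched from the directory holding
# data.yaml, whose data_args section points at the train/dev JSON files):
#   python train1.py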