# Intention/uie/train1.py

import json
import os
import yaml
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional
import paddle
from paddlenlp.metrics import SpanEvaluator
from sklearn.metrics import classification_report
from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import MapDataset
from paddlenlp.trainer import CompressionArguments, Trainer, get_last_checkpoint
from paddlenlp.transformers import ErnieForTokenClassification, AutoTokenizer, UIEM, UIE, export_model
from paddlenlp.utils.log import logger
def load_config(config_path):
    """Load the YAML configuration file."""
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
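
# Illustrative sketch of the expected data.yaml layout. Only keys that are
# actually read below are shown; the concrete values and label names are
# placeholders, not taken from the original project:
#
# model_args:
#   model_name_or_path: uie-base
#   multilingual: false
# data_args:
#   train_path: data/train.json
#   dev_path: data/val.json
#   max_seq_length: 512
# training_args:
#   output_dir: ./checkpoint
#   learning_rate: 1e-5
#   do_train: true
#   do_eval: true
#   do_export: false
#   do_compress: false
# labels:
#   - LABEL_A
#   - LABEL_B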
# === 1. Load data ===
def load_dataset(data_path):
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return MapDataset(data)
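
# Assumed layout of the JSON files loaded here, inferred from how
# preprocess_function reads each example below (values are placeholders):
# [
#   {
#     "text": "...",
#     "annotations": [
#       {"start": 0, "end": 2, "label": "LABEL_A"}
#     ]
#   }
# ]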
@dataclass
class DataArguments:
    train_path: str
    dev_path: str
    max_seq_length: Optional[int] = 512
    dynamic_max_length: Optional[List[int]] = None


@dataclass
class ModelArguments:
    model_name_or_path: str = "uie-base"
    export_model_dir: Optional[str] = None
    multilingual: bool = False
def preprocess_function(example, tokenizer):
    # Tokenize the text
    inputs = tokenizer(example["text"], max_length=512, truncation=True, return_offsets_mapping=True)
    offset_mapping = inputs["offset_mapping"]
    # Initialize label_ids (0 means the O label)
    label_ids = [0] * len(offset_mapping)  # 0: O, 1..N: B-tags, N+1..2N: I-tags
    # Process entity annotations
    if "annotations" in example:
        for entity in example["annotations"]:
            start, end, entity_label = entity["start"], entity["end"], entity["label"]
            # Make sure entity_label is one of the configured labels
            if entity_label not in config["labels"]:
                continue  # Skip entities whose label is out of scope
            # Map the entity type to an index (1, 2, ... correspond to the B-tags)
            entity_class = config["labels"].index(entity_label) + 1
            # Track whether the entity has already started
            entity_started = False
            for idx, (char_start, char_end) in enumerate(offset_mapping):
                token = inputs["input_ids"][idx]
                # Skip special tokens such as [CLS] and [SEP]
                if token == tokenizer.cls_token_id or token == tokenizer.sep_token_id:
                    continue
                if char_start >= start and char_end <= end:
                    if not entity_started:
                        label_ids[idx] = entity_class  # B- tag
                        entity_started = True
                    else:
                        label_ids[idx] = entity_class + len(config["labels"])  # I- tag
    # Attach the labels to the inputs
    inputs["labels"] = label_ids
    del inputs["offset_mapping"]  # offset_mapping is no longer needed
    return inputs
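
# Worked example of the label-id scheme above, with a hypothetical label set:
# if config["labels"] == ["PER", "LOC"], then O = 0, B-PER = 1, B-LOC = 2,
# I-PER = 3, I-LOC = 4, i.e. B = label index + 1 and I = B + len(config["labels"]).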
# Load the configuration file
config = load_config("data.yaml")
# Extract arguments from the configuration
model_args = ModelArguments(**config["model_args"])
data_args = DataArguments(**config["data_args"])
# Make sure the learning rate is a float (YAML may parse values like "1e-5" as strings)
if isinstance(config["training_args"]["learning_rate"], str):
    config["training_args"]["learning_rate"] = float(config["training_args"]["learning_rate"])
training_args = CompressionArguments(**config["training_args"])
# Print the model, data, and training configuration
print(f"Model config: {model_args}")
print(f"Data config: {data_args}")
print(f"Training config: {training_args}")
paddle.set_device(training_args.device)
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
    + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
# Check for a checkpoint from a previous training run
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )
    elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
if model_args.multilingual:
    model = UIEM.from_pretrained(model_args.model_name_or_path)
else:
    model = UIE.from_pretrained(model_args.model_name_or_path)
# === 4. Load datasets ===
train_dataset = load_dataset(data_args.train_path)  # training set
dev_dataset = load_dataset(data_args.dev_path)  # validation set
# === 5. Preprocess data ===
train_ds = train_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
dev_ds = dev_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
if training_args.device == "npu":
data_collator = DataCollatorWithPadding(tokenizer, padding="longest")
else:
data_collator = DataCollatorWithPadding(tokenizer)
criterion = paddle.nn.BCELoss()


def uie_loss_func(outputs, labels):
    # Average the binary cross-entropy of the start- and end-position probabilities
    start_ids, end_ids = labels
    start_prob, end_prob = outputs
    start_ids = paddle.cast(start_ids, "float32")
    end_ids = paddle.cast(end_ids, "float32")
    loss_start = criterion(start_prob, start_ids)
    loss_end = criterion(end_prob, end_ids)
    loss = (loss_start + loss_end) / 2.0
    return loss
def compute_metrics(p):
    metric = SpanEvaluator()
    start_prob, end_prob = p.predictions
    start_ids, end_ids = p.label_ids
    metric.reset()
    num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
    metric.update(num_correct, num_infer, num_label)
    precision, recall, f1 = metric.accumulate()
    metric.reset()
    return {"precision": precision, "recall": recall, "f1": f1}
trainer = Trainer(
    model=model,
    criterion=uie_loss_func,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_ds if training_args.do_train or training_args.do_compress else None,
    eval_dataset=dev_ds if training_args.do_eval or training_args.do_compress else None,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.optimizer = paddle.optimizer.AdamW(
    learning_rate=training_args.learning_rate, parameters=model.parameters()
)
checkpoint = None
if training_args.resume_from_checkpoint is not None:
    checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
    checkpoint = last_checkpoint
# Training
if training_args.do_train:
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    trainer.save_model()
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
# Evaluate the model
if training_args.do_eval:
    eval_metrics = trainer.evaluate()
    trainer.log_metrics("eval", eval_metrics)
# Export the inference model
if training_args.do_export:
    if training_args.device == "npu":
        input_spec_dtype = "int32"
    else:
        input_spec_dtype = "int64"
    input_spec = [
        paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="input_ids"),
        paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="position_ids"),
    ]
    if model_args.export_model_dir is None:
        model_args.export_model_dir = os.path.join(training_args.output_dir, "export")
    export_model(model=trainer.model, input_spec=input_spec, path=model_args.export_model_dir)
    trainer.tokenizer.save_pretrained(model_args.export_model_dir)
# Optionally compress the model
if training_args.do_compress:

    @paddle.no_grad()
    def custom_evaluate(self, model, data_loader):
        metric = SpanEvaluator()
        model.eval()
        metric.reset()
        for batch in data_loader:
            if model_args.multilingual:
                logits = model(input_ids=batch["input_ids"], position_ids=batch["position_ids"])
            else:
                logits = model(
                    input_ids=batch["input_ids"],
                    token_type_ids=batch["token_type_ids"],
                    position_ids=batch["position_ids"],
                    attention_mask=batch["attention_mask"],
                )
            start_prob, end_prob = logits
            start_ids, end_ids = batch["start_positions"], batch["end_positions"]
            num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
            metric.update(num_correct, num_infer, num_label)
        precision, recall, f1 = metric.accumulate()
        logger.info("f1: %s, precision: %s, recall: %s" % (f1, precision, recall))
        model.train()
        return f1

    trainer.compress(custom_evaluate=custom_evaluate)
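
# Rough usage sketch (an assumption, not part of the original file): run from
# the directory that contains data.yaml and the JSON files named in data_args:
#   python train1.py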