import json
import os
from dataclasses import dataclass
from typing import Dict, List, Optional

import numpy as np
import paddle
import yaml
from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import MapDataset
from paddlenlp.metrics import SpanEvaluator
# CompressionArguments is required below for training_args and trainer.compress()
from paddlenlp.trainer import CompressionArguments, Trainer, TrainingArguments, get_last_checkpoint
from paddlenlp.transformers import UIE, UIEM, AutoTokenizer, ErnieForTokenClassification, export_model
from paddlenlp.utils.log import logger
from sklearn.metrics import classification_report


def load_config(config_path):
    """Load a YAML config file."""
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

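# The rest of this script assumes data.yaml provides the keys accessed below
# (model_args, data_args, training_args, labels). A hypothetical minimal file,
# with illustrative values only:
#
#   model_args:
#     model_name_or_path: uie-base
#     multilingual: false
#   data_args:
#     train_path: data/train.json
#     dev_path: data/val.json
#     max_seq_length: 512
#   training_args:
#     output_dir: ./checkpoint
#     learning_rate: 1e-5   # PyYAML loads this as a string; the script casts it below
#     do_train: true
#     do_eval: true
#   labels:
#     - PER
#     - LOC
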
# === 1. Load data ===
def load_dataset(data_path):
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return MapDataset(data)

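# Inferred from preprocess_function below: each record is expected to carry a
# "text" field plus optional character-level "annotations". A hypothetical record:
#
#   {"text": "张三在北京工作",
#    "annotations": [{"start": 0, "end": 2, "label": "PER"},
#                    {"start": 3, "end": 5, "label": "LOC"}]}
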
@dataclass
class DataArguments:
    train_path: str
    dev_path: str
    max_seq_length: Optional[int] = 512
    dynamic_max_length: Optional[List[int]] = None


@dataclass
class ModelArguments:
    model_name_or_path: str = "uie-base"
    export_model_dir: Optional[str] = None
    multilingual: bool = False

def preprocess_function(example, tokenizer):
    # Tokenize the text
    inputs = tokenizer(example["text"], max_length=512, truncation=True, return_offsets_mapping=True)
    offset_mapping = inputs["offset_mapping"]
    # Initialize label_ids (0 is the O tag; B tags are 1..N, I tags are N+1..2N)
    label_ids = [0] * len(offset_mapping)

    # Process entity annotations
    if "annotations" in example:
        for entity in example["annotations"]:
            start, end, entity_label = entity["start"], entity["end"], entity["label"]

            # Make sure entity_label is within our label set
            if entity_label not in config["labels"]:
                continue  # skip entities whose label is out of scope
            # Map the entity type to a class index
            entity_class = config["labels"].index(entity_label) + 1  # 1: B-label1, 2: B-label2, ...
            # Track whether the B tag for this entity has been emitted yet
            entity_started = False
            for idx, (char_start, char_end) in enumerate(offset_mapping):
                token = inputs["input_ids"][idx]
                # Skip special tokens
                if token == tokenizer.cls_token_id or token == tokenizer.sep_token_id:
                    continue  # skip [CLS] and [SEP]
                if char_start >= start and char_end <= end:
                    if not entity_started:
                        label_ids[idx] = entity_class  # B-entity
                        entity_started = True
                    else:
                        label_ids[idx] = entity_class + len(config["labels"])  # I-entity

    # Attach the labels to the inputs
    inputs["labels"] = label_ids
    del inputs["offset_mapping"]  # offset_mapping is no longer needed
    return inputs

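# A worked example of the labeling scheme above (illustrative only; exact token
# boundaries depend on the tokenizer): with config["labels"] = ["PER", "LOC"]
# and annotation {"start": 0, "end": 2, "label": "PER"} on text "张三来了",
# the tokens covering characters [0, 2) get label 1 (B-PER) on the first token
# and 3 (= 1 + len(labels), i.e. I-PER) on the rest; all other tokens stay 0 (O).
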
# Load the config file
config = load_config("data.yaml")

# Extract the arguments from the config
model_args = ModelArguments(**config["model_args"])
data_args = DataArguments(**config["data_args"])

# Make sure the learning rate is a float: PyYAML parses scientific notation
# without a decimal point (e.g. "1e-5") as a string rather than a float
if isinstance(config["training_args"]["learning_rate"], str):
    config["training_args"]["learning_rate"] = float(config["training_args"]["learning_rate"])

training_args = CompressionArguments(**config["training_args"])

# Print the model, data, and training configs
print(f"Model config: {model_args}")
print(f"Data config: {data_args}")
print(f"Training config: {training_args}")

paddle.set_device(training_args.device)

logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
    + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)

# Check for a checkpoint left over from a previous run
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )
    elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )

tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
if model_args.multilingual:
    model = UIEM.from_pretrained(model_args.model_name_or_path)
else:
    model = UIE.from_pretrained(model_args.model_name_or_path)

# === 4. Load the datasets ===
train_dataset = load_dataset("data/train.json")  # training set
dev_dataset = load_dataset("data/val.json")  # validation set

# === 5. Preprocess the data ===
train_ds = train_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
dev_ds = dev_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)

if training_args.device == "npu":
    data_collator = DataCollatorWithPadding(tokenizer, padding="longest")
else:
    data_collator = DataCollatorWithPadding(tokenizer)

criterion = paddle.nn.BCELoss()

def uie_loss_func(outputs, labels):
    """Average the binary cross-entropy over the start- and end-pointer heads."""
    start_ids, end_ids = labels
    start_prob, end_prob = outputs
    start_ids = paddle.cast(start_ids, "float32")
    end_ids = paddle.cast(end_ids, "float32")
    loss_start = criterion(start_prob, start_ids)
    loss_end = criterion(end_prob, end_ids)
    loss = (loss_start + loss_end) / 2.0
    return loss

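# Shape sketch for uie_loss_func (illustrative): start_prob/end_prob are
# per-token probabilities in [0, 1] of shape [batch_size, seq_len] (UIE applies
# a sigmoid internally, hence BCELoss rather than a with-logits variant), and
# start_ids/end_ids are 0/1 indicators of the same shape. With a dummy batch:
#
#   probs = paddle.to_tensor([[0.9, 0.1, 0.2]])
#   ids = paddle.to_tensor([[1.0, 0.0, 0.0]])
#   uie_loss_func((probs, probs), (ids, ids))  # scalar BCE loss
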
def compute_metrics(p):
    # SpanEvaluator computes span-level precision/recall/F1 from the
    # start/end pointer probabilities and the gold start/end ids
    metric = SpanEvaluator()
    start_prob, end_prob = p.predictions
    start_ids, end_ids = p.label_ids
    metric.reset()

    num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
    metric.update(num_correct, num_infer, num_label)
    precision, recall, f1 = metric.accumulate()
    metric.reset()

    return {"precision": precision, "recall": recall, "f1": f1}

trainer = Trainer(
    model=model,
    criterion=uie_loss_func,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_ds if training_args.do_train or training_args.do_compress else None,
    eval_dataset=dev_ds if training_args.do_eval or training_args.do_compress else None,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.optimizer = paddle.optimizer.AdamW(
    learning_rate=training_args.learning_rate, parameters=model.parameters()
)
checkpoint = None
if training_args.resume_from_checkpoint is not None:
    checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
    checkpoint = last_checkpoint

# Training
if training_args.do_train:
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    trainer.save_model()
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

# Evaluation
if training_args.do_eval:
    eval_metrics = trainer.evaluate()
    trainer.log_metrics("eval", eval_metrics)

# Export the inference model
if training_args.do_export:
    # Use int32 input tensors on NPU, int64 elsewhere
    if training_args.device == "npu":
        input_spec_dtype = "int32"
    else:
        input_spec_dtype = "int64"

    input_spec = [
        paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="input_ids"),
        paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="position_ids"),
    ]

    if model_args.export_model_dir is None:
        model_args.export_model_dir = os.path.join(training_args.output_dir, "export")

    export_model(model=trainer.model, input_spec=input_spec, path=model_args.export_model_dir)
    trainer.tokenizer.save_pretrained(model_args.export_model_dir)

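# The exported static graph can afterwards be loaded for inference; a
# hypothetical usage (the "model" file prefix is an assumption about
# export_model's output layout, not confirmed by this script):
#
#   loaded = paddle.jit.load(os.path.join(model_args.export_model_dir, "model"))
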
# Model compression, if requested
if training_args.do_compress:

    @paddle.no_grad()
    def custom_evaluate(self, model, data_loader):
        metric = SpanEvaluator()
        model.eval()
        metric.reset()
        for batch in data_loader:
            if model_args.multilingual:
                logits = model(input_ids=batch["input_ids"], position_ids=batch["position_ids"])
            else:
                logits = model(
                    input_ids=batch["input_ids"],
                    token_type_ids=batch["token_type_ids"],
                    position_ids=batch["position_ids"],
                    attention_mask=batch["attention_mask"],
                )
            start_prob, end_prob = logits
            start_ids, end_ids = batch["start_positions"], batch["end_positions"]
            num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
            metric.update(num_correct, num_infer, num_label)
        # Accumulate over the whole dataset, then report once
        precision, recall, f1 = metric.accumulate()
        logger.info("f1: %s, precision: %s, recall: %s" % (f1, precision, recall))
        model.train()
        return f1

    trainer.compress(custom_evaluate=custom_evaluate)