# Intention/uie/train.py

import json
import paddle
from paddlenlp.datasets import MapDataset
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
from paddlenlp.trainer import Trainer, TrainingArguments
from paddlenlp.data import DataCollatorForTokenClassification
# === 1. Load data ===
def load_dataset(data_path):
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return MapDataset(data)
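# A sketch of the record shape this loader expects, inferred from the fields
# read in preprocess_function below (the real files may carry extra keys):
# {
#     "text": "...",
#     "annotations": [
#         {"text": "...", "start": 0, "end": 9, "label": "date"}
#     ]
# }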
# === 2. Preprocess data ===
def preprocess_function(example, tokenizer):
    # Predefined list of entity types
    entity_types = [
        'date', 'project_name', 'project_type', 'construction_unit',
        'implementation_organization', 'project_department', 'project_manager',
        'subcontractor', 'team_leader', 'risk_level', 'page', 'team_name'
    ]
    # Tokenize the text
    inputs = tokenizer(example["text"], max_length=512, truncation=True, return_offsets_mapping=True)
    offset_mapping = inputs["offset_mapping"]
    # Initialize label_ids (0 means the O tag)
    label_ids = [0] * len(offset_mapping)  # 0: O, 1-12: B-<type>, 13-24: I-<type>
    # Process the annotated entities
    if "annotations" in example:
        for entity in example["annotations"]:
            entity_text = entity["text"]
            start, end, entity_label = entity["start"], entity["end"], entity["label"]
            # Make sure entity_label is one of the known types
            if entity_label not in entity_types:
                continue  # Skip entities whose label is out of scope
            # Map the entity type to a class index
            entity_class = entity_types.index(entity_label) + 1  # 1: B-<type 1>, 2: B-<type 2>, ...
            # Mark the entity span
            entity_started = False  # Whether the entity has already started
            for idx, (char_start, char_end) in enumerate(offset_mapping):
                token = inputs['input_ids'][idx]
                # Exclude special tokens
                if token == tokenizer.cls_token_id or token == tokenizer.sep_token_id:
                    continue  # Skip the [CLS] and [SEP] tokens
                if char_start >= start and char_end <= end:
                    if not entity_started:
                        label_ids[idx] = entity_class  # B-<entity>
                        entity_started = True
                    else:
                        label_ids[idx] = entity_class + len(entity_types)  # I-<entity>
    # Attach the labels to the inputs
    inputs["labels"] = label_ids
    del inputs["offset_mapping"]  # Drop offset_mapping; the collator cannot pad it
    return inputs
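# For decoding predictions later, the id-to-tag mapping implied by the scheme
# above would look like this (illustrative only; not part of the training flow):
# id2label = {0: "O"}
# id2label.update({i + 1: f"B-{t}" for i, t in enumerate(entity_types)})
# id2label.update({i + 1 + len(entity_types): f"I-{t}" for i, t in enumerate(entity_types)})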
# === 3. Load the UIE pretrained model ===
model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=25)  # 25 classes: O + 12 B- + 12 I-
tokenizer = ErnieTokenizer.from_pretrained("uie-base")
# === 4. Load the datasets ===
train_dataset = load_dataset("data/train.json")  # Training set
dev_dataset = load_dataset("data/val.json")  # Validation set
# === 5. Preprocess the datasets ===
train_dataset = train_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
dev_dataset = dev_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
# === 6. Data collation ===
data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)
# === 7. Training arguments ===
training_args = TrainingArguments(
    output_dir="./output",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=16,  # Adjust to fit your GPU memory
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    num_train_epochs=10,  # Number of training epochs
    weight_decay=0.01,
    save_total_limit=1,  # Keep only the most recent checkpoint
    logging_dir="./logs",
    logging_steps=10,
    eval_steps=2000,  # Unused while evaluation_strategy="epoch"
    save_steps=2000,  # Unused while save_strategy="epoch"
    seed=1000,
    load_best_model_at_end=True,
)
# === 8. Train ===
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
trainer.train()
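# Persist the final model and tokenizer for later reuse; with
# load_best_model_at_end=True this saves the best checkpoint. A small addition,
# assuming the standard Trainer.save_model API (the path is a placeholder):
trainer.save_model("./output/best_model")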
# Define the input spec for exporting the model
input_spec = [
    paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="input_ids"),
    paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="token_type_ids"),
    paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="position_ids"),
    paddle.static.InputSpec(shape=[None, 512], dtype="float32", name="attention_mask"),
]
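# input_spec is otherwise unused as the script stands; a minimal export sketch,
# assuming the intent is a static-graph model for inference (path is a placeholder):
paddle.jit.save(model, "./output/export/inference", input_spec=input_spec)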