"""Train an ERNIE token-classification (NER) model with PaddleNLP.

Reconstructed module header: the original text here was repository
web-view residue (filename Intention/uie/train.py, size, commit dates),
not source code.
"""
import json

import paddle
from paddlenlp.data import DataCollatorForTokenClassification
from paddlenlp.datasets import MapDataset
from paddlenlp.trainer import Trainer, TrainingArguments
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
# === 1. Load data ===
def load_dataset(data_path):
    """Read a JSON file of examples and wrap it in a PaddleNLP MapDataset."""
    with open(data_path, "r", encoding="utf-8") as f:
        records = json.load(f)
    return MapDataset(records)
# === 2. Preprocess data ===
def preprocess_function(example, tokenizer):
    """Tokenize one example and build token-classification (BIO-style) labels.

    Args:
        example: dict with "text" and, optionally, "annotations"; each
            annotation carries "start"/"end" character offsets and a "label".
        tokenizer: fast tokenizer returning "input_ids" and "offset_mapping"
            when called with return_offsets_mapping=True.

    Returns:
        The tokenizer output dict, with "labels" added
        (0 = O, 1..14 = B-entity, 15..28 = I-entity) and "offset_mapping"
        removed.
    """
    # Known entity label set; index i maps to B-label i+1 and
    # I-label i+1+len(entity_types).
    entity_types = [
        'date', 'project_name', 'project_type', 'construction_unit',
        'implementation_organization', 'project_department', 'project_manager',
        'subcontractor', 'team_leader', 'risk_level', 'page', 'operating',
        'team_name', 'construction_area',
    ]
    # Tokenize the text, keeping character offsets for span alignment.
    inputs = tokenizer(example["text"], max_length=512, truncation=True,
                       return_offsets_mapping=True)
    offset_mapping = inputs["offset_mapping"]
    # One label per token; 0 means "O" (outside any entity).
    label_ids = [0] * len(offset_mapping)
    # Hoist special-token ids out of the loops.
    special_ids = (tokenizer.cls_token_id, tokenizer.sep_token_id)
    for entity in example.get("annotations", []):
        start, end, entity_label = entity["start"], entity["end"], entity["label"]
        if entity_label not in entity_types:
            continue  # skip annotations whose label is not in the known set
        # 1-based class index: 1..14 are B-labels; +len(entity_types) is the I-label.
        entity_class = entity_types.index(entity_label) + 1
        entity_started = False  # first token inside the span gets the B-label
        for idx, (char_start, char_end) in enumerate(offset_mapping):
            if inputs['input_ids'][idx] in special_ids:
                continue  # skip [CLS] / [SEP]
            # Label only tokens fully inside the entity's character span.
            if char_start >= start and char_end <= end:
                if not entity_started:
                    label_ids[idx] = entity_class  # B-entity
                    entity_started = True
                else:
                    label_ids[idx] = entity_class + len(entity_types)  # I-entity
    inputs["labels"] = label_ids
    del inputs["offset_mapping"]  # not needed by the model / collator
    return inputs
# === 3. Load the pretrained model ===
# num_classes=29 = 1 "O" class + 14 B-labels + 14 I-labels (one B/I pair per
# entity type in preprocess_function) — the original "3 classes (O, B, I)"
# comment was wrong.
model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=29)
tokenizer = ErnieTokenizer.from_pretrained("uie-base")

# === 4. Load the datasets ===
train_dataset = load_dataset("data/train.json")  # training split
dev_dataset = load_dataset("data/val.json")      # validation split

# === 5. Preprocess ===
train_dataset = train_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
dev_dataset = dev_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)

# === 6. Batch collation (dynamic padding of inputs and labels) ===
data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)
# === 7. Training arguments ===
training_args = TrainingArguments(
    output_dir="./output_temp",
    evaluation_strategy="epoch",     # evaluate once per epoch
    save_strategy="epoch",           # checkpoint once per epoch (matches load_best_model_at_end)
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=2e-5,
    num_train_epochs=10,
    weight_decay=0.01,
    save_total_limit=1,              # keep only one checkpoint (original comment said 2 — value is 1)
    logging_dir="./logs",
    logging_steps=100,
    eval_steps=5000,                 # NOTE(review): unused while evaluation_strategy="epoch" — confirm intent
    save_steps=5000,                 # NOTE(review): unused while save_strategy="epoch"
    seed=1000,
    load_best_model_at_end=True,
)
# === 8. Run training ===
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    tokenizer=tokenizer,  # passed so it is saved alongside checkpoints
)
trainer.train()
# Static-graph input specification (batch, seq_len=512) for model export.
# NOTE(review): input_spec is defined but never used in this script —
# presumably intended for a later paddle.jit.save/export step; confirm.
_SEQ_LEN = 512
input_spec = [
    paddle.static.InputSpec(shape=[None, _SEQ_LEN], dtype="int64", name="input_ids"),
    paddle.static.InputSpec(shape=[None, _SEQ_LEN], dtype="int64", name="token_type_ids"),
    paddle.static.InputSpec(shape=[None, _SEQ_LEN], dtype="int64", name="position_ids"),
    paddle.static.InputSpec(shape=[None, _SEQ_LEN], dtype="float32", name="attention_mask"),
]