# Intention/uie/train.py
# Source-viewer snapshot: 67 lines, 1.6 KiB, Python, dated 2025-02-21 16:52:03 +08:00
import os
import json
import paddle
import paddlenlp
from paddlenlp.utils.log import logger
from paddlenlp.datasets import MapDataset
from paddlenlp.transformers import UIE, ErnieTokenizer
from paddlenlp.trainer import TrainingArguments, Trainer
def read_data(filepath):
    """Load and return the JSON document stored at ``filepath``.

    Args:
        filepath: Path of the JSON file to read; it is opened as UTF-8 text.

    Returns:
        The decoded JSON value, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)
def convert_data(examples, tokenizer):
    """Tokenize each example's text and pair the encoding with its label.

    Args:
        examples: Iterable of mappings, each providing a ``"text"`` entry
            (passed to the tokenizer) and a ``"label"`` entry (copied through
            unchanged).
        tokenizer: Callable invoked as ``tokenizer(text)``. Its result must be
            indexable by ``"input_ids"``, ``"token_type_ids"`` and
            ``"attention_mask"``.

    Returns:
        A list with one dict per example, in input order, holding the keys
        ``"input_ids"``, ``"token_type_ids"``, ``"attention_mask"`` and
        ``"labels"``.

    Raises:
        KeyError: If a dict-like example or encoding lacks one of the keys
            listed above.
    """
    # NOTE(review): this relies on the tokenizer returning "attention_mask"
    # without being asked to; confirm for the tokenizer in use, otherwise the
    # lookup below fails and the call needs an explicit request for the mask.
    results = []
    for example in examples:
        encoding = tokenizer(example["text"])
        results.append({
            "input_ids": encoding["input_ids"],
            "token_type_ids": encoding["token_type_ids"],
            "attention_mask": encoding["attention_mask"],
            "labels": example["label"],
        })
    return results
# Load the raw training and validation examples from disk.
train_data = read_data("data/train.json")
dev_data = read_data("data/dev.json")

# Pick the pretrained checkpoint and load its tokenizer and model weights.
model_name = "uie-base"
tokenizer = ErnieTokenizer.from_pretrained(model_name)
model = UIE.from_pretrained(model_name)

# Tokenize both splits and wrap them as map-style datasets.
train_dataset = MapDataset(convert_data(train_data, tokenizer))
dev_dataset = MapDataset(convert_data(dev_data, tokenizer))

# Training hyper-parameters plus checkpoint and logging locations.
training_args = TrainingArguments(
    output_dir="./checkpoint",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=1e-5,
    num_train_epochs=10,
    logging_dir="./logs",
    logging_steps=100,
    save_steps=500,
    evaluation_strategy="epoch",
    save_total_limit=2,
)

# Build the trainer from the model, arguments and both datasets.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
)

# Start training.
trainer.train()