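"""Fine-tune the "uie-base" checkpoint as a BIO token classifier for
project-document NER, using the PaddleNLP Trainer API."""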
import json

import paddle
from paddlenlp.datasets import MapDataset
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
from paddlenlp.trainer import Trainer, TrainingArguments
from paddlenlp.data import DataCollatorForTokenClassification

# === 1. Load data ===
def load_dataset(data_path):
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return MapDataset(data)

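# The loader above makes no assumptions about the JSON layout; the shape below
# is what preprocess_function expects (a hypothetical record for illustration,
# with "start"/"end" as character offsets into "text"):
#
# [
#   {
#     "text": "Project manager Zhang San ...",
#     "annotations": [
#       {"text": "Zhang San", "start": 16, "end": 25, "label": "project_manager"}
#     ]
#   }
# ]
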
# === 2. Preprocess data ===
def preprocess_function(example, tokenizer):
    # Predefined entity types (12 types)
    entity_types = [
        'date', 'project_name', 'project_type', 'construction_unit',
        'implementation_organization', 'project_department', 'project_manager',
        'subcontractor', 'team_leader', 'risk_level', 'page', 'team_name'
    ]

    # Tokenize the text, keeping character offsets for label alignment
    inputs = tokenizer(example["text"], max_length=512, truncation=True, return_offsets_mapping=True)
    offset_mapping = inputs["offset_mapping"]

    # Initialize label_ids (0 is the O tag; 1-12 are B-* tags, 13-24 are I-* tags)
    label_ids = [0] * len(offset_mapping)

    # Assign B/I labels to the tokens covered by each annotated entity
    if "annotations" in example:
        for entity in example["annotations"]:
            start, end, entity_label = entity["start"], entity["end"], entity["label"]

            # Skip annotations whose label is outside the predefined set
            if entity_label not in entity_types:
                continue

            # Map the entity type to its 1-based B-tag id
            entity_class = entity_types.index(entity_label) + 1

            entity_started = False  # whether the B tag has been emitted yet
            for idx, (char_start, char_end) in enumerate(offset_mapping):
                token = inputs['input_ids'][idx]
                # Skip the special [CLS] and [SEP] tokens
                if token == tokenizer.cls_token_id or token == tokenizer.sep_token_id:
                    continue
                if char_start >= start and char_end <= end:
                    if not entity_started:
                        label_ids[idx] = entity_class  # B-entity
                        entity_started = True
                    else:
                        label_ids[idx] = entity_class + len(entity_types)  # I-entity

    # Attach the labels and drop offset_mapping (the model does not consume it)
    inputs["labels"] = label_ids
    del inputs["offset_mapping"]
    return inputs

# === 3. Load the UIE pre-trained model ===
# 25 classes = O + 12 B-* tags + 12 I-* tags (see entity_types above)
model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=25)
tokenizer = ErnieTokenizer.from_pretrained("uie-base")

# === 4. Load the datasets ===
train_dataset = load_dataset("data/train.json")  # training set
dev_dataset = load_dataset("data/val.json")      # validation set

# === 5. Preprocess the datasets ===
train_dataset = train_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
dev_dataset = dev_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)

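# Optional sanity check (not in the original script): eyeball one processed
# example to confirm the B/I label ids line up with the annotated spans.
# sample = train_dataset[0]
# print(tokenizer.convert_ids_to_tokens(sample["input_ids"]))
# print(sample["labels"])
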
# === 6. Data collation (dynamic padding of inputs and labels) ===
data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)

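# A minimal evaluation-metric sketch, not part of the original script: token-level
# accuracy that skips the -100 id the collator assigns to label padding.
# To use it, pass compute_metrics=compute_metrics to the Trainer below.
import numpy as np

def compute_metrics(eval_preds):
    predictions = np.argmax(eval_preds.predictions, axis=-1)
    labels = eval_preds.label_ids
    mask = labels != -100  # ignore padded label positions
    accuracy = (predictions[mask] == labels[mask]).mean()
    return {"token_accuracy": float(accuracy)}
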
# === 7. Training arguments ===
training_args = TrainingArguments(
    output_dir="./output",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=16,  # tune to the available GPU memory
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    num_train_epochs=10,
    weight_decay=0.01,
    save_total_limit=1,  # keep only the most recent checkpoint
    logging_dir="./logs",
    logging_steps=10,
    eval_steps=2000,  # ignored while evaluation_strategy="epoch"
    save_steps=2000,  # ignored while save_strategy="epoch"
    seed=1000,
    load_best_model_at_end=True,
)

# === 8. Train ===
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()

# Define the input spec for static-graph export
input_spec = [
    paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="input_ids"),
    paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="token_type_ids"),
    paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="position_ids"),
    paddle.static.InputSpec(shape=[None, 512], dtype="float32", name="attention_mask"),
]
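# The script stops after defining input_spec; a minimal export sketch that puts
# it to use (the "./output/inference" path is a hypothetical choice):
model.eval()
paddle.jit.save(model, "./output/inference", input_spec=input_spec)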