import json

import paddle
import paddlenlp
from paddlenlp.datasets import MapDataset
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
from paddlenlp.trainer import Trainer, TrainingArguments
from paddlenlp.data import DataCollatorForTokenClassification


# === 1. Load the data ===
def load_dataset(data_path):
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return MapDataset(data)

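# A hedged sketch of the record format this script assumes: preprocess_function
# below reads example["text"] and example["annotations"] with "text", "start",
# "end" and "label" fields, so each JSON file is expected to look roughly like
# the illustrative (made-up) record below; the real files under data/ may differ.
#
# [
#   {
#     "text": "...",
#     "annotations": [
#       {"text": "...", "start": 0, "end": 5, "label": "date"}
#     ]
#   }
# ]
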
# === 2. Preprocess the data ===
def preprocess_function(example, tokenizer):
    # Predefined list of entity types
    entity_types = [
        'date', 'project_name', 'project_type', 'construction_unit',
        'implementation_organization', 'project_department', 'project_manager',
        'subcontractor', 'team_leader', 'risk_level'
    ]

    # Tokenize the text
    inputs = tokenizer(example["text"], max_length=512, truncation=True, return_offsets_mapping=True)
    offset_mapping = inputs["offset_mapping"]

    # Initialize label_ids (0 means the O tag)
    label_ids = [0] * len(offset_mapping)  # 0: O, 1-10: B-*, 11-20: I-*

    # Process the entities
    if "annotations" in example:
        for entity in example["annotations"]:
            print(entity)  # debug output
            entity_text = entity["text"]  # surface form of the entity (not used below)
            start, end, entity_label = entity["start"], entity["end"], entity["label"]

            # Make sure entity_label is one of the known entity types
            if entity_label not in entity_types:
                continue  # skip entities whose label is out of scope

            # Map the entity type to its class index
            entity_class = entity_types.index(entity_label) + 1  # 1: B-date, 2: B-project_name, ...

            # Label the tokens covered by this entity
            entity_started = False  # tracks whether the entity has already started
            for idx, (char_start, char_end) in enumerate(offset_mapping):
                token = inputs['input_ids'][idx]
                # Exclude special tokens
                if token == tokenizer.cls_token_id or token == tokenizer.sep_token_id:
                    continue  # skip the [CLS] and [SEP] tokens
                if char_start >= start and char_end <= end:
                    if not entity_started:
                        label_ids[idx] = entity_class  # B- tag
                        entity_started = True
                    else:
                        label_ids[idx] = entity_class + len(entity_types)  # I- tag

    # Attach the labels to the tokenized inputs
    inputs["labels"] = label_ids
    del inputs["offset_mapping"]  # the model does not take offset_mapping as input
    print(inputs)  # debug output
    return inputs

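# A small illustrative sketch (an addition, not part of the original training
# flow): the labeling scheme produced by preprocess_function can be written out
# as an id2label mapping, assuming the same entity_types order as above.
# Index 0 is O, 1-10 are the B-* tags and 11-20 the I-* tags, which is why the
# model below is created with num_classes=21.
_entity_types = [
    'date', 'project_name', 'project_type', 'construction_unit',
    'implementation_organization', 'project_department', 'project_manager',
    'subcontractor', 'team_leader', 'risk_level'
]
id2label = {0: "O"}
id2label.update({i + 1: f"B-{t}" for i, t in enumerate(_entity_types)})
id2label.update({i + 1 + len(_entity_types): f"I-{t}" for i, t in enumerate(_entity_types)})
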
# === 3. Load the UIE pretrained model ===
model = ErnieForTokenClassification.from_pretrained(r"/mnt/d/weiweiwang/intention/models/uie-base", num_classes=21)  # 21 classes: O + 10 B-* + 10 I-* tags
tokenizer = ErnieTokenizer.from_pretrained(r"/mnt/d/weiweiwang/intention/models/uie-base")

# === 4. Load the datasets ===
train_dataset = load_dataset("data/data_part1.json")  # training set
dev_dataset = load_dataset("data/data_part2.json")  # validation set
print(train_dataset)

# === 5. Preprocess the datasets ===
train_dataset = train_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
dev_dataset = dev_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)


# === 6. Data collation ===
data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)

# === 7. Training arguments ===
training_args = TrainingArguments(
    output_dir="./output",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=16,  # adjust batch_size according to available GPU memory
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    num_train_epochs=10,  # number of training epochs
    weight_decay=0.01,
    save_total_limit=2,  # keep only the 2 most recent checkpoints
    logging_dir="./logs",
    logging_steps=10,
    load_best_model_at_end=True,
)

# === 8. Training ===
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()


# Define the input spec for the model
input_spec = [
    paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="input_ids"),
    paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="token_type_ids"),
    paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="position_ids"),
    paddle.static.InputSpec(shape=[None, 512], dtype="float32", name="attention_mask")
]

# === 9. Export the model as a static graph ===
# After training finishes, save the model as a static graph
paddle.jit.save(model, "./saved_model_static", input_spec=input_spec)

# === 10. Save the model weights ===
# Save the model weights so they can be reloaded later
trainer.save_model("./saved_model_static")  # saves into the directory passed here
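

# A hedged usage sketch (an assumption, not part of the original script): the
# static graph exported above can later be reloaded for inference with
# paddle.jit.load. The returned layer is called with tensors matching the
# input_spec defined before export (input_ids, token_type_ids, position_ids,
# attention_mask); the path prefix is the one passed to paddle.jit.save.
def load_exported_model(path="./saved_model_static"):
    loaded = paddle.jit.load(path)  # load the exported static graph
    loaded.eval()                   # switch to inference mode
    return loaded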