# Intention/ernie/test_model.py

import functools
import json

import numpy as np
import paddle
import paddle.nn.functional as F
import yaml
from paddle.io import DataLoader
from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import load_dataset
from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Load the configuration
with open("data.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
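
# data.yaml is expected to provide at least the keys used below ("labels" and
# "test"); a minimal sketch with hypothetical values:
#
#   labels:
#     - label_a
#     - label_b
#   test: path/to/test.json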

# Load the fine-tuned model and its tokenizer from the same checkpoint
model = ErnieForSequenceClassification.from_pretrained(
    r"E:\workingSpace\PycharmProjects\Intention_dev\ernie\output\checkpoint-4160"
)
tokenizer = ErnieTokenizer.from_pretrained(
    r"E:\workingSpace\PycharmProjects\Intention_dev\ernie\output\checkpoint-4160"
)

# Build the label <-> id mappings from the configured label list
label_id = {label: idx for idx, label in enumerate(config["labels"])}
id_label = {idx: label for label, idx in label_id.items()}
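# For example, with config["labels"] == ["label_a", "label_b"] (hypothetical),
# label_id == {"label_a": 0, "label_b": 1} and id_label inverts that mapping.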

# Read a local JSON dataset and yield examples one by one
def read_local_dataset(path, label2id=None, is_test=True):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    for item in data:
        if "text" in item:
            yield {
                "text": item["text"],
                "label": label2id.get(item["label"], -1) if not is_test else -1,
            }
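
# Each entry in the JSON file is expected to look like (hypothetical example):
#   {"text": "some input sentence", "label": "label_a"}
# Labels missing from label2id map to -1 via label2id.get(..., -1).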

# Tokenize the text and attach labels for non-test data
def preprocess_function(examples, tokenizer, max_length, is_test=False):
    result = tokenizer(
        examples["text"],
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_attention_mask=True,  # make sure attention_mask is returned
    )
    if not is_test:
        result["labels"] = np.array(examples["label"], dtype="int64")  # int64 directly, no extra wrapping
    return result
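
# Note: with return_attention_mask=True the paddlenlp tokenizer output is
# expected to include input_ids, token_type_ids, and attention_mask, all
# padded to max_length.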

# Load and preprocess a dataset
def load_and_preprocess_dataset(path, label2id, tokenizer, max_length, is_test=False):
    try:
        dataset = load_dataset(read_local_dataset, path=path, label2id=label2id, lazy=True, is_test=is_test)
        trans_func = functools.partial(preprocess_function, tokenizer=tokenizer, max_length=max_length, is_test=is_test)
        return dataset.map(trans_func)
    except Exception as e:
        raise ValueError(f"Error while loading and preprocessing the dataset: {e}") from e

# Load the test set; is_test=False so gold labels are kept for evaluation
test_ds = load_and_preprocess_dataset(config["test"], label_id, tokenizer, max_length=256, is_test=False)

# Batch the examples with a DataLoader; the collator pads and stacks fields
data_collator = DataCollatorWithPadding(tokenizer)
test_dataloader = DataLoader(test_ds, batch_size=16, collate_fn=data_collator)

# Evaluate the model
model.eval()
all_preds = []
all_labels = []

# Run inference over the test set; no_grad avoids tracking gradients
with paddle.no_grad():
    for batch in test_dataloader:
        input_ids = paddle.to_tensor(batch["input_ids"])
        attention_mask = paddle.to_tensor(batch["attention_mask"])
        logits = model(input_ids, attention_mask=attention_mask)
        probs = F.softmax(logits, axis=1)  # softmax is monotonic; argmax over logits gives the same labels
        pred_labels = paddle.argmax(probs, axis=1).numpy()
        all_preds.extend(pred_labels)
        all_labels.extend(batch["labels"].numpy())

# Compute evaluation metrics
accuracy = accuracy_score(all_labels, all_preds)
precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="weighted")

# Report the results
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")