2025-02-21 16:52:03 +08:00
|
|
|
import paddle
|
|
|
|
|
import numpy as np
|
|
|
|
|
import yaml
|
|
|
|
|
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
|
|
|
|
from paddlenlp.transformers import ErnieTokenizer
|
|
|
|
|
from paddle.io import DataLoader
|
|
|
|
|
from paddlenlp.data import DataCollatorWithPadding
|
|
|
|
|
import json
|
|
|
|
|
import functools
|
|
|
|
|
from paddlenlp.datasets import load_dataset
|
2025-03-03 13:21:11 +08:00
|
|
|
from paddlenlp.transformers import ErnieTokenizer, ErnieForSequenceClassification
|
|
|
|
|
import paddle.nn.functional as F
|
2025-02-21 16:52:03 +08:00
|
|
|
|
|
|
|
|
# Load the run configuration. The script reads its "labels" list and its
# "test" dataset path further down.
CONFIG_FILE = "data.yaml"
with open(CONFIG_FILE, mode="r", encoding="utf-8") as config_stream:
    config = yaml.safe_load(config_stream)

# Load the fine-tuned classifier and its tokenizer from one checkpoint
# directory. Defining the path once keeps the two from drifting apart.
CHECKPOINT_DIR = R"E:\workingSpace\PycharmProjects\Intention_dev\ernie\output\checkpoint-4160"

model = ErnieForSequenceClassification.from_pretrained(CHECKPOINT_DIR)
tokenizer = ErnieTokenizer.from_pretrained(CHECKPOINT_DIR)

# Build label <-> id mappings from the label order listed in data.yaml.
# NOTE(review): this assumes config["labels"] is in the same order the model
# was trained with; the checkpoint's own label mapping is not consulted here.
label_names = config["labels"]
if len(set(label_names)) != len(label_names):
    # A repeated name would silently collapse in label_id and leave one class
    # index without a name, so predicted ids would no longer line up.
    raise ValueError(f"Duplicate entries in config['labels']: {label_names}")

label_id = {label: idx for idx, label in enumerate(label_names)}
id_label = {idx: label for label, idx in label_id.items()}

# Read the dataset.
def read_local_dataset(path, label2id=None, is_test=True):
    """Yield ``{"text": ..., "label": ...}`` records from a local JSON file.

    The top-level JSON value is expected to be a list of objects. Objects
    without a "text" key are skipped.

    Args:
        path: Path to the UTF-8 encoded JSON file.
        label2id: Mapping from label name to integer id. Required when
            ``is_test`` is false.
        is_test: When true, every record gets label -1 and ``label2id`` is
            not consulted.

    Yields:
        Dicts holding the record's "text" and an integer "label". The label
        is -1 for test data, and for records whose "label" is missing or not
        present in ``label2id``.

    Raises:
        ValueError: When iteration starts, if ``is_test`` is false and
            ``label2id`` is None.
    """
    if not is_test and label2id is None:
        raise ValueError("label2id is required when is_test is False")

    # Parse the whole file up front and close it right away, rather than
    # keeping the handle open for as long as this generator is alive.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    for item in data:
        if "text" not in item:
            continue
        if is_test:
            label = -1
        else:
            # A missing or unknown label maps to -1 instead of raising KeyError.
            label = label2id.get(item.get("label"), -1)
        yield {"text": item["text"], "label": label}

# Preprocess the data.
def preprocess_function(examples, tokenizer, max_length, is_test=False, padding="max_length"):
    """Tokenize ``examples["text"]`` and, for labelled data, attach its label.

    Args:
        examples: Mapping with a "text" entry and, unless ``is_test`` is true,
            a "label" entry holding the integer label id.
        tokenizer: Tokenizer callable (an ``ErnieTokenizer`` in this script).
        max_length: Maximum token length; longer inputs are truncated.
        is_test: When true, no "labels" field is added.
        padding: Forwarded to the tokenizer. Defaults to ``"max_length"``,
            which pads every example to ``max_length`` (the original
            behaviour). Pass ``False`` to leave padding to a batch collator.

    Returns:
        The tokenizer's output mapping. When ``is_test`` is false it also
        holds "labels", an int64 numpy array.
    """
    result = tokenizer(
        examples["text"],
        max_length=max_length,
        truncation=True,
        padding=padding,
        return_attention_mask=True,  # the evaluation loop reads batch["attention_mask"]
    )
    if not is_test:
        # Use int64 directly to avoid an extra wrapping step.
        result["labels"] = np.array(examples["label"], dtype="int64")
    return result

# Load and preprocess the dataset.
def load_and_preprocess_dataset(path, label2id, tokenizer, max_length, is_test=False):
    """Load a local JSON dataset and attach the tokenization transform.

    Args:
        path: Path to the JSON file read by ``read_local_dataset``.
        label2id: Mapping from label name to integer id.
        tokenizer: Tokenizer passed through to ``preprocess_function``.
        max_length: Maximum token length passed to ``preprocess_function``.
        is_test: Forwarded to both the reader and the preprocessor. When
            true, labels are not read or attached.

    Returns:
        The dataset produced by ``load_dataset(...).map(...)``.

    Raises:
        ValueError: If building the dataset fails. The original exception is
            chained as ``__cause__`` so its traceback is preserved.

    NOTE(review): the dataset is created with ``lazy=True``, so errors raised
    while reading or tokenizing individual records presumably surface only
    when the dataset is iterated (e.g. by the DataLoader), not inside this
    function.
    """
    try:
        dataset = load_dataset(read_local_dataset, path=path, label2id=label2id, lazy=True, is_test=is_test)
        trans_func = functools.partial(preprocess_function, tokenizer=tokenizer, max_length=max_length, is_test=is_test)
        return dataset.map(trans_func)
    except Exception as e:
        raise ValueError(f"加载和预处理数据集时出错: {str(e)}") from e

# Build the evaluation dataset. is_test=False is deliberate: the gold labels
# are needed to score the predictions.
MAX_SEQ_LEN = 256
EVAL_BATCH_SIZE = 16

test_ds = load_and_preprocess_dataset(
    config["test"],
    label_id,
    tokenizer,
    max_length=MAX_SEQ_LEN,
    is_test=False,
)

# Batch the examples. DataCollatorWithPadding is expected to pad and stack
# each batch. NOTE(review): preprocess_function already pads every example to
# max_length, so collator padding is presumably a no-op here.
data_collator = DataCollatorWithPadding(tokenizer)
test_dataloader = DataLoader(
    test_ds,
    batch_size=EVAL_BATCH_SIZE,
    collate_fn=data_collator,
)

# Evaluate the model: run it over the test set and collect the predicted and
# gold label ids.
model.eval()
all_preds = []
all_labels = []

# Inference only. no_grad stops paddle from recording an autograd graph for
# every batch, which would otherwise keep intermediate activations alive.
with paddle.no_grad():
    for batch in test_dataloader:
        input_ids = paddle.to_tensor(batch["input_ids"])
        attention_mask = paddle.to_tensor(batch["attention_mask"])

        logits = model(input_ids, attention_mask=attention_mask)
        # softmax is monotonic, so argmax over the raw logits picks the same
        # class as argmax over the probabilities; the softmax is skipped.
        pred_labels = paddle.argmax(logits, axis=1).numpy()

        all_preds.extend(pred_labels)
        all_labels.extend(batch["labels"].numpy())

# Compute the evaluation metrics over the whole test set.
if not all_labels:
    # Fail with a clear message instead of an obscure sklearn error (or NaN)
    # when the test file yielded no records.
    raise ValueError("No labelled examples were evaluated; check config['test'].")

accuracy = accuracy_score(all_labels, all_preds)
# Support-weighted average over classes. zero_division=0 makes explicit the
# value sklearn already uses (with a warning) for classes that never appear in
# the predictions, such as gold label -1 assigned to unknown labels.
precision, recall, f1, _ = precision_recall_fscore_support(
    all_labels, all_preds, average="weighted", zero_division=0
)

# Print the evaluation results.
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")