Intention/uie/test_model.py

67 lines
2.3 KiB
Python
Raw Normal View History

2025-02-25 09:27:14 +08:00
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
import paddle
# 1. Load the fine-tuned token-classification model and its tokenizer.
# NOTE(review): this is a machine-specific absolute path; point it at your
# own checkpoint directory before running.
model_path = R"E:\workingSpace\PycharmProjects\Intention_dev\uie\uie_ner\checkpoint-4320"
model = ErnieForTokenClassification.from_pretrained(model_path)
tokenizer = ErnieTokenizer.from_pretrained(model_path)

# 2. Encode the input text.
text = "5月24日金上-湖北线路工程川12标风险等级为8级的工程作业内容是什么"
# The tokenizer's length-limit keyword is `max_length` (not `max_len`);
# `truncation=True` makes the 512-token cap explicit instead of implicit.
# `return_tensors="pd"` yields Paddle tensors that are indexed with [0] later
# on, i.e. the encoding carries a leading batch axis.
inputs = tokenizer(text, max_length=512, truncation=True, return_tensors="pd")

# 3. Run the prediction.
model.eval()  # switch the network to evaluation mode before inference
with paddle.no_grad():  # no gradients are needed for inference
    logits = model(**inputs)
# Pick the highest-scoring class id for every token position.
predictions = paddle.argmax(logits, axis=-1)
# 4. Label mapping: class id -> BIO tag.
# Id 0 is the "outside" (non-entity) tag. The k-th entity type (k = 1..10)
# uses id k for its B- (begin) tag and id k + 10 for its I- (inside) tag.
_ENTITY_TYPES = (
    'date',
    'project_name',
    'project_type',
    'construction_unit',
    'implementation_organization',
    'project_department',
    'project_manager',
    'subcontractor',
    'team_leader',
    'risk_level',
)

label_map = {0: 'O'}
for _type_id, _type_name in enumerate(_ENTITY_TYPES, start=1):
    label_map[_type_id] = 'B-' + _type_name
    label_map[_type_id + 10] = 'I-' + _type_name
# 5. Decode the per-token predictions into entity spans.
# Both the predictions and the input ids are indexed with [0]: only the first
# sequence of the batch is decoded.
predicted_labels = predictions.numpy()[0]
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy())

entities = []          # collected {"text": ..., "label": ...} dicts
current_entity = None  # token pieces of the entity being built, or None
current_label = None   # type of that entity, without the B-/I- prefix

for token, label_id in zip(tokens, predicted_labels):
    # Normalise the numpy integer to a plain int before the dict lookup;
    # unknown ids are treated as non-entity.
    label = label_map.get(int(label_id), "O")

    # WordPiece marks word-internal pieces with a leading "##". Strip the
    # marker so the joined entity text does not contain stray "##".
    if token.startswith("##") and len(token) > 2:
        piece = token[2:]
    else:
        piece = token

    if label.startswith("B-"):  # a new entity starts here
        if current_entity:
            # Close the entity that was still open.
            entities.append({"text": "".join(current_entity), "label": current_label})
        current_entity = [piece]
        current_label = label[2:]  # drop the "B-" prefix
    elif label.startswith("I-") and current_entity and label[2:] == current_label:
        # Continuation of the currently open entity of the same type.
        current_entity.append(piece)
    else:
        # Non-entity token, or an I- tag that does not continue the open
        # entity: close whatever is open and reset the state.
        if current_entity:
            entities.append({"text": "".join(current_entity), "label": current_label})
        current_entity = None
        current_label = None

# Flush the entity that is still open when the sequence ends.
if current_entity:
    entities.append({"text": "".join(current_entity), "label": current_label})
# 6. Print the final entities, one per line, with their type.
for entity in entities:
    report_line = f"Entity: {entity['text']}, Label: {entity['label']}"
    print(report_line)