Intention/uie/test_model.py

67 lines
2.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
import paddle
# 1. Load the fine-tuned token-classification model and its tokenizer
#    from the same local checkpoint directory.
model_path = R"E:\workingSpace\PycharmProjects\Intention_dev\uie\uie_ner\checkpoint-4320"  # path to your model checkpoint
model = ErnieForTokenClassification.from_pretrained(model_path)
tokenizer = ErnieTokenizer.from_pretrained(model_path)

# 2. Tokenize the input text.
text = "5月24日金上-湖北线路工程川12标风险等级为8级的工程作业内容是什么"
# NOTE(review): the tokenizer's length keyword is `max_length`, and it is only
# enforced when truncation is enabled; the previous `max_len=512` keyword is
# not a documented argument, so the 512-token cap was presumably not applied.
inputs = tokenizer(text, max_length=512, truncation=True, return_tensors="pd")

# 3. Run inference with dropout disabled and gradient tracking turned off.
model.eval()
with paddle.no_grad():
    logits = model(**inputs)
    # Pick the highest-scoring label id along the last (label) axis.
    predictions = paddle.argmax(logits, axis=-1)
# 4. Label-id -> BIO tag mapping.
# Id 0 is the non-entity tag "O". Each entity type listed below owns two ids:
# its 1-based position is the "B-" (begin) tag and position + number-of-types
# is the matching "I-" (inside) tag.
entity_types = (
    "date",
    "project_name",
    "project_type",
    "construction_unit",
    "implementation_organization",
    "project_department",
    "project_manager",
    "subcontractor",
    "team_leader",
    "risk_level",
)
inside_offset = len(entity_types)

label_map = {0: "O"}  # non-entity
for position, entity_type in enumerate(entity_types, start=1):
    label_map[position] = "B-" + entity_type
    label_map[position + inside_offset] = "I-" + entity_type
# 5. Decode the per-token predictions into entity spans.
def _collect_entities(tokens, label_ids, id_to_label, skip_tokens=()):
    """Merge BIO-tagged tokens into a list of entities.

    Args:
        tokens: Token strings, aligned one-to-one with ``label_ids``.
        label_ids: Predicted label id for each token (anything ``int()``
            accepts, e.g. NumPy integers).
        id_to_label: Mapping from label id to a tag of the form ``"O"``,
            ``"B-<type>"`` or ``"I-<type>"``; unknown ids are treated as "O".
        skip_tokens: Tokens that are always treated as "O" regardless of the
            predicted label (used for tokenizer special tokens).

    Returns:
        A list of ``{"text": str, "label": str}`` dicts in order of
        appearance, where ``label`` is the entity type without its B-/I-
        prefix.
    """
    entities = []
    pieces = []         # token texts of the entity currently being built
    entity_type = None  # type of the entity currently being built

    for token, label_id in zip(tokens, label_ids):
        if token in skip_tokens:
            label = "O"
        else:
            label = id_to_label.get(int(label_id), "O")
        prefix, _, tag_type = label.partition("-")

        # An "I-" tag only extends an open entity of the same type.
        if prefix == "I" and pieces and tag_type == entity_type:
            # Drop the WordPiece continuation marker so it does not leak
            # into the reconstructed entity text.
            pieces.append(token[2:] if token.startswith("##") else token)
            continue

        # Any other tag ("B-", "O", or a mismatched "I-") closes the open
        # entity, if there is one.
        if pieces:
            entities.append({"text": "".join(pieces), "label": entity_type})
            pieces, entity_type = [], None

        if prefix == "B":  # start a new entity
            pieces = [token[2:] if token.startswith("##") else token]
            entity_type = tag_type

    # Flush an entity that runs up to the end of the sequence.
    if pieces:
        entities.append({"text": "".join(pieces), "label": entity_type})
    return entities


# Only one text was encoded, so take the first (and only) row of the batch.
predicted_labels = predictions.numpy()[0]
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy())
# Special tokens are not part of the input text and must never become
# (part of) an entity, whatever label the model assigned to them.
special_tokens = {tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token}
entities = _collect_entities(tokens, predicted_labels, label_map, special_tokens)

# Print the final entities.
for entity in entities:
    print(f"Entity: {entity['text']}, Label: {entity['label']}")