2025-02-25 09:27:14 +08:00
|
|
|
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
|
|
|
|
|
import paddle
|
|
|
|
|
|
|
|
|
|
# Directory of the fine-tuned checkpoint; both the model and the tokenizer
# are loaded from it.
MODEL_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\uie\output\checkpoint-2440"  # your model path

# Sample query decoded when the script is run directly.
TEXT = "李四班组今天有多少作业计划"

# Upper bound on the encoded sequence length passed to the tokenizer.
MAX_SEQ_LENGTH = 512

# Class id -> BIO tag.  Ids 1-13 are the "B-" (entity start) tags and ids
# 14-26 are the matching "I-" (entity continuation) tags.
# NOTE(review): presumably this must mirror the label ids used at training
# time -- verify against the training config.
LABEL_MAP = {
    0: 'O',  # not part of any entity
    1: 'B-date', 14: 'I-date',
    2: 'B-projectName', 15: 'I-projectName',
    3: 'B-projectType', 16: 'I-projectType',
    4: 'B-constructionUnit', 17: 'I-constructionUnit',
    5: 'B-implementationOrganization', 18: 'I-implementationOrganization',
    6: 'B-projectDepartment', 19: 'I-projectDepartment',
    7: 'B-projectManager', 20: 'I-projectManager',
    8: 'B-subcontractor', 21: 'I-subcontractor',
    9: 'B-teamLeader', 22: 'I-teamLeader',
    10: 'B-riskLevel', 23: 'I-riskLevel',
    11: 'B-page', 24: 'I-page',
    12: 'B-operating', 25: 'I-operating',
    13: 'B-teamName', 26: 'I-teamName',
}


def _strip_wordpiece(token):
    """Return *token* without a leading "##" sub-word continuation marker."""
    # A bare "##" is left alone: it has no payload after the marker.
    if len(token) > 2 and token.startswith("##"):
        return token[2:]
    return token


def extract_entities(tokens, label_ids, label_map=None, skip_tokens=()):
    """Merge per-token BIO predictions into a list of entities.

    Args:
        tokens: Token strings, aligned one-to-one with ``label_ids``.
        label_ids: Predicted class id for each token (plain ints or numpy
            integer scalars).
        label_map: Mapping from class id to BIO tag.  Defaults to
            ``LABEL_MAP``.  Ids missing from the mapping are treated as "O".
        skip_tokens: Tokens that can never belong to an entity (e.g. the
            tokenizer's [CLS]/[SEP]/[PAD]).  They are handled as "O"
            regardless of the predicted id, so they also close any entity
            that is currently open.

    Returns:
        A list of ``{"text": ..., "label": ...}`` dicts in order of
        appearance.  ``text`` is the concatenation of the entity's tokens and
        ``label`` is the tag with its ``B-``/``I-`` prefix removed.
    """
    if label_map is None:
        label_map = LABEL_MAP

    entities = []
    pieces = []         # tokens of the entity being built; empty when none is open
    entity_type = None  # type of the entity being built

    for token, label_id in zip(tokens, label_ids):
        if token in skip_tokens:
            label = "O"
        else:
            # int() so numpy integer scalars look up exactly like plain ints.
            label = label_map.get(int(label_id), "O")

        # Continuation of the currently open entity.
        if label.startswith("I-") and pieces and label[2:] == entity_type:
            pieces.append(_strip_wordpiece(token))
            continue

        # Anything else ends the open entity.  This includes an "I-" tag with
        # no open entity or with a different type: such a token is dropped
        # rather than starting a new entity.
        if pieces:
            entities.append({"text": "".join(pieces), "label": entity_type})
            pieces, entity_type = [], None

        if label.startswith("B-"):  # start a new entity
            pieces = [_strip_wordpiece(token)]
            entity_type = label[2:]  # drop the "B-" prefix

    # Flush an entity that runs up to the end of the sequence.
    if pieces:
        entities.append({"text": "".join(pieces), "label": entity_type})

    return entities


def main():
    """Load the checkpoint, tag ``TEXT`` and print the recognised entities."""
    # 1. Load the model and tokenizer.
    model = ErnieForTokenClassification.from_pretrained(MODEL_PATH)
    tokenizer = ErnieTokenizer.from_pretrained(MODEL_PATH)

    # 2. Encode the input text.  The length cap is passed as `max_length`
    #    (the original `max_len` is not a documented tokenizer argument), and
    #    truncation is requested explicitly so the cap is actually enforced.
    inputs = tokenizer(
        TEXT,
        max_length=MAX_SEQ_LENGTH,
        truncation=True,
        return_tensors="pd",
    )

    # 3. Run inference without tracking gradients.
    model.eval()
    with paddle.no_grad():
        logits = model(**inputs)
        predictions = paddle.argmax(logits, axis=-1)

    # 4. Decode the single sequence in the batch.
    predicted_labels = predictions.numpy()[0]
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy().tolist())

    # Per-token class ids, printed before the entities exactly as before.
    for label_id in predicted_labels:
        print(label_id)

    # Structural tokens must never be glued into an entity's text.
    structural_tokens = {tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token}
    entities = extract_entities(
        tokens, predicted_labels, LABEL_MAP, skip_tokens=structural_tokens
    )

    # 5. Print the final entities.
    for entity in entities:
        print(f"Entity: {entity['text']}, Label: {entity['label']}")


if __name__ == "__main__":
    main()
|