Intention/api/slotRecognition.py

190 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import paddle
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
from globalData import GlobalData
from utils import standardize_name_only_high_score, clean_useless_company_name
from constants import SUBCONTRACTOR, CONSTRUCTION_UNIT, IMPLEMENTATION_ORG, PAGE, PROGRAM_NAVIGATION
import paddle.nn.functional as F
class SlotRecognition:
    """Slot (entity) recognizer built on an ERNIE token-classification model.

    Every token of the input text is tagged with a BIO label looked up in
    ``label_map``. A ``B-X`` token followed by any number of ``I-X`` tokens
    is merged into one entity stored under slot name ``X``. When the same
    slot occurs more than once in a text, the last occurrence wins.
    """

    def __init__(self, model_path: str, label_map: dict):
        """Load the model and tokenizer.

        :param model_path: path passed to ``from_pretrained`` for both the
            model and the tokenizer
        :param label_map: mapping from predicted label id to BIO tag string
            (e.g. ``"B-..."``, ``"I-..."``); ids missing from the map are
            treated as ``"O"``
        """
        self.model = ErnieForTokenClassification.from_pretrained(model_path)
        # This class only runs inference. ``paddle.no_grad()`` does not turn
        # dropout off, so switch to eval mode to keep predictions (and the
        # reported probabilities) deterministic.
        self.model.eval()
        self.tokenizer = ErnieTokenizer.from_pretrained(model_path)
        self.label_map = label_map

    def _forward(self, text: str):
        """Tokenize ``text`` and run the model without gradient tracking.

        :return: ``(tokens, logits)`` where ``tokens`` are the token strings
            of the first (only) sequence and ``logits`` is the raw model
            output for the batch
        """
        inputs = self.tokenizer(text, max_length=512, return_tensors="pd")
        with paddle.no_grad():
            logits = self.model(**inputs)
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy())
        return tokens, logits

    def _decode_spans(self, tokens, label_ids, token_probs=None):
        """Merge BIO-tagged tokens into slot entities.

        :param tokens: token strings of one sequence
        :param label_ids: predicted label id per token
        :param token_probs: optional per-token probability rows indexed as
            ``token_probs[token_index][label_id]``; when given, the mean
            probability of each entity's tokens is reported as well
        :return: ``(entities, slot_probabilities)``; the second dict is empty
            when ``token_probs`` is ``None``
        """
        spans = []
        current = None  # (slot, tokens, probs) of the entity being built
        for idx, (token, label_id) in enumerate(zip(tokens, label_ids)):
            label = self.label_map.get(label_id, "O")
            starts_entity = label.startswith("B-")
            continues_entity = (
                label.startswith("I-")
                and current is not None
                and label[2:] == current[0]
            )
            if starts_entity:
                current = (label[2:], [], [])
                spans.append(current)
            elif not continues_entity:
                # "O", or an I- tag that does not continue the open entity:
                # either way the open entity (if any) ends here.
                current = None
                continue
            current[1].append(token)
            if token_probs is not None:
                current[2].append(token_probs[idx][label_id])

        entities = {}
        slot_probabilities = {}
        for slot, span_tokens, span_probs in spans:
            # Drop every '#' so sub-word continuation markers ("##") do not
            # leak into the entity text.
            entities[slot] = "".join(span_tokens).replace('#', '')
            if token_probs is not None:
                avg_prob = float(sum(span_probs) / len(span_probs))
                slot_probabilities[slot] = round(avg_prob, 4)
        return entities, slot_probabilities

    @staticmethod
    def _is_known_name(value, name_map, pinyin_name_map) -> bool:
        """Return True if ``value`` yields a match in the given name maps.

        Delegates to ``standardize_name_only_high_score`` with the threshold
        argument 90 used throughout this class.
        """
        return bool(
            standardize_name_only_high_score(
                value, clean_useless_company_name, name_map, pinyin_name_map, 90
            )
        )

    def _resolve_org_slot(self, key, value):
        """Pick the slot an organisation-like entity should be stored under.

        The name maps are tried in a fixed order (company, construction
        unit, subcontractor); the first one that matches decides the slot.
        If none matches, the slot predicted by the model is kept.
        """
        if self._is_known_name(
            value,
            GlobalData.simply_to_standard_company_name_map,
            GlobalData.pinyin_simply_to_standard_company_name_map,
        ):
            return IMPLEMENTATION_ORG
        if self._is_known_name(
            value,
            GlobalData.simply_to_standard_construct_name_map,
            GlobalData.pinyin_simply_to_standard_construct_name_map,
        ):
            return CONSTRUCTION_UNIT
        if self._is_known_name(
            value,
            GlobalData.simply_to_standard_constractor_name_map,
            GlobalData.pinyin_simply_to_standard_constractor_name_map,
        ):
            return SUBCONTRACTOR
        return key

    def recognize(self, text: str):
        """Run slot recognition on ``text`` and return the entities found.

        :param text: input text
        :return: dict mapping slot name to entity text
        """
        tokens, logits = self._forward(text)
        label_ids = paddle.argmax(logits, axis=-1).numpy()[0]
        entities, _ = self._decode_spans(tokens, label_ids)

        remapped = {}
        for key, value in entities.items():
            # Queries by subcontractor / construction unit are not supported
            # yet, so these specific names are folded into IMPLEMENTATION_ORG.
            if key in (SUBCONTRACTOR, CONSTRUCTION_UNIT) and ("宏源" in value or "宏远" in value):
                remapped[IMPLEMENTATION_ORG] = value
            else:
                remapped[key] = value
        return remapped

    def recognize_probability(self, text: str):
        """Run slot recognition on ``text`` and also report confidences.

        :param text: input text
        :return: ``(entities, slot_probabilities)``
            - entities: dict, slot name -> entity text
            - slot_probabilities: dict, slot name -> mean softmax probability
              of the entity's tokens, rounded to 4 decimals
        """
        tokens, logits = self._forward(text)
        probs = F.softmax(logits, axis=-1)
        label_ids = paddle.argmax(probs, axis=-1).numpy()[0]
        entities, slot_probabilities = self._decode_spans(
            tokens, label_ids, probs.numpy()[0]
        )

        updates = {}
        prob_updates = {}
        for key, value in entities.items():
            if key in (SUBCONTRACTOR, CONSTRUCTION_UNIT, IMPLEMENTATION_ORG):
                # The model's choice among the organisation slots is
                # re-checked against the known-name maps.
                target = self._resolve_org_slot(key, value)
                updates[target] = value
                prob_updates[target] = slot_probabilities[key]
            elif key == PROGRAM_NAVIGATION or key == PAGE:
                # NOTE(review): `"" in value` is always True, so these slots
                # are always rewritten to the fixed platform name. The
                # literal may have lost a character; confirm the intended
                # substring before changing this.
                if "" in value:
                    updates[key] = "施工生产管理平台"
                else:
                    updates[key] = value
                prob_updates[key] = slot_probabilities[key]
            else:
                updates[key] = value
                prob_updates[key] = slot_probabilities[key]
        return updates, prob_updates