Intention/api/slotRecognition.py

184 lines
8.1 KiB
Python
Raw Normal View History

2025-02-27 09:06:34 +08:00
import paddle
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
from globalData import GlobalData
from utils import standardize_name_only_high_score, clean_useless_company_name
from constants import SUBCONTRACTOR, CONSTRUCTION_UNIT, IMPLEMENTATION_ORG
import paddle.nn.functional as F
2025-02-27 09:06:34 +08:00
class SlotRecognition:
    """BIO slot recognizer built on a PaddleNLP ERNIE token-classification model.

    ``recognize`` returns ``{slot_name: entity_text}``; ``recognize_probability``
    additionally returns the mean probability of the predicted label over each
    entity's tokens. Both apply project-specific remapping of company-name slots
    before returning.
    """

    # max_length passed to the tokenizer.
    _MAX_SEQ_LEN = 512
    # Score threshold passed to standardize_name_only_high_score
    # (presumably a 0-100 fuzzy-match score -- confirm against utils).
    _MATCH_SCORE = 90

    def __init__(self, model_path: str, label_map: dict):
        """
        Load the token-classification model and its tokenizer.

        :param model_path: model path/name passed to ``from_pretrained``
        :param label_map: maps predicted label ids to BIO label strings such as
            ``"B-<slot>"``, ``"I-<slot>"`` or ``"O"``; ids missing from the map
            are treated as ``"O"``
        """
        self.model = ErnieForTokenClassification.from_pretrained(model_path)
        # The model is only used for inference: disable dropout so predictions
        # are deterministic.
        self.model.eval()
        self.tokenizer = ErnieTokenizer.from_pretrained(model_path)
        self.label_map = label_map

    def recognize(self, text: str):
        """
        Recognize slot entities in ``text``.

        :param text: input text
        :return: dict mapping slot name -> entity text, with '#' characters
            (WordPiece continuation markers) removed. If a slot occurs more than
            once, the last occurrence wins.
        """
        tokens, logits = self._forward(text)
        label_ids = paddle.argmax(logits, axis=-1).numpy()[0]
        entities = {
            slot: self._join_tokens(tokens[pos] for pos in positions)
            for slot, positions in self._decode_bio(tokens, label_ids)
        }

        remapped = {}
        for slot, value in entities.items():
            # Original note: queries on subcontractors / supervising units are not
            # supported yet, so such values mentioning 宏源/宏远 are reported under
            # IMPLEMENTATION_ORG instead.
            if slot in (SUBCONTRACTOR, CONSTRUCTION_UNIT) and ("宏源" in value or "宏远" in value):
                remapped[IMPLEMENTATION_ORG] = value
            else:
                remapped[slot] = value
        return remapped

    def recognize_probability(self, text: str):
        """
        Recognize slot entities in ``text`` together with a confidence for each.

        :param text: input text
        :return: (entities, slot_probabilities)
            - entities: dict, slot name -> entity text ('#' characters removed)
            - slot_probabilities: dict, slot name -> mean probability (0-1,
              rounded to 4 decimals) of the predicted label over the entity's tokens
        """
        tokens, logits = self._forward(text)
        probs = F.softmax(logits, axis=-1)  # per-token label distribution
        label_ids = paddle.argmax(probs, axis=-1).numpy()[0]
        probs_np = probs.numpy()[0]  # (seq_len, num_labels)

        entities = {}
        slot_probabilities = {}
        for slot, positions in self._decode_bio(tokens, label_ids):
            token_probs = [probs_np[pos][label_ids[pos]] for pos in positions]
            entities[slot] = self._join_tokens(tokens[pos] for pos in positions)
            slot_probabilities[slot] = round(float(sum(token_probs) / len(token_probs)), 4)

        remapped = {}
        remapped_probs = {}
        for slot, value in entities.items():
            target = slot
            if slot in (SUBCONTRACTOR, CONSTRUCTION_UNIT, IMPLEMENTATION_ORG):
                print(f"recognize_probability- key:{slot}value:{value}")
                matched = self._match_company_slot(value)
                if matched is not None:
                    target = matched
            # NOTE(review): if two source slots resolve to the same target, the
            # later one overwrites the earlier (same as before).
            remapped[target] = value
            remapped_probs[target] = slot_probabilities[slot]
        return remapped, remapped_probs

    def _forward(self, text: str):
        """
        Tokenize ``text`` and run the model without gradient tracking.

        :return: (tokens, logits) -- the token strings of the encoded sequence
            (including any special tokens the tokenizer adds) and the raw model
            output, expected shape (1, seq_len, num_labels)
        """
        inputs = self.tokenizer(text, max_length=self._MAX_SEQ_LEN, return_tensors="pd")
        with paddle.no_grad():
            logits = self.model(**inputs)
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy())
        return tokens, logits

    def _decode_bio(self, tokens, label_ids):
        """
        Group BIO-labelled tokens into entity spans.

        A ``B-<slot>`` label opens a span; ``I-<slot>`` with the same slot extends
        the open span; anything else closes it. An ``I-`` label without a matching
        open span is ignored. The tokenizer's cls/sep/pad tokens are always treated
        as ``O`` so they never end up inside an entity.

        :param tokens: token strings, aligned with ``label_ids``
        :param label_ids: predicted label id per token
        :return: list of (slot_name, token_positions) in order of appearance
        """
        boundary_tokens = {
            tok
            for tok in (
                getattr(self.tokenizer, attr, None)
                for attr in ("cls_token", "sep_token", "pad_token")
            )
            if tok
        }
        spans = []
        open_span = None
        for pos, (token, label_id) in enumerate(zip(tokens, label_ids)):
            if token in boundary_tokens:
                label = "O"
            else:
                label = self.label_map.get(int(label_id), "O")
            if label.startswith("B-"):  # start a new entity
                open_span = (label[2:], [pos])
                spans.append(open_span)
            elif label.startswith("I-") and open_span is not None and label[2:] == open_span[0]:
                open_span[1].append(pos)  # continue the current entity
            else:  # outside any entity
                open_span = None
        return spans

    @staticmethod
    def _join_tokens(tokens):
        """Concatenate tokens and remove every '#'.

        This strips WordPiece '##' continuation markers; note it also removes any
        literal '#' characters that were part of the text.
        """
        return "".join(tokens).replace("#", "")

    def _match_company_slot(self, value: str):
        """
        Decide which company slot ``value`` belongs to by matching it against the
        standard-name maps in GlobalData, in priority order: implementation
        organization, construction unit, subcontractor.

        :param value: entity text to classify
        :return: the first slot constant whose maps produce a truthy match from
            standardize_name_only_high_score, or None if none do
        """
        candidates = (
            (IMPLEMENTATION_ORG,
             "simply_to_standard_company_name_map",
             "pinyin_simply_to_standard_company_name_map"),
            (CONSTRUCTION_UNIT,
             "simply_to_standard_construct_name_map",
             "pinyin_simply_to_standard_construct_name_map"),
            (SUBCONTRACTOR,
             "simply_to_standard_constractor_name_map",
             "pinyin_simply_to_standard_constractor_name_map"),
        )
        for slot, name_map_attr, pinyin_map_attr in candidates:
            # Maps are looked up lazily so later ones are only read when the
            # earlier ones did not match.
            match_results = standardize_name_only_high_score(
                value,
                clean_useless_company_name,
                getattr(GlobalData, name_map_attr),
                getattr(GlobalData, pinyin_map_attr),
                self._MATCH_SCORE,
            )
            if match_results:
                print(f"recognize_probability-key:{slot}value:{value}")
                return slot
        return None