190 lines
8.5 KiB
Python
190 lines
8.5 KiB
Python
import paddle
|
||
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
|
||
|
||
from globalData import GlobalData
|
||
from utils import standardize_name_only_high_score, clean_useless_company_name
|
||
from constants import SUBCONTRACTOR, CONSTRUCTION_UNIT, IMPLEMENTATION_ORG, PAGE, PROGRAM_NAVIGATION
|
||
import paddle.nn.functional as F
|
||
|
||
|
||
class SlotRecognition:
    """Slot recognizer built on an ERNIE token-classification model.

    Per-token label predictions are decoded with a BIO scheme ("B-X" opens a
    span for slot X, "I-X" continues it, anything else closes it) into a
    ``{slot_name: entity_text}`` mapping, followed by project-specific
    post-processing of company and page/navigation slots.
    """

    def __init__(self, model_path: str, label_map: dict):
        """
        Load the slot-recognition model and its tokenizer.

        :param model_path: model path passed to ``from_pretrained`` for both
            the model and the tokenizer
        :param label_map: mapping from predicted label id to label string
            (``"B-<slot>"``, ``"I-<slot>"``, ``"O"``, ...); ids missing from
            the map are treated as ``"O"``
        """
        self.model = ErnieForTokenClassification.from_pretrained(model_path)
        # Inference only: make sure dropout is disabled so predictions are
        # deterministic (harmless if from_pretrained already did this).
        self.model.eval()
        self.tokenizer = ErnieTokenizer.from_pretrained(model_path)
        self.label_map = label_map

    def _predict(self, text):
        """
        Tokenize ``text`` and run the model on it without gradient tracking.

        :param text: input text
        :return: ``(tokens, logits)`` where ``tokens`` are the tokenizer tokens
            of the (single) input sequence and ``logits`` is the raw model
            output, shape ``(1, seq_len, num_labels)``
        """
        inputs = self.tokenizer(text, max_length=512, return_tensors="pd")
        with paddle.no_grad():
            logits = self.model(**inputs)
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy())
        return tokens, logits

    def _decode_spans(self, tokens, label_ids):
        """
        Group tokens into BIO entity spans.

        A "B-X" label opens a new span for slot X (closing any open span). An
        "I-X" label extends the open span only when the span's slot is also
        X. Every other label closes the open span, including "O", an unknown
        id, and an "I-" tag that does not continue the open span.

        :param tokens: token strings, aligned with ``label_ids``
        :param label_ids: predicted label id per token
        :return: list of ``(slot_name, span_tokens, span_positions)`` in the
            order the spans end; positions index into ``tokens``
        """
        spans = []
        cur_slot, cur_tokens, cur_positions = None, [], []
        for pos, (token, label_id) in enumerate(zip(tokens, label_ids)):
            label = self.label_map.get(label_id, "O")
            if label.startswith("B-"):
                if cur_tokens:
                    spans.append((cur_slot, cur_tokens, cur_positions))
                cur_slot, cur_tokens, cur_positions = label[2:], [token], [pos]
            elif label.startswith("I-") and cur_tokens and label[2:] == cur_slot:
                cur_tokens.append(token)
                cur_positions.append(pos)
            else:
                if cur_tokens:
                    spans.append((cur_slot, cur_tokens, cur_positions))
                cur_slot, cur_tokens, cur_positions = None, [], []
        # Flush a span that runs to the end of the sequence.
        if cur_tokens:
            spans.append((cur_slot, cur_tokens, cur_positions))
        return spans

    @staticmethod
    def _span_text(span_tokens):
        """Join span tokens and strip every '#' (removes WordPiece '##' markers,
        and also any literal '#' characters in the entity)."""
        return "".join(span_tokens).replace('#', '')

    @staticmethod
    def _match_company_slot(value):
        """
        Find which company dictionary ``value`` matches.

        Dictionaries are tried in a fixed order (implementation org,
        construction unit, subcontractor). Each is checked with
        ``standardize_name_only_high_score`` (score argument 90), and the
        first truthy result wins. Later dictionaries are not consulted once a
        match is found.

        :param value: entity text
        :return: the matching slot constant, or ``None`` if nothing matched
        """
        candidates = (
            (IMPLEMENTATION_ORG, "simply_to_standard_company_name_map",
             "pinyin_simply_to_standard_company_name_map"),
            (CONSTRUCTION_UNIT, "simply_to_standard_construct_name_map",
             "pinyin_simply_to_standard_construct_name_map"),
            (SUBCONTRACTOR, "simply_to_standard_constractor_name_map",
             "pinyin_simply_to_standard_constractor_name_map"),
        )
        for slot, name_map_attr, pinyin_map_attr in candidates:
            # Read the GlobalData maps lazily, only when this candidate is tried.
            match_results = standardize_name_only_high_score(
                value,
                clean_useless_company_name,
                getattr(GlobalData, name_map_attr),
                getattr(GlobalData, pinyin_map_attr),
                90,
            )
            if match_results:
                return slot
        return None

    def recognize(self, text: str) -> dict:
        """
        Recognize slots in the input text and return the extracted entities.

        :param text: input text
        :return: dict mapping slot name -> entity text. If a slot occurs more
            than once, the last occurrence wins.
        """
        tokens, logits = self._predict(text)
        predicted_labels = paddle.argmax(logits, axis=-1).numpy()[0]

        # Strip '#' here, before post-processing, so the cleaned text is what
        # ends up in the result.
        entities = {
            slot: self._span_text(span_tokens)
            for slot, span_tokens, _ in self._decode_spans(tokens, predicted_labels)
        }

        updates = {}
        for key, value in entities.items():
            # Subcontractor / construction-unit queries are not supported yet:
            # such entities mentioning "宏源" / "宏远" (presumably variants of
            # the implementing organisation's name) are remapped to
            # IMPLEMENTATION_ORG.
            if key in (SUBCONTRACTOR, CONSTRUCTION_UNIT) and ("宏源" in value or "宏远" in value):
                updates[IMPLEMENTATION_ORG] = value
            else:
                updates[key] = value
        return updates

    def recognize_probability(self, text: str) -> tuple:
        """
        Recognize slots in the input text and return the entities together
        with their confidence.

        :param text: input text
        :return: (entities, slot_probabilities)
            - entities: dict, slot name -> entity text
            - slot_probabilities: dict, slot name -> mean softmax probability
              of the predicted label over the entity's tokens, rounded to 4
              decimals. It has exactly the same keys as ``entities``.
        """
        tokens, logits = self._predict(text)
        with paddle.no_grad():
            probs = F.softmax(logits, axis=-1)  # (1, seq_len, num_labels)
            predictions = paddle.argmax(probs, axis=-1)  # (1, seq_len)
        predicted_labels = predictions.numpy()[0]
        probs_np = probs.numpy()[0]  # (seq_len, num_labels)

        entities = {}
        slot_probabilities = {}
        for slot, span_tokens, positions in self._decode_spans(tokens, predicted_labels):
            span_probs = [probs_np[p][predicted_labels[p]] for p in positions]
            entities[slot] = self._span_text(span_tokens)
            slot_probabilities[slot] = round(float(sum(span_probs) / len(span_probs)), 4)

        resolved_entities = {}
        resolved_probs = {}
        for key, value in entities.items():
            target = key
            if key in (SUBCONTRACTOR, CONSTRUCTION_UNIT, IMPLEMENTATION_ORG):
                # Re-assign company entities to whichever company dictionary
                # they match; keep the predicted slot when none matches.
                matched_slot = self._match_company_slot(value)
                if matched_slot is not None:
                    target = matched_slot
            elif key in (PROGRAM_NAVIGATION, PAGE):
                # Any page/navigation mention containing "施" is normalized to
                # the construction production management platform name.
                if "施" in value:
                    value = "施工生产管理平台"
            resolved_entities[target] = value
            # Always record the probability so both dicts share the same keys
            # (previously missing for the normalized page/navigation case).
            resolved_probs[target] = slot_probabilities[key]

        return resolved_entities, resolved_probs
|