import paddle
import paddle.nn.functional as F
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer

from constants import SUBCONTRACTOR, CONSTRUCTION_UNIT, IMPLEMENTATION_ORG
from globalData import GlobalData
from utils import standardize_name_only_high_score, clean_useless_company_name


class SlotRecognition:
    """Slot (entity) recognition on top of an ERNIE token-classification model.

    Token predictions are decoded with a BIO scheme: a ``B-XXX`` tag opens a
    new ``XXX`` entity, consecutive ``I-XXX`` tags extend it, and any other
    tag closes it.
    """

    # max_length passed to the tokenizer for every request.
    MAX_SEQ_LEN = 512
    # Score threshold passed to standardize_name_only_high_score when deciding
    # which organisation dictionary an extracted name belongs to.
    ORG_MATCH_THRESHOLD = 90
    # recognize(): SUBCONTRACTOR / CONSTRUCTION_UNIT entities containing any of
    # these substrings are re-routed to IMPLEMENTATION_ORG.
    _IMPLEMENTATION_ORG_ALIASES = ("宏源", "宏远")
    # recognize_probability(): organisation dictionaries tried in priority
    # order as (target slot, GlobalData simplified-name map attribute,
    # GlobalData pinyin map attribute). Attribute *names* are stored so that
    # GlobalData is read at call time, and only as far as the lookup gets.
    _ORG_LOOKUP_ORDER = (
        (
            IMPLEMENTATION_ORG,
            "simply_to_standard_company_name_map",
            "pinyin_simply_to_standard_company_name_map",
        ),
        (
            CONSTRUCTION_UNIT,
            "simply_to_standard_construct_name_map",
            "pinyin_simply_to_standard_construct_name_map",
        ),
        (
            SUBCONTRACTOR,
            "simply_to_standard_constractor_name_map",
            "pinyin_simply_to_standard_constractor_name_map",
        ),
    )

    def __init__(self, model_path: str, label_map: dict):
        """
        Load the slot-recognition model and its tokenizer.

        :param model_path: path (or name) passed to ``from_pretrained`` for
            both the model and the tokenizer
        :param label_map: maps a predicted label id to its tag string, e.g.
            ``"B-XXX"`` / ``"I-XXX"``; ids missing from the map are treated
            as ``"O"``
        """
        self.model = ErnieForTokenClassification.from_pretrained(model_path)
        # Inference only: make sure dropout and similar layers are disabled so
        # repeated calls on the same text give the same prediction.
        self.model.eval()
        self.tokenizer = ErnieTokenizer.from_pretrained(model_path)
        self.label_map = label_map

    def _predict(self, text: str, with_probs: bool = False):
        """
        Run the model on ``text`` and return per-token predictions.

        :param text: input text
        :param with_probs: also compute the softmax distribution
        :return: ``(tokens, predicted_labels, probs_np)``. ``predicted_labels``
            holds the argmax label ids of the first (only) batch item.
            ``probs_np`` is that item's softmax matrix, or ``None`` when
            ``with_probs`` is False.
        """
        inputs = self.tokenizer(text, max_length=self.MAX_SEQ_LEN, return_tensors="pd")
        with paddle.no_grad():
            logits = self.model(**inputs)
            # argmax over softmax == argmax over logits; softmax is only needed
            # when the caller wants per-token confidences.
            scores = F.softmax(logits, axis=-1) if with_probs else logits
            predicted_labels = paddle.argmax(scores, axis=-1).numpy()[0]
        probs_np = scores.numpy()[0] if with_probs else None
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy())
        return tokens, predicted_labels, probs_np

    def _decode_bio(self, tokens, predicted_labels, probs_np=None):
        """
        Merge BIO-tagged tokens into entity strings.

        If the same slot occurs more than once, the last span wins.
        NOTE(review): every ``'#'`` is removed from the joined text. This strips
        WordPiece ``##`` continuation markers, but it also removes literal
        ``'#'`` characters from the input (e.g. building numbers); confirm
        that this is intended.

        :param tokens: token strings aligned with ``predicted_labels``
        :param predicted_labels: predicted label id per token
        :param probs_np: optional per-token probability matrix indexed
            ``[token_idx][label_id]``; when given, each span's mean
            probability of its predicted labels is also reported
        :return: ``(entities, slot_probabilities)``. Both map slot name to a
            value; ``slot_probabilities`` is empty when ``probs_np`` is None.
        """
        entities = {}
        slot_probabilities = {}
        span_tokens = []
        span_label = None
        span_probs = []

        def close_span():
            # Emit the currently open span, if any.
            if span_tokens:
                entities[span_label] = "".join(span_tokens).replace("#", "")
                if probs_np is not None:
                    avg_prob = float(sum(span_probs) / len(span_probs))
                    slot_probabilities[span_label] = round(avg_prob, 4)

        for idx, (token, label_id) in enumerate(zip(tokens, predicted_labels)):
            label = self.label_map.get(label_id, "O")
            if label.startswith("B-"):
                # A B- tag always starts a new entity, closing any open one.
                close_span()
                span_tokens = [token]
                span_label = label[2:]
                span_probs = [probs_np[idx][label_id]] if probs_np is not None else []
            elif label.startswith("I-") and span_tokens and label[2:] == span_label:
                span_tokens.append(token)
                if probs_np is not None:
                    span_probs.append(probs_np[idx][label_id])
            else:
                # "O", an orphan I- tag, or an I- tag of a different slot.
                close_span()
                span_tokens = []
                span_label = None
                span_probs = []
        close_span()
        return entities, slot_probabilities

    def _match_org_slot(self, name: str):
        """
        Return the first organisation slot whose dictionary matches ``name``.

        :param name: extracted organisation name
        :return: the matching slot constant, or ``None`` if no dictionary
            yields a result at ``ORG_MATCH_THRESHOLD``
        """
        for slot, name_map_attr, pinyin_map_attr in self._ORG_LOOKUP_ORDER:
            match_results = standardize_name_only_high_score(
                name,
                clean_useless_company_name,
                getattr(GlobalData, name_map_attr),
                getattr(GlobalData, pinyin_map_attr),
                self.ORG_MATCH_THRESHOLD,
            )
            if match_results:
                return slot
        return None

    def recognize(self, text: str) -> dict:
        """
        Recognize slots in ``text``.

        :param text: input text
        :return: dict mapping slot name to entity text (``'#'`` removed)
        """
        tokens, predicted_labels, _ = self._predict(text)
        entities, _ = self._decode_bio(tokens, predicted_labels)

        # Subcontractor / supervising-unit queries are not supported yet, so
        # such entities naming one of the aliases are routed to
        # IMPLEMENTATION_ORG. All other entities keep their original slot.
        result = {}
        for key, value in entities.items():
            if key in (SUBCONTRACTOR, CONSTRUCTION_UNIT) and any(
                alias in value for alias in self._IMPLEMENTATION_ORG_ALIASES
            ):
                result[IMPLEMENTATION_ORG] = value
            else:
                result[key] = value
        return result

    def recognize_probability(self, text: str):
        """
        Recognize slots in ``text`` and report a confidence for each.

        :param text: input text
        :return: ``(entities, slot_probabilities)``
            - entities: dict, slot name -> entity text
            - slot_probabilities: dict, slot name -> mean softmax probability
              of the entity's predicted labels, rounded to 4 decimals
        """
        tokens, predicted_labels, probs_np = self._predict(text, with_probs=True)
        entities, slot_probabilities = self._decode_bio(tokens, predicted_labels, probs_np)

        # Organisation-type slots are re-classified by dictionary lookup
        # (implementation org, then construction unit, then subcontractor).
        # Names that match none of the dictionaries, and all other slots,
        # keep their predicted slot.
        resolved_entities = {}
        resolved_probabilities = {}
        for key, value in entities.items():
            target = None
            if key in (SUBCONTRACTOR, CONSTRUCTION_UNIT, IMPLEMENTATION_ORG):
                target = self._match_org_slot(value)
            if target is None:
                target = key
            resolved_entities[target] = value
            resolved_probabilities[target] = slot_probabilities[key]
        return resolved_entities, resolved_probabilities