import paddle
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer

from constants import SUBCONTRACTOR, CONSTRUCTION_UNIT, IMPLEMENTATION_ORG


class SlotRecognition:
    """Slot (entity) recognizer built on an ERNIE token-classification model."""

    def __init__(self, model_path: str, label_map: dict):
        """
        Load the slot-recognition model and its tokenizer.

        :param model_path: path passed to ``from_pretrained`` for both the
            model and the tokenizer
        :param label_map: mapping from predicted label id to a tag string;
            tags starting with ``B-`` / ``I-`` delimit entities, anything
            else (or a missing id) is treated as "O" (outside any entity)
        """
        self.model = ErnieForTokenClassification.from_pretrained(model_path)
        self.tokenizer = ErnieTokenizer.from_pretrained(model_path)
        self.label_map = label_map

    def recognize(self, text: str):
        """
        Run slot recognition on the input text and return the entities found.

        :param text: input text
        :return: dict mapping slot name to the recognized entity string.
            If the same slot is recognized more than once, the last
            occurrence wins.
        """
        # Tokenize the input text into paddle tensors.
        inputs = self.tokenizer(text, max_length=512, return_tensors="pd")

        # Inference only: no gradients are needed.
        with paddle.no_grad():
            logits = self.model(**inputs)
            predictions = paddle.argmax(logits, axis=-1)

        # Only the first (and only) sequence of the batch is decoded.
        predicted_labels = predictions.numpy()[0]
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy())

        entities = self._collect_entities(tokens, predicted_labels)
        return self._normalize_entities(entities)

    def _collect_entities(self, tokens, label_ids) -> dict:
        """Merge per-token B-/I- tags into a ``{slot name: joined tokens}`` dict."""
        entities = {}
        span_tokens = None
        span_label = None

        for token, label_id in zip(tokens, label_ids):
            # NOTE(review): assumes label_map keys compare equal to the numpy
            # integer ids produced by argmax - confirm against the caller.
            tag = self.label_map.get(label_id, "O")

            # An I- tag only extends the span that is currently open for the
            # same slot; otherwise it is handled like a non-entity token.
            if tag.startswith("I-") and span_tokens and tag[2:] == span_label:
                span_tokens.append(token)
                continue

            # Any other tag closes the span that is currently open, if any.
            if span_tokens:
                entities[span_label] = "".join(span_tokens)

            if tag.startswith("B-"):
                # Start a new entity; drop the "B-" prefix to get the slot name.
                span_tokens, span_label = [token], tag[2:]
            else:
                span_tokens, span_label = None, None

        # Flush the entity that was still open at the end of the sequence.
        if span_tokens:
            entities[span_label] = "".join(span_tokens)

        return entities

    @staticmethod
    def _normalize_entities(entities: dict) -> dict:
        """Strip '#' from every entity value and fold unsupported slots together."""
        normalized = {}
        for key, value in entities.items():
            # Remove '#' characters (the tokenizer's '##' sub-word markers).
            # The cleaned string is what gets stored: previously the cleaned
            # value was discarded and the raw one was returned.
            # NOTE(review): this also removes any literal '#' in the entity.
            cleaned = value.replace('#', '')

            # Subcontractor and supervisory-unit queries are not supported
            # yet, so both slots are mapped onto IMPLEMENTATION_ORG.
            if key == SUBCONTRACTOR or key == CONSTRUCTION_UNIT:
                normalized[IMPLEMENTATION_ORG] = cleaned
            else:
                normalized[key] = cleaned  # keep the original slot name

        return normalized