"""Run a fine-tuned ERNIE token-classification checkpoint on one sentence.

The script tokenizes ``TEXT``, predicts one class id per token with the
checkpoint at ``MODEL_PATH``, merges the BIO tag predictions into entity spans
and prints each entity together with its label.
"""

MODEL_PATH = r"E:\workingSpace\PycharmProjects\Intention_dev\uie\output\checkpoint-2440"  # your model path
TEXT = "李四班组今天有多少作业计划"
MAX_LENGTH = 512

# Model class id -> BIO tag. Id 0 is "O" (not an entity), ids 1-13 are the
# "B-" (begin) tags and ids 14-26 the matching "I-" (inside) tags.
LABEL_MAP = {
    0: 'O',  # not an entity
    1: 'B-date', 14: 'I-date',
    2: 'B-projectName', 15: 'I-projectName',
    3: 'B-projectType', 16: 'I-projectType',
    4: 'B-constructionUnit', 17: 'I-constructionUnit',
    5: 'B-implementationOrganization', 18: 'I-implementationOrganization',
    6: 'B-projectDepartment', 19: 'I-projectDepartment',
    7: 'B-projectManager', 20: 'I-projectManager',
    8: 'B-subcontractor', 21: 'I-subcontractor',
    9: 'B-teamLeader', 22: 'I-teamLeader',
    10: 'B-riskLevel', 23: 'I-riskLevel',
    11: 'B-page', 24: 'I-page',
    12: 'B-operating', 25: 'I-operating',
    13: 'B-teamName', 26: 'I-teamName',
}


def decode_bio_entities(tokens, label_ids, label_map, special_tokens=()):
    """Merge per-token BIO predictions into entity spans.

    Args:
        tokens: Token strings, aligned one-to-one with ``label_ids``.
        label_ids: Predicted class ids (any int-like values, e.g. numpy ints).
        label_map: Mapping from class id to a BIO tag such as ``"B-date"``,
            ``"I-date"`` or ``"O"``. Ids missing from the map are treated as "O".
        special_tokens: Tokens (e.g. ``"[CLS]"``, ``"[SEP]"``) that can never be
            part of an entity. They are treated as "O" whatever was predicted.

    Returns:
        A list of ``{"text": str, "label": str}`` dicts in order of appearance.

    An "I-" tag only extends an entity that is already open with the same type.
    A stray "I-" tag (after "O" or after a different type) closes any open
    entity and is otherwise dropped.
    """
    entities = []
    current_tokens = []
    current_label = None

    for token, label_id in zip(tokens, label_ids):
        tag = "O" if token in special_tokens else label_map.get(int(label_id), "O")
        # WordPiece marks word continuations with "##"; drop it so the
        # joined entity text matches the original characters.
        piece = token[2:] if token.startswith("##") else token

        if tag.startswith("I-") and current_tokens and tag[2:] == current_label:
            current_tokens.append(piece)  # continue the same entity
            continue

        # Any other tag ends the entity that is currently open (if any).
        if current_tokens:
            entities.append({"text": "".join(current_tokens), "label": current_label})

        if tag.startswith("B-"):
            current_tokens = [piece]
            current_label = tag[2:]  # strip the "B-" prefix
        else:
            current_tokens = []
            current_label = None

    # Flush an entity that runs to the last token.
    if current_tokens:
        entities.append({"text": "".join(current_tokens), "label": current_label})
    return entities


def main(model_path=MODEL_PATH, text=TEXT):
    """Load the checkpoint, tag ``text``, print and return the decoded entities."""
    # Imported here so decode_bio_entities / LABEL_MAP stay importable
    # without Paddle installed.
    import paddle
    from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer

    # 1. Load the model and tokenizer.
    model = ErnieForTokenClassification.from_pretrained(model_path)
    tokenizer = ErnieTokenizer.from_pretrained(model_path)

    # 2. Tokenize the input. The tokenizer has no `max_len` argument. Use
    #    `max_length` with truncation so long input is actually cut.
    inputs = tokenizer(text, max_length=MAX_LENGTH, truncation=True, return_tensors="pd")

    # 3. Predict one class id per token.
    model.eval()
    with paddle.no_grad():
        logits = model(**inputs)
    predictions = paddle.argmax(logits, axis=-1)

    # 4. Decode the single sentence in the batch (index 0). Convert to plain
    #    lists, because convert_ids_to_tokens expects a list, not a numpy array.
    label_ids = predictions.numpy()[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy().tolist())
    entities = decode_bio_entities(
        tokens,
        label_ids,
        LABEL_MAP,
        special_tokens=set(tokenizer.all_special_tokens),
    )

    # 5. Print the final entities.
    for entity in entities:
        print(f"Entity: {entity['text']}, Label: {entity['label']}")
    return entities


if __name__ == "__main__":
    main()