Intention/api/slotRecognition.py

62 lines
2.3 KiB
Python

import paddle
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
class SlotRecognition:
    """Slot (entity) recognizer built on an ERNIE token-classification model.

    The model assigns one BIO label id to every token; ``recognize`` merges
    consecutive ``B-x`` / ``I-x`` tokens into a single string per slot ``x``.
    """

    # Marker that WordPiece-style vocabularies put in front of a sub-word
    # continuation piece; it is not part of the surface text.
    _SUBWORD_PREFIX = "##"

    def __init__(self, model_path: str, label_map: dict):
        """Load the slot-recognition model and its tokenizer.

        :param model_path: path (or name) passed to ``from_pretrained`` for
            both the model and the tokenizer
        :param label_map: mapping from predicted label id to a BIO label
            string such as ``"O"``, ``"B-city"`` or ``"I-city"``
        """
        self.model = ErnieForTokenClassification.from_pretrained(model_path)
        # Inference only: switch the layer to evaluation mode so that
        # train-time behaviour (e.g. dropout) does not affect predictions.
        self.model.eval()
        self.tokenizer = ErnieTokenizer.from_pretrained(model_path)
        self.label_map = label_map

    def recognize(self, text: str) -> dict:
        """Run slot recognition on ``text`` and return the recognized entities.

        :param text: input text
        :return: dict mapping slot name (the label without its ``B-``/``I-``
            prefix) to the entity text; if a slot occurs more than once the
            last occurrence wins
        """
        inputs = self.tokenizer(text, max_length=512, return_tensors="pd")

        # No gradients are needed for prediction.
        with paddle.no_grad():
            logits = self.model(**inputs)
        predictions = paddle.argmax(logits, axis=-1)

        # Only the first (and only) sequence of the batch is decoded.
        label_ids = predictions.numpy()[0].tolist()
        token_ids = inputs["input_ids"][0].numpy().tolist()
        tokens = self.tokenizer.convert_ids_to_tokens(token_ids)

        # Fall back to "no special tokens" if the tokenizer does not expose
        # the attribute, which keeps the previous behaviour in that case.
        special_tokens = getattr(self.tokenizer, "all_special_tokens", None) or ()
        return self._decode_entities(tokens, label_ids, special_tokens)

    def _lookup_label(self, label_id) -> str:
        """Return the BIO label for ``label_id``, defaulting to ``"O"``.

        Integer keys are tried first; string keys (e.g. a map loaded from
        JSON) are accepted as a fallback.
        """
        index = int(label_id)
        label = self.label_map.get(index)
        if label is None:
            label = self.label_map.get(str(index), "O")
        return label

    def _decode_entities(self, tokens, label_ids, special_tokens=()) -> dict:
        """Merge per-token BIO labels into a ``{slot: text}`` dict.

        :param tokens: token strings, aligned with ``label_ids``
        :param label_ids: predicted label id for each token
        :param special_tokens: tokens (such as ``[CLS]``/``[SEP]``) that are
            always treated as non-entity positions
        :return: dict mapping slot name to the concatenated entity text
        """
        special = frozenset(special_tokens)
        prefix = self._SUBWORD_PREFIX

        entities = {}
        pieces = None  # text pieces of the entity currently being built
        slot = None    # slot name of the entity currently being built

        for token, label_id in zip(tokens, label_ids):
            label = "O" if token in special else self._lookup_label(label_id)
            # Drop only the leading sub-word marker; a '#' that belongs to
            # the text itself is preserved.
            piece = token[len(prefix):] if token.startswith(prefix) else token

            if label.startswith("B-"):
                # A new entity starts: store the one in progress, if any.
                if pieces:
                    entities[slot] = "".join(pieces)
                pieces = [piece]
                slot = label[2:]
            elif label.startswith("I-") and pieces and label[2:] == slot:
                # Continuation of the entity in progress.
                pieces.append(piece)
            else:
                # "O", a special token, or an I- label that does not continue
                # the current entity: close whatever is in progress.
                if pieces:
                    entities[slot] = "".join(pieces)
                pieces = None
                slot = None

        # Flush an entity that runs to the end of the sequence.
        if pieces:
            entities[slot] = "".join(pieces)
        return entities