62 lines
2.3 KiB
Python
62 lines
2.3 KiB
Python
|
|
import paddle
|
|||
|
|
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
|
|||
|
|
|
|||
|
|
class SlotRecognition:
    """Slot recognizer built on an ERNIE token-classification model.

    Per-token BIO tags (``B-<slot>`` / ``I-<slot>`` / ``O``) predicted by the
    model are merged into whole entity strings keyed by slot type.
    """

    def __init__(self, model_path: str, label_map: dict):
        """Load the slot-recognition model and its tokenizer.

        :param model_path: path or pretrained name passed to ``from_pretrained``
            for both the model and the tokenizer.
        :param label_map: maps predicted label ids (ints) to BIO tag strings
            such as ``"B-LOC"``, ``"I-LOC"`` or ``"O"``; ids missing from the
            map are treated as ``"O"``.
        """
        self.model = ErnieForTokenClassification.from_pretrained(model_path)
        # This class only runs inference: disable dropout so that repeated
        # calls on the same text give the same prediction.
        self.model.eval()
        self.tokenizer = ErnieTokenizer.from_pretrained(model_path)
        self.label_map = label_map

    def recognize(self, text: str) -> dict:
        """Run slot recognition on ``text`` and return the recognized entities.

        :param text: input text.
        :return: dict mapping slot type (tag without its ``B-``/``I-`` prefix)
            to the entity text. If a slot type occurs more than once, the last
            occurrence wins.
        """
        inputs = self.tokenizer(text, max_length=512, return_tensors="pd")

        with paddle.no_grad():
            # NOTE(review): assumes the model returns a bare logits tensor of
            # shape (batch, seq_len, num_labels) — confirm for the installed
            # paddlenlp version.
            logits = self.model(**inputs)
            predictions = paddle.argmax(logits, axis=-1)

        # A single text is encoded as a batch of one, hence the [0].
        predicted_labels = predictions.numpy()[0]
        tokens = self.tokenizer.convert_ids_to_tokens(
            inputs["input_ids"][0].numpy().tolist()
        )

        # Positions holding tokenizer-added markers never belong to an entity;
        # whatever the model predicts there is ignored.
        special_tokens = {
            tok
            for tok in (
                self.tokenizer.cls_token,
                self.tokenizer.sep_token,
                self.tokenizer.pad_token,
            )
            if tok
        }
        return self._decode_bio(tokens, predicted_labels, special_tokens)

    def _decode_bio(self, tokens, label_ids, special_tokens=frozenset()):
        """Merge BIO-tagged tokens into a ``{slot_type: text}`` dict.

        A ``B-X`` tag starts a new entity of type ``X``. An ``I-X`` tag extends
        it only when an entity of the same type ``X`` is open. Any other tag
        (``O``, an unknown id, a mismatched or orphan ``I-``) closes the open
        entity. Tokens in ``special_tokens`` are always treated as ``O``.

        Entity text is the concatenation of its tokens with the WordPiece
        ``##`` continuation prefix removed. Tokens are joined without
        separators, so space-separated words are run together, and ``[UNK]``
        tokens appear literally.

        :param tokens: token strings, aligned with ``label_ids``.
        :param label_ids: predicted label ids (ints or numpy integer scalars).
        :param special_tokens: token strings to force to ``O``.
        :return: dict mapping slot type to entity text; last occurrence wins.
        """
        entities = {}
        current_tokens = []
        current_label = None

        for token, label_id in zip(tokens, label_ids):
            if token in special_tokens:
                label = "O"
            else:
                label = self.label_map.get(int(label_id), "O")

            # Continue the open entity only with an I- tag of the same type.
            if label.startswith("I-") and current_tokens and label[2:] == current_label:
                current_tokens.append(self._strip_subword_prefix(token))
                continue

            # Every other tag closes the entity in progress.
            if current_tokens:
                entities[current_label] = "".join(current_tokens)
            current_tokens, current_label = [], None

            if label.startswith("B-"):
                current_tokens = [self._strip_subword_prefix(token)]
                current_label = label[2:]

        # Flush an entity that runs to the end of the token sequence.
        if current_tokens:
            entities[current_label] = "".join(current_tokens)
        return entities

    @staticmethod
    def _strip_subword_prefix(token):
        """Remove the WordPiece ``##`` continuation prefix from ``token``.

        Only a leading ``##`` followed by at least one character is removed.
        Literal ``#`` characters elsewhere are kept, and a bare ``"##"`` token
        is returned unchanged.
        """
        if token.startswith("##") and len(token) > 2:
            return token[2:]
        return token
|