必须槽位检查

This commit is contained in:
weiweiw 2025-02-27 16:33:26 +08:00
parent a20f513d38
commit 9109890fca
2 changed files with 48 additions and 36 deletions

View File

@@ -9,12 +9,12 @@ import paddle.nn.functional as F # 用于 Softmax
from typing import List, Dict
from pydantic import ValidationError
from api.intentRecognition import IntentRecognition
from api.slotRecognition import SlotRecognition
from intentRecognition import IntentRecognition
from slotRecognition import SlotRecognition
# 常量
MODEL_ERNIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\ernie\output\checkpoint-4160"
MODEL_UIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\uie\uie_ner\checkpoint-4320"
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-4160"
MODEL_UIE_PATH = R"../uie/output/checkpoint-4320"
# 类别名称列表
labels = [
@@ -36,7 +36,6 @@ label_map = {
10: 'B-riskLevel', 20: 'I-riskLevel'
}
# 初始化工具类
intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
@@ -45,6 +44,7 @@ slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
# 设置Flask应用
app = Flask(__name__)
# 统一的异常处理函数
@app.errorhandler(Exception)
def handle_exception(e):
@@ -117,7 +117,7 @@ def intent_reco():
return user_validation_error
# 调用predict方法进行意图识别
predicted_label, predicted_probability,predicted_id = intent_recognizer.predict(text)
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(text)
return jsonify(
code=200,
@@ -190,19 +190,30 @@ def agent():
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)
# 再进行槽位抽取
entities = slot_recognizer.recognize(query)
status, sk = check_lost(predicted_label, entities)
# 返回意图和槽位识别的结果
return jsonify({
"code": 200,
"msg": "成功",
"answer": {
"int": predicted_id,
"label": predicted_label,
"probability": predicted_probability,
"slot": entities
},
})
print(f"意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}")
#必须槽位缺失检查
status, sk = check_lost(predicted_id, entities)
if status == 1:
return jsonify({
"code": 200,
"msg": "成功",
"answer": {
"miss": sk
},
})
else:
return jsonify({
"code": 200,
"msg": "成功",
"answer": {
"int": predicted_id,
"label": predicted_label,
"probability": predicted_probability,
"slot": entities
},
})
# 如果是后续轮次(多轮对话),这里只做示例,可能需要根据具体需求进行处理
else:
@@ -219,27 +230,24 @@ def agent():
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
# def check_lost(int_res, slot):
# return 0, ""
def check_lost(int_res, slot):
# mapping = {
# "页面切换":[['页面','应用']],
# "作业计划数量查询":[['时间']],
# "周计划查询":[['时间']],
# "作业内容":[['时间']],
# "施工人数":[['时间']],
# "作业考勤人数":[['时间']],
# }
#labels: ["天气查询","互联网查询","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"]
mapping = {
1: [['date', 'area']],
3: [['page'], ['app'], ['module']],
2: [['page'], ['app'], ['module']],
3: [['date']],
4: [['date']],
5: [['date']],
6: [['date']],
7: [['date']],
8: [[]],
9: [[]],
8: [['date']],
}
#3:"页面切换",
intention_mapping = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容",
6: "周计划作业内容",7: "施工人数",8: "作业考勤人数"}
if not mapping.__contains__(int_res):
return 0, []
return 0, ""
cur_k = list(slot.keys())
idx = -1
idx_len = 99
@@ -258,8 +266,12 @@ def check_lost(int_res, slot):
if idx_len == 0: # 匹配通过
return 0, cur_k
left = [x for x in mapping[int_res][idx] if x not in cur_k]
return 1, left # mapping[int_res][idx]
apologize_str = "非常抱歉,"
if int_res == 2:
return 1, f"{apologize_str}请问你想查询哪个页面?"
elif int_res in [3, 4, 5, 6, 7, 8]:
return 1, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}"
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000, debug=True)
app.run(host='0.0.0.0', port=18074, debug=True)

View File

@@ -63,8 +63,8 @@ def preprocess_function(example, tokenizer):
# === 3. 加载 UIE 预训练模型 ===
model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=21) # 3 类 (O, B, I)
tokenizer = ErnieTokenizer.from_pretrained("uie-base")
model = ErnieForTokenClassification.from_pretrained(r"/mnt/d/weiweiwang/intention/models/uie-base", num_classes=21) # 3 类 (O, B, I)
tokenizer = ErnieTokenizer.from_pretrained(r"/mnt/d/weiweiwang/intention/models/uie-base")
# === 4. 加载数据集 ===
train_dataset = load_dataset("data/data_part1.json") # 训练数据集
@@ -81,7 +81,7 @@ data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)
# === 7. 训练参数 ===
training_args = TrainingArguments(
output_dir="./uie_ner",
output_dir="./output",
evaluation_strategy="epoch",
save_strategy="epoch",
per_device_train_batch_size=16, # 你的显存较大,可调整 batch_size