必须槽位检查

This commit is contained in:
weiweiw 2025-02-27 16:33:26 +08:00
parent a20f513d38
commit 9109890fca
2 changed files with 48 additions and 36 deletions

View File

@ -9,12 +9,12 @@ import paddle.nn.functional as F # 用于 Softmax
from typing import List, Dict from typing import List, Dict
from pydantic import ValidationError from pydantic import ValidationError
from api.intentRecognition import IntentRecognition from intentRecognition import IntentRecognition
from api.slotRecognition import SlotRecognition from slotRecognition import SlotRecognition
# 常量 # 常量
MODEL_ERNIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\ernie\output\checkpoint-4160" MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-4160"
MODEL_UIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\uie\uie_ner\checkpoint-4320" MODEL_UIE_PATH = R"../uie/output/checkpoint-4320"
# 类别名称列表 # 类别名称列表
labels = [ labels = [
@ -36,7 +36,6 @@ label_map = {
10: 'B-riskLevel', 20: 'I-riskLevel' 10: 'B-riskLevel', 20: 'I-riskLevel'
} }
# 初始化工具类 # 初始化工具类
intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels) intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
@ -45,6 +44,7 @@ slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
# 设置Flask应用 # 设置Flask应用
app = Flask(__name__) app = Flask(__name__)
# 统一的异常处理函数 # 统一的异常处理函数
@app.errorhandler(Exception) @app.errorhandler(Exception)
def handle_exception(e): def handle_exception(e):
@ -117,7 +117,7 @@ def intent_reco():
return user_validation_error return user_validation_error
# 调用predict方法进行意图识别 # 调用predict方法进行意图识别
predicted_label, predicted_probability,predicted_id = intent_recognizer.predict(text) predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(text)
return jsonify( return jsonify(
code=200, code=200,
@ -190,19 +190,30 @@ def agent():
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query) predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)
# 再进行槽位抽取 # 再进行槽位抽取
entities = slot_recognizer.recognize(query) entities = slot_recognizer.recognize(query)
status, sk = check_lost(predicted_label, entities)
# 返回意图和槽位识别的结果 print(f"意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}")
return jsonify({ #必须槽位缺失检查
"code": 200, status, sk = check_lost(predicted_id, entities)
"msg": "成功", if status == 1:
"answer": { return jsonify({
"int": predicted_id, "code": 200,
"label": predicted_label, "msg": "成功",
"probability": predicted_probability, "answer": {
"slot": entities "miss": sk
}, },
}) })
else:
return jsonify({
"code": 200,
"msg": "成功",
"answer": {
"int": predicted_id,
"label": predicted_label,
"probability": predicted_probability,
"slot": entities
},
})
# 如果是后续轮次(多轮对话),这里只做示例,可能需要根据具体需求进行处理 # 如果是后续轮次(多轮对话),这里只做示例,可能需要根据具体需求进行处理
else: else:
@ -219,27 +230,24 @@ def agent():
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回 return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
# def check_lost(int_res, slot):
# return 0, ""
def check_lost(int_res, slot): def check_lost(int_res, slot):
# mapping = { #labels: ["天气查询","互联网查询","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"]
# "页面切换":[['页面','应用']],
# "作业计划数量查询":[['时间']],
# "周计划查询":[['时间']],
# "作业内容":[['时间']],
# "施工人数":[['时间']],
# "作业考勤人数":[['时间']],
# }
mapping = { mapping = {
1: [['date', 'area']], 2: [['page'], ['app'], ['module']],
3: [['page'], ['app'], ['module']], 3: [['date']],
4: [['date']], 4: [['date']],
5: [['date']], 5: [['date']],
6: [['date']], 6: [['date']],
7: [['date']], 7: [['date']],
8: [[]], 8: [['date']],
9: [[]],
} }
#3:"页面切换",
intention_mapping = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容",
6: "周计划作业内容",7: "施工人数",8: "作业考勤人数"}
if not mapping.__contains__(int_res): if not mapping.__contains__(int_res):
return 0, [] return 0, ""
cur_k = list(slot.keys()) cur_k = list(slot.keys())
idx = -1 idx = -1
idx_len = 99 idx_len = 99
@ -258,8 +266,12 @@ def check_lost(int_res, slot):
if idx_len == 0: # 匹配通过 if idx_len == 0: # 匹配通过
return 0, cur_k return 0, cur_k
left = [x for x in mapping[int_res][idx] if x not in cur_k] left = [x for x in mapping[int_res][idx] if x not in cur_k]
return 1, left # mapping[int_res][idx] apologize_str = "非常抱歉,"
if int_res == 2:
return 1, f"{apologize_str}请问你想查询哪个页面?"
elif int_res in [3, 4, 5, 6, 7, 8]:
return 1, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}"
if __name__ == '__main__': if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000, debug=True) app.run(host='0.0.0.0', port=18074, debug=True)

View File

@ -63,8 +63,8 @@ def preprocess_function(example, tokenizer):
# === 3. 加载 UIE 预训练模型 === # === 3. 加载 UIE 预训练模型 ===
model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=21) # 3 类 (O, B, I) model = ErnieForTokenClassification.from_pretrained(r"/mnt/d/weiweiwang/intention/models/uie-base", num_classes=21) # 3 类 (O, B, I)
tokenizer = ErnieTokenizer.from_pretrained("uie-base") tokenizer = ErnieTokenizer.from_pretrained(r"/mnt/d/weiweiwang/intention/models/uie-base")
# === 4. 加载数据集 === # === 4. 加载数据集 ===
train_dataset = load_dataset("data/data_part1.json") # 训练数据集 train_dataset = load_dataset("data/data_part1.json") # 训练数据集
@ -81,7 +81,7 @@ data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)
# === 7. 训练参数 === # === 7. 训练参数 ===
training_args = TrainingArguments( training_args = TrainingArguments(
output_dir="./uie_ner", output_dir="./output",
evaluation_strategy="epoch", evaluation_strategy="epoch",
save_strategy="epoch", save_strategy="epoch",
per_device_train_batch_size=16, # 你的显存较大,可调整 batch_size per_device_train_batch_size=16, # 你的显存较大,可调整 batch_size