必须槽位检查
This commit is contained in:
parent
a20f513d38
commit
9109890fca
78
api/mian.py
78
api/mian.py
|
|
@ -9,12 +9,12 @@ import paddle.nn.functional as F # 用于 Softmax
|
|||
from typing import List, Dict
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.intentRecognition import IntentRecognition
|
||||
from api.slotRecognition import SlotRecognition
|
||||
from intentRecognition import IntentRecognition
|
||||
from slotRecognition import SlotRecognition
|
||||
|
||||
# 常量
|
||||
MODEL_ERNIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\ernie\output\checkpoint-4160"
|
||||
MODEL_UIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\uie\uie_ner\checkpoint-4320"
|
||||
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-4160"
|
||||
MODEL_UIE_PATH = R"../uie/output/checkpoint-4320"
|
||||
|
||||
# 类别名称列表
|
||||
labels = [
|
||||
|
|
@ -36,7 +36,6 @@ label_map = {
|
|||
10: 'B-riskLevel', 20: 'I-riskLevel'
|
||||
}
|
||||
|
||||
|
||||
# 初始化工具类
|
||||
intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
|
||||
|
||||
|
|
@ -45,6 +44,7 @@ slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
|
|||
# 设置Flask应用
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
# 统一的异常处理函数
|
||||
@app.errorhandler(Exception)
|
||||
def handle_exception(e):
|
||||
|
|
@ -117,7 +117,7 @@ def intent_reco():
|
|||
return user_validation_error
|
||||
|
||||
# 调用predict方法进行意图识别
|
||||
predicted_label, predicted_probability,predicted_id = intent_recognizer.predict(text)
|
||||
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(text)
|
||||
|
||||
return jsonify(
|
||||
code=200,
|
||||
|
|
@ -190,19 +190,30 @@ def agent():
|
|||
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)
|
||||
# 再进行槽位抽取
|
||||
entities = slot_recognizer.recognize(query)
|
||||
status, sk = check_lost(predicted_label, entities)
|
||||
|
||||
# 返回意图和槽位识别的结果
|
||||
return jsonify({
|
||||
"code": 200,
|
||||
"msg": "成功",
|
||||
"answer": {
|
||||
"int": predicted_id,
|
||||
"label": predicted_label,
|
||||
"probability": predicted_probability,
|
||||
"slot": entities
|
||||
},
|
||||
})
|
||||
print(f"意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}")
|
||||
#必须槽位缺失检查
|
||||
status, sk = check_lost(predicted_id, entities)
|
||||
if status == 1:
|
||||
return jsonify({
|
||||
"code": 200,
|
||||
"msg": "成功",
|
||||
"answer": {
|
||||
"miss": sk
|
||||
},
|
||||
})
|
||||
|
||||
else:
|
||||
return jsonify({
|
||||
"code": 200,
|
||||
"msg": "成功",
|
||||
"answer": {
|
||||
"int": predicted_id,
|
||||
"label": predicted_label,
|
||||
"probability": predicted_probability,
|
||||
"slot": entities
|
||||
},
|
||||
})
|
||||
|
||||
# 如果是后续轮次(多轮对话),这里只做示例,可能需要根据具体需求进行处理
|
||||
else:
|
||||
|
|
@ -219,27 +230,24 @@ def agent():
|
|||
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
|
||||
|
||||
|
||||
# def check_lost(int_res, slot):
|
||||
# return 0, ""
|
||||
def check_lost(int_res, slot):
|
||||
# mapping = {
|
||||
# "页面切换":[['页面','应用']],
|
||||
# "作业计划数量查询":[['时间']],
|
||||
# "周计划查询":[['时间']],
|
||||
# "作业内容":[['时间']],
|
||||
# "施工人数":[['时间']],
|
||||
# "作业考勤人数":[['时间']],
|
||||
# }
|
||||
#labels: ["天气查询","互联网查询","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"]
|
||||
mapping = {
|
||||
1: [['date', 'area']],
|
||||
3: [['page'], ['app'], ['module']],
|
||||
2: [['page'], ['app'], ['module']],
|
||||
3: [['date']],
|
||||
4: [['date']],
|
||||
5: [['date']],
|
||||
6: [['date']],
|
||||
7: [['date']],
|
||||
8: [[]],
|
||||
9: [[]],
|
||||
8: [['date']],
|
||||
}
|
||||
#3:"页面切换",
|
||||
intention_mapping = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容",
|
||||
6: "周计划作业内容",7: "施工人数",8: "作业考勤人数"}
|
||||
if not mapping.__contains__(int_res):
|
||||
return 0, []
|
||||
return 0, ""
|
||||
cur_k = list(slot.keys())
|
||||
idx = -1
|
||||
idx_len = 99
|
||||
|
|
@ -258,8 +266,12 @@ def check_lost(int_res, slot):
|
|||
if idx_len == 0: # 匹配通过
|
||||
return 0, cur_k
|
||||
left = [x for x in mapping[int_res][idx] if x not in cur_k]
|
||||
return 1, left # mapping[int_res][idx]
|
||||
apologize_str = "非常抱歉,"
|
||||
if int_res == 2:
|
||||
return 1, f"{apologize_str}请问你想查询哪个页面?"
|
||||
elif int_res in [3, 4, 5, 6, 7, 8]:
|
||||
return 1, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}?"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(host='0.0.0.0', port=5000, debug=True)
|
||||
app.run(host='0.0.0.0', port=18074, debug=True)
|
||||
|
|
|
|||
|
|
@ -63,8 +63,8 @@ def preprocess_function(example, tokenizer):
|
|||
|
||||
|
||||
# === 3. 加载 UIE 预训练模型 ===
|
||||
model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=21) # 3 类 (O, B, I)
|
||||
tokenizer = ErnieTokenizer.from_pretrained("uie-base")
|
||||
model = ErnieForTokenClassification.from_pretrained(r"/mnt/d/weiweiwang/intention/models/uie-base", num_classes=21) # 3 类 (O, B, I)
|
||||
tokenizer = ErnieTokenizer.from_pretrained(r"/mnt/d/weiweiwang/intention/models/uie-base")
|
||||
|
||||
# === 4. 加载数据集 ===
|
||||
train_dataset = load_dataset("data/data_part1.json") # 训练数据集
|
||||
|
|
@ -81,7 +81,7 @@ data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)
|
|||
|
||||
# === 7. 训练参数 ===
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./uie_ner",
|
||||
output_dir="./output",
|
||||
evaluation_strategy="epoch",
|
||||
save_strategy="epoch",
|
||||
per_device_train_batch_size=16, # 你的显存较大,可调整 batch_size
|
||||
|
|
|
|||
Loading…
Reference in New Issue