必须槽位检查
This commit is contained in:
parent
a20f513d38
commit
9109890fca
56
api/mian.py
56
api/mian.py
|
|
@ -9,12 +9,12 @@ import paddle.nn.functional as F # 用于 Softmax
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from api.intentRecognition import IntentRecognition
|
from intentRecognition import IntentRecognition
|
||||||
from api.slotRecognition import SlotRecognition
|
from slotRecognition import SlotRecognition
|
||||||
|
|
||||||
# 常量
|
# 常量
|
||||||
MODEL_ERNIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\ernie\output\checkpoint-4160"
|
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-4160"
|
||||||
MODEL_UIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\uie\uie_ner\checkpoint-4320"
|
MODEL_UIE_PATH = R"../uie/output/checkpoint-4320"
|
||||||
|
|
||||||
# 类别名称列表
|
# 类别名称列表
|
||||||
labels = [
|
labels = [
|
||||||
|
|
@ -36,7 +36,6 @@ label_map = {
|
||||||
10: 'B-riskLevel', 20: 'I-riskLevel'
|
10: 'B-riskLevel', 20: 'I-riskLevel'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# 初始化工具类
|
# 初始化工具类
|
||||||
intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
|
intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
|
||||||
|
|
||||||
|
|
@ -45,6 +44,7 @@ slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
|
||||||
# 设置Flask应用
|
# 设置Flask应用
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
|
||||||
# 统一的异常处理函数
|
# 统一的异常处理函数
|
||||||
@app.errorhandler(Exception)
|
@app.errorhandler(Exception)
|
||||||
def handle_exception(e):
|
def handle_exception(e):
|
||||||
|
|
@ -190,9 +190,20 @@ def agent():
|
||||||
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)
|
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)
|
||||||
# 再进行槽位抽取
|
# 再进行槽位抽取
|
||||||
entities = slot_recognizer.recognize(query)
|
entities = slot_recognizer.recognize(query)
|
||||||
status, sk = check_lost(predicted_label, entities)
|
|
||||||
|
|
||||||
# 返回意图和槽位识别的结果
|
print(f"意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}")
|
||||||
|
#必须槽位缺失检查
|
||||||
|
status, sk = check_lost(predicted_id, entities)
|
||||||
|
if status == 1:
|
||||||
|
return jsonify({
|
||||||
|
"code": 200,
|
||||||
|
"msg": "成功",
|
||||||
|
"answer": {
|
||||||
|
"miss": sk
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
else:
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"code": 200,
|
"code": 200,
|
||||||
"msg": "成功",
|
"msg": "成功",
|
||||||
|
|
@ -219,27 +230,24 @@ def agent():
|
||||||
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
|
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
|
||||||
|
|
||||||
|
|
||||||
|
# def check_lost(int_res, slot):
|
||||||
|
# return 0, ""
|
||||||
def check_lost(int_res, slot):
|
def check_lost(int_res, slot):
|
||||||
# mapping = {
|
#labels: ["天气查询","互联网查询","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"]
|
||||||
# "页面切换":[['页面','应用']],
|
|
||||||
# "作业计划数量查询":[['时间']],
|
|
||||||
# "周计划查询":[['时间']],
|
|
||||||
# "作业内容":[['时间']],
|
|
||||||
# "施工人数":[['时间']],
|
|
||||||
# "作业考勤人数":[['时间']],
|
|
||||||
# }
|
|
||||||
mapping = {
|
mapping = {
|
||||||
1: [['date', 'area']],
|
2: [['page'], ['app'], ['module']],
|
||||||
3: [['page'], ['app'], ['module']],
|
3: [['date']],
|
||||||
4: [['date']],
|
4: [['date']],
|
||||||
5: [['date']],
|
5: [['date']],
|
||||||
6: [['date']],
|
6: [['date']],
|
||||||
7: [['date']],
|
7: [['date']],
|
||||||
8: [[]],
|
8: [['date']],
|
||||||
9: [[]],
|
|
||||||
}
|
}
|
||||||
|
#3:"页面切换",
|
||||||
|
intention_mapping = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容",
|
||||||
|
6: "周计划作业内容",7: "施工人数",8: "作业考勤人数"}
|
||||||
if not mapping.__contains__(int_res):
|
if not mapping.__contains__(int_res):
|
||||||
return 0, []
|
return 0, ""
|
||||||
cur_k = list(slot.keys())
|
cur_k = list(slot.keys())
|
||||||
idx = -1
|
idx = -1
|
||||||
idx_len = 99
|
idx_len = 99
|
||||||
|
|
@ -258,8 +266,12 @@ def check_lost(int_res, slot):
|
||||||
if idx_len == 0: # 匹配通过
|
if idx_len == 0: # 匹配通过
|
||||||
return 0, cur_k
|
return 0, cur_k
|
||||||
left = [x for x in mapping[int_res][idx] if x not in cur_k]
|
left = [x for x in mapping[int_res][idx] if x not in cur_k]
|
||||||
return 1, left # mapping[int_res][idx]
|
apologize_str = "非常抱歉,"
|
||||||
|
if int_res == 2:
|
||||||
|
return 1, f"{apologize_str}请问你想查询哪个页面?"
|
||||||
|
elif int_res in [3, 4, 5, 6, 7, 8]:
|
||||||
|
return 1, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}?"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
app.run(host='0.0.0.0', port=5000, debug=True)
|
app.run(host='0.0.0.0', port=18074, debug=True)
|
||||||
|
|
|
||||||
|
|
@ -63,8 +63,8 @@ def preprocess_function(example, tokenizer):
|
||||||
|
|
||||||
|
|
||||||
# === 3. 加载 UIE 预训练模型 ===
|
# === 3. 加载 UIE 预训练模型 ===
|
||||||
model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=21) # 3 类 (O, B, I)
|
model = ErnieForTokenClassification.from_pretrained(r"/mnt/d/weiweiwang/intention/models/uie-base", num_classes=21) # 3 类 (O, B, I)
|
||||||
tokenizer = ErnieTokenizer.from_pretrained("uie-base")
|
tokenizer = ErnieTokenizer.from_pretrained(r"/mnt/d/weiweiwang/intention/models/uie-base")
|
||||||
|
|
||||||
# === 4. 加载数据集 ===
|
# === 4. 加载数据集 ===
|
||||||
train_dataset = load_dataset("data/data_part1.json") # 训练数据集
|
train_dataset = load_dataset("data/data_part1.json") # 训练数据集
|
||||||
|
|
@ -81,7 +81,7 @@ data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)
|
||||||
|
|
||||||
# === 7. 训练参数 ===
|
# === 7. 训练参数 ===
|
||||||
training_args = TrainingArguments(
|
training_args = TrainingArguments(
|
||||||
output_dir="./uie_ner",
|
output_dir="./output",
|
||||||
evaluation_strategy="epoch",
|
evaluation_strategy="epoch",
|
||||||
save_strategy="epoch",
|
save_strategy="epoch",
|
||||||
per_device_train_batch_size=16, # 你的显存较大,可调整 batch_size
|
per_device_train_batch_size=16, # 你的显存较大,可调整 batch_size
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue