diff --git a/api/mian.py b/api/mian.py
index b2fe016..ec9fdbd 100644
--- a/api/mian.py
+++ b/api/mian.py
@@ -9,12 +9,12 @@ import paddle.nn.functional as F # 用于 Softmax
 from typing import List, Dict
 from pydantic import ValidationError
 
-from api.intentRecognition import IntentRecognition
-from api.slotRecognition import SlotRecognition
+from intentRecognition import IntentRecognition
+from slotRecognition import SlotRecognition
 
 # 常量
-MODEL_ERNIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\ernie\output\checkpoint-4160"
-MODEL_UIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\uie\uie_ner\checkpoint-4320"
+MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-4160"
+MODEL_UIE_PATH = R"../uie/output/checkpoint-4320"
 
 # 类别名称列表
 labels = [
@@ -36,7 +36,6 @@ label_map = {
     10: 'B-riskLevel', 20: 'I-riskLevel'
 }
 
-
 # 初始化工具类
 intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
 
@@ -45,6 +44,7 @@ slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
 # 设置Flask应用
 app = Flask(__name__)
 
+
 # 统一的异常处理函数
 @app.errorhandler(Exception)
 def handle_exception(e):
@@ -117,7 +117,7 @@ def intent_reco():
         return user_validation_error
 
     # 调用predict方法进行意图识别
-    predicted_label, predicted_probability,predicted_id = intent_recognizer.predict(text)
+    predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(text)
 
     return jsonify(
         code=200,
@@ -190,19 +190,30 @@ def agent():
             predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)
             # 再进行槽位抽取
             entities = slot_recognizer.recognize(query)
-            status, sk = check_lost(predicted_label, entities)
-            # 返回意图和槽位识别的结果
-            return jsonify({
-                "code": 200,
-                "msg": "成功",
-                "answer": {
-                    "int": predicted_id,
-                    "label": predicted_label,
-                    "probability": predicted_probability,
-                    "slot": entities
-                },
-            })
+            print(f"意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}")
+            #必须槽位缺失检查
+            status, sk = check_lost(predicted_id, entities)
+            if status == 1:
+                return jsonify({
+                    "code": 200,
+                    "msg": "成功",
+                    "answer": {
+                        "miss": sk
+                    },
+                })
+
+            else:
+                return jsonify({
+                    "code": 200,
+                    "msg": "成功",
+                    "answer": {
+                        "int": predicted_id,
+                        "label": predicted_label,
+                        "probability": predicted_probability,
+                        "slot": entities
+                    },
+                })
 
 
         # 如果是后续轮次(多轮对话),这里只做示例,可能需要根据具体需求进行处理
         else:
@@ -219,27 +230,24 @@ def agent():
         return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
 
 
+# def check_lost(int_res, slot):
+#     return 0, ""
 def check_lost(int_res, slot):
-    # mapping = {
-    #     "页面切换":[['页面','应用']],
-    #     "作业计划数量查询":[['时间']],
-    #     "周计划查询":[['时间']],
-    #     "作业内容":[['时间']],
-    #     "施工人数":[['时间']],
-    #     "作业考勤人数":[['时间']],
-    # }
+    #labels: ["天气查询","互联网查询","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"]
     mapping = {
-        1: [['date', 'area']],
-        3: [['page'], ['app'], ['module']],
+        2: [['page'], ['app'], ['module']],
+        3: [['date']],
         4: [['date']],
         5: [['date']],
         6: [['date']],
         7: [['date']],
-        8: [[]],
-        9: [[]],
+        8: [['date']],
     }
+    #3:"页面切换",
+    intention_mapping = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容",
+                         6: "周计划作业内容",7: "施工人数",8: "作业考勤人数"}
     if not mapping.__contains__(int_res):
-        return 0, []
+        return 0, ""
     cur_k = list(slot.keys())
     idx = -1
     idx_len = 99
@@ -258,8 +266,12 @@ def check_lost(int_res, slot):
     if idx_len == 0:
         # 匹配通过
         return 0, cur_k
     left = [x for x in mapping[int_res][idx] if x not in cur_k]
-    return 1, left # mapping[int_res][idx]
+    apologize_str = "非常抱歉,"
+    if int_res == 2:
+        return 1, f"{apologize_str}请问你想查询哪个页面?"
+    elif int_res in [3, 4, 5, 6, 7, 8]:
+        return 1, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}?"
 
 if __name__ == '__main__':
-    app.run(host='0.0.0.0', port=5000, debug=True)
+    app.run(host='0.0.0.0', port=18074, debug=True)
diff --git a/uie/train.py b/uie/train.py
index 269ed31..491b160 100644
--- a/uie/train.py
+++ b/uie/train.py
@@ -63,8 +63,8 @@ def preprocess_function(example, tokenizer):
 
 
 # === 3. 加载 UIE 预训练模型 ===
-model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=21) # 3 类 (O, B, I)
-tokenizer = ErnieTokenizer.from_pretrained("uie-base")
+model = ErnieForTokenClassification.from_pretrained(r"/mnt/d/weiweiwang/intention/models/uie-base", num_classes=21) # 3 类 (O, B, I)
+tokenizer = ErnieTokenizer.from_pretrained(r"/mnt/d/weiweiwang/intention/models/uie-base")
 
 # === 4. 加载数据集 ===
 train_dataset = load_dataset("data/data_part1.json") # 训练数据集
@@ -81,7 +81,7 @@ data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)
 
 # === 7. 训练参数 ===
 training_args = TrainingArguments(
-    output_dir="./uie_ner",
+    output_dir="./output",
     evaluation_strategy="epoch",
     save_strategy="epoch",
     per_device_train_batch_size=16, # 你的显存较大,可调整 batch_size