From 2411eba2f0c892c121044e712e7f8cfea08bc92e Mon Sep 17 00:00:00 2001 From: weiweiw <14335254+weiweiw22@user.noreply.gitee.com> Date: Thu, 27 Feb 2025 19:51:17 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E5=B7=A5=E7=A8=8B=E5=90=8D?= =?UTF-8?q?=E5=92=8C=E9=A1=B9=E7=9B=AE=E5=90=8D=E8=A7=84=E8=8C=83=E5=8C=96?= =?UTF-8?q?=E5=92=8C=E5=A4=9A=E8=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/constants.py | 21 +++++++++++++ api/mian.py | 82 +++++++++++++++++++++++++++++++----------------- api/utils.py | 14 +++++++++ 3 files changed, 88 insertions(+), 29 deletions(-) create mode 100644 api/constants.py diff --git a/api/constants.py b/api/constants.py new file mode 100644 index 0000000..0e7fd56 --- /dev/null +++ b/api/constants.py @@ -0,0 +1,21 @@ +# constants.py +#日期 +DATE = "date" +#工程名称 +PROJECT_NAME = "projectName" +#工程性质 +PROJECT_TYPE = "projectType" +#建管单位 +CONSTRUCTION_UNIT = "constructionUnit" +#分公司 +IMPLEMENTATION_ORG = "implementationOrganization" +#项目部 +PROJECT_DEPARTMENT = "projectDepartment" +#项目经理 +PROJECT_MANAGER = "projectManager" +#分包商 +SUBCONTRACTOR = "subcontractor" +#班组 +TEAM_LEADER = "teamLeader" +#风险 +RISK_LEVEL = "riskLevel" diff --git a/api/mian.py b/api/mian.py index 7d01b0d..b17726f 100644 --- a/api/mian.py +++ b/api/mian.py @@ -12,7 +12,8 @@ from pydantic import ValidationError from intentRecognition import IntentRecognition from slotRecognition import SlotRecognition from fuzzywuzzy import process -import utils +from utils import CheckResult, StandardType +from constants import PROJECT_NAME, PROJECT_DEPARTMENT # 常量 MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-4160" @@ -199,28 +200,26 @@ def agent(): entities = slot_recognizer.recognize(query) print(f"意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}") + #必须槽位缺失检查 status, sk = check_lost(predicted_id, entities) - if status == 1: + if status == CheckResult.NEEDS_MORE_ROUNDS: + return jsonify({"code": 10001, "msg": "成功", + "answer": { "miss": sk}, + }) + + #工程名和项目名标准化 + result, information = check_project_standard_slot(predicted_id, entities) + if result == CheckResult.NEEDS_MORE_ROUNDS: return jsonify({ - "code": 200, - "msg": "成功", - "answer": { - "miss": sk - }, + "code": 10001, "msg": "成功", + "answer": {"miss": information}, }) - else: - return jsonify({ - "code": 200, - "msg": "成功", - "answer": { - "int": predicted_id, - "label": predicted_label, - "probability": predicted_probability, - "slot": entities - }, - }) + return jsonify({ + "code": 200,"msg": "成功", + "answer": {"int": predicted_id, "label": predicted_label, "probability": predicted_probability, "slot": entities }, + }) # 如果是后续轮次(多轮对话),这里只做示例,可能需要根据具体需求进行处理 else: @@ -248,39 +247,64 @@ def check_lost(int_res, slot): 7: [['date']], 8: [['date']], } - #3:"页面切换", + intention_mapping = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容", 6: "周计划作业内容",7: "施工人数",8: "作业考勤人数"} if not mapping.__contains__(int_res): return 0, "" + #提取的槽位信息 cur_k = list(slot.keys()) idx = -1 idx_len = 99 for i in range(len(mapping[int_res])): sk = mapping[int_res][i] - left = [x for x in sk if x not in cur_k] - more = [x for x in cur_k if x not in sk] - if len(more) >= 0 and len(left) == 0: + #不在提取的槽位信息里,但是在必须槽位表里 + miss_params = [x for x in sk if x not in cur_k] + #不在必须槽位表里,但是在提取的槽位信息里 + extra_params = [x for x in cur_k if x not in sk] + if len(extra_params) >= 0 and len(miss_params) == 0: idx = i idx_len = 0 break - if len(left) < idx_len: + if len(miss_params) < idx_len: idx = i - idx_len = len(left) + idx_len = len(miss_params) if idx_len == 0: # 匹配通过 - return 0, cur_k + return CheckResult.NO_MATCH, cur_k + #符合当前意图的的必须槽位,但是不在提取的槽位信息里 left = [x for x in mapping[int_res][idx] if x not in cur_k] + print(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}") apologize_str = "非常抱歉," if int_res == 2: - return 1, f"{apologize_str}请问你想查询哪个页面?" + return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询哪个页面?" elif int_res in [3, 4, 5, 6, 7, 8]: - return 1, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}?" + return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}?" #标准化工程名 -def check_project_standard_slot(int_res, slot)->tuple: - return 0, '' +def check_project_standard_slot(int_res, slot) -> tuple: + intention_list = {3, 4, 5, 6, 7, 8} + if int_res not in intention_list: + return CheckResult.NO_MATCH,"" + + for key, value in slot.items(): + if key == PROJECT_NAME: + match_project, match_possibility = fuzzy_match(value, standard_project_name_list) + if match_possibility >= 0.9: + slot[key] = match_project + else: + return CheckResult.NEEDS_MORE_ROUNDS, f"抱歉,您说的工程名是{match_project}吗" + + if key == PROJECT_DEPARTMENT: + match_program, match_possibility = fuzzy_match(value, standard_program_name_list) + if match_possibility >= 0.9: + slot[key] = match_program + else: + return CheckResult.NEEDS_MORE_ROUNDS, f"抱歉,您说的项目名是{match_program}吗" + + return CheckResult.NO_MATCH,"" + def fuzzy_match(user_input, standard_name): result = process.extract(user_input, standard_name) diff --git a/api/utils.py b/api/utils.py index 083a32d..711d7be 100644 --- a/api/utils.py +++ b/api/utils.py @@ -1,3 +1,4 @@ +from enum import Enum def load_standard_name(file_path:str): try: # f = open(file_path, 'r', encoding='utf-8') @@ -16,3 +17,16 @@ def load_standard_name(file_path:str): except Exception as e: print(f"读取文件时发生错误:{e}") raise Exception(f"错误:文件 {file_path} 不存在") + + +class CheckResult(Enum): + NO_MATCH = 0 # 不符合检查条件 + MATCH_FOUND = 1 # 匹配到了值 + NEEDS_MORE_ROUNDS = 2 # 需要多轮 + + +class StandardType(Enum): + #工程名检查 + PROJECT_CHECK = 0 + #项目名检查 + PROGRAM_CHECK = 1