@classmethod
def update_from_redis(cls):
    """Refresh the shared standard-name tables from Redis, with local-file fallback.

    This is a classmethod of ``GlobalData``. For the company/project-department
    mapping, the project names and the subcontractor names, the JSON stored
    under the matching ``SBD_QUERY_DATA:*`` key is used when it can be read,
    parsed and has the expected container type; otherwise the corresponding
    file under ``./standard_data`` is loaded. Construction-unit names are
    always read from the local file.

    A table is rebuilt only when the freshly loaded data differs from what is
    currently held, and every container is mutated in place (``clear`` followed
    by ``update``/``extend``) rather than rebound.

    Returns:
        None. The result is the updated class-level containers.
    """
    # Imported at call time so this method carries everything it needs;
    # ``utils`` itself imports ``GlobalData``, so importing it lazily also
    # avoids a circular import at module load.
    import json
    import time

    import redis

    from config import redis_url
    from utils import (
        load_standard_data,
        load_standard_name,
        clean_useless_company_name,
        clean_useless_project_name,
        text_to_pinyin,
    )

    # NOTE(review): ``redis_url`` in config.py embeds the password in plain
    # text (with an unescaped '@'); consider reading it from the environment
    # and percent-encoding it.
    client = redis.from_url(redis_url, decode_responses=True)

    def load_from_redis(key, expected_type):
        """Return the decoded JSON stored at *key*, or None to use the local file."""
        try:
            raw = client.get(key)
        except redis.RedisError as exc:
            # A Redis outage must not abort the refresh: fall back to local data.
            print(f"update_from_redis: cannot read {key} from redis ({exc}), "
                  f"falling back to local file", flush=True)
            return None
        if not raw:
            return None
        try:
            data = json.loads(raw)
        except ValueError as exc:  # json.JSONDecodeError is a ValueError
            print(f"update_from_redis: invalid JSON under {key} ({exc}), "
                  f"falling back to local file", flush=True)
            return None
        if not isinstance(data, expected_type):
            # e.g. extending a list with a dict or a str would silently
            # produce garbage entries.
            print(f"update_from_redis: unexpected {type(data).__name__} under {key}, "
                  f"falling back to local file", flush=True)
            return None
        print(f"update_from_redis: {key} loaded from redis")
        return data

    def rebuild_name_tables(names, name_list, simple_map, pinyin_map, clean):
        """Replace *name_list* with *names* and rebuild both lookup maps in place."""
        name_list.clear()
        name_list.extend(names)

        simple_map.clear()
        simple_map.update({clean(name): name for name in name_list})

        pinyin_map.clear()
        pinyin_map.update({text_to_pinyin(clean(name)): name for name in name_list})

    # Company data: standard company name -> project departments.
    company_program = load_from_redis('SBD_QUERY_DATA:STANDARD_COMPANY_PROGRAM', dict)
    if company_program is None:
        company_program = load_standard_data("./standard_data/standard_company_program.json")

    if company_program != cls.standard_company_program:
        cls.standard_company_program.clear()
        cls.standard_company_program.update(company_program)
        rebuild_name_tables(
            list(cls.standard_company_program.keys()),
            cls.standard_company_name_list,
            cls.simply_to_standard_company_name_map,
            cls.pinyin_simply_to_standard_company_name_map,
            clean_useless_company_name,
        )

    # Project (engineering) names.
    project_names = load_from_redis('SBD_QUERY_DATA:PROJECT_NAME', list)
    if project_names is None:
        project_names = load_standard_name('./standard_data/standard_project.txt')

    if project_names != cls.standard_project_name_list:
        rebuild_name_tables(
            project_names,
            cls.standard_project_name_list,
            cls.simply_to_standard_project_name_map,
            cls.pinyin_simply_to_standard_project_name_map,
            clean_useless_project_name,
        )

    # Construction-management units: only a local source exists.
    construct_names = load_standard_name('./standard_data/construct_unit.txt')
    if construct_names != cls.standard_construct_name_list:
        rebuild_name_tables(
            construct_names,
            cls.standard_construct_name_list,
            cls.simply_to_standard_construct_name_map,
            cls.pinyin_simply_to_standard_construct_name_map,
            clean_useless_company_name,
        )

    # Subcontractors.
    constractor_names = load_from_redis('SBD_QUERY_DATA:SUBCONTRACTOR', list)
    if constractor_names is None:
        constractor_names = load_standard_name('./standard_data/sub_contract.txt')

    if constractor_names != cls.standard_constractor_name_list:
        rebuild_name_tables(
            constractor_names,
            cls.standard_constractor_name_list,
            cls.simply_to_standard_constractor_name_map,
            cls.pinyin_simply_to_standard_constractor_name_map,
            clean_useless_company_name,
        )

    print(f"✅ Data updated from redis (local fallback) at {time.strftime('%Y-%m-%d %H:%M:%S')}")
-def update_data_from_local(): - global standard_company_program, standard_company_name_list, simply_to_standard_company_name_map, \ - pinyin_simply_to_standard_company_name_map, standard_project_name_list, simply_to_standard_project_name_map, \ - pinyin_simply_to_standard_project_name_map - - #标准公司名和项目名中文mapping - temp_standard_company_program = load_standard_data("./standard_data/standard_company_program.json") - if temp_standard_company_program != standard_company_program: - standard_company_program = temp_standard_company_program - standard_company_name_list = list(standard_company_program.keys()) - simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list} - pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in - standard_company_name_list} - - # 标准工程名 - temp_standard_project_name_list = load_standard_name('./standard_data/standard_project.txt') - if temp_standard_project_name_list != standard_project_name_list: - standard_project_name_list = temp_standard_project_name_list - simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list} - pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in - standard_project_name_list} - - current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - print(f"Updated data from local at {current_time}") - - # 初始化工具类 intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels) @@ -94,7 +47,9 @@ intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels) slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map) # 设置Flask应用 -update_data_from_local() +# update_data_from_local() +from globalData import GlobalData +GlobalData.update_from_local() app = Flask(__name__) @@ -207,8 +162,10 @@ def slot_reco(): return user_validation_error # 调用 recognize 方法进行槽位识别 - entities = slot_recognizer.recognize(text) - + 
entities, slot_probability = slot_recognizer.recognize_probability(text) + print( + f"槽位抽取后的实体:{entities},实体后的可能值:{slot_probability}", + flush=True) return jsonify( code=200, msg="成功", @@ -246,10 +203,9 @@ def agent(): # 先进行意图识别 predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query) # 再进行槽位抽取 - entities = slot_recognizer.recognize(query) - + entities,slot_probability = slot_recognizer.recognize_probability(query) print( - f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}", + f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},slot_probability:{slot_probability},message:{messages}", flush=True) # 多轮 else: @@ -264,9 +220,9 @@ def agent(): "answer": {"int": predicted_id, "label": predicted_label, "probability": predicted_probability}, "finalQuery": res }) - entities = slot_recognizer.recognize(res) + entities, slot_probability = slot_recognizer.recognize_probability(res) print( - f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}", + f"多轮意图识别后的槽位:槽位抽取后的实体:{entities},slot_probability:{slot_probability}", flush=True) #必须槽位缺失检查 @@ -277,7 +233,12 @@ def agent(): }) #工程名、分公司名和项目名标准化 - result, information = check_standard_name_slot(predicted_id, entities) + result, information = check_standard_name_slot_probability(predicted_id, entities) + if result == CheckResult.NEEDS_MORE_ROUNDS: + return jsonify({ + "code": 10001, "msg": "成功", + "answer": {"miss": information}, + }) if result == CheckResult.NEEDS_MORE_ROUNDS: return jsonify({ "code": 10001, "msg": "成功", @@ -403,112 +364,6 @@ def extract_multi_chat(messages): return res -#槽位缺失检查 -def check_lost(int_res, slot): - #labels: ["天气查询","通用对话","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"] - mapping = { - 2: [['page'], ['app'], ['module']], - 3: [['date']], - 4: [['date']], - 5: [['date']], - 6: [['date']], - 7: [['date']], - 8: [['date']], - 11: [['date']], - 12: 
[['date']], - 13: [['date']], - 14: [['date']], - 15: [['date']], - } - - intention_mapping = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容", - 6: "周计划作业内容", 7: "施工人数", 8: "作业考勤人数", 11: "作业面查询", - 12: "班组人数查询", 13: "班组数查询", 14: "作业面内容", 15: "班组详情"} - if not mapping.__contains__(int_res): - return 0, "" - #提取的槽位信息 - cur_k = list(slot.keys()) - idx = -1 - idx_len = 99 - for i in range(len(mapping[int_res])): - sk = mapping[int_res][i] - #不在提取的槽位信息里,但是在必须槽位表里 - miss_params = [x for x in sk if x not in cur_k] - #不在必须槽位表里,但是在提取的槽位信息里 - extra_params = [x for x in cur_k if x not in sk] - if len(extra_params) >= 0 and len(miss_params) == 0: - idx = i - idx_len = 0 - break - if len(miss_params) < idx_len: - idx = i - idx_len = len(miss_params) - - if idx_len == 0: # 匹配通过 - return CheckResult.NO_MATCH, cur_k - #符合当前意图的的必须槽位,但是不在提取的槽位信息里 - left = [x for x in mapping[int_res][idx] if x not in cur_k] - print(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}", flush=True) - apologize_str = "非常抱歉," - if int_res == 2: - return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询哪个页面?" - elif int_res in [3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15]: - return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}?" 
- - -#标准化分公司名,工程名,项目名等 -def check_standard_name_slot(int_res, slot) -> tuple: - intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15} - if int_res not in intention_list: - return CheckResult.NO_MATCH, "" - - #项目名 当项目名存在时需要一定存在分公司(实施组织)名 - if PROJECT_DEPARTMENT in slot: - if IMPLEMENTATION_ORG not in slot: - return CheckResult.NEEDS_MORE_ROUNDS, "请补充该项目部所属的分公司名称" - - #工程名和分公司名和项目名标准化 - for key, value in slot.items(): - if key == PROJECT_NAME: - print(f"check_standard_name_slot 原始工程名 : {slot[PROJECT_NAME]}") - match_results = standardize_project_name(value, simply_to_standard_project_name_map, - pinyin_simply_to_standard_project_name_map, 70, 90) - print(f"check_standard_name_slot 匹配后工程名 :result:{match_results}", flush=True) - if match_results and len(match_results) == 1: - slot[key] = match_results[0] - else: - prompt = generate_project_prompt(match_results, original_name=slot[PROJECT_NAME], type="工程名") - return CheckResult.NEEDS_MORE_ROUNDS, prompt - - if key == IMPLEMENTATION_ORG and slot[key] != "公司": - print(f"check_standard_name_slot 原始分公司名 : {slot[IMPLEMENTATION_ORG]}") - match_results = standardize_sub_company(value, simply_to_standard_company_name_map, - pinyin_simply_to_standard_company_name_map, 55, 80) - print(f"check_standard_name_slot 匹配后分公司名: result:{match_results}", flush=True) - if match_results and len(match_results) == 1: - slot[key] = match_results[0] - else: - prompt = generate_project_prompt(match_results, original_name=slot[IMPLEMENTATION_ORG], type="分公司名") - return CheckResult.NEEDS_MORE_ROUNDS, prompt - - if key == PROJECT_DEPARTMENT: - print(f"check_standard_name_slot 原始项目部名 : {slot[PROJECT_DEPARTMENT]}") - match_results = standardize_projectDepartment(slot[IMPLEMENTATION_ORG], value, standard_company_program, - high_score=90) - print(f"check_standard_name_slot 匹配后项目部名: result:{match_results}", flush=True) - if match_results and len(match_results) == 1: - slot[key] = match_results[0] - else: - prompt = 
generate_project_prompt(match_results, original_name=slot[PROJECT_DEPARTMENT], type="项目名") - return CheckResult.NEEDS_MORE_ROUNDS, prompt - - if key == RISK_LEVEL: - if slot[RISK_LEVEL] not in ["2级", "3级", "4级", "5级"] and slot[RISK_LEVEL] not in ["二级", "三级", "四级", - "五级"]: - return CheckResult.NEEDS_MORE_ROUNDS, "您查询的风险等级在系统中未找到,请确认风险等级后再次提问" - - return CheckResult.NO_MATCH, "" - # # # # test_cases = [ diff --git a/api/main_temp.py b/api/main_temp.py index e045830..d7804e2 100644 --- a/api/main_temp.py +++ b/api/main_temp.py @@ -7,25 +7,19 @@ import time from intentRecognition import IntentRecognition from slotRecognition import SlotRecognition -from utils import CheckResult, load_standard_name, generate_project_prompt, \ - load_standard_data, text_to_pinyin, \ - standardize_projectDepartment, standardize_project_name, clean_useless_project_name, \ - clean_useless_company_name, standardize_sub_company, standardize_name_only_high_score, \ - generate_project_prompt_with_key - -from constants import PROJECT_NAME, PROJECT_DEPARTMENT, SIMILARITY_VALUE, IMPLEMENTATION_ORG, RISK_LEVEL, \ - CONSTRUCTION_UNIT, SUBCONTRACTOR +from utils import CheckResult, check_standard_name_slot_probability, check_lost, standardize_sub_company, \ + standardize_project_name, standardize_projectDepartment from config import * -MODEL_ERNIE_PATH = R"../ernie/output_temp/checkpoint-22620" -MODEL_UIE_PATH = R"../uie/output_temp/checkpoint-22320" - +MODEL_ERNIE_PATH = R"../ernie/output_temp/checkpoint-22960" +MODEL_UIE_PATH = R"../uie/output_temp/checkpoint-22670" # 类别名称列表 labels = [ "天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询", "日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答", - "通用对话", "作业面查询", "班组人数查询", "班组数查询", "作业面内容", "班组详情" + "通用对话", "作业面查询", "班组人数查询", "班组数查询", "作业面内容", "班组详情", + "工程进度查询" ] # 标签映射 @@ -48,9 +42,6 @@ label_map = { } - - - # 初始化工具类 intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels) @@ -58,13 +49,12 @@ intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels) 
slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map) # 设置Flask应用 -# update_data_from_local() from globalData import GlobalData GlobalData.update_from_local() +# GlobalData.update_from_redis() app = Flask(__name__) - # 统一的异常处理函数 @app.errorhandler(Exception) def handle_exception(e): @@ -326,7 +316,7 @@ def extract_multi_chat(messages): 第四步:用户最新问题是否为序号指代(第一个/第2个)?→ 用完整工程/项目/公司名替换补全 - 精确提取用户所指的序号(如“第3个”指第3个工程名、公司名或项目部名); - 将该工程、公司或项目部的完整名称(包括括号中的编号)提取出来; - - **用完整名称替换掉用户上一个问题中出现的简称或模糊表达,并保留用户问题中的其它部分原样不变(如时间、计划数、内容)不变**; + - **用完整名称替换掉用户上一个问题中出现的简称或模糊表达,并保留用户问题中的其它部分原样不变(如时间、计划数、内容如“进度情况”“作业计划”“作业内容”)不变**; - 示例1: - 用户最新问题:"第一个" 或"第1个" - 对话记录的最后一个用户问题:"2025年南苑调相机检修(PROJ-2023-0179)今天有多少作业计划"" @@ -334,11 +324,11 @@ def extract_multi_chat(messages): - 则最终提问应为: `检修公司调相机一二次设备检修维护和改造服务框架-2025年南苑调相机检修(PROJ-2023-0179)今天有多少作业计划` - 示例2: - - 用户的最新问题:"第二个" 或"第2个" - - 对话记录的最后一个用户问题:"宏源电力建设公司第三项目部今天有多少项作业计划"" - - 对话记录的最后一个AI回答:列出多个分公司名,第2个:"安徽宏源电力建设有限公司(线路)" + - 用户的最新问题:"第一个" 或"第1个" + - 对话记录的最后一个用户问题:"请帮我查一下今天芦集变电站的进度情况" + - 对话记录的最后一个AI回答:列出多个工程名,第1个:"芦集-古沟π入潘集变电站220kV线路工程(PROJ-2024-0189)" - 则最终提问应为: - "安徽宏源电力建设有限公司(线路)第三项目部今天有多少项作业计划" + "请帮我查一下今天芦集-古沟π入潘集变电站220kV线路工程(PROJ-2024-0189)的进度情况" 第五步:输出最终问题 - 直接输出最终问题(无解释、无多余前缀或后缀) @@ -373,187 +363,59 @@ def extract_multi_chat(messages): return res -#槽位缺失检查 -def check_lost(int_res, slot): - #labels: ["天气查询","通用对话","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"] - mapping = { - 2: [['page'], ['app'], ['module']], - 3: [['date']], - 4: [['date']], - 5: [['date']], - 6: [['date']], - 7: [['date']], - 8: [['date']], - 11: [['date']], - 12: [['date']], - 13: [['date']], - 14: [['date']], - 15: [['date']], - } - - intention_mapping = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容", - 6: "周计划作业内容", 7: "施工人数", 8: "作业考勤人数", 11: "作业面查询", - 12: "班组人数查询", 13: "班组数查询", 14: "作业面内容", 15: "班组详情"} - if not mapping.__contains__(int_res): - return 0, "" - #提取的槽位信息 - cur_k = list(slot.keys()) - idx = -1 - idx_len = 99 - for 
i in range(len(mapping[int_res])): - sk = mapping[int_res][i] - #不在提取的槽位信息里,但是在必须槽位表里 - miss_params = [x for x in sk if x not in cur_k] - #不在必须槽位表里,但是在提取的槽位信息里 - extra_params = [x for x in cur_k if x not in sk] - if len(extra_params) >= 0 and len(miss_params) == 0: - idx = i - idx_len = 0 - break - if len(miss_params) < idx_len: - idx = i - idx_len = len(miss_params) - - if idx_len == 0: # 匹配通过 - return CheckResult.NO_MATCH, cur_k - #符合当前意图的的必须槽位,但是不在提取的槽位信息里 - left = [x for x in mapping[int_res][idx] if x not in cur_k] - print(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}", flush=True) - apologize_str = "非常抱歉," - if int_res == 2: - return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询哪个页面?" - elif int_res in [3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15]: - return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}?" - - #标准化分公司名,工程名,项目名等 -def check_standard_name_slot(int_res, slot) -> tuple: - intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15} - if int_res not in intention_list: - return CheckResult.NO_MATCH, "" - - #项目名 当项目名存在时需要一定存在分公司(实施组织)名 - if PROJECT_DEPARTMENT in slot: - if IMPLEMENTATION_ORG not in slot: - return CheckResult.NEEDS_MORE_ROUNDS, "请补充该项目部所属的分公司名称" - - #工程名和分公司名和项目名标准化 - for key, value in slot.items(): - if key == PROJECT_NAME: - print(f"check_standard_name_slot 原始工程名 : {slot[PROJECT_NAME]}") - match_results = standardize_project_name(value, simply_to_standard_project_name_map, - pinyin_simply_to_standard_project_name_map, 70, 90) - print(f"check_standard_name_slot 匹配后工程名 :result:{match_results}", flush=True) - if match_results and len(match_results) == 1: - slot[key] = match_results[0] - else: - prompt = generate_project_prompt(match_results, original_name=slot[PROJECT_NAME], type="工程名") - return CheckResult.NEEDS_MORE_ROUNDS, prompt - - if key == IMPLEMENTATION_ORG and slot[key] != "公司": - print(f"check_standard_name_slot 原始分公司名 : {slot[IMPLEMENTATION_ORG]}") - match_results = standardize_sub_company(value, 
simply_to_standard_company_name_map, - pinyin_simply_to_standard_company_name_map, 55, 80) - print(f"check_standard_name_slot 匹配后分公司名: result:{match_results}", flush=True) - if match_results and len(match_results) == 1: - slot[key] = match_results[0] - else: - prompt = generate_project_prompt(match_results, original_name=slot[IMPLEMENTATION_ORG], type="分公司名") - return CheckResult.NEEDS_MORE_ROUNDS, prompt - - if key == PROJECT_DEPARTMENT: - print(f"check_standard_name_slot 原始项目部名 : {slot[PROJECT_DEPARTMENT]}") - match_results = standardize_projectDepartment(slot[IMPLEMENTATION_ORG], value, standard_company_program, - high_score=90) - print(f"check_standard_name_slot 匹配后项目部名: result:{match_results}", flush=True) - if match_results and len(match_results) == 1: - slot[key] = match_results[0] - else: - prompt = generate_project_prompt(match_results, original_name=slot[PROJECT_DEPARTMENT], type="项目名") - return CheckResult.NEEDS_MORE_ROUNDS, prompt - if key == RISK_LEVEL: - if slot[RISK_LEVEL] not in ["2级", "3级", "4级", "5级"] and slot[RISK_LEVEL] not in ["二级", "三级", "四级", - "五级"]: - return CheckResult.NEEDS_MORE_ROUNDS, "您查询的风险等级在系统中未找到,请确认风险等级后再次提问" - - return CheckResult.NO_MATCH, "" - - -def check_standard_name_slot_probability(int_res, slot) -> tuple: - intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15} - if int_res not in intention_list: - return CheckResult.NO_MATCH, "" - - #项目名 当项目名存在时需要一定存在分公司(实施组织)名 - if PROJECT_DEPARTMENT in slot: - if IMPLEMENTATION_ORG not in slot: - return CheckResult.NEEDS_MORE_ROUNDS, "请补充该项目部所属的分公司名称" - - #工程名和分公司名和项目名标准化 - for key, value in slot.items(): - if key == PROJECT_NAME: - print(f"check_standard_name_slot_probability 原始工程名 : {slot[PROJECT_NAME]}") - match_results = standardize_project_name(value, GlobalData.simply_to_standard_project_name_map, - GlobalData.pinyin_simply_to_standard_project_name_map, 70, 90) - print(f"check_standard_name_slot_probability 匹配后工程名 :result:{match_results}", flush=True) - if match_results and 
len(match_results) == 1: - slot[key] = match_results[0] - else: - prompt = generate_project_prompt(match_results, original_name=slot[PROJECT_NAME], type="工程名") - return CheckResult.NEEDS_MORE_ROUNDS, prompt - - if key == IMPLEMENTATION_ORG and slot[key] != "公司": - print(f"check_standard_name_slot_probability 原始分公司名 : {slot[IMPLEMENTATION_ORG]}") - match_results = standardize_sub_company(value, GlobalData.simply_to_standard_company_name_map, - GlobalData.pinyin_simply_to_standard_company_name_map, 60, 80) - print(f"check_standard_name_slot_probability 匹配后分公司名: result:{match_results}", flush=True) - if match_results and len(match_results) == 1: - slot[key] = match_results[0] - else: - prompt = generate_project_prompt_with_key(match_results, original_name=slot[IMPLEMENTATION_ORG], slot_key= IMPLEMENTATION_ORG) - return CheckResult.NEEDS_MORE_ROUNDS, prompt - - if key == CONSTRUCTION_UNIT: - print(f"check_standard_name_slot_probability 原始建管单位名 : {slot[CONSTRUCTION_UNIT]}") - match_results = standardize_sub_company(value, GlobalData.simply_to_standard_construct_name_map, - GlobalData.pinyin_simply_to_standard_construct_name_map, 55, 80) - print(f"check_standard_name_slot_probability 匹配后建管单位名: result:{match_results}", flush=True) - if match_results and len(match_results) == 1: - slot[key] = match_results[0] - else: - prompt = generate_project_prompt_with_key(match_results, original_name=slot[CONSTRUCTION_UNIT], slot_key= CONSTRUCTION_UNIT) - return CheckResult.NEEDS_MORE_ROUNDS, prompt - - if key == SUBCONTRACTOR: - print(f"check_standard_name_slot_probability 原始分包单位名 : {slot[SUBCONTRACTOR]}") - match_results = standardize_sub_company(value, GlobalData.simply_to_standard_constractor_name_map, - GlobalData.pinyin_simply_to_standard_constractor_name_map, 55, 80) - print(f"check_standard_name_slot_probability 匹配后分包单位名: result:{match_results}", flush=True) - if match_results and len(match_results) == 1: - slot[key] = match_results[0] - else: - prompt = 
generate_project_prompt_with_key(match_results, original_name=slot[SUBCONTRACTOR], slot_key= SUBCONTRACTOR) - return CheckResult.NEEDS_MORE_ROUNDS, prompt - - if key == PROJECT_DEPARTMENT: - print(f"check_standard_name_slot 原始项目部名 : {slot[PROJECT_DEPARTMENT]}") - match_results = standardize_projectDepartment(slot[IMPLEMENTATION_ORG], value, GlobalData.standard_company_program, - high_score=90) - print(f"check_standard_name_slot 匹配后项目部名: result:{match_results}", flush=True) - if match_results and len(match_results) == 1: - slot[key] = match_results[0] - else: - prompt = generate_project_prompt(match_results, original_name=slot[PROJECT_DEPARTMENT], type="项目名") - return CheckResult.NEEDS_MORE_ROUNDS, prompt - - if key == RISK_LEVEL: - if slot[RISK_LEVEL] not in ["2级", "3级", "4级", "5级"] and slot[RISK_LEVEL] not in ["二级", "三级", "四级", - "五级"]: - return CheckResult.NEEDS_MORE_ROUNDS, "您查询的风险等级在系统中未找到,请确认风险等级后再次提问" - - return CheckResult.NO_MATCH, "" +# def check_standard_name_slot(int_res, slot) -> tuple: +# intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15} +# if int_res not in intention_list: +# return CheckResult.NO_MATCH, "" +# +# #项目名 当项目名存在时需要一定存在分公司(实施组织)名 +# if PROJECT_DEPARTMENT in slot: +# if IMPLEMENTATION_ORG not in slot: +# return CheckResult.NEEDS_MORE_ROUNDS, "请补充该项目部所属的分公司名称" +# +# #工程名和分公司名和项目名标准化 +# for key, value in slot.items(): +# if key == PROJECT_NAME: +# print(f"check_standard_name_slot 原始工程名 : {slot[PROJECT_NAME]}") +# match_results = standardize_project_name(value, simply_to_standard_project_name_map, +# pinyin_simply_to_standard_project_name_map, 70, 90) +# print(f"check_standard_name_slot 匹配后工程名 :result:{match_results}", flush=True) +# if match_results and len(match_results) == 1: +# slot[key] = match_results[0] +# else: +# prompt = generate_project_prompt(match_results, original_name=slot[PROJECT_NAME], type="工程名") +# return CheckResult.NEEDS_MORE_ROUNDS, prompt +# +# if key == IMPLEMENTATION_ORG and slot[key] != "公司": +# 
print(f"check_standard_name_slot 原始分公司名 : {slot[IMPLEMENTATION_ORG]}") +# match_results = standardize_sub_company(value, simply_to_standard_company_name_map, +# pinyin_simply_to_standard_company_name_map, 55, 80) +# print(f"check_standard_name_slot 匹配后分公司名: result:{match_results}", flush=True) +# if match_results and len(match_results) == 1: +# slot[key] = match_results[0] +# else: +# prompt = generate_project_prompt(match_results, original_name=slot[IMPLEMENTATION_ORG], type="分公司名") +# return CheckResult.NEEDS_MORE_ROUNDS, prompt +# +# if key == PROJECT_DEPARTMENT: +# print(f"check_standard_name_slot 原始项目部名 : {slot[PROJECT_DEPARTMENT]}") +# match_results = standardize_projectDepartment(slot[IMPLEMENTATION_ORG], value, standard_company_program, +# high_score=90) +# print(f"check_standard_name_slot 匹配后项目部名: result:{match_results}", flush=True) +# if match_results and len(match_results) == 1: +# slot[key] = match_results[0] +# else: +# prompt = generate_project_prompt(match_results, original_name=slot[PROJECT_DEPARTMENT], type="项目名") +# return CheckResult.NEEDS_MORE_ROUNDS, prompt +# if key == RISK_LEVEL: +# if slot[RISK_LEVEL] not in ["2级", "3级", "4级", "5级"] and slot[RISK_LEVEL] not in ["二级", "三级", "四级", +# "五级"]: +# return CheckResult.NEEDS_MORE_ROUNDS, "您查询的风险等级在系统中未找到,请确认风险等级后再次提问" +# +# return CheckResult.NO_MATCH, "" + # -# # # test_cases = [ # ("送一分公司"), # ("送二分公司"), @@ -579,12 +441,12 @@ def check_standard_name_slot_probability(int_res, slot) -> tuple: # print(f"加权混合策略 分公司名匹配**********************") # start = time.perf_counter() # for item in test_cases: -# match_results = standardize_sub_company(item,simply_to_standard_company_name_map, pinyin_simply_to_standard_company_name_map,55,80) +# match_results = standardize_sub_company(item,GlobalData.simply_to_standard_company_name_map, GlobalData.pinyin_simply_to_standard_company_name_map,70,90) # print(f"加权混合策略 分公司名匹配 输入: {item}-> 输出: {match_results}") # end = time.perf_counter() # print(f"加权混合策略 耗时: {end - 
start:.4f} 秒") # - +# # # test_cases = [ # ("合肥供电公司"), @@ -595,17 +457,17 @@ def check_standard_name_slot_probability(int_res, slot) -> tuple: # print(f"加权混合策略 建管单位名匹配**********************") # start = time.perf_counter() # for item in test_cases: -# match_results = standardize_sub_company(item,simply_to_standard_construct_name_map, pinyin_simply_to_standard_construct_name_map,55,80) +# match_results = standardize_sub_company(item,GlobalData.simply_to_standard_construct_name_map, GlobalData.pinyin_simply_to_standard_construct_name_map,70,90) # print(f"加权混合策略 建管单位名匹配 输入: {item}-> 输出: {match_results}") # # print(f"加权混合策略,分公司名匹配**********************") # for item in test_cases: -# match_results = standardize_sub_company(item,simply_to_standard_company_name_map, pinyin_simply_to_standard_company_name_map,55,80) +# match_results = standardize_sub_company(item,GlobalData.simply_to_standard_company_name_map, GlobalData.pinyin_simply_to_standard_company_name_map,70,90) # print(f"加权混合策略 分公司名匹配 输入: {item}-> 输出: {match_results}") # end = time.perf_counter() # print(f"加权混合策略 耗时: {end - start:.4f} 秒") - - +# +# # # # test_cases = [ # ("卢集"), @@ -648,8 +510,8 @@ def check_standard_name_slot_probability(int_res, slot) -> tuple: # print(f"去不重要词汇 工程名匹配******************************************") # start = time.perf_counter() # for item in test_cases: -# match_results = standardize_project_name(item, simply_to_standard_project_name_map, -# pinyin_simply_to_standard_project_name_map, 70, 90) +# match_results = standardize_project_name(item, GlobalData.simply_to_standard_project_name_map, +# GlobalData.pinyin_simply_to_standard_project_name_map, 70, 90) # print(f"工程名匹配 输入: {item}-> 输出: {match_results}") # end = time.perf_counter() # print(f"词集匹配 耗时: {end - start:.4f} 秒") @@ -690,9 +552,9 @@ def check_standard_name_slot_probability(int_res, slot) -> tuple: # ("电缆班"), # ] # -# for company in standard_company_name_list: +# for company in GlobalData.standard_company_name_list: # for 
# Intents that need a ``date`` slot, mapped to the wording used in the
# follow-up question. Single source of truth: adding an intent here is enough
# for both checks below.
_DATE_INTENT_NAMES = {
    3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容",
    6: "周计划作业内容", 7: "施工人数", 8: "作业考勤人数", 11: "作业面查询",
    12: "班组人数查询", 13: "班组数查询", 14: "作业面内容", 15: "班组详情",
    16: "工程进度查询",
}

# Intent id of "页面切换" (page switch).
_PAGE_SWITCH_INTENT = 2

# intent id -> list of alternative required-slot groups; satisfying any one
# group is enough.
_REQUIRED_SLOT_ALTERNATIVES = {
    _PAGE_SWITCH_INTENT: [['page'], ['app'], ['module']],
    **{intent: [['date']] for intent in _DATE_INTENT_NAMES},
}

# Risk levels accepted in either Arabic or Chinese numerals.
_VALID_RISK_LEVELS = frozenset({"2级", "3级", "4级", "5级", "二级", "三级", "四级", "五级"})


def check_lost(int_res, slot):
    """Check whether the slots required by the recognised intent are present.

    Args:
        int_res: Predicted intent id.
        slot: Mapping of extracted slot name to slot value.

    Returns:
        A ``(status, payload)`` tuple:
        * ``(0, "")`` when the intent has no required slots;
        * ``(CheckResult.NO_MATCH, slot_names)`` when one alternative group of
          required slots is fully present;
        * ``(CheckResult.NEEDS_MORE_ROUNDS, question)`` when slots are missing,
          where ``question`` asks the user for the missing information.
    """
    alternatives = _REQUIRED_SLOT_ALTERNATIVES.get(int_res)
    if alternatives is None:
        # NOTE(review): this returns a bare 0 while the other paths return
        # CheckResult members; kept as-is because callers may rely on it.
        return 0, ""

    # Names of the slots that were actually extracted.
    cur_k = list(slot.keys())

    # For each alternative group, the required slots that are still missing.
    missing_per_alternative = [
        [name for name in required if name not in cur_k]
        for required in alternatives
    ]
    # Closest alternative; ``min`` keeps the first one on ties.
    left = min(missing_per_alternative, key=len)

    if not left:  # every slot of one alternative is present
        return CheckResult.NO_MATCH, cur_k

    print(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}", flush=True)
    apologize_str = "非常抱歉,"
    if int_res == _PAGE_SWITCH_INTENT:
        return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询哪个页面?"
    return (CheckResult.NEEDS_MORE_ROUNDS,
            f"{apologize_str}请问你想查询什么时间的{_DATE_INTENT_NAMES[int_res]}?")


def _is_unique_match(match_results):
    """Return True when standardization produced exactly one candidate."""
    return bool(match_results) and len(match_results) == 1


def check_standard_name_slot_probability(int_res, slot) -> tuple:
    """Standardize the name slots in place, or ask the user to disambiguate.

    Slots are handled in a fixed order (project name, implementation
    organisation, construction unit, subcontractor, project department, risk
    level) instead of dict order. The project-department lookup reads
    ``slot[IMPLEMENTATION_ORG]``, so the organisation must have been processed
    first (assumes ``standardize_projectDepartment`` expects the standardized
    company name - TODO confirm).

    Args:
        int_res: Predicted intent id.
        slot: Mapping of slot name to extracted value; uniquely matched values
            are overwritten with their standard name.

    Returns:
        ``(CheckResult.NO_MATCH, "")`` when nothing needs clarification, or
        ``(CheckResult.NEEDS_MORE_ROUNDS, prompt)`` with the question to send
        back to the user.
    """
    if int_res not in _DATE_INTENT_NAMES:
        return CheckResult.NO_MATCH, ""

    # A project department is only meaningful together with its company.
    if PROJECT_DEPARTMENT in slot and IMPLEMENTATION_ORG not in slot:
        return CheckResult.NEEDS_MORE_ROUNDS, "请补充该项目部所属的分公司名称"

    if PROJECT_NAME in slot:
        print(f"check_standard_name_slot_probability 原始工程名 : {slot[PROJECT_NAME]}")
        match_results = standardize_project_name(
            slot[PROJECT_NAME],
            GlobalData.simply_to_standard_project_name_map,
            GlobalData.pinyin_simply_to_standard_project_name_map, 70, 90)
        print(f"check_standard_name_slot_probability 匹配后工程名 :result:{match_results}", flush=True)
        if not _is_unique_match(match_results):
            prompt = generate_project_prompt(match_results, original_name=slot[PROJECT_NAME], type="工程名")
            return CheckResult.NEEDS_MORE_ROUNDS, prompt
        slot[PROJECT_NAME] = match_results[0]

    # The literal "公司" is passed through without standardization.
    if IMPLEMENTATION_ORG in slot and slot[IMPLEMENTATION_ORG] != "公司":
        print(f"check_standard_name_slot_probability 原始分公司名 : {slot[IMPLEMENTATION_ORG]}")
        match_results = standardize_sub_company(
            slot[IMPLEMENTATION_ORG],
            GlobalData.simply_to_standard_company_name_map,
            GlobalData.pinyin_simply_to_standard_company_name_map, 70, 90)
        print(f"check_standard_name_slot_probability 匹配后分公司名: result:{match_results}", flush=True)
        if not _is_unique_match(match_results):
            prompt = generate_project_prompt_with_key(
                match_results, original_name=slot[IMPLEMENTATION_ORG], slot_key=IMPLEMENTATION_ORG)
            return CheckResult.NEEDS_MORE_ROUNDS, prompt
        slot[IMPLEMENTATION_ORG] = match_results[0]

    if CONSTRUCTION_UNIT in slot:
        print(f"check_standard_name_slot_probability 原始建管单位名 : {slot[CONSTRUCTION_UNIT]}")
        match_results = standardize_sub_company(
            slot[CONSTRUCTION_UNIT],
            GlobalData.simply_to_standard_construct_name_map,
            GlobalData.pinyin_simply_to_standard_construct_name_map, 70, 90)
        print(f"check_standard_name_slot_probability 匹配后建管单位名: result:{match_results}", flush=True)
        if not _is_unique_match(match_results):
            prompt = generate_project_prompt_with_key(
                match_results, original_name=slot[CONSTRUCTION_UNIT], slot_key=CONSTRUCTION_UNIT)
            return CheckResult.NEEDS_MORE_ROUNDS, prompt
        slot[CONSTRUCTION_UNIT] = match_results[0]

    if SUBCONTRACTOR in slot:
        print(f"check_standard_name_slot_probability 原始分包单位名 : {slot[SUBCONTRACTOR]}")
        match_results = standardize_sub_company(
            slot[SUBCONTRACTOR],
            GlobalData.simply_to_standard_constractor_name_map,
            GlobalData.pinyin_simply_to_standard_constractor_name_map, 70, 90)
        print(f"check_standard_name_slot_probability 匹配后分包单位名: result:{match_results}", flush=True)
        if not _is_unique_match(match_results):
            prompt = generate_project_prompt_with_key(
                match_results, original_name=slot[SUBCONTRACTOR], slot_key=SUBCONTRACTOR)
            return CheckResult.NEEDS_MORE_ROUNDS, prompt
        slot[SUBCONTRACTOR] = match_results[0]

    if PROJECT_DEPARTMENT in slot:
        # Runs after the organisation step so the company name used for the
        # lookup is the one that step left in the slot.
        print(f"check_standard_name_slot_probability 原始项目部名 : {slot[PROJECT_DEPARTMENT]}")
        match_results = standardize_projectDepartment(
            slot[IMPLEMENTATION_ORG], slot[PROJECT_DEPARTMENT],
            GlobalData.standard_company_program, high_score=90)
        print(f"check_standard_name_slot_probability 匹配后项目部名: result:{match_results}", flush=True)
        if not _is_unique_match(match_results):
            prompt = generate_project_prompt(match_results, original_name=slot[PROJECT_DEPARTMENT], type="项目名")
            return CheckResult.NEEDS_MORE_ROUNDS, prompt
        slot[PROJECT_DEPARTMENT] = match_results[0]

    if RISK_LEVEL in slot and slot[RISK_LEVEL] not in _VALID_RISK_LEVELS:
        return CheckResult.NEEDS_MORE_ROUNDS, "您查询的风险等级在系统中未找到,请确认风险等级后再次提问"

    return CheckResult.NO_MATCH, ""