标准化算法提优

This commit is contained in:
weiweiw 2025-04-18 16:39:06 +08:00
parent 510e829382
commit f38aa8ed01
4 changed files with 164 additions and 115 deletions

View File

@ -43,28 +43,48 @@ label_map = {
13: 'B-teamName', 26: 'I-teamName', 13: 'B-teamName', 26: 'I-teamName',
} }
#标准公司名和项目名 # 全局变量
standard_company_program = load_standard_data("./standard_data/standard_company_program.json") #标准公司名和项目名中文mapping
standard_company_program = {}
#标准分公司名
standard_company_name_list = []
#去不重要词条后中文分公司名和标准化分公司名mapping
simply_to_standard_company_name_map = {}
#去不重要词条后拼音分公司名和标准化分公司名mapping
pinyin_simply_to_standard_company_name_map = {}
# 标准工程名,标准工程名拼音和工程名映射,标准工程名拼音 # 标准工程名
standard_project_name_list = load_standard_name('./standard_data/standard_project.txt') standard_project_name_list = []
pinyin_to_standard_project_name_map = {text_to_pinyin(kw): kw for kw in standard_project_name_list} #去不重要词条后中文分公司名和标准化分公司名mapping
standard_project_name_pinyin_list = list(pinyin_to_standard_project_name_map.keys()) simply_to_standard_project_name_map = {}
#去不重要词条后工程名拼音和标准化工程名mapping
pinyin_simply_to_standard_project_name_map = {}
#标准分公司名,标准分公司名拼音和分公司名映射,标公司名拼音 def update_data_from_local():
standard_company_name_list = list(standard_company_program.keys()) global standard_company_program, standard_company_name_list, simply_to_standard_company_name_map, \
pinyin_to_standard_company_name_map = {text_to_pinyin(kw): kw for kw in standard_company_name_list} pinyin_simply_to_standard_company_name_map, standard_project_name_list, simply_to_standard_project_name_map, \
standard_company_name_pinyin_list = list(pinyin_to_standard_company_name_map.keys()) pinyin_simply_to_standard_project_name_map
simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list} #标准公司名和项目名中文mapping
temp_standard_company_program = load_standard_data("./standard_data/standard_company_program.json")
if temp_standard_company_program != standard_company_program:
standard_company_program = temp_standard_company_program
standard_company_name_list = list(standard_company_program.keys())
simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list}
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in
standard_company_name_list}
pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in # 标准工程名
temp_standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
if temp_standard_project_name_list != standard_project_name_list:
standard_project_name_list = temp_standard_project_name_list
simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list}
pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in
standard_project_name_list} standard_project_name_list}
simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list} current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
print(f"Updated data from local at {current_time}")
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in
standard_company_name_list}
# 初始化工具类 # 初始化工具类
intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels) intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
@ -73,9 +93,7 @@ intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map) slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
# 设置Flask应用 # 设置Flask应用
print(f"标准化的工程名是:{standard_project_name_list}", flush=True) update_data_from_local()
print(f"pinyin标准化的工程名是 list{standard_project_name_pinyin_list}", flush=True)
print(f"pinyin-工程民对应关系 map{pinyin_to_standard_company_name_map}", flush=True)
app = Flask(__name__) app = Flask(__name__)
@ -384,6 +402,7 @@ def extract_multi_chat(messages):
return res return res
#槽位缺失检查
def check_lost(int_res, slot): def check_lost(int_res, slot):
#labels: ["天气查询","通用对话","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"] #labels: ["天气查询","通用对话","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"]
mapping = { mapping = {
@ -436,7 +455,7 @@ def check_lost(int_res, slot):
return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}" return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}"
#标准化工程名 #标准化分公司名,工程名,项目名等
def check_standard_name_slot(int_res, slot) -> tuple: def check_standard_name_slot(int_res, slot) -> tuple:
intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15} intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15}
if int_res not in intention_list: if int_res not in intention_list:
@ -489,8 +508,8 @@ def check_standard_name_slot(int_res, slot) -> tuple:
return CheckResult.NO_MATCH, "" return CheckResult.NO_MATCH, ""
# #
# #
# test_cases = [ # test_cases = [
# ("送一分公司"), # ("送一分公司"),
# ("送二分公司"), # ("送二分公司"),
@ -522,7 +541,7 @@ def check_standard_name_slot(int_res, slot) -> tuple:
# print(f"加权混合策略 耗时: {end - start:.4f} 秒") # print(f"加权混合策略 耗时: {end - start:.4f} 秒")
# #
# #
# # #
# test_cases = [ # test_cases = [
# ("卢集"), # ("卢集"),
# ("芦集"), # ("芦集"),
@ -564,7 +583,8 @@ def check_standard_name_slot(int_res, slot) -> tuple:
# print(f"去不重要词汇 工程名匹配******************************************") # print(f"去不重要词汇 工程名匹配******************************************")
# start = time.perf_counter() # start = time.perf_counter()
# for item in test_cases: # for item in test_cases:
# match_results = standardize_project_name(item, simply_to_standard_project_name_map, pinyin_simply_to_standard_project_name_map,70,90) # match_results = standardize_project_name(item, simply_to_standard_project_name_map,
# pinyin_simply_to_standard_project_name_map, 70, 90)
# print(f"工程名匹配 输入: {item}-> 输出: {match_results}") # print(f"工程名匹配 输入: {item}-> 输出: {match_results}")
# end = time.perf_counter() # end = time.perf_counter()
# print(f"词集匹配 耗时: {end - start:.4f} 秒") # print(f"词集匹配 耗时: {end - start:.4f} 秒")

View File

@ -15,14 +15,14 @@ from utils import CheckResult, load_standard_name, generate_project_prompt, \
from constants import PROJECT_NAME, PROJECT_DEPARTMENT, SIMILARITY_VALUE, IMPLEMENTATION_ORG, RISK_LEVEL from constants import PROJECT_NAME, PROJECT_DEPARTMENT, SIMILARITY_VALUE, IMPLEMENTATION_ORG, RISK_LEVEL
from config import * from config import *
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-17890" MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-31940"
MODEL_UIE_PATH = R"../uie/output/checkpoint-17290" MODEL_UIE_PATH = R"../uie/output/checkpoint-31350"
# 类别名称列表 # 类别名称列表
labels = [ labels = [
"天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询", "天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
"日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答", "日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答",
"通用对话", "作业面查询","班组人数查询","班组数查询","作业面内容","班组详情" "通用对话", "作业面查询", "班组人数查询", "班组数查询", "作业面内容", "班组详情"
] ]
# 标签映射 # 标签映射
@ -43,27 +43,47 @@ label_map = {
13: 'B-teamName', 26: 'I-teamName', 13: 'B-teamName', 26: 'I-teamName',
} }
# 全局变量
#标准公司名和项目名中文mapping
standard_company_program = {}
#标准分公司名
standard_company_name_list = []
#去不重要词条后中文分公司名和标准化分公司名mapping
simply_to_standard_company_name_map = {}
#去不重要词条后拼音分公司名和标准化分公司名mapping
pinyin_simply_to_standard_company_name_map = {}
#标准公司名和项目名 # 标准工程名
standard_company_program = load_standard_data("./standard_data/standard_company_program.json") standard_project_name_list = []
#去不重要词条后中文分公司名和标准化分公司名mapping
simply_to_standard_project_name_map = {}
#去不重要词条后工程名拼音和标准化工程名mapping
pinyin_simply_to_standard_project_name_map = {}
# 标准工程名,标准工程名拼音和工程名映射,标准工程名拼音 def update_data_from_local():
standard_project_name_list = load_standard_name('./standard_data/standard_project.txt') global standard_company_program, standard_company_name_list, simply_to_standard_company_name_map, \
pinyin_to_standard_project_name_map = {text_to_pinyin(kw): kw for kw in standard_project_name_list} pinyin_simply_to_standard_company_name_map, standard_project_name_list, simply_to_standard_project_name_map, \
standard_project_name_pinyin_list = list(pinyin_to_standard_project_name_map.keys()) pinyin_simply_to_standard_project_name_map
#标准分公司名,标准分公司名拼音和分公司名映射,标公司名拼音 #标准公司名和项目名中文mapping
standard_company_name_list = list(standard_company_program.keys()) temp_standard_company_program = load_standard_data("./standard_data/standard_company_program.json")
pinyin_to_standard_company_name_map = {text_to_pinyin(kw): kw for kw in standard_company_name_list} if temp_standard_company_program != standard_company_program:
standard_company_name_pinyin_list = list(pinyin_to_standard_company_name_map.keys()) standard_company_program = temp_standard_company_program
standard_company_name_list = list(standard_company_program.keys())
simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list}
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in
standard_company_name_list}
simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list} # 标准工程名
temp_standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
if temp_standard_project_name_list != standard_project_name_list:
standard_project_name_list = temp_standard_project_name_list
simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list}
pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in
standard_project_name_list}
pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in standard_project_name_list} current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
print(f"Updated data from local at {current_time}")
simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list}
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in standard_company_name_list}
# 初始化工具类 # 初始化工具类
@ -73,12 +93,11 @@ intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map) slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
# 设置Flask应用 # 设置Flask应用
print(f"标准化的工程名是:{standard_project_name_list}", flush=True) update_data_from_local()
print(f"pinyin标准化的工程名是 list{standard_project_name_pinyin_list}", flush=True)
print(f"pinyin-工程民对应关系 map{pinyin_to_standard_company_name_map}", flush=True)
app = Flask(__name__) app = Flask(__name__)
# 统一的异常处理函数 # 统一的异常处理函数
@app.errorhandler(Exception) @app.errorhandler(Exception)
def handle_exception(e): def handle_exception(e):
@ -229,7 +248,8 @@ def agent():
entities = slot_recognizer.recognize(query) entities = slot_recognizer.recognize(query)
print( print(
f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",flush=True) f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",
flush=True)
# 多轮 # 多轮
else: else:
res = extract_multi_chat(messages) res = extract_multi_chat(messages)
@ -245,7 +265,8 @@ def agent():
}) })
entities = slot_recognizer.recognize(res) entities = slot_recognizer.recognize(res)
print( print(
f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",flush=True) f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",
flush=True)
#必须槽位缺失检查 #必须槽位缺失检查
status, sk = check_lost(predicted_id, entities) status, sk = check_lost(predicted_id, entities)
@ -273,6 +294,7 @@ def agent():
except Exception as e: except Exception as e:
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回 return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
def extract_multi_chat(messages): def extract_multi_chat(messages):
from openai import OpenAI from openai import OpenAI
client = OpenAI(base_url=api_base_url, api_key=api_key) client = OpenAI(base_url=api_base_url, api_key=api_key)
@ -379,6 +401,8 @@ def extract_multi_chat(messages):
print(f"多轮意图后用户想要的问题是:{res}", flush=True) print(f"多轮意图后用户想要的问题是:{res}", flush=True)
return res return res
#槽位缺失检查
def check_lost(int_res, slot): def check_lost(int_res, slot):
#labels: ["天气查询","通用对话","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"] #labels: ["天气查询","通用对话","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"]
mapping = { mapping = {
@ -398,7 +422,7 @@ def check_lost(int_res, slot):
intention_mapping = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容", intention_mapping = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容",
6: "周计划作业内容", 7: "施工人数", 8: "作业考勤人数", 11: "作业面查询", 6: "周计划作业内容", 7: "施工人数", 8: "作业考勤人数", 11: "作业面查询",
12:"班组人数查询", 13:"班组数查询", 14:"作业面内容", 15:"班组详情"} 12: "班组人数查询", 13: "班组数查询", 14: "作业面内容", 15: "班组详情"}
if not mapping.__contains__(int_res): if not mapping.__contains__(int_res):
return 0, "" return 0, ""
#提取的槽位信息 #提取的槽位信息
@ -423,7 +447,7 @@ def check_lost(int_res, slot):
return CheckResult.NO_MATCH, cur_k return CheckResult.NO_MATCH, cur_k
#符合当前意图的的必须槽位,但是不在提取的槽位信息里 #符合当前意图的的必须槽位,但是不在提取的槽位信息里
left = [x for x in mapping[int_res][idx] if x not in cur_k] left = [x for x in mapping[int_res][idx] if x not in cur_k]
print(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}",flush=True) print(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}", flush=True)
apologize_str = "非常抱歉," apologize_str = "非常抱歉,"
if int_res == 2: if int_res == 2:
return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询哪个页面?" return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询哪个页面?"
@ -431,7 +455,7 @@ def check_lost(int_res, slot):
return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}" return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}"
#标准化工程名 #标准化分公司名,工程名,项目名等
def check_standard_name_slot(int_res, slot) -> tuple: def check_standard_name_slot(int_res, slot) -> tuple:
intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15} intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15}
if int_res not in intention_list: if int_res not in intention_list:
@ -446,8 +470,9 @@ def check_standard_name_slot(int_res, slot) -> tuple:
for key, value in slot.items(): for key, value in slot.items():
if key == PROJECT_NAME: if key == PROJECT_NAME:
print(f"check_standard_name_slot 原始工程名 : {slot[PROJECT_NAME]}") print(f"check_standard_name_slot 原始工程名 : {slot[PROJECT_NAME]}")
match_results = standardize_project_name(value, simply_to_standard_project_name_map, pinyin_simply_to_standard_project_name_map,70,90) match_results = standardize_project_name(value, simply_to_standard_project_name_map,
print(f"check_standard_name_slot 匹配后工程名 result:{match_results}",flush=True) pinyin_simply_to_standard_project_name_map, 70, 90)
print(f"check_standard_name_slot 匹配后工程名 result:{match_results}", flush=True)
if match_results and len(match_results) == 1: if match_results and len(match_results) == 1:
slot[key] = match_results[0] slot[key] = match_results[0]
else: else:
@ -456,8 +481,9 @@ def check_standard_name_slot(int_res, slot) -> tuple:
if key == IMPLEMENTATION_ORG and slot[key] != "公司": if key == IMPLEMENTATION_ORG and slot[key] != "公司":
print(f"check_standard_name_slot 原始分公司名 : {slot[IMPLEMENTATION_ORG]}") print(f"check_standard_name_slot 原始分公司名 : {slot[IMPLEMENTATION_ORG]}")
match_results = standardize_sub_company(value,simply_to_standard_company_name_map, pinyin_simply_to_standard_company_name_map,55,80) match_results = standardize_sub_company(value, simply_to_standard_company_name_map,
print(f"check_standard_name_slot 匹配后分公司名: result:{match_results}",flush=True) pinyin_simply_to_standard_company_name_map, 55, 80)
print(f"check_standard_name_slot 匹配后分公司名: result:{match_results}", flush=True)
if match_results and len(match_results) == 1: if match_results and len(match_results) == 1:
slot[key] = match_results[0] slot[key] = match_results[0]
else: else:
@ -466,8 +492,9 @@ def check_standard_name_slot(int_res, slot) -> tuple:
if key == PROJECT_DEPARTMENT: if key == PROJECT_DEPARTMENT:
print(f"check_standard_name_slot 原始项目部名 : {slot[PROJECT_DEPARTMENT]}") print(f"check_standard_name_slot 原始项目部名 : {slot[PROJECT_DEPARTMENT]}")
match_results = standardize_projectDepartment(slot[IMPLEMENTATION_ORG], value, standard_company_program, high_score=90) match_results = standardize_projectDepartment(slot[IMPLEMENTATION_ORG], value, standard_company_program,
print(f"check_standard_name_slot 匹配后项目部名: result:{match_results}",flush=True) high_score=90)
print(f"check_standard_name_slot 匹配后项目部名: result:{match_results}", flush=True)
if match_results and len(match_results) == 1: if match_results and len(match_results) == 1:
slot[key] = match_results[0] slot[key] = match_results[0]
else: else:
@ -475,13 +502,14 @@ def check_standard_name_slot(int_res, slot) -> tuple:
return CheckResult.NEEDS_MORE_ROUNDS, prompt return CheckResult.NEEDS_MORE_ROUNDS, prompt
if key == RISK_LEVEL: if key == RISK_LEVEL:
if slot[RISK_LEVEL] not in["2级","3级","4级","5级"] and slot[RISK_LEVEL] not in["二级","三级","四级","五级"]: if slot[RISK_LEVEL] not in ["2级", "3级", "4级", "5级"] and slot[RISK_LEVEL] not in ["二级", "三级", "四级",
"五级"]:
return CheckResult.NEEDS_MORE_ROUNDS, "您查询的风险等级在系统中未找到,请确认风险等级后再次提问" return CheckResult.NEEDS_MORE_ROUNDS, "您查询的风险等级在系统中未找到,请确认风险等级后再次提问"
return CheckResult.NO_MATCH, "" return CheckResult.NO_MATCH, ""
# #
# #
# test_cases = [ # test_cases = [
# ("送一分公司"), # ("送一分公司"),
# ("送二分公司"), # ("送二分公司"),
@ -513,7 +541,7 @@ def check_standard_name_slot(int_res, slot) -> tuple:
# print(f"加权混合策略 耗时: {end - start:.4f} 秒") # print(f"加权混合策略 耗时: {end - start:.4f} 秒")
# #
# #
# # #
# test_cases = [ # test_cases = [
# ("卢集"), # ("卢集"),
# ("芦集"), # ("芦集"),
@ -555,7 +583,8 @@ def check_standard_name_slot(int_res, slot) -> tuple:
# print(f"去不重要词汇 工程名匹配******************************************") # print(f"去不重要词汇 工程名匹配******************************************")
# start = time.perf_counter() # start = time.perf_counter()
# for item in test_cases: # for item in test_cases:
# match_results = standardize_project_name(item, simply_to_standard_project_name_map, pinyin_simply_to_standard_project_name_map,70,90) # match_results = standardize_project_name(item, simply_to_standard_project_name_map,
# pinyin_simply_to_standard_project_name_map, 70, 90)
# print(f"工程名匹配 输入: {item}-> 输出: {match_results}") # print(f"工程名匹配 输入: {item}-> 输出: {match_results}")
# end = time.perf_counter() # end = time.perf_counter()
# print(f"词集匹配 耗时: {end - start:.4f} 秒") # print(f"词集匹配 耗时: {end - start:.4f} 秒")

View File

@ -4,7 +4,7 @@ from typing import cast, Callable
from rapidfuzz import process, fuzz from rapidfuzz import process, fuzz
from rapidfuzz.fuzz import WRatio from rapidfuzz.fuzz import WRatio
import json import json
from pypinyin import lazy_pinyin from pypinyin import lazy_pinyin, Style
from constants import USELESS_COMPANY_WORDS, USELESS_PROJECT_WORDS from constants import USELESS_COMPANY_WORDS, USELESS_PROJECT_WORDS
@ -36,7 +36,7 @@ def arabic_to_chinese_number(text):
def text_to_pinyin(text): def text_to_pinyin(text):
"""将文本转换为拼音字符串""" """将文本转换为拼音字符串"""
return ''.join(lazy_pinyin(text)) return ''.join(lazy_pinyin(text, Style.TONE2))
def load_standard_data(path): def load_standard_data(path):