标准化算法提优

This commit is contained in:
weiweiw 2025-04-18 16:39:06 +08:00
parent 510e829382
commit f38aa8ed01
4 changed files with 164 additions and 115 deletions

View File

@ -43,28 +43,48 @@ label_map = {
13: 'B-teamName', 26: 'I-teamName',
}
#标准公司名和项目名
standard_company_program = load_standard_data("./standard_data/standard_company_program.json")
# 全局变量
#标准公司名和项目名中文mapping
standard_company_program = {}
#标准分公司名
standard_company_name_list = []
#去不重要词条后中文分公司名和标准化分公司名mapping
simply_to_standard_company_name_map = {}
#去不重要词条后拼音分公司名和标准化分公司名mapping
pinyin_simply_to_standard_company_name_map = {}
# 标准工程名,标准工程名拼音和工程名映射,标准工程名拼音
standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
pinyin_to_standard_project_name_map = {text_to_pinyin(kw): kw for kw in standard_project_name_list}
standard_project_name_pinyin_list = list(pinyin_to_standard_project_name_map.keys())
# 标准工程名
standard_project_name_list = []
#去不重要词条后中文分公司名和标准化分公司名mapping
simply_to_standard_project_name_map = {}
#去不重要词条后工程名拼音和标准化工程名mapping
pinyin_simply_to_standard_project_name_map = {}
#标准分公司名,标准分公司名拼音和分公司名映射,标公司名拼音
standard_company_name_list = list(standard_company_program.keys())
pinyin_to_standard_company_name_map = {text_to_pinyin(kw): kw for kw in standard_company_name_list}
standard_company_name_pinyin_list = list(pinyin_to_standard_company_name_map.keys())
def update_data_from_local():
global standard_company_program, standard_company_name_list, simply_to_standard_company_name_map, \
pinyin_simply_to_standard_company_name_map, standard_project_name_list, simply_to_standard_project_name_map, \
pinyin_simply_to_standard_project_name_map
simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list}
#标准公司名和项目名中文mapping
temp_standard_company_program = load_standard_data("./standard_data/standard_company_program.json")
if temp_standard_company_program != standard_company_program:
standard_company_program = temp_standard_company_program
standard_company_name_list = list(standard_company_program.keys())
simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list}
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in
standard_company_name_list}
pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in
# 标准工程名
temp_standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
if temp_standard_project_name_list != standard_project_name_list:
standard_project_name_list = temp_standard_project_name_list
simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list}
pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in
standard_project_name_list}
simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list}
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
print(f"Updated data from local at {current_time}")
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in
standard_company_name_list}
# 初始化工具类
intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
@ -73,9 +93,7 @@ intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
# 设置Flask应用
print(f"标准化的工程名是:{standard_project_name_list}", flush=True)
print(f"pinyin标准化的工程名是 list{standard_project_name_pinyin_list}", flush=True)
print(f"pinyin-工程民对应关系 map{pinyin_to_standard_company_name_map}", flush=True)
update_data_from_local()
app = Flask(__name__)
@ -384,6 +402,7 @@ def extract_multi_chat(messages):
return res
#槽位缺失检查
def check_lost(int_res, slot):
#labels: ["天气查询","通用对话","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"]
mapping = {
@ -436,7 +455,7 @@ def check_lost(int_res, slot):
return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}"
#标准化工程名
#标准化分公司名,工程名,项目名等
def check_standard_name_slot(int_res, slot) -> tuple:
intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15}
if int_res not in intention_list:
@ -489,8 +508,8 @@ def check_standard_name_slot(int_res, slot) -> tuple:
return CheckResult.NO_MATCH, ""
#
# #
# test_cases = [
# ("送一分公司"),
# ("送二分公司"),
@ -522,7 +541,7 @@ def check_standard_name_slot(int_res, slot) -> tuple:
# print(f"加权混合策略 耗时: {end - start:.4f} 秒")
#
#
#
# #
# test_cases = [
# ("卢集"),
# ("芦集"),
@ -564,7 +583,8 @@ def check_standard_name_slot(int_res, slot) -> tuple:
# print(f"去不重要词汇 工程名匹配******************************************")
# start = time.perf_counter()
# for item in test_cases:
# match_results = standardize_project_name(item, simply_to_standard_project_name_map, pinyin_simply_to_standard_project_name_map,70,90)
# match_results = standardize_project_name(item, simply_to_standard_project_name_map,
# pinyin_simply_to_standard_project_name_map, 70, 90)
# print(f"工程名匹配 输入: {item}-> 输出: {match_results}")
# end = time.perf_counter()
# print(f"词集匹配 耗时: {end - start:.4f} 秒")

View File

@ -15,14 +15,14 @@ from utils import CheckResult, load_standard_name, generate_project_prompt, \
from constants import PROJECT_NAME, PROJECT_DEPARTMENT, SIMILARITY_VALUE, IMPLEMENTATION_ORG, RISK_LEVEL
from config import *
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-17890"
MODEL_UIE_PATH = R"../uie/output/checkpoint-17290"
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-31940"
MODEL_UIE_PATH = R"../uie/output/checkpoint-31350"
# 类别名称列表
labels = [
"天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
"日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答",
"通用对话", "作业面查询","班组人数查询","班组数查询","作业面内容","班组详情"
"通用对话", "作业面查询", "班组人数查询", "班组数查询", "作业面内容", "班组详情"
]
# 标签映射
@ -43,27 +43,47 @@ label_map = {
13: 'B-teamName', 26: 'I-teamName',
}
# 全局变量
#标准公司名和项目名中文mapping
standard_company_program = {}
#标准分公司名
standard_company_name_list = []
#去不重要词条后中文分公司名和标准化分公司名mapping
simply_to_standard_company_name_map = {}
#去不重要词条后拼音分公司名和标准化分公司名mapping
pinyin_simply_to_standard_company_name_map = {}
#标准公司名和项目名
standard_company_program = load_standard_data("./standard_data/standard_company_program.json")
# 标准工程名
standard_project_name_list = []
#去不重要词条后中文分公司名和标准化分公司名mapping
simply_to_standard_project_name_map = {}
#去不重要词条后工程名拼音和标准化工程名mapping
pinyin_simply_to_standard_project_name_map = {}
# 标准工程名,标准工程名拼音和工程名映射,标准工程名拼音
standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
pinyin_to_standard_project_name_map = {text_to_pinyin(kw): kw for kw in standard_project_name_list}
standard_project_name_pinyin_list = list(pinyin_to_standard_project_name_map.keys())
def update_data_from_local():
global standard_company_program, standard_company_name_list, simply_to_standard_company_name_map, \
pinyin_simply_to_standard_company_name_map, standard_project_name_list, simply_to_standard_project_name_map, \
pinyin_simply_to_standard_project_name_map
#标准分公司名,标准分公司名拼音和分公司名映射,标公司名拼音
standard_company_name_list = list(standard_company_program.keys())
pinyin_to_standard_company_name_map = {text_to_pinyin(kw): kw for kw in standard_company_name_list}
standard_company_name_pinyin_list = list(pinyin_to_standard_company_name_map.keys())
#标准公司名和项目名中文mapping
temp_standard_company_program = load_standard_data("./standard_data/standard_company_program.json")
if temp_standard_company_program != standard_company_program:
standard_company_program = temp_standard_company_program
standard_company_name_list = list(standard_company_program.keys())
simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list}
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in
standard_company_name_list}
simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list}
# 标准工程名
temp_standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
if temp_standard_project_name_list != standard_project_name_list:
standard_project_name_list = temp_standard_project_name_list
simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list}
pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in
standard_project_name_list}
pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in standard_project_name_list}
simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list}
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in standard_company_name_list}
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
print(f"Updated data from local at {current_time}")
# 初始化工具类
@ -73,12 +93,11 @@ intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
# 设置Flask应用
print(f"标准化的工程名是:{standard_project_name_list}", flush=True)
print(f"pinyin标准化的工程名是 list{standard_project_name_pinyin_list}", flush=True)
print(f"pinyin-工程民对应关系 map{pinyin_to_standard_company_name_map}", flush=True)
update_data_from_local()
app = Flask(__name__)
# 统一的异常处理函数
@app.errorhandler(Exception)
def handle_exception(e):
@ -229,7 +248,8 @@ def agent():
entities = slot_recognizer.recognize(query)
print(
f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",flush=True)
f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",
flush=True)
# 多轮
else:
res = extract_multi_chat(messages)
@ -245,7 +265,8 @@ def agent():
})
entities = slot_recognizer.recognize(res)
print(
f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",flush=True)
f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",
flush=True)
#必须槽位缺失检查
status, sk = check_lost(predicted_id, entities)
@ -273,6 +294,7 @@ def agent():
except Exception as e:
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
def extract_multi_chat(messages):
from openai import OpenAI
client = OpenAI(base_url=api_base_url, api_key=api_key)
@ -379,6 +401,8 @@ def extract_multi_chat(messages):
print(f"多轮意图后用户想要的问题是:{res}", flush=True)
return res
#槽位缺失检查
def check_lost(int_res, slot):
#labels: ["天气查询","通用对话","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"]
mapping = {
@ -398,7 +422,7 @@ def check_lost(int_res, slot):
intention_mapping = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容",
6: "周计划作业内容", 7: "施工人数", 8: "作业考勤人数", 11: "作业面查询",
12:"班组人数查询", 13:"班组数查询", 14:"作业面内容", 15:"班组详情"}
12: "班组人数查询", 13: "班组数查询", 14: "作业面内容", 15: "班组详情"}
if not mapping.__contains__(int_res):
return 0, ""
#提取的槽位信息
@ -423,7 +447,7 @@ def check_lost(int_res, slot):
return CheckResult.NO_MATCH, cur_k
#符合当前意图的的必须槽位,但是不在提取的槽位信息里
left = [x for x in mapping[int_res][idx] if x not in cur_k]
print(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}",flush=True)
print(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}", flush=True)
apologize_str = "非常抱歉,"
if int_res == 2:
return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询哪个页面?"
@ -431,7 +455,7 @@ def check_lost(int_res, slot):
return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}"
#标准化工程名
#标准化分公司名,工程名,项目名等
def check_standard_name_slot(int_res, slot) -> tuple:
intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15}
if int_res not in intention_list:
@ -446,8 +470,9 @@ def check_standard_name_slot(int_res, slot) -> tuple:
for key, value in slot.items():
if key == PROJECT_NAME:
print(f"check_standard_name_slot 原始工程名 : {slot[PROJECT_NAME]}")
match_results = standardize_project_name(value, simply_to_standard_project_name_map, pinyin_simply_to_standard_project_name_map,70,90)
print(f"check_standard_name_slot 匹配后工程名 result:{match_results}",flush=True)
match_results = standardize_project_name(value, simply_to_standard_project_name_map,
pinyin_simply_to_standard_project_name_map, 70, 90)
print(f"check_standard_name_slot 匹配后工程名 result:{match_results}", flush=True)
if match_results and len(match_results) == 1:
slot[key] = match_results[0]
else:
@ -456,8 +481,9 @@ def check_standard_name_slot(int_res, slot) -> tuple:
if key == IMPLEMENTATION_ORG and slot[key] != "公司":
print(f"check_standard_name_slot 原始分公司名 : {slot[IMPLEMENTATION_ORG]}")
match_results = standardize_sub_company(value,simply_to_standard_company_name_map, pinyin_simply_to_standard_company_name_map,55,80)
print(f"check_standard_name_slot 匹配后分公司名: result:{match_results}",flush=True)
match_results = standardize_sub_company(value, simply_to_standard_company_name_map,
pinyin_simply_to_standard_company_name_map, 55, 80)
print(f"check_standard_name_slot 匹配后分公司名: result:{match_results}", flush=True)
if match_results and len(match_results) == 1:
slot[key] = match_results[0]
else:
@ -466,8 +492,9 @@ def check_standard_name_slot(int_res, slot) -> tuple:
if key == PROJECT_DEPARTMENT:
print(f"check_standard_name_slot 原始项目部名 : {slot[PROJECT_DEPARTMENT]}")
match_results = standardize_projectDepartment(slot[IMPLEMENTATION_ORG], value, standard_company_program, high_score=90)
print(f"check_standard_name_slot 匹配后项目部名: result:{match_results}",flush=True)
match_results = standardize_projectDepartment(slot[IMPLEMENTATION_ORG], value, standard_company_program,
high_score=90)
print(f"check_standard_name_slot 匹配后项目部名: result:{match_results}", flush=True)
if match_results and len(match_results) == 1:
slot[key] = match_results[0]
else:
@ -475,13 +502,14 @@ def check_standard_name_slot(int_res, slot) -> tuple:
return CheckResult.NEEDS_MORE_ROUNDS, prompt
if key == RISK_LEVEL:
if slot[RISK_LEVEL] not in["2级","3级","4级","5级"] and slot[RISK_LEVEL] not in["二级","三级","四级","五级"]:
if slot[RISK_LEVEL] not in ["2级", "3级", "4级", "5级"] and slot[RISK_LEVEL] not in ["二级", "三级", "四级",
"五级"]:
return CheckResult.NEEDS_MORE_ROUNDS, "您查询的风险等级在系统中未找到,请确认风险等级后再次提问"
return CheckResult.NO_MATCH, ""
#
# #
# test_cases = [
# ("送一分公司"),
# ("送二分公司"),
@ -513,7 +541,7 @@ def check_standard_name_slot(int_res, slot) -> tuple:
# print(f"加权混合策略 耗时: {end - start:.4f} 秒")
#
#
#
# #
# test_cases = [
# ("卢集"),
# ("芦集"),
@ -555,7 +583,8 @@ def check_standard_name_slot(int_res, slot) -> tuple:
# print(f"去不重要词汇 工程名匹配******************************************")
# start = time.perf_counter()
# for item in test_cases:
# match_results = standardize_project_name(item, simply_to_standard_project_name_map, pinyin_simply_to_standard_project_name_map,70,90)
# match_results = standardize_project_name(item, simply_to_standard_project_name_map,
# pinyin_simply_to_standard_project_name_map, 70, 90)
# print(f"工程名匹配 输入: {item}-> 输出: {match_results}")
# end = time.perf_counter()
# print(f"词集匹配 耗时: {end - start:.4f} 秒")

View File

@ -4,7 +4,7 @@ from typing import cast, Callable
from rapidfuzz import process, fuzz
from rapidfuzz.fuzz import WRatio
import json
from pypinyin import lazy_pinyin
from pypinyin import lazy_pinyin, Style
from constants import USELESS_COMPANY_WORDS, USELESS_PROJECT_WORDS
@ -36,7 +36,7 @@ def arabic_to_chinese_number(text):
def text_to_pinyin(text):
"""将文本转换为拼音字符串"""
return ''.join(lazy_pinyin(text))
return ''.join(lazy_pinyin(text, Style.TONE2))
def load_standard_data(path):