标准化算法提优

This commit is contained in:
weiweiw 2025-04-18 16:39:06 +08:00
parent 510e829382
commit f38aa8ed01
4 changed files with 164 additions and 115 deletions

View File

@ -43,28 +43,48 @@ label_map = {
13: 'B-teamName', 26: 'I-teamName',
}
#标准公司名和项目名
standard_company_program = load_standard_data("./standard_data/standard_company_program.json")
# 全局变量
#标准公司名和项目名中文mapping
standard_company_program = {}
#标准分公司名
standard_company_name_list = []
#去不重要词条后中文分公司名和标准化分公司名mapping
simply_to_standard_company_name_map = {}
#去不重要词条后拼音分公司名和标准化分公司名mapping
pinyin_simply_to_standard_company_name_map = {}
# 标准工程名,标准工程名拼音和工程名映射,标准工程名拼音
standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
pinyin_to_standard_project_name_map = {text_to_pinyin(kw): kw for kw in standard_project_name_list}
standard_project_name_pinyin_list = list(pinyin_to_standard_project_name_map.keys())
# 标准工程名
standard_project_name_list = []
# Mapping from the Chinese project name with unimportant words removed to the standardized project name
simply_to_standard_project_name_map = {}
#去不重要词条后工程名拼音和标准化工程名mapping
pinyin_simply_to_standard_project_name_map = {}
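# Standard sub-company names, the mapping from standard sub-company name pinyin to sub-company name, and the standard sub-company name pinyin list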
def update_data_from_local():
global standard_company_program, standard_company_name_list, simply_to_standard_company_name_map, \
pinyin_simply_to_standard_company_name_map, standard_project_name_list, simply_to_standard_project_name_map, \
pinyin_simply_to_standard_project_name_map
#标准公司名和项目名中文mapping
temp_standard_company_program = load_standard_data("./standard_data/standard_company_program.json")
if temp_standard_company_program != standard_company_program:
standard_company_program = temp_standard_company_program
standard_company_name_list = list(standard_company_program.keys())
pinyin_to_standard_company_name_map = {text_to_pinyin(kw): kw for kw in standard_company_name_list}
standard_company_name_pinyin_list = list(pinyin_to_standard_company_name_map.keys())
simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list}
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in
standard_company_name_list}
# 标准工程名
temp_standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
if temp_standard_project_name_list != standard_project_name_list:
standard_project_name_list = temp_standard_project_name_list
simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list}
pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in
standard_project_name_list}
simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list}
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
print(f"Updated data from local at {current_time}")
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in
standard_company_name_list}
# 初始化工具类
intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
@ -73,9 +93,7 @@ intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
# 设置Flask应用
print(f"标准化的工程名是:{standard_project_name_list}", flush=True)
print(f"pinyin标准化的工程名是 list{standard_project_name_pinyin_list}", flush=True)
print(f"pinyin-工程民对应关系 map{pinyin_to_standard_company_name_map}", flush=True)
update_data_from_local()
app = Flask(__name__)
@ -384,6 +402,7 @@ def extract_multi_chat(messages):
return res
#槽位缺失检查
def check_lost(int_res, slot):
#labels: ["天气查询","通用对话","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"]
mapping = {
@ -436,7 +455,7 @@ def check_lost(int_res, slot):
return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}"
#标准化工程名
#标准化分公司名,工程名,项目名等
def check_standard_name_slot(int_res, slot) -> tuple:
intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15}
if int_res not in intention_list:
@ -489,8 +508,8 @@ def check_standard_name_slot(int_res, slot) -> tuple:
return CheckResult.NO_MATCH, ""
#
# #
# test_cases = [
# ("送一分公司"),
# ("送二分公司"),
@ -522,7 +541,7 @@ def check_standard_name_slot(int_res, slot) -> tuple:
# print(f"加权混合策略 耗时: {end - start:.4f} 秒")
#
#
#
# #
# test_cases = [
# ("卢集"),
# ("芦集"),
@ -564,7 +583,8 @@ def check_standard_name_slot(int_res, slot) -> tuple:
# print(f"去不重要词汇 工程名匹配******************************************")
# start = time.perf_counter()
# for item in test_cases:
# match_results = standardize_project_name(item, simply_to_standard_project_name_map, pinyin_simply_to_standard_project_name_map,70,90)
# match_results = standardize_project_name(item, simply_to_standard_project_name_map,
# pinyin_simply_to_standard_project_name_map, 70, 90)
# print(f"工程名匹配 输入: {item}-> 输出: {match_results}")
# end = time.perf_counter()
# print(f"词集匹配 耗时: {end - start:.4f} 秒")

View File

@ -15,8 +15,8 @@ from utils import CheckResult, load_standard_name, generate_project_prompt, \
from constants import PROJECT_NAME, PROJECT_DEPARTMENT, SIMILARITY_VALUE, IMPLEMENTATION_ORG, RISK_LEVEL
from config import *
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-17890"
MODEL_UIE_PATH = R"../uie/output/checkpoint-17290"
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-31940"
MODEL_UIE_PATH = R"../uie/output/checkpoint-31350"
# 类别名称列表
labels = [
@ -43,27 +43,47 @@ label_map = {
13: 'B-teamName', 26: 'I-teamName',
}
# 全局变量
#标准公司名和项目名中文mapping
standard_company_program = {}
#标准分公司名
standard_company_name_list = []
#去不重要词条后中文分公司名和标准化分公司名mapping
simply_to_standard_company_name_map = {}
#去不重要词条后拼音分公司名和标准化分公司名mapping
pinyin_simply_to_standard_company_name_map = {}
#标准公司名和项目名
standard_company_program = load_standard_data("./standard_data/standard_company_program.json")
# 标准工程名
standard_project_name_list = []
# Mapping from the Chinese project name with unimportant words removed to the standardized project name
simply_to_standard_project_name_map = {}
#去不重要词条后工程名拼音和标准化工程名mapping
pinyin_simply_to_standard_project_name_map = {}
# 标准工程名,标准工程名拼音和工程名映射,标准工程名拼音
standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
pinyin_to_standard_project_name_map = {text_to_pinyin(kw): kw for kw in standard_project_name_list}
standard_project_name_pinyin_list = list(pinyin_to_standard_project_name_map.keys())
def update_data_from_local():
global standard_company_program, standard_company_name_list, simply_to_standard_company_name_map, \
pinyin_simply_to_standard_company_name_map, standard_project_name_list, simply_to_standard_project_name_map, \
pinyin_simply_to_standard_project_name_map
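# Standard sub-company names, the mapping from standard sub-company name pinyin to sub-company name, and the standard sub-company name pinyin list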
#标准公司名和项目名中文mapping
temp_standard_company_program = load_standard_data("./standard_data/standard_company_program.json")
if temp_standard_company_program != standard_company_program:
standard_company_program = temp_standard_company_program
standard_company_name_list = list(standard_company_program.keys())
pinyin_to_standard_company_name_map = {text_to_pinyin(kw): kw for kw in standard_company_name_list}
standard_company_name_pinyin_list = list(pinyin_to_standard_company_name_map.keys())
simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list}
pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in standard_project_name_list}
simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list}
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in
standard_company_name_list}
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in standard_company_name_list}
# 标准工程名
temp_standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
if temp_standard_project_name_list != standard_project_name_list:
standard_project_name_list = temp_standard_project_name_list
simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list}
pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in
standard_project_name_list}
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
print(f"Updated data from local at {current_time}")
# 初始化工具类
@ -73,12 +93,11 @@ intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
# 设置Flask应用
print(f"标准化的工程名是:{standard_project_name_list}", flush=True)
print(f"pinyin标准化的工程名是 list{standard_project_name_pinyin_list}", flush=True)
print(f"pinyin-工程民对应关系 map{pinyin_to_standard_company_name_map}", flush=True)
update_data_from_local()
app = Flask(__name__)
# 统一的异常处理函数
@app.errorhandler(Exception)
def handle_exception(e):
@ -229,7 +248,8 @@ def agent():
entities = slot_recognizer.recognize(query)
print(
f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",flush=True)
f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",
flush=True)
# 多轮
else:
res = extract_multi_chat(messages)
@ -245,7 +265,8 @@ def agent():
})
entities = slot_recognizer.recognize(res)
print(
f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",flush=True)
f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",
flush=True)
#必须槽位缺失检查
status, sk = check_lost(predicted_id, entities)
@ -273,6 +294,7 @@ def agent():
except Exception as e:
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
def extract_multi_chat(messages):
from openai import OpenAI
client = OpenAI(base_url=api_base_url, api_key=api_key)
@ -379,6 +401,8 @@ def extract_multi_chat(messages):
print(f"多轮意图后用户想要的问题是:{res}", flush=True)
return res
#槽位缺失检查
def check_lost(int_res, slot):
#labels: ["天气查询","通用对话","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"]
mapping = {
@ -431,7 +455,7 @@ def check_lost(int_res, slot):
return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}"
#标准化工程名
#标准化分公司名,工程名,项目名等
def check_standard_name_slot(int_res, slot) -> tuple:
intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15}
if int_res not in intention_list:
@ -446,7 +470,8 @@ def check_standard_name_slot(int_res, slot) -> tuple:
for key, value in slot.items():
if key == PROJECT_NAME:
print(f"check_standard_name_slot 原始工程名 : {slot[PROJECT_NAME]}")
match_results = standardize_project_name(value, simply_to_standard_project_name_map, pinyin_simply_to_standard_project_name_map,70,90)
match_results = standardize_project_name(value, simply_to_standard_project_name_map,
pinyin_simply_to_standard_project_name_map, 70, 90)
print(f"check_standard_name_slot 匹配后工程名 result:{match_results}", flush=True)
if match_results and len(match_results) == 1:
slot[key] = match_results[0]
@ -456,7 +481,8 @@ def check_standard_name_slot(int_res, slot) -> tuple:
if key == IMPLEMENTATION_ORG and slot[key] != "公司":
print(f"check_standard_name_slot 原始分公司名 : {slot[IMPLEMENTATION_ORG]}")
match_results = standardize_sub_company(value,simply_to_standard_company_name_map, pinyin_simply_to_standard_company_name_map,55,80)
match_results = standardize_sub_company(value, simply_to_standard_company_name_map,
pinyin_simply_to_standard_company_name_map, 55, 80)
print(f"check_standard_name_slot 匹配后分公司名: result:{match_results}", flush=True)
if match_results and len(match_results) == 1:
slot[key] = match_results[0]
@ -466,7 +492,8 @@ def check_standard_name_slot(int_res, slot) -> tuple:
if key == PROJECT_DEPARTMENT:
print(f"check_standard_name_slot 原始项目部名 : {slot[PROJECT_DEPARTMENT]}")
match_results = standardize_projectDepartment(slot[IMPLEMENTATION_ORG], value, standard_company_program, high_score=90)
match_results = standardize_projectDepartment(slot[IMPLEMENTATION_ORG], value, standard_company_program,
high_score=90)
print(f"check_standard_name_slot 匹配后项目部名: result:{match_results}", flush=True)
if match_results and len(match_results) == 1:
slot[key] = match_results[0]
@ -475,13 +502,14 @@ def check_standard_name_slot(int_res, slot) -> tuple:
return CheckResult.NEEDS_MORE_ROUNDS, prompt
if key == RISK_LEVEL:
if slot[RISK_LEVEL] not in["2级","3级","4级","5级"] and slot[RISK_LEVEL] not in["二级","三级","四级","五级"]:
if slot[RISK_LEVEL] not in ["2级", "3级", "4级", "5级"] and slot[RISK_LEVEL] not in ["二级", "三级", "四级",
"五级"]:
return CheckResult.NEEDS_MORE_ROUNDS, "您查询的风险等级在系统中未找到,请确认风险等级后再次提问"
return CheckResult.NO_MATCH, ""
#
# #
# test_cases = [
# ("送一分公司"),
# ("送二分公司"),
@ -513,7 +541,7 @@ def check_standard_name_slot(int_res, slot) -> tuple:
# print(f"加权混合策略 耗时: {end - start:.4f} 秒")
#
#
#
# #
# test_cases = [
# ("卢集"),
# ("芦集"),
@ -555,7 +583,8 @@ def check_standard_name_slot(int_res, slot) -> tuple:
# print(f"去不重要词汇 工程名匹配******************************************")
# start = time.perf_counter()
# for item in test_cases:
# match_results = standardize_project_name(item, simply_to_standard_project_name_map, pinyin_simply_to_standard_project_name_map,70,90)
# match_results = standardize_project_name(item, simply_to_standard_project_name_map,
# pinyin_simply_to_standard_project_name_map, 70, 90)
# print(f"工程名匹配 输入: {item}-> 输出: {match_results}")
# end = time.perf_counter()
# print(f"词集匹配 耗时: {end - start:.4f} 秒")

View File

@ -4,7 +4,7 @@ from typing import cast, Callable
from rapidfuzz import process, fuzz
from rapidfuzz.fuzz import WRatio
import json
from pypinyin import lazy_pinyin
from pypinyin import lazy_pinyin, Style
from constants import USELESS_COMPANY_WORDS, USELESS_PROJECT_WORDS
@ -36,7 +36,7 @@ def arabic_to_chinese_number(text):
def text_to_pinyin(text):
"""将文本转换为拼音字符串"""
return ''.join(lazy_pinyin(text))
return ''.join(lazy_pinyin(text, Style.TONE2))
def load_standard_data(path):