标准化算法提优
This commit is contained in:
parent
510e829382
commit
f38aa8ed01
64
api/main.py
64
api/main.py
|
|
@ -43,28 +43,48 @@ label_map = {
|
|||
13: 'B-teamName', 26: 'I-teamName',
|
||||
}
|
||||
|
||||
#标准公司名和项目名
|
||||
standard_company_program = load_standard_data("./standard_data/standard_company_program.json")
|
||||
# 全局变量
|
||||
#标准公司名和项目名中文mapping
|
||||
standard_company_program = {}
|
||||
#标准分公司名
|
||||
standard_company_name_list = []
|
||||
#去不重要词条后中文分公司名和标准化分公司名mapping
|
||||
simply_to_standard_company_name_map = {}
|
||||
#去不重要词条后拼音分公司名和标准化分公司名mapping
|
||||
pinyin_simply_to_standard_company_name_map = {}
|
||||
|
||||
# 标准工程名,标准工程名拼音和工程名映射,标准工程名拼音
|
||||
standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
|
||||
pinyin_to_standard_project_name_map = {text_to_pinyin(kw): kw for kw in standard_project_name_list}
|
||||
standard_project_name_pinyin_list = list(pinyin_to_standard_project_name_map.keys())
|
||||
# 标准工程名
|
||||
standard_project_name_list = []
|
||||
#去不重要词条后中文工程名和标准化工程名mapping
|
||||
simply_to_standard_project_name_map = {}
|
||||
#去不重要词条后工程名拼音和标准化工程名mapping
|
||||
pinyin_simply_to_standard_project_name_map = {}
|
||||
|
||||
#标准分公司名,标准分公司名拼音和分公司名映射,标公司名拼音
|
||||
standard_company_name_list = list(standard_company_program.keys())
|
||||
pinyin_to_standard_company_name_map = {text_to_pinyin(kw): kw for kw in standard_company_name_list}
|
||||
standard_company_name_pinyin_list = list(pinyin_to_standard_company_name_map.keys())
|
||||
def update_data_from_local():
|
||||
global standard_company_program, standard_company_name_list, simply_to_standard_company_name_map, \
|
||||
pinyin_simply_to_standard_company_name_map, standard_project_name_list, simply_to_standard_project_name_map, \
|
||||
pinyin_simply_to_standard_project_name_map
|
||||
|
||||
simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list}
|
||||
#标准公司名和项目名中文mapping
|
||||
temp_standard_company_program = load_standard_data("./standard_data/standard_company_program.json")
|
||||
if temp_standard_company_program != standard_company_program:
|
||||
standard_company_program = temp_standard_company_program
|
||||
standard_company_name_list = list(standard_company_program.keys())
|
||||
simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list}
|
||||
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in
|
||||
standard_company_name_list}
|
||||
|
||||
pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in
|
||||
# 标准工程名
|
||||
temp_standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
|
||||
if temp_standard_project_name_list != standard_project_name_list:
|
||||
standard_project_name_list = temp_standard_project_name_list
|
||||
simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list}
|
||||
pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in
|
||||
standard_project_name_list}
|
||||
|
||||
simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list}
|
||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
print(f"Updated data from local at {current_time}")
|
||||
|
||||
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in
|
||||
standard_company_name_list}
|
||||
|
||||
# 初始化工具类
|
||||
intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
|
||||
|
|
@ -73,9 +93,7 @@ intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
|
|||
slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
|
||||
# 设置Flask应用
|
||||
|
||||
print(f"标准化的工程名是:{standard_project_name_list}", flush=True)
|
||||
print(f"pinyin标准化的工程名是 list:{standard_project_name_pinyin_list}", flush=True)
|
||||
print(f"pinyin-工程民对应关系 map:{pinyin_to_standard_company_name_map}", flush=True)
|
||||
update_data_from_local()
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
|
@ -384,6 +402,7 @@ def extract_multi_chat(messages):
|
|||
return res
|
||||
|
||||
|
||||
#槽位缺失检查
|
||||
def check_lost(int_res, slot):
|
||||
#labels: ["天气查询","通用对话","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"]
|
||||
mapping = {
|
||||
|
|
@ -436,7 +455,7 @@ def check_lost(int_res, slot):
|
|||
return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}?"
|
||||
|
||||
|
||||
#标准化工程名
|
||||
#标准化分公司名,工程名,项目名等
|
||||
def check_standard_name_slot(int_res, slot) -> tuple:
|
||||
intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15}
|
||||
if int_res not in intention_list:
|
||||
|
|
@ -489,8 +508,8 @@ def check_standard_name_slot(int_res, slot) -> tuple:
|
|||
|
||||
return CheckResult.NO_MATCH, ""
|
||||
|
||||
|
||||
#
|
||||
# #
|
||||
# test_cases = [
|
||||
# ("送一分公司"),
|
||||
# ("送二分公司"),
|
||||
|
|
@ -522,7 +541,7 @@ def check_standard_name_slot(int_res, slot) -> tuple:
|
|||
# print(f"加权混合策略 耗时: {end - start:.4f} 秒")
|
||||
#
|
||||
#
|
||||
#
|
||||
# #
|
||||
# test_cases = [
|
||||
# ("卢集"),
|
||||
# ("芦集"),
|
||||
|
|
@ -564,7 +583,8 @@ def check_standard_name_slot(int_res, slot) -> tuple:
|
|||
# print(f"去不重要词汇 工程名匹配******************************************")
|
||||
# start = time.perf_counter()
|
||||
# for item in test_cases:
|
||||
# match_results = standardize_project_name(item, simply_to_standard_project_name_map, pinyin_simply_to_standard_project_name_map,70,90)
|
||||
# match_results = standardize_project_name(item, simply_to_standard_project_name_map,
|
||||
# pinyin_simply_to_standard_project_name_map, 70, 90)
|
||||
# print(f"工程名匹配 输入: {item}-> 输出: {match_results}")
|
||||
# end = time.perf_counter()
|
||||
# print(f"词集匹配 耗时: {end - start:.4f} 秒")
|
||||
|
|
|
|||
103
api/main_temp.py
103
api/main_temp.py
|
|
@ -15,14 +15,14 @@ from utils import CheckResult, load_standard_name, generate_project_prompt, \
|
|||
from constants import PROJECT_NAME, PROJECT_DEPARTMENT, SIMILARITY_VALUE, IMPLEMENTATION_ORG, RISK_LEVEL
|
||||
from config import *
|
||||
|
||||
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-17890"
|
||||
MODEL_UIE_PATH = R"../uie/output/checkpoint-17290"
|
||||
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-31940"
|
||||
MODEL_UIE_PATH = R"../uie/output/checkpoint-31350"
|
||||
|
||||
# 类别名称列表
|
||||
labels = [
|
||||
"天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
|
||||
"日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答",
|
||||
"通用对话", "作业面查询","班组人数查询","班组数查询","作业面内容","班组详情"
|
||||
"通用对话", "作业面查询", "班组人数查询", "班组数查询", "作业面内容", "班组详情"
|
||||
]
|
||||
|
||||
# 标签映射
|
||||
|
|
@ -43,27 +43,47 @@ label_map = {
|
|||
13: 'B-teamName', 26: 'I-teamName',
|
||||
}
|
||||
|
||||
# 全局变量
|
||||
#标准公司名和项目名中文mapping
|
||||
standard_company_program = {}
|
||||
#标准分公司名
|
||||
standard_company_name_list = []
|
||||
#去不重要词条后中文分公司名和标准化分公司名mapping
|
||||
simply_to_standard_company_name_map = {}
|
||||
#去不重要词条后拼音分公司名和标准化分公司名mapping
|
||||
pinyin_simply_to_standard_company_name_map = {}
|
||||
|
||||
#标准公司名和项目名
|
||||
standard_company_program = load_standard_data("./standard_data/standard_company_program.json")
|
||||
# 标准工程名
|
||||
standard_project_name_list = []
|
||||
#去不重要词条后中文工程名和标准化工程名mapping
|
||||
simply_to_standard_project_name_map = {}
|
||||
#去不重要词条后工程名拼音和标准化工程名mapping
|
||||
pinyin_simply_to_standard_project_name_map = {}
|
||||
|
||||
# 标准工程名,标准工程名拼音和工程名映射,标准工程名拼音
|
||||
standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
|
||||
pinyin_to_standard_project_name_map = {text_to_pinyin(kw): kw for kw in standard_project_name_list}
|
||||
standard_project_name_pinyin_list = list(pinyin_to_standard_project_name_map.keys())
|
||||
def update_data_from_local():
|
||||
global standard_company_program, standard_company_name_list, simply_to_standard_company_name_map, \
|
||||
pinyin_simply_to_standard_company_name_map, standard_project_name_list, simply_to_standard_project_name_map, \
|
||||
pinyin_simply_to_standard_project_name_map
|
||||
|
||||
#标准分公司名,标准分公司名拼音和分公司名映射,标公司名拼音
|
||||
standard_company_name_list = list(standard_company_program.keys())
|
||||
pinyin_to_standard_company_name_map = {text_to_pinyin(kw): kw for kw in standard_company_name_list}
|
||||
standard_company_name_pinyin_list = list(pinyin_to_standard_company_name_map.keys())
|
||||
#标准公司名和项目名中文mapping
|
||||
temp_standard_company_program = load_standard_data("./standard_data/standard_company_program.json")
|
||||
if temp_standard_company_program != standard_company_program:
|
||||
standard_company_program = temp_standard_company_program
|
||||
standard_company_name_list = list(standard_company_program.keys())
|
||||
simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list}
|
||||
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in
|
||||
standard_company_name_list}
|
||||
|
||||
simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list}
|
||||
# 标准工程名
|
||||
temp_standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
|
||||
if temp_standard_project_name_list != standard_project_name_list:
|
||||
standard_project_name_list = temp_standard_project_name_list
|
||||
simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list}
|
||||
pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in
|
||||
standard_project_name_list}
|
||||
|
||||
pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in standard_project_name_list}
|
||||
|
||||
simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list}
|
||||
|
||||
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in standard_company_name_list}
|
||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
print(f"Updated data from local at {current_time}")
|
||||
|
||||
|
||||
# 初始化工具类
|
||||
|
|
@ -73,12 +93,11 @@ intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
|
|||
slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
|
||||
# 设置Flask应用
|
||||
|
||||
print(f"标准化的工程名是:{standard_project_name_list}", flush=True)
|
||||
print(f"pinyin标准化的工程名是 list:{standard_project_name_pinyin_list}", flush=True)
|
||||
print(f"pinyin-工程民对应关系 map:{pinyin_to_standard_company_name_map}", flush=True)
|
||||
update_data_from_local()
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
# 统一的异常处理函数
|
||||
@app.errorhandler(Exception)
|
||||
def handle_exception(e):
|
||||
|
|
@ -229,7 +248,8 @@ def agent():
|
|||
entities = slot_recognizer.recognize(query)
|
||||
|
||||
print(
|
||||
f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",flush=True)
|
||||
f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",
|
||||
flush=True)
|
||||
# 多轮
|
||||
else:
|
||||
res = extract_multi_chat(messages)
|
||||
|
|
@ -245,7 +265,8 @@ def agent():
|
|||
})
|
||||
entities = slot_recognizer.recognize(res)
|
||||
print(
|
||||
f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",flush=True)
|
||||
f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",
|
||||
flush=True)
|
||||
|
||||
#必须槽位缺失检查
|
||||
status, sk = check_lost(predicted_id, entities)
|
||||
|
|
@ -273,6 +294,7 @@ def agent():
|
|||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
|
||||
|
||||
|
||||
def extract_multi_chat(messages):
|
||||
from openai import OpenAI
|
||||
client = OpenAI(base_url=api_base_url, api_key=api_key)
|
||||
|
|
@ -379,6 +401,8 @@ def extract_multi_chat(messages):
|
|||
print(f"多轮意图后用户想要的问题是:{res}", flush=True)
|
||||
return res
|
||||
|
||||
|
||||
#槽位缺失检查
|
||||
def check_lost(int_res, slot):
|
||||
#labels(下标即意图id): ["天气查询","互联网查询","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答","通用对话","作业面查询","班组人数查询","班组数查询","作业面内容","班组详情"]
|
||||
mapping = {
|
||||
|
|
@ -398,7 +422,7 @@ def check_lost(int_res, slot):
|
|||
|
||||
intention_mapping = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容",
|
||||
6: "周计划作业内容", 7: "施工人数", 8: "作业考勤人数", 11: "作业面查询",
|
||||
12:"班组人数查询", 13:"班组数查询", 14:"作业面内容", 15:"班组详情"}
|
||||
12: "班组人数查询", 13: "班组数查询", 14: "作业面内容", 15: "班组详情"}
|
||||
if not mapping.__contains__(int_res):
|
||||
return 0, ""
|
||||
#提取的槽位信息
|
||||
|
|
@ -423,7 +447,7 @@ def check_lost(int_res, slot):
|
|||
return CheckResult.NO_MATCH, cur_k
|
||||
#符合当前意图的必须槽位,但是不在提取的槽位信息里
|
||||
left = [x for x in mapping[int_res][idx] if x not in cur_k]
|
||||
print(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}",flush=True)
|
||||
print(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}", flush=True)
|
||||
apologize_str = "非常抱歉,"
|
||||
if int_res == 2:
|
||||
return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询哪个页面?"
|
||||
|
|
@ -431,7 +455,7 @@ def check_lost(int_res, slot):
|
|||
return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}?"
|
||||
|
||||
|
||||
#标准化工程名
|
||||
#标准化分公司名,工程名,项目名等
|
||||
def check_standard_name_slot(int_res, slot) -> tuple:
|
||||
intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15}
|
||||
if int_res not in intention_list:
|
||||
|
|
@ -446,8 +470,9 @@ def check_standard_name_slot(int_res, slot) -> tuple:
|
|||
for key, value in slot.items():
|
||||
if key == PROJECT_NAME:
|
||||
print(f"check_standard_name_slot 原始工程名 : {slot[PROJECT_NAME]}")
|
||||
match_results = standardize_project_name(value, simply_to_standard_project_name_map, pinyin_simply_to_standard_project_name_map,70,90)
|
||||
print(f"check_standard_name_slot 匹配后工程名 :result:{match_results}",flush=True)
|
||||
match_results = standardize_project_name(value, simply_to_standard_project_name_map,
|
||||
pinyin_simply_to_standard_project_name_map, 70, 90)
|
||||
print(f"check_standard_name_slot 匹配后工程名 :result:{match_results}", flush=True)
|
||||
if match_results and len(match_results) == 1:
|
||||
slot[key] = match_results[0]
|
||||
else:
|
||||
|
|
@ -456,8 +481,9 @@ def check_standard_name_slot(int_res, slot) -> tuple:
|
|||
|
||||
if key == IMPLEMENTATION_ORG and slot[key] != "公司":
|
||||
print(f"check_standard_name_slot 原始分公司名 : {slot[IMPLEMENTATION_ORG]}")
|
||||
match_results = standardize_sub_company(value,simply_to_standard_company_name_map, pinyin_simply_to_standard_company_name_map,55,80)
|
||||
print(f"check_standard_name_slot 匹配后分公司名: result:{match_results}",flush=True)
|
||||
match_results = standardize_sub_company(value, simply_to_standard_company_name_map,
|
||||
pinyin_simply_to_standard_company_name_map, 55, 80)
|
||||
print(f"check_standard_name_slot 匹配后分公司名: result:{match_results}", flush=True)
|
||||
if match_results and len(match_results) == 1:
|
||||
slot[key] = match_results[0]
|
||||
else:
|
||||
|
|
@ -466,8 +492,9 @@ def check_standard_name_slot(int_res, slot) -> tuple:
|
|||
|
||||
if key == PROJECT_DEPARTMENT:
|
||||
print(f"check_standard_name_slot 原始项目部名 : {slot[PROJECT_DEPARTMENT]}")
|
||||
match_results = standardize_projectDepartment(slot[IMPLEMENTATION_ORG], value, standard_company_program, high_score=90)
|
||||
print(f"check_standard_name_slot 匹配后项目部名: result:{match_results}",flush=True)
|
||||
match_results = standardize_projectDepartment(slot[IMPLEMENTATION_ORG], value, standard_company_program,
|
||||
high_score=90)
|
||||
print(f"check_standard_name_slot 匹配后项目部名: result:{match_results}", flush=True)
|
||||
if match_results and len(match_results) == 1:
|
||||
slot[key] = match_results[0]
|
||||
else:
|
||||
|
|
@ -475,13 +502,14 @@ def check_standard_name_slot(int_res, slot) -> tuple:
|
|||
return CheckResult.NEEDS_MORE_ROUNDS, prompt
|
||||
|
||||
if key == RISK_LEVEL:
|
||||
if slot[RISK_LEVEL] not in["2级","3级","4级","5级"] and slot[RISK_LEVEL] not in["二级","三级","四级","五级"]:
|
||||
if slot[RISK_LEVEL] not in ["2级", "3级", "4级", "5级"] and slot[RISK_LEVEL] not in ["二级", "三级", "四级",
|
||||
"五级"]:
|
||||
return CheckResult.NEEDS_MORE_ROUNDS, "您查询的风险等级在系统中未找到,请确认风险等级后再次提问"
|
||||
|
||||
return CheckResult.NO_MATCH, ""
|
||||
|
||||
|
||||
#
|
||||
# #
|
||||
# test_cases = [
|
||||
# ("送一分公司"),
|
||||
# ("送二分公司"),
|
||||
|
|
@ -513,7 +541,7 @@ def check_standard_name_slot(int_res, slot) -> tuple:
|
|||
# print(f"加权混合策略 耗时: {end - start:.4f} 秒")
|
||||
#
|
||||
#
|
||||
#
|
||||
# #
|
||||
# test_cases = [
|
||||
# ("卢集"),
|
||||
# ("芦集"),
|
||||
|
|
@ -555,7 +583,8 @@ def check_standard_name_slot(int_res, slot) -> tuple:
|
|||
# print(f"去不重要词汇 工程名匹配******************************************")
|
||||
# start = time.perf_counter()
|
||||
# for item in test_cases:
|
||||
# match_results = standardize_project_name(item, simply_to_standard_project_name_map, pinyin_simply_to_standard_project_name_map,70,90)
|
||||
# match_results = standardize_project_name(item, simply_to_standard_project_name_map,
|
||||
# pinyin_simply_to_standard_project_name_map, 70, 90)
|
||||
# print(f"工程名匹配 输入: {item}-> 输出: {match_results}")
|
||||
# end = time.perf_counter()
|
||||
# print(f"词集匹配 耗时: {end - start:.4f} 秒")
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import cast, Callable
|
|||
from rapidfuzz import process, fuzz
|
||||
from rapidfuzz.fuzz import WRatio
|
||||
import json
|
||||
from pypinyin import lazy_pinyin
|
||||
from pypinyin import lazy_pinyin, Style
|
||||
|
||||
from constants import USELESS_COMPANY_WORDS, USELESS_PROJECT_WORDS
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ def arabic_to_chinese_number(text):
|
|||
|
||||
def text_to_pinyin(text):
|
||||
"""将文本转换为拼音字符串"""
|
||||
return ''.join(lazy_pinyin(text))
|
||||
return ''.join(lazy_pinyin(text, Style.TONE2))
|
||||
|
||||
|
||||
def load_standard_data(path):
|
||||
|
|
|
|||
Loading…
Reference in New Issue