实现工程名和项目名规范化和多轮
This commit is contained in:
parent
823c6ebafa
commit
2411eba2f0
|
|
@ -0,0 +1,21 @@
|
|||
# constants.py
|
||||
#日期
|
||||
DATE = "date"
|
||||
#工程名称
|
||||
PROJECT_NAME = "projectName"
|
||||
#工程性质
|
||||
PROJECT_TYPE = "projectType"
|
||||
#建管单位
|
||||
CONSTRUCTION_UNIT = "constructionUnit"
|
||||
#分公司
|
||||
IMPLEMENTATION_ORG = "implementationOrganization"
|
||||
#项目部
|
||||
PROJECT_DEPARTMENT = "projectDepartment"
|
||||
#项目经理
|
||||
PROJECT_MANAGER = "projectManager"
|
||||
#分包商
|
||||
SUBCONTRACTOR = "subcontractor"
|
||||
#班组
|
||||
TEAM_LEADER = "teamLeader"
|
||||
#风险
|
||||
RISK_LEVEL = "riskLevel"
|
||||
82
api/mian.py
82
api/mian.py
|
|
@ -12,7 +12,8 @@ from pydantic import ValidationError
|
|||
from intentRecognition import IntentRecognition
|
||||
from slotRecognition import SlotRecognition
|
||||
from fuzzywuzzy import process
|
||||
import utils
|
||||
from utils import CheckResult, StandardType
|
||||
from constants import PROJECT_NAME, PROJECT_DEPARTMENT
|
||||
|
||||
# 常量
|
||||
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-4160"
|
||||
|
|
@ -199,28 +200,26 @@ def agent():
|
|||
entities = slot_recognizer.recognize(query)
|
||||
|
||||
print(f"意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}")
|
||||
|
||||
#必须槽位缺失检查
|
||||
status, sk = check_lost(predicted_id, entities)
|
||||
if status == 1:
|
||||
if status == CheckResult.NEEDS_MORE_ROUNDS:
|
||||
return jsonify({"code": 10001, "msg": "成功",
|
||||
"answer": { "miss": sk},
|
||||
})
|
||||
|
||||
#工程名和项目名标准化
|
||||
result, information = check_project_standard_slot(predicted_id, entities)
|
||||
if result == CheckResult.NEEDS_MORE_ROUNDS:
|
||||
return jsonify({
|
||||
"code": 200,
|
||||
"msg": "成功",
|
||||
"answer": {
|
||||
"miss": sk
|
||||
},
|
||||
"code": 10001, "msg": "成功",
|
||||
"answer": {"miss": information},
|
||||
})
|
||||
|
||||
else:
|
||||
return jsonify({
|
||||
"code": 200,
|
||||
"msg": "成功",
|
||||
"answer": {
|
||||
"int": predicted_id,
|
||||
"label": predicted_label,
|
||||
"probability": predicted_probability,
|
||||
"slot": entities
|
||||
},
|
||||
})
|
||||
return jsonify({
|
||||
"code": 200,"msg": "成功",
|
||||
"answer": {"int": predicted_id, "label": predicted_label, "probability": predicted_probability, "slot": entities },
|
||||
})
|
||||
|
||||
# 如果是后续轮次(多轮对话),这里只做示例,可能需要根据具体需求进行处理
|
||||
else:
|
||||
|
|
@ -248,39 +247,64 @@ def check_lost(int_res, slot):
|
|||
7: [['date']],
|
||||
8: [['date']],
|
||||
}
|
||||
#3:"页面切换",
|
||||
|
||||
intention_mapping = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容",
|
||||
6: "周计划作业内容",7: "施工人数",8: "作业考勤人数"}
|
||||
if not mapping.__contains__(int_res):
|
||||
return 0, ""
|
||||
#提取的槽位信息
|
||||
cur_k = list(slot.keys())
|
||||
idx = -1
|
||||
idx_len = 99
|
||||
for i in range(len(mapping[int_res])):
|
||||
sk = mapping[int_res][i]
|
||||
left = [x for x in sk if x not in cur_k]
|
||||
more = [x for x in cur_k if x not in sk]
|
||||
if len(more) >= 0 and len(left) == 0:
|
||||
#不在提取的槽位信息里,但是在必须槽位表里
|
||||
miss_params = [x for x in sk if x not in cur_k]
|
||||
#不在必须槽位表里,但是在提取的槽位信息里
|
||||
extra_params = [x for x in cur_k if x not in sk]
|
||||
if len(extra_params) >= 0 and len(miss_params) == 0:
|
||||
idx = i
|
||||
idx_len = 0
|
||||
break
|
||||
if len(left) < idx_len:
|
||||
if len(miss_params) < idx_len:
|
||||
idx = i
|
||||
idx_len = len(left)
|
||||
idx_len = len(miss_params)
|
||||
|
||||
if idx_len == 0: # 匹配通过
|
||||
return 0, cur_k
|
||||
return CheckResult.NO_MATCH, cur_k
|
||||
#符合当前意图的的必须槽位,但是不在提取的槽位信息里
|
||||
left = [x for x in mapping[int_res][idx] if x not in cur_k]
|
||||
print(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}")
|
||||
apologize_str = "非常抱歉,"
|
||||
if int_res == 2:
|
||||
return 1, f"{apologize_str}请问你想查询哪个页面?"
|
||||
return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询哪个页面?"
|
||||
elif int_res in [3, 4, 5, 6, 7, 8]:
|
||||
return 1, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}?"
|
||||
return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}?"
|
||||
|
||||
|
||||
#标准化工程名
|
||||
def check_project_standard_slot(int_res, slot)->tuple:
|
||||
return 0, ''
|
||||
def check_project_standard_slot(int_res, slot) -> tuple:
|
||||
intention_list = {3, 4, 5, 6, 7, 8}
|
||||
if int_res not in intention_list:
|
||||
return CheckResult.NO_MATCH,""
|
||||
|
||||
for key, value in slot.items():
|
||||
if key == PROJECT_NAME:
|
||||
match_project, match_possibility = fuzzy_match(value, standard_project_name_list)
|
||||
if match_possibility >= 0.9:
|
||||
slot[key] = match_project
|
||||
else:
|
||||
return CheckResult.NEEDS_MORE_ROUNDS, f"抱歉,您说的工程名是{match_project}吗"
|
||||
|
||||
if key == PROJECT_DEPARTMENT:
|
||||
match_program, match_possibility = fuzzy_match(value, standard_program_name_list)
|
||||
if match_possibility >= 0.9:
|
||||
slot[key] = match_program
|
||||
else:
|
||||
return CheckResult.NEEDS_MORE_ROUNDS, f"抱歉,您说的项目名是{match_program}吗"
|
||||
|
||||
return CheckResult.NO_MATCH,""
|
||||
|
||||
|
||||
def fuzzy_match(user_input, standard_name):
|
||||
result = process.extract(user_input, standard_name)
|
||||
|
|
|
|||
14
api/utils.py
14
api/utils.py
|
|
@ -1,3 +1,4 @@
|
|||
from enum import Enum
|
||||
def load_standard_name(file_path:str):
|
||||
try:
|
||||
# f = open(file_path, 'r', encoding='utf-8')
|
||||
|
|
@ -16,3 +17,16 @@ def load_standard_name(file_path:str):
|
|||
except Exception as e:
|
||||
print(f"读取文件时发生错误:{e}")
|
||||
raise Exception(f"错误:文件 {file_path} 不存在")
|
||||
|
||||
|
||||
class CheckResult(Enum):
|
||||
NO_MATCH = 0 # 不符合检查条件
|
||||
MATCH_FOUND = 1 # 匹配到了值
|
||||
NEEDS_MORE_ROUNDS = 2 # 需要多轮
|
||||
|
||||
|
||||
class StandardType(Enum):
|
||||
#工程名检查
|
||||
PROJECT_CHECK = 0
|
||||
#项目名检查
|
||||
PROGRAM_CHECK = 1
|
||||
|
|
|
|||
Loading…
Reference in New Issue