Intention/api/main_temp.py

606 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from flask import Flask, jsonify, request
from pydantic import BaseModel, Field
from werkzeug.exceptions import HTTPException
from typing import List
from pydantic import ValidationError
import time
from intentRecognition import IntentRecognition
from slotRecognition import SlotRecognition
from utils import CheckResult, load_standard_name, generate_project_prompt, \
load_standard_data, text_to_pinyin, \
standardize_projectDepartment, standardize_project_name, clean_useless_project_name, \
clean_useless_company_name, standardize_sub_company
from constants import PROJECT_NAME, PROJECT_DEPARTMENT, SIMILARITY_VALUE, IMPLEMENTATION_ORG, RISK_LEVEL
from config import *
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-17890"
MODEL_UIE_PATH = R"../uie/output/checkpoint-17290"
# 类别名称列表
labels = [
"天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
"日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答",
"通用对话", "作业面查询","班组人数查询","班组数查询","作业面内容","班组详情"
]
# 标签映射
label_map = {
0: 'O', # 非实体
1: 'B-date', 14: 'I-date',
2: 'B-projectName', 15: 'I-projectName',
3: 'B-projectType', 16: 'I-projectType',
4: 'B-constructionUnit', 17: 'I-constructionUnit',
5: 'B-implementationOrganization', 18: 'I-implementationOrganization',
6: 'B-projectDepartment', 19: 'I-projectDepartment',
7: 'B-projectManager', 20: 'I-projectManager',
8: 'B-subcontractor', 21: 'I-subcontractor',
9: 'B-teamLeader', 22: 'I-teamLeader',
10: 'B-riskLevel', 23: 'I-riskLevel',
11: 'B-page', 24: 'I-page',
12: 'B-operating', 25: 'I-operating',
13: 'B-teamName', 26: 'I-teamName',
}
#标准公司名和项目名
standard_company_program = load_standard_data("./standard_data/standard_company_program.json")
# 标准工程名,标准工程名拼音和工程名映射,标准工程名拼音
standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
pinyin_to_standard_project_name_map = {text_to_pinyin(kw): kw for kw in standard_project_name_list}
standard_project_name_pinyin_list = list(pinyin_to_standard_project_name_map.keys())
#标准分公司名,标准分公司名拼音和分公司名映射,标公司名拼音
standard_company_name_list = list(standard_company_program.keys())
pinyin_to_standard_company_name_map = {text_to_pinyin(kw): kw for kw in standard_company_name_list}
standard_company_name_pinyin_list = list(pinyin_to_standard_company_name_map.keys())
simply_to_standard_project_name_map = {clean_useless_project_name(kw): kw for kw in standard_project_name_list}
pinyin_simply_to_standard_project_name_map = {text_to_pinyin(clean_useless_project_name(kw)): kw for kw in standard_project_name_list}
simply_to_standard_company_name_map = {clean_useless_company_name(kw): kw for kw in standard_company_name_list}
pinyin_simply_to_standard_company_name_map = {text_to_pinyin(clean_useless_company_name(kw)): kw for kw in standard_company_name_list}
# 初始化工具类
intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
# 初始化槽位识别工具类
slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
# 设置Flask应用
print(f"标准化的工程名是:{standard_project_name_list}", flush=True)
print(f"pinyin标准化的工程名是 list{standard_project_name_pinyin_list}", flush=True)
print(f"pinyin-工程民对应关系 map{pinyin_to_standard_company_name_map}", flush=True)
app = Flask(__name__)
# 统一的异常处理函数
@app.errorhandler(Exception)
def handle_exception(e):
"""统一异常处理"""
if isinstance(e, HTTPException):
return jsonify({
"error": {
"type": e.name,
"message": e.description,
"status_code": e.code
}
}), e.code
return jsonify({
"error": {
"type": "InternalServerError",
"message": str(e)
}
}), 500
def validate_user(data):
"""验证用户ID"""
if data.get("user_id") != '3bb66776-1722-4c36-b14a-73dd210fe750':
return jsonify(
code=401,
msg='权限验证失败,请联系接口开发人员',
label=-1,
probability=-1
), 401
return None
class LabelMessage(BaseModel):
text: str = Field(..., description="消息内容")
user_id: str = Field(..., description="消息内容")
# 每条消息的结构
class Message(BaseModel):
role: str = Field(..., description="消息内容")
content: str = Field(..., description="消息内容")
# 请求数据的结构
class RequestData(BaseModel):
messages: List[Message] = Field(..., description="消息列表")
user_id: str = Field(..., description="用户ID")
# 意图识别
@app.route('/intent_reco', methods=['POST'])
def intent_reco():
"""意图识别"""
try:
# 获取请求中的 JSON 数据
data = request.get_json()
request_data = LabelMessage(**data) # Pydantic 会验证数据结构
text = request_data.text
user_id = request_data.user_id
# 检查必需字段
if not text:
return jsonify({"error": "text is required"}), 400
if not user_id:
return jsonify({"error": "user_id is required"}), 400
# 验证用户ID
user_validation_error = validate_user(data)
if user_validation_error:
return user_validation_error
# 调用predict方法进行意图识别
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(text)
return jsonify(
code=200,
msg="成功",
int=predicted_id,
label=predicted_label,
probability=float(predicted_probability)
)
except Exception as e:
return jsonify({"error": str(e)}), 500
# 槽位抽取
@app.route('/slot_reco', methods=['POST'])
def slot_reco():
"""槽位识别"""
try:
# 获取请求中的 JSON 数据
data = request.get_json()
request_data = LabelMessage(**data) # Pydantic 会验证数据结构
text = request_data.text
user_id = request_data.user_id
# 检查必需字段
if not text:
return jsonify({"error": "text is required"}), 400
if not user_id:
return jsonify({"error": "user_id is required"}), 400
# 验证用户ID
user_validation_error = validate_user(data)
if user_validation_error:
return user_validation_error
# 调用 recognize 方法进行槽位识别
entities = slot_recognizer.recognize(text)
return jsonify(
code=200,
msg="成功",
slot=entities)
except Exception as e:
return jsonify({"error": str(e)}), 500
@app.route('/agent', methods=['POST'])
def agent():
try:
data = request.get_json()
except Exception as e:
print(f"body不是一个有效的json")
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
try:
# 使用 Pydantic 来验证数据结构
request_data = RequestData(**data) # Pydantic 会验证数据结构
messages = request_data.messages
user_id = request_data.user_id
# 检查必需字段是否存在
if not messages:
return jsonify({"error": "messages is required"}), 400
if not user_id:
return jsonify({"error": "user_id is required"}), 400
# 验证用户ID假设这个函数已经定义
user_validation_error = validate_user(data)
if user_validation_error:
return user_validation_error
if len(messages) == 1: # 首轮
query = messages[0].content # 使用 Message 对象的 .content 属性
# 先进行意图识别
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)
# 再进行槽位抽取
entities = slot_recognizer.recognize(query)
print(
f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",flush=True)
# 多轮
else:
res = extract_multi_chat(messages)
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(res)
#0:天气1互联网查询9知识问答10通用对话
if predicted_id in [0, 1, 9, 10]:
print(f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},message:{messages}",
flush=True)
return jsonify({
"code": 200, "msg": "成功",
"answer": {"int": predicted_id, "label": predicted_label, "probability": predicted_probability},
"finalQuery": res
})
entities = slot_recognizer.recognize(res)
print(
f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",flush=True)
#必须槽位缺失检查
status, sk = check_lost(predicted_id, entities)
if status == CheckResult.NEEDS_MORE_ROUNDS:
return jsonify({"code": 10001, "msg": "成功",
"answer": {"miss": sk},
})
#工程名、分公司名和项目名标准化
result, information = check_standard_name_slot(predicted_id, entities)
if result == CheckResult.NEEDS_MORE_ROUNDS:
return jsonify({
"code": 10001, "msg": "成功",
"answer": {"miss": information},
})
return jsonify({
"code": 200, "msg": "成功",
"answer": {"int": predicted_id, "label": predicted_label, "probability": predicted_probability,
"slot": entities},
})
except ValidationError as e:
return jsonify({"error": e.errors()}), 400 # 捕捉 Pydantic 错误并返回
except Exception as e:
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
def extract_multi_chat(messages):
from openai import OpenAI
client = OpenAI(base_url=api_base_url, api_key=api_key)
latest_message = messages[-1] # 最后一条用户提问
if latest_message.role == "user":
latest_user_question = latest_message.content.strip()
time_prefixes = ["今天", "昨天", "本周", "下周", "明天", "今日"] # 可扩展的时间前缀列表
if any(latest_user_question.startswith(prefix) for prefix in time_prefixes):
history_messages = []
else:
history_messages = messages[:-1] # 除最后一条之外的历史记录
# 格式化对话历史
chat_history = "\n".join([f"{msg.role}: {msg.content}" for msg in history_messages])
latest_user_question = latest_message.content if latest_message.role == "user" else ""
prompt = f'''
你是一个意图识别与补全助手,你的任务是根据用户的最新问题判断是否需要补全,如果不需要补全,则原样返回用户的最新问题,否则需要结合对话记录请你补用户的最新问题,并只返回最终的完整问题。请严格按照如下逻辑判断并执行:
---
【规则判断与补全流程】
第一步:用户最新问题是否以“公司”为主语?→ 原样返回,无需补全
- 若用户最新问题主语是“公司”,直接返回原句,无需补全。
- 主语为“公司”的典型句式:
- 以“公司”开头;
- 以“今天”“昨天”“本周”“下周”等时间词开头,紧跟“公司”作为主语;
- 示例:
- 用户的最新问题:“今天公司有多少四级风险作业计划?”
- 用户的最新问题:“今天公司有多少作业计划”
- 用户的最新问题“公司今天有多少4级风险的作业面
- 最终提问均为: 原句不变。
第二步:用户最新问题是否是完整的问题?→ 原样返回,无需补全
- 若用户最新问题中包含下列之一:具体的项目部名、工程名、分公司名、班组名、地区名等信息,且同时出现作业计划、作业面、班组等查询对象,视为完整问题,直接返回原句,无需补全。
- 示例:
- 用户最新问题:“今天张三班组有多少作业计划?”
- 用户最新问题:“今天绿雪莲塘工程有多少作业计划”
- 最终提问均为: 原句不变。
第三步:用户最新问题是否存在指代词?→ 结合用户最新问题和对话记录进行补全
- 若用户最新问题问题中出现模糊表达如“具体是哪些项”、“是哪两个”、“作业计划分别是什么”、“合肥中心变工程呢”、“具体是哪20项”等请结合上一个用户问题和上一个AI回复补全问题信息。
- 示例1
- 用户最新问题:“具体的作业计划分别是什么”
- 对话记录的最后一个用户问题:“今天送一分公司有多少项作业计划”
- 对话记录的最后一个AI回答“今天送电一分公司有21项作业计划”
- 则最终提问应为:
“今天送电一分公司的21项作业计划分别是什么”
- 示例2
- 用户的最新问题:“具体的作业内容是什么”
- 对话记录的最后一个用户问题:今天送一分公司第一项目部有多少项作业计划
- 对话记录的最后一个AI回答今天送电一分公司第一项目管理部有21项作业计划
- 则最终提问应为:
“今天送电一分公司第一项目管理部的21项作业计划分别是什么”
第四步:用户最新问题是否为序号指代(第一个/第2个→ 用完整工程/项目/公司名替换补全
- 精确提取用户所指的序号如“第3个”指第3个工程名、公司名或项目部名
- 将该工程、公司或项目部的完整名称(包括括号中的编号)提取出来;
- **用完整名称替换掉用户上一个问题中出现的简称或模糊表达,并保留用户问题中的其它部分原样不变(如时间、计划数、内容)不变**
- 示例1
- 用户最新问题:"第一个""第1个"
- 对话记录的最后一个用户问题:"2025年南苑调相机检修(PROJ-2023-0179)今天有多少作业计划""
- 对话记录的最后一个的AI回答列出多个工程名第1个是`检修公司调相机一二次设备检修维护和改造服务框架-2025年南苑调相机检修(PROJ-2023-0179)`
- 则最终提问应为:
`检修公司调相机一二次设备检修维护和改造服务框架-2025年南苑调相机检修(PROJ-2023-0179)今天有多少作业计划`
- 示例2
- 用户的最新问题:"第二个""第2个"
- 对话记录的最后一个用户问题:"宏源电力建设公司第三项目部今天有多少项作业计划""
- 对话记录的最后一个AI回答列出多个分公司名第2个"安徽宏源电力建设有限公司(线路)"
- 则最终提问应为:
"安徽宏源电力建设有限公司(线路)第三项目部今天有多少项作业计划"
第五步:输出最终问题
- 直接输出最终问题(无解释、无多余前缀或后缀)
- 保持句式自然清晰
---
对话记录:
{chat_history}
用户最新问题:
{latest_user_question}
请输出最终问题:'''
message = [
{"role": "user", "content": prompt}
]
print(f"message:{message}")
response = client.chat.completions.create(
messages=message,
model=model_name,
max_tokens=100,
temperature=0.1, # 降低随机性,提高确定性
stream=False
)
res = response.choices[0].message.content.strip()
print(f"多轮意图后用户想要的问题是:{res}", flush=True)
return res
def check_lost(int_res, slot):
#labels: ["天气查询","通用对话","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"]
mapping = {
2: [['page'], ['app'], ['module']],
3: [['date']],
4: [['date']],
5: [['date']],
6: [['date']],
7: [['date']],
8: [['date']],
11: [['date']],
12: [['date']],
13: [['date']],
14: [['date']],
15: [['date']],
}
intention_mapping = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容",
6: "周计划作业内容", 7: "施工人数", 8: "作业考勤人数", 11: "作业面查询",
12:"班组人数查询", 13:"班组数查询", 14:"作业面内容", 15:"班组详情"}
if not mapping.__contains__(int_res):
return 0, ""
#提取的槽位信息
cur_k = list(slot.keys())
idx = -1
idx_len = 99
for i in range(len(mapping[int_res])):
sk = mapping[int_res][i]
#不在提取的槽位信息里,但是在必须槽位表里
miss_params = [x for x in sk if x not in cur_k]
#不在必须槽位表里,但是在提取的槽位信息里
extra_params = [x for x in cur_k if x not in sk]
if len(extra_params) >= 0 and len(miss_params) == 0:
idx = i
idx_len = 0
break
if len(miss_params) < idx_len:
idx = i
idx_len = len(miss_params)
if idx_len == 0: # 匹配通过
return CheckResult.NO_MATCH, cur_k
#符合当前意图的的必须槽位,但是不在提取的槽位信息里
left = [x for x in mapping[int_res][idx] if x not in cur_k]
print(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}",flush=True)
apologize_str = "非常抱歉,"
if int_res == 2:
return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询哪个页面?"
elif int_res in [3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15]:
return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}"
#标准化工程名
def check_standard_name_slot(int_res, slot) -> tuple:
intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15}
if int_res not in intention_list:
return CheckResult.NO_MATCH, ""
#项目名 当项目名存在时需要一定存在分公司(实施组织)名
if PROJECT_DEPARTMENT in slot:
if IMPLEMENTATION_ORG not in slot:
return CheckResult.NEEDS_MORE_ROUNDS, "请补充该项目部所属的分公司名称"
#工程名和分公司名和项目名标准化
for key, value in slot.items():
if key == PROJECT_NAME:
print(f"check_standard_name_slot 原始工程名 : {slot[PROJECT_NAME]}")
match_results = standardize_project_name(value, simply_to_standard_project_name_map, pinyin_simply_to_standard_project_name_map,70,90)
print(f"check_standard_name_slot 匹配后工程名 result:{match_results}",flush=True)
if match_results and len(match_results) == 1:
slot[key] = match_results[0]
else:
prompt = generate_project_prompt(match_results, original_name=slot[PROJECT_NAME], type="工程名")
return CheckResult.NEEDS_MORE_ROUNDS, prompt
if key == IMPLEMENTATION_ORG and slot[key] != "公司":
print(f"check_standard_name_slot 原始分公司名 : {slot[IMPLEMENTATION_ORG]}")
match_results = standardize_sub_company(value,simply_to_standard_company_name_map, pinyin_simply_to_standard_company_name_map,55,80)
print(f"check_standard_name_slot 匹配后分公司名: result:{match_results}",flush=True)
if match_results and len(match_results) == 1:
slot[key] = match_results[0]
else:
prompt = generate_project_prompt(match_results, original_name=slot[IMPLEMENTATION_ORG], type="分公司名")
return CheckResult.NEEDS_MORE_ROUNDS, prompt
if key == PROJECT_DEPARTMENT:
print(f"check_standard_name_slot 原始项目部名 : {slot[PROJECT_DEPARTMENT]}")
match_results = standardize_projectDepartment(slot[IMPLEMENTATION_ORG], value, standard_company_program, high_score=90)
print(f"check_standard_name_slot 匹配后项目部名: result:{match_results}",flush=True)
if match_results and len(match_results) == 1:
slot[key] = match_results[0]
else:
prompt = generate_project_prompt(match_results, original_name=slot[PROJECT_DEPARTMENT], type="项目名")
return CheckResult.NEEDS_MORE_ROUNDS, prompt
if key == RISK_LEVEL:
if slot[RISK_LEVEL] not in["2级","3级","4级","5级"] and slot[RISK_LEVEL] not in["二级","三级","四级","五级"]:
return CheckResult.NEEDS_MORE_ROUNDS, "您查询的风险等级在系统中未找到,请确认风险等级后再次提问"
return CheckResult.NO_MATCH, ""
#
# test_cases = [
# ("送一分公司"),
# ("送二分公司"),
# ("变电分公司"),
# ("建筑分公司"),
# ("检修试验分公司"),
# ("宏源电力公司"),
# ("宏源电力限公司"),
# ("宏源电力限公司线路"),
# ("宏源电力限公司变电"),
# ("送一分"),
# ("送二分"),
# ("变电分"),
# ("建筑分"),
# ("检修试验分"),
# ("宏源电力"),
# ("红源电力"),
# ("宏源电力有限"),
# ("宏源电力限线路"),
# ("宏源电力限变电"),
# ]
#
# print(f"加权混合策略 分公司名匹配**********************")
# start = time.perf_counter()
# for item in test_cases:
# match_results = standardize_sub_company(item,simply_to_standard_company_name_map, pinyin_simply_to_standard_company_name_map,55,80)
# print(f"加权混合策略 分公司名匹配 输入: {item}-> 输出: {match_results}")
# end = time.perf_counter()
# print(f"加权混合策略 耗时: {end - start:.4f} 秒")
#
#
#
# test_cases = [
# ("卢集"),
# ("芦集"),
# ("芦集变电站"),
# ("安庆四变电站"),
# ("锦绣变电站"),
# ("滁州护桥变电站"),
# ("合州换流站"),
# ("陕北合州换流站"),
# ("陕北安徽合州换流站"),
# ("金牛变电站"),
# ("香涧鹭岛工程"),
# ("延庆换流站"),
# ("国网延庆换流站"),
# ("国网北京延庆换流站"),
# ("陶楼广银线路工程"),
# ("紫蓬变电站"),
# ("宿州萧砀变电站"),
# ("冯井变电站"),
# ("富邦秋浦变电站"),
# ("包河玉龙变电站"),
#
# ("绿雪莲塘工程"),
# ("合肥循环园工程"),
# ("合肥长临河工程"),
# ("合肥中心变"),
# ("锁库变电站工程"),
# ("槽坊工程"),
#
# ("安庆四500kV变电站新建工程(PROJ-2024-0862)"),
# ("锦绣-常青π入中心变电站220kV架空线路工程(PROJ-2024-1206)"),
# ("渝北±800千伏换流站电气安装A包(调试部分)(PROJ-2024-1192)"),
# ("先锋-泉河π入安庆四变电站220kV线路工程(PROJ-2024-0834)"),
# ("安徽滁州护桥220kV变电站2号主变扩建工程(PROJ-2024-0821)"),
# ("合州士800千伏换流站电气安装A包(PROJ-2025-0056)"),
# ("卫田-陶楼T接首业变电站110kV电缆线路工程(PROJ-2024-1236)"),
# ("谯城(亳三)-希夷220kV线路工程(PROJ-2024-1205)"),
# ]
# print(f"去不重要词汇 工程名匹配******************************************")
# start = time.perf_counter()
# for item in test_cases:
# match_results = standardize_project_name(item, simply_to_standard_project_name_map, pinyin_simply_to_standard_project_name_map,70,90)
# print(f"工程名匹配 输入: {item}-> 输出: {match_results}")
# end = time.perf_counter()
# print(f"词集匹配 耗时: {end - start:.4f} 秒")
#
# print(f"项目名匹配******************************************")
# oral_program_name_list = [
# ("第1项目部"), # 期望返回所有"第三项目管理部"
# ("第2项目部"),
# ("第3项目部"),
# ("第4项目部"),
# ("第5项目部"),
# ("第6项目部"),
# ("第7项目部"),
# ("第8项目部"),
# ("第9项目部"),
# ("第10项目部"),
# ("第11项目部"),
# ("第12项目部"),
# ("第13项目部"),
# ("电缆班"),
# ("调试1队"),
# ("调试2队"),
# ("调试3队"),
# ("调试4队"),
# ("调试5队"),
# ("第一项目管理部"),
# ("第二项目管理部"),
# ("第五项目管理部"),
# ("第十一项目管理部(萧砀线路)"),
# ("第三项目管理部(张店线路)"),
# ("第三项目管理部(岳西线路)"),
# ("第五项目管理部(蚌埠)"),
# ("第三项目管理部(六安线路)"),
# ("第十一项目管理部(宿州线路)"),
# ("调试一队"),
# ("调试二队"),
# ("调试三队"),
# ("电缆班"),
# ]
#
# for company in standard_company_name_list:
# for program in oral_program_name_list:
# match_results = standardize_projectDepartment(company, program, standard_company_program, high_score=90)
# print(f"加权混合策略 项目部名称 输入: 公司:{company},项目部:{program}-> 输出: {match_results}")
if __name__ == '__main__':
app.run(host='0.0.0.0', port=18073, debug=True)