491 lines
19 KiB
Python
491 lines
19 KiB
Python
from flask import Flask, jsonify, request
|
||
from pydantic import BaseModel, Field
|
||
from werkzeug.exceptions import HTTPException
|
||
from typing import List
|
||
from pydantic import ValidationError
|
||
import time
|
||
|
||
from intentRecognition import IntentRecognition
|
||
from slotRecognition import SlotRecognition
|
||
from utils import CheckResult, check_standard_name_slot_probability, check_lost
|
||
|
||
from config import *
|
||
|
||
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-22620"
|
||
MODEL_UIE_PATH = R"../uie/output/checkpoint-22320"
|
||
|
||
# 类别名称列表
|
||
labels = [
|
||
"天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
|
||
"日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答",
|
||
"通用对话", "作业面查询", "班组人数查询", "班组数查询", "作业面内容", "班组详情"
|
||
]
|
||
|
||
# 标签映射
|
||
label_map = {
|
||
0: 'O', # 非实体
|
||
1: 'B-date', 15: 'I-date',
|
||
2: 'B-projectName', 16: 'I-projectName',
|
||
3: 'B-projectType', 17: 'I-projectType',
|
||
4: 'B-constructionUnit', 18: 'I-constructionUnit',
|
||
5: 'B-implementationOrganization', 19: 'I-implementationOrganization',
|
||
6: 'B-projectDepartment', 20: 'I-projectDepartment',
|
||
7: 'B-projectManager', 21: 'I-projectManager',
|
||
8: 'B-subcontractor', 22: 'I-subcontractor',
|
||
9: 'B-teamLeader', 23: 'I-teamLeader',
|
||
10: 'B-riskLevel', 24: 'I-riskLevel',
|
||
11: 'B-page', 25: 'I-page',
|
||
12: 'B-operating', 26: 'I-operating',
|
||
13: 'B-teamName', 27: 'I-teamName',
|
||
14: 'B-constructionArea', 28: 'I-constructionArea',
|
||
}
|
||
|
||
# 初始化工具类
|
||
intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
|
||
|
||
# 初始化槽位识别工具类
|
||
slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
|
||
# 设置Flask应用
|
||
|
||
# update_data_from_local()
|
||
from globalData import GlobalData
|
||
GlobalData.update_from_local()
|
||
|
||
app = Flask(__name__)
|
||
|
||
|
||
# 统一的异常处理函数
|
||
@app.errorhandler(Exception)
|
||
def handle_exception(e):
|
||
"""统一异常处理"""
|
||
if isinstance(e, HTTPException):
|
||
return jsonify({
|
||
"error": {
|
||
"type": e.name,
|
||
"message": e.description,
|
||
"status_code": e.code
|
||
}
|
||
}), e.code
|
||
return jsonify({
|
||
"error": {
|
||
"type": "InternalServerError",
|
||
"message": str(e)
|
||
}
|
||
}), 500
|
||
|
||
|
||
def validate_user(data):
|
||
"""验证用户ID"""
|
||
if data.get("user_id") != '3bb66776-1722-4c36-b14a-73dd210fe750':
|
||
return jsonify(
|
||
code=401,
|
||
msg='权限验证失败,请联系接口开发人员',
|
||
label=-1,
|
||
probability=-1
|
||
), 401
|
||
return None
|
||
|
||
|
||
class LabelMessage(BaseModel):
|
||
text: str = Field(..., description="消息内容")
|
||
user_id: str = Field(..., description="消息内容")
|
||
|
||
|
||
# 每条消息的结构
|
||
class Message(BaseModel):
|
||
role: str = Field(..., description="消息内容")
|
||
content: str = Field(..., description="消息内容")
|
||
|
||
|
||
# 请求数据的结构
|
||
class RequestData(BaseModel):
|
||
messages: List[Message] = Field(..., description="消息列表")
|
||
user_id: str = Field(..., description="用户ID")
|
||
|
||
|
||
# 意图识别
|
||
@app.route('/intent_reco', methods=['POST'])
|
||
def intent_reco():
|
||
"""意图识别"""
|
||
try:
|
||
# 获取请求中的 JSON 数据
|
||
data = request.get_json()
|
||
request_data = LabelMessage(**data) # Pydantic 会验证数据结构
|
||
text = request_data.text
|
||
user_id = request_data.user_id
|
||
# 检查必需字段
|
||
if not text:
|
||
return jsonify({"error": "text is required"}), 400
|
||
if not user_id:
|
||
return jsonify({"error": "user_id is required"}), 400
|
||
|
||
# 验证用户ID
|
||
user_validation_error = validate_user(data)
|
||
if user_validation_error:
|
||
return user_validation_error
|
||
|
||
# 调用predict方法进行意图识别
|
||
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(text)
|
||
|
||
return jsonify(
|
||
code=200,
|
||
msg="成功",
|
||
int=predicted_id,
|
||
label=predicted_label,
|
||
probability=float(predicted_probability)
|
||
)
|
||
|
||
except Exception as e:
|
||
return jsonify({"error": str(e)}), 500
|
||
|
||
|
||
# 槽位抽取
|
||
@app.route('/slot_reco', methods=['POST'])
|
||
def slot_reco():
|
||
"""槽位识别"""
|
||
try:
|
||
# 获取请求中的 JSON 数据
|
||
data = request.get_json()
|
||
request_data = LabelMessage(**data) # Pydantic 会验证数据结构
|
||
text = request_data.text
|
||
user_id = request_data.user_id
|
||
|
||
# 检查必需字段
|
||
if not text:
|
||
return jsonify({"error": "text is required"}), 400
|
||
if not user_id:
|
||
return jsonify({"error": "user_id is required"}), 400
|
||
|
||
# 验证用户ID
|
||
user_validation_error = validate_user(data)
|
||
if user_validation_error:
|
||
return user_validation_error
|
||
|
||
# 调用 recognize 方法进行槽位识别
|
||
entities, slot_probability = slot_recognizer.recognize_probability(text)
|
||
print(
|
||
f"槽位抽取后的实体:{entities},实体后的可能值:{slot_probability}",
|
||
flush=True)
|
||
return jsonify(
|
||
code=200,
|
||
msg="成功",
|
||
slot=entities)
|
||
|
||
except Exception as e:
|
||
return jsonify({"error": str(e)}), 500
|
||
|
||
|
||
@app.route('/agent', methods=['POST'])
|
||
def agent():
|
||
try:
|
||
data = request.get_json()
|
||
except Exception as e:
|
||
print(f"body不是一个有效的json")
|
||
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
|
||
try:
|
||
# 使用 Pydantic 来验证数据结构
|
||
request_data = RequestData(**data) # Pydantic 会验证数据结构
|
||
messages = request_data.messages
|
||
user_id = request_data.user_id
|
||
|
||
# 检查必需字段是否存在
|
||
if not messages:
|
||
return jsonify({"error": "messages is required"}), 400
|
||
if not user_id:
|
||
return jsonify({"error": "user_id is required"}), 400
|
||
|
||
# 验证用户ID(假设这个函数已经定义)
|
||
user_validation_error = validate_user(data)
|
||
if user_validation_error:
|
||
return user_validation_error
|
||
if len(messages) == 1: # 首轮
|
||
query = messages[0].content # 使用 Message 对象的 .content 属性
|
||
# 先进行意图识别
|
||
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)
|
||
# 再进行槽位抽取
|
||
entities,slot_probability = slot_recognizer.recognize_probability(query)
|
||
print(
|
||
f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},slot_probability:{slot_probability},message:{messages}",
|
||
flush=True)
|
||
# 多轮
|
||
else:
|
||
res = extract_multi_chat(messages)
|
||
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(res)
|
||
#0:天气,1:互联网查询,9:知识问答,10:通用对话
|
||
if predicted_id in [0, 1, 9, 10]:
|
||
print(f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},message:{messages}",
|
||
flush=True)
|
||
return jsonify({
|
||
"code": 200, "msg": "成功",
|
||
"answer": {"int": predicted_id, "label": predicted_label, "probability": predicted_probability},
|
||
"finalQuery": res
|
||
})
|
||
entities, slot_probability = slot_recognizer.recognize_probability(res)
|
||
print(
|
||
f"多轮意图识别后的槽位:槽位抽取后的实体:{entities},slot_probability:{slot_probability}",
|
||
flush=True)
|
||
|
||
#必须槽位缺失检查
|
||
status, sk = check_lost(predicted_id, entities)
|
||
if status == CheckResult.NEEDS_MORE_ROUNDS:
|
||
return jsonify({"code": 10001, "msg": "成功",
|
||
"answer": {"miss": sk},
|
||
})
|
||
|
||
#工程名、分公司名和项目名标准化
|
||
result, information = check_standard_name_slot_probability(predicted_id, entities)
|
||
if result == CheckResult.NEEDS_MORE_ROUNDS:
|
||
return jsonify({
|
||
"code": 10001, "msg": "成功",
|
||
"answer": {"miss": information},
|
||
})
|
||
if result == CheckResult.NEEDS_MORE_ROUNDS:
|
||
return jsonify({
|
||
"code": 10001, "msg": "成功",
|
||
"answer": {"miss": information},
|
||
})
|
||
|
||
return jsonify({
|
||
"code": 200, "msg": "成功",
|
||
"answer": {"int": predicted_id, "label": predicted_label, "probability": predicted_probability,
|
||
"slot": entities},
|
||
})
|
||
|
||
except ValidationError as e:
|
||
return jsonify({"error": e.errors()}), 400 # 捕捉 Pydantic 错误并返回
|
||
except Exception as e:
|
||
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
|
||
|
||
|
||
def extract_multi_chat(messages):
|
||
from openai import OpenAI
|
||
client = OpenAI(base_url=api_base_url, api_key=api_key)
|
||
|
||
latest_message = messages[-1] # 最后一条用户提问
|
||
if latest_message.role == "user":
|
||
latest_user_question = latest_message.content.strip()
|
||
time_prefixes = ["今天", "昨天", "本周", "下周", "明天", "今日"] # 可扩展的时间前缀列表
|
||
if any(latest_user_question.startswith(prefix) for prefix in time_prefixes):
|
||
history_messages = []
|
||
else:
|
||
history_messages = messages[:-1] # 除最后一条之外的历史记录
|
||
|
||
# 格式化对话历史
|
||
chat_history = "\n".join([f"{msg.role}: {msg.content}" for msg in history_messages])
|
||
latest_user_question = latest_message.content if latest_message.role == "user" else ""
|
||
|
||
prompt = f'''
|
||
你是一个意图识别与补全助手,你的任务是根据用户的最新问题判断是否需要补全,如果不需要补全,则原样返回用户的最新问题,否则需要结合对话记录请你补用户的最新问题,并只返回最终的完整问题。请严格按照如下逻辑判断并执行:
|
||
|
||
---
|
||
|
||
【规则判断与补全流程】
|
||
|
||
第一步:用户最新问题是否以“公司”为主语?→ 原样返回,无需补全
|
||
- 若用户最新问题主语是“公司”,直接返回原句,无需补全。
|
||
- 主语为“公司”的典型句式:
|
||
- 以“公司”开头;
|
||
- 以“今天”“昨天”“本周”“下周”等时间词开头,紧跟“公司”作为主语;
|
||
- 示例:
|
||
- 用户的最新问题:“今天公司有多少四级风险作业计划?”
|
||
- 用户的最新问题:“今天公司有多少作业计划”
|
||
- 用户的最新问题:“公司今天有多少4级风险的作业面?”
|
||
- 最终提问均为: 原句不变。
|
||
|
||
第二步:用户最新问题是否是完整的问题?→ 原样返回,无需补全
|
||
- 若用户最新问题中包含下列之一:具体的项目部名、工程名、分公司名、班组名、地区名等信息,且同时出现作业计划、作业面、班组等查询对象,视为完整问题,直接返回原句,无需补全。
|
||
- 示例:
|
||
- 用户最新问题:“今天张三班组有多少作业计划?”
|
||
- 用户最新问题:“今天绿雪莲塘工程有多少作业计划”
|
||
- 最终提问均为: 原句不变。
|
||
|
||
第三步:用户最新问题是否存在指代词?→ 结合用户最新问题和对话记录进行补全
|
||
- 若用户最新问题问题中出现模糊表达,如“具体是哪些项”、“是哪两个”、“作业计划分别是什么”、“合肥中心变工程呢”、“具体是哪20项”等,请结合上一个用户问题和上一个AI回复补全问题信息。
|
||
- 示例1:
|
||
- 用户最新问题:“具体的作业计划分别是什么”
|
||
- 对话记录的最后一个用户问题:“今天送一分公司有多少项作业计划”
|
||
- 对话记录的最后一个AI回答:“今天送电一分公司有21项作业计划”
|
||
- 则最终提问应为:
|
||
“今天送电一分公司的21项作业计划分别是什么”
|
||
- 示例2:
|
||
- 用户的最新问题:“具体的作业内容是什么”
|
||
- 对话记录的最后一个用户问题:今天送一分公司第一项目部有多少项作业计划
|
||
- 对话记录的最后一个AI回答:今天送电一分公司第一项目管理部有21项作业计划
|
||
- 则最终提问应为:
|
||
“今天送电一分公司第一项目管理部的21项作业计划分别是什么”
|
||
|
||
第四步:用户最新问题是否为序号指代(第一个/第2个)?→ 用完整工程/项目/公司名替换补全
|
||
- 精确提取用户所指的序号(如“第3个”指第3个工程名、公司名或项目部名);
|
||
- 将该工程、公司或项目部的完整名称(包括括号中的编号)提取出来;
|
||
- **用完整名称替换掉用户上一个问题中出现的简称或模糊表达,并保留用户问题中的其它部分原样不变(如时间、计划数、内容)不变**;
|
||
- 示例1:
|
||
- 用户最新问题:"第一个" 或"第1个"
|
||
- 对话记录的最后一个用户问题:"2025年南苑调相机检修(PROJ-2023-0179)今天有多少作业计划""
|
||
- 对话记录的最后一个的AI回答:列出多个工程名,第1个是`检修公司调相机一二次设备检修维护和改造服务框架-2025年南苑调相机检修(PROJ-2023-0179)`
|
||
- 则最终提问应为:
|
||
`检修公司调相机一二次设备检修维护和改造服务框架-2025年南苑调相机检修(PROJ-2023-0179)今天有多少作业计划`
|
||
- 示例2:
|
||
- 用户的最新问题:"第二个" 或"第2个"
|
||
- 对话记录的最后一个用户问题:"宏源电力建设公司第三项目部今天有多少项作业计划""
|
||
- 对话记录的最后一个AI回答:列出多个分公司名,第2个:"安徽宏源电力建设有限公司(线路)"
|
||
- 则最终提问应为:
|
||
"安徽宏源电力建设有限公司(线路)第三项目部今天有多少项作业计划"
|
||
|
||
第五步:输出最终问题
|
||
- 直接输出最终问题(无解释、无多余前缀或后缀)
|
||
- 保持句式自然清晰
|
||
|
||
---
|
||
|
||
对话记录:
|
||
{chat_history}
|
||
|
||
用户最新问题:
|
||
{latest_user_question}
|
||
|
||
请输出最终问题:'''
|
||
|
||
message = [
|
||
{"role": "user", "content": prompt}
|
||
]
|
||
|
||
print(f"message:{message}")
|
||
|
||
response = client.chat.completions.create(
|
||
messages=message,
|
||
model=model_name,
|
||
max_tokens=100,
|
||
temperature=0.1, # 降低随机性,提高确定性
|
||
stream=False
|
||
)
|
||
|
||
res = response.choices[0].message.content.strip()
|
||
print(f"多轮意图后用户想要的问题是:{res}", flush=True)
|
||
return res
|
||
|
||
|
||
#
|
||
# #
|
||
# test_cases = [
|
||
# ("送一分公司"),
|
||
# ("送二分公司"),
|
||
# ("变电分公司"),
|
||
# ("建筑分公司"),
|
||
# ("检修试验分公司"),
|
||
# ("宏源电力公司"),
|
||
# ("宏源电力限公司"),
|
||
# ("宏源电力限公司线路"),
|
||
# ("宏源电力限公司变电"),
|
||
# ("送一分"),
|
||
# ("送二分"),
|
||
# ("变电分"),
|
||
# ("建筑分"),
|
||
# ("检修试验分"),
|
||
# ("宏源电力"),
|
||
# ("红源电力"),
|
||
# ("宏源电力有限"),
|
||
# ("宏源电力限线路"),
|
||
# ("宏源电力限变电"),
|
||
# ]
|
||
#
|
||
# print(f"加权混合策略 分公司名匹配**********************")
|
||
# start = time.perf_counter()
|
||
# for item in test_cases:
|
||
# match_results = standardize_sub_company(item,simply_to_standard_company_name_map, pinyin_simply_to_standard_company_name_map,55,80)
|
||
# print(f"加权混合策略 分公司名匹配 输入: {item}-> 输出: {match_results}")
|
||
# end = time.perf_counter()
|
||
# print(f"加权混合策略 耗时: {end - start:.4f} 秒")
|
||
#
|
||
#
|
||
# #
|
||
# test_cases = [
|
||
# ("卢集"),
|
||
# ("芦集"),
|
||
# ("芦集变电站"),
|
||
# ("安庆四变电站"),
|
||
# ("锦绣变电站"),
|
||
# ("滁州护桥变电站"),
|
||
# ("合州换流站"),
|
||
# ("陕北合州换流站"),
|
||
# ("陕北安徽合州换流站"),
|
||
# ("金牛变电站"),
|
||
# ("香涧鹭岛工程"),
|
||
# ("延庆换流站"),
|
||
# ("国网延庆换流站"),
|
||
# ("国网北京延庆换流站"),
|
||
# ("陶楼广银线路工程"),
|
||
# ("紫蓬变电站"),
|
||
# ("宿州萧砀变电站"),
|
||
# ("冯井变电站"),
|
||
# ("富邦秋浦变电站"),
|
||
# ("包河玉龙变电站"),
|
||
#
|
||
# ("绿雪莲塘工程"),
|
||
# ("合肥循环园工程"),
|
||
# ("合肥长临河工程"),
|
||
# ("合肥中心变"),
|
||
# ("锁库变电站工程"),
|
||
# ("槽坊工程"),
|
||
#
|
||
# ("安庆四500kV变电站新建工程(PROJ-2024-0862)"),
|
||
# ("锦绣-常青π入中心变电站220kV架空线路工程(PROJ-2024-1206)"),
|
||
# ("渝北±800千伏换流站电气安装A包(调试部分)(PROJ-2024-1192)"),
|
||
# ("先锋-泉河π入安庆四变电站220kV线路工程(PROJ-2024-0834)"),
|
||
# ("安徽滁州护桥220kV变电站2号主变扩建工程(PROJ-2024-0821)"),
|
||
# ("合州士800千伏换流站电气安装A包(PROJ-2025-0056)"),
|
||
# ("卫田-陶楼T接首业变电站110kV电缆线路工程(PROJ-2024-1236)"),
|
||
# ("谯城(亳三)-希夷220kV线路工程(PROJ-2024-1205)"),
|
||
# ]
|
||
# print(f"去不重要词汇 工程名匹配******************************************")
|
||
# start = time.perf_counter()
|
||
# for item in test_cases:
|
||
# match_results = standardize_project_name(item, simply_to_standard_project_name_map,
|
||
# pinyin_simply_to_standard_project_name_map, 70, 90)
|
||
# print(f"工程名匹配 输入: {item}-> 输出: {match_results}")
|
||
# end = time.perf_counter()
|
||
# print(f"词集匹配 耗时: {end - start:.4f} 秒")
|
||
#
|
||
# print(f"项目名匹配******************************************")
|
||
# oral_program_name_list = [
|
||
# ("第1项目部"), # 期望返回所有"第三项目管理部"
|
||
# ("第2项目部"),
|
||
# ("第3项目部"),
|
||
# ("第4项目部"),
|
||
# ("第5项目部"),
|
||
# ("第6项目部"),
|
||
# ("第7项目部"),
|
||
# ("第8项目部"),
|
||
# ("第9项目部"),
|
||
# ("第10项目部"),
|
||
# ("第11项目部"),
|
||
# ("第12项目部"),
|
||
# ("第13项目部"),
|
||
# ("电缆班"),
|
||
# ("调试1队"),
|
||
# ("调试2队"),
|
||
# ("调试3队"),
|
||
# ("调试4队"),
|
||
# ("调试5队"),
|
||
# ("第一项目管理部"),
|
||
# ("第二项目管理部"),
|
||
# ("第五项目管理部"),
|
||
# ("第十一项目管理部(萧砀线路)"),
|
||
# ("第三项目管理部(张店线路)"),
|
||
# ("第三项目管理部(岳西线路)"),
|
||
# ("第五项目管理部(蚌埠)"),
|
||
# ("第三项目管理部(六安线路)"),
|
||
# ("第十一项目管理部(宿州线路)"),
|
||
# ("调试一队"),
|
||
# ("调试二队"),
|
||
# ("调试三队"),
|
||
# ("电缆班"),
|
||
# ]
|
||
#
|
||
# for company in standard_company_name_list:
|
||
# for program in oral_program_name_list:
|
||
# match_results = standardize_projectDepartment(company, program, standard_company_program, high_score=90)
|
||
# print(f"加权混合策略 项目部名称 输入: 公司:{company},项目部:{program}-> 输出: {match_results}")
|
||
|
||
if __name__ == '__main__':
|
||
app.run(host='0.0.0.0', port=18074, debug=True)
|