Intention/api/main.py

491 lines
19 KiB
Python
Raw Normal View History

2025-02-25 09:27:14 +08:00
from flask import Flask, jsonify, request
2025-02-27 09:06:34 +08:00
from pydantic import BaseModel, Field
2025-02-25 09:27:14 +08:00
from werkzeug.exceptions import HTTPException
2025-03-03 13:21:11 +08:00
from typing import List
2025-02-27 09:06:34 +08:00
from pydantic import ValidationError
2025-04-18 13:19:44 +08:00
import time
2025-02-25 09:27:14 +08:00
2025-02-27 16:33:26 +08:00
from intentRecognition import IntentRecognition
from slotRecognition import SlotRecognition
from utils import CheckResult, check_standard_name_slot_probability, check_lost
2025-03-03 11:03:48 +08:00
from config import *
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-22620"
MODEL_UIE_PATH = R"../uie/output/checkpoint-22320"
2025-02-27 09:06:34 +08:00
# 类别名称列表
labels = [
"天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
"日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答",
2025-04-18 13:19:44 +08:00
"通用对话", "作业面查询", "班组人数查询", "班组数查询", "作业面内容", "班组详情"
2025-02-27 09:06:34 +08:00
]
2025-02-25 09:27:14 +08:00
# 标签映射
label_map = {
2025-03-09 14:51:30 +08:00
0: 'O', # 非实体
1: 'B-date', 15: 'I-date',
2: 'B-projectName', 16: 'I-projectName',
3: 'B-projectType', 17: 'I-projectType',
4: 'B-constructionUnit', 18: 'I-constructionUnit',
5: 'B-implementationOrganization', 19: 'I-implementationOrganization',
6: 'B-projectDepartment', 20: 'I-projectDepartment',
7: 'B-projectManager', 21: 'I-projectManager',
8: 'B-subcontractor', 22: 'I-subcontractor',
9: 'B-teamLeader', 23: 'I-teamLeader',
10: 'B-riskLevel', 24: 'I-riskLevel',
11: 'B-page', 25: 'I-page',
12: 'B-operating', 26: 'I-operating',
13: 'B-teamName', 27: 'I-teamName',
14: 'B-constructionArea', 28: 'I-constructionArea',
2025-02-25 09:27:14 +08:00
}
2025-04-18 13:19:44 +08:00
# 初始化工具类
intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
# 初始化槽位识别工具类
slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
# 设置Flask应用
# update_data_from_local()
from globalData import GlobalData
GlobalData.update_from_local()
2025-02-27 09:06:34 +08:00
app = Flask(__name__)
2025-02-25 09:27:14 +08:00
2025-04-18 13:19:44 +08:00
2025-02-25 09:27:14 +08:00
# 统一的异常处理函数
@app.errorhandler(Exception)
def handle_exception(e):
"""统一异常处理"""
if isinstance(e, HTTPException):
return jsonify({
"error": {
"type": e.name,
"message": e.description,
"status_code": e.code
}
}), e.code
return jsonify({
"error": {
"type": "InternalServerError",
"message": str(e)
}
}), 500
2025-02-27 09:06:34 +08:00
def validate_user(data):
"""验证用户ID"""
if data.get("user_id") != '3bb66776-1722-4c36-b14a-73dd210fe750':
return jsonify(
code=401,
msg='权限验证失败,请联系接口开发人员',
label=-1,
probability=-1
), 401
return None
class LabelMessage(BaseModel):
text: str = Field(..., description="消息内容")
user_id: str = Field(..., description="消息内容")
# 每条消息的结构
class Message(BaseModel):
role: str = Field(..., description="消息内容")
content: str = Field(..., description="消息内容")
# 请求数据的结构
class RequestData(BaseModel):
messages: List[Message] = Field(..., description="消息列表")
user_id: str = Field(..., description="用户ID")
# 意图识别
@app.route('/intent_reco', methods=['POST'])
def intent_reco():
"""意图识别"""
try:
# 获取请求中的 JSON 数据
data = request.get_json()
request_data = LabelMessage(**data) # Pydantic 会验证数据结构
text = request_data.text
user_id = request_data.user_id
# 检查必需字段
if not text:
return jsonify({"error": "text is required"}), 400
if not user_id:
return jsonify({"error": "user_id is required"}), 400
# 验证用户ID
user_validation_error = validate_user(data)
if user_validation_error:
return user_validation_error
# 调用predict方法进行意图识别
2025-02-27 16:33:26 +08:00
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(text)
2025-02-27 09:06:34 +08:00
return jsonify(
code=200,
msg="成功",
int=predicted_id,
label=predicted_label,
probability=float(predicted_probability)
)
except Exception as e:
return jsonify({"error": str(e)}), 500
# 槽位抽取
@app.route('/slot_reco', methods=['POST'])
def slot_reco():
"""槽位识别"""
try:
# 获取请求中的 JSON 数据
data = request.get_json()
request_data = LabelMessage(**data) # Pydantic 会验证数据结构
text = request_data.text
user_id = request_data.user_id
# 检查必需字段
if not text:
return jsonify({"error": "text is required"}), 400
if not user_id:
return jsonify({"error": "user_id is required"}), 400
# 验证用户ID
user_validation_error = validate_user(data)
if user_validation_error:
return user_validation_error
2025-02-25 09:27:14 +08:00
2025-02-27 09:06:34 +08:00
# 调用 recognize 方法进行槽位识别
entities, slot_probability = slot_recognizer.recognize_probability(text)
print(
f"槽位抽取后的实体:{entities},实体后的可能值:{slot_probability}",
flush=True)
2025-02-27 09:06:34 +08:00
return jsonify(
code=200,
msg="成功",
slot=entities)
2025-02-25 09:27:14 +08:00
2025-02-27 09:06:34 +08:00
except Exception as e:
return jsonify({"error": str(e)}), 500
2025-02-25 09:27:14 +08:00
2025-02-27 09:06:34 +08:00
@app.route('/agent', methods=['POST'])
def agent():
try:
data = request.get_json()
except Exception as e:
print(f"body不是一个有效的json")
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
try:
2025-02-27 09:06:34 +08:00
# 使用 Pydantic 来验证数据结构
request_data = RequestData(**data) # Pydantic 会验证数据结构
messages = request_data.messages
user_id = request_data.user_id
2025-02-25 09:27:14 +08:00
2025-02-27 09:06:34 +08:00
# 检查必需字段是否存在
if not messages:
return jsonify({"error": "messages is required"}), 400
if not user_id:
return jsonify({"error": "user_id is required"}), 400
2025-02-25 09:27:14 +08:00
2025-02-27 09:06:34 +08:00
# 验证用户ID假设这个函数已经定义
user_validation_error = validate_user(data)
if user_validation_error:
return user_validation_error
if len(messages) == 1: # 首轮
query = messages[0].content # 使用 Message 对象的 .content 属性
# 先进行意图识别
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)
# 再进行槽位抽取
entities,slot_probability = slot_recognizer.recognize_probability(query)
print(
f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},slot_probability:{slot_probability},message:{messages}",
2025-04-18 13:19:44 +08:00
flush=True)
# 多轮
2025-02-27 09:06:34 +08:00
else:
res = extract_multi_chat(messages)
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(res)
#0:天气1互联网查询9知识问答10通用对话
if predicted_id in [0, 1, 9, 10]:
print(f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},message:{messages}",
flush=True)
return jsonify({
"code": 200, "msg": "成功",
"answer": {"int": predicted_id, "label": predicted_label, "probability": predicted_probability},
"finalQuery": res
})
entities, slot_probability = slot_recognizer.recognize_probability(res)
print(
f"多轮意图识别后的槽位:槽位抽取后的实体:{entities},slot_probability:{slot_probability}",
2025-04-18 13:19:44 +08:00
flush=True)
2025-03-03 10:58:46 +08:00
#必须槽位缺失检查
status, sk = check_lost(predicted_id, entities)
if status == CheckResult.NEEDS_MORE_ROUNDS:
return jsonify({"code": 10001, "msg": "成功",
"answer": {"miss": sk},
})
2025-04-17 09:11:53 +08:00
#工程名、分公司名和项目名标准化
result, information = check_standard_name_slot_probability(predicted_id, entities)
if result == CheckResult.NEEDS_MORE_ROUNDS:
return jsonify({
"code": 10001, "msg": "成功",
"answer": {"miss": information},
})
if result == CheckResult.NEEDS_MORE_ROUNDS:
return jsonify({
"code": 10001, "msg": "成功",
"answer": {"miss": information},
})
2025-02-25 09:27:14 +08:00
2025-03-03 10:58:46 +08:00
return jsonify({
"code": 200, "msg": "成功",
"answer": {"int": predicted_id, "label": predicted_label, "probability": predicted_probability,
"slot": entities},
2025-03-03 10:58:46 +08:00
})
2025-02-27 09:06:34 +08:00
except ValidationError as e:
return jsonify({"error": e.errors()}), 400 # 捕捉 Pydantic 错误并返回
except Exception as e:
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
2025-02-25 09:27:14 +08:00
2025-04-18 13:19:44 +08:00
def extract_multi_chat(messages):
2025-03-03 10:58:46 +08:00
from openai import OpenAI
client = OpenAI(base_url=api_base_url, api_key=api_key)
2025-04-18 13:19:44 +08:00
latest_message = messages[-1] # 最后一条用户提问
2025-04-17 09:11:53 +08:00
if latest_message.role == "user":
latest_user_question = latest_message.content.strip()
time_prefixes = ["今天", "昨天", "本周", "下周", "明天", "今日"] # 可扩展的时间前缀列表
if any(latest_user_question.startswith(prefix) for prefix in time_prefixes):
history_messages = []
else:
history_messages = messages[:-1] # 除最后一条之外的历史记录
# 格式化对话历史
2025-04-17 09:11:53 +08:00
chat_history = "\n".join([f"{msg.role}: {msg.content}" for msg in history_messages])
latest_user_question = latest_message.content if latest_message.role == "user" else ""
prompt = f'''
你是一个意图识别与补全助手你的任务是根据用户的最新问题判断是否需要补全如果不需要补全则原样返回用户的最新问题否则需要结合对话记录请你补用户的最新问题并只返回最终的完整问题请严格按照如下逻辑判断并执行
---
规则判断与补全流程
第一步用户最新问题是否以公司为主语 原样返回无需补全
- 若用户最新问题主语是公司直接返回原句无需补全
- 主语为公司的典型句式
- 公司开头
- 今天昨天本周下周等时间词开头紧跟公司作为主语
- 示例
- 用户的最新问题今天公司有多少四级风险作业计划
- 用户的最新问题今天公司有多少作业计划
- 用户的最新问题公司今天有多少4级风险的作业面
- 最终提问均为 原句不变
第二步用户最新问题是否是完整的问题 原样返回无需补全
- 若用户最新问题中包含下列之一具体的项目部名工程名分公司名班组名地区名等信息且同时出现作业计划作业面班组等查询对象视为完整问题直接返回原句无需补全
- 示例
- 用户最新问题今天张三班组有多少作业计划
- 用户最新问题今天绿雪莲塘工程有多少作业计划
- 最终提问均为 原句不变
第三步用户最新问题是否存在指代词 结合用户最新问题和对话记录进行补全
- 若用户最新问题问题中出现模糊表达具体是哪些项是哪两个作业计划分别是什么合肥中心变工程呢具体是哪20项请结合上一个用户问题和上一个AI回复补全问题信息
- 示例1
- 用户最新问题具体的作业计划分别是什么
- 对话记录的最后一个用户问题今天送一分公司有多少项作业计划
- 对话记录的最后一个AI回答今天送电一分公司有21项作业计划
- 则最终提问应为
今天送电一分公司的21项作业计划分别是什么
- 示例2
- 用户的最新问题具体的作业内容是什么
- 对话记录的最后一个用户问题今天送一分公司第一项目部有多少项作业计划
- 对话记录的最后一个AI回答今天送电一分公司第一项目管理部有21项作业计划
- 则最终提问应为
今天送电一分公司第一项目管理部的21项作业计划分别是什么
第四步用户最新问题是否为序号指代第一个/第2个 用完整工程/项目/公司名替换补全
- 精确提取用户所指的序号第3个指第3个工程名公司名或项目部名
- 将该工程公司或项目部的完整名称包括括号中的编号提取出来
- **用完整名称替换掉用户上一个问题中出现的简称或模糊表达并保留用户问题中的其它部分原样不变如时间计划数内容不变**
- 示例1
- 用户最新问题:"第一个" "第1个"
- 对话记录的最后一个用户问题"2025年南苑调相机检修(PROJ-2023-0179)今天有多少作业计划""
- 对话记录的最后一个的AI回答列出多个工程名第1个是`检修公司调相机一二次设备检修维护和改造服务框架-2025年南苑调相机检修(PROJ-2023-0179)`
- 则最终提问应为
`检修公司调相机一二次设备检修维护和改造服务框架-2025年南苑调相机检修(PROJ-2023-0179)今天有多少作业计划`
- 示例2
- 用户的最新问题:"第二个" "第2个"
- 对话记录的最后一个用户问题"宏源电力建设公司第三项目部今天有多少项作业计划""
- 对话记录的最后一个AI回答列出多个分公司名第2个"安徽宏源电力建设有限公司(线路)"
- 则最终提问应为
"安徽宏源电力建设有限公司(线路)第三项目部今天有多少项作业计划"
第五步输出最终问题
- 直接输出最终问题无解释无多余前缀或后缀
- 保持句式自然清晰
---
对话记录
{chat_history}
2025-04-17 09:11:53 +08:00
用户最新问题
{latest_user_question}
请输出最终问题'''
message = [
{"role": "user", "content": prompt}
]
2025-03-03 10:58:46 +08:00
2025-04-17 09:11:53 +08:00
print(f"message:{message}")
2025-03-03 10:58:46 +08:00
response = client.chat.completions.create(
messages=message,
model=model_name,
max_tokens=100,
2025-04-18 13:19:44 +08:00
temperature=0.1, # 降低随机性,提高确定性
2025-03-03 10:58:46 +08:00
stream=False
)
res = response.choices[0].message.content.strip()
print(f"多轮意图后用户想要的问题是:{res}", flush=True)
return res
2025-04-18 13:19:44 +08:00
#
2025-04-18 16:39:06 +08:00
# #
2025-04-18 13:19:44 +08:00
# test_cases = [
# ("送一分公司"),
# ("送二分公司"),
# ("变电分公司"),
# ("建筑分公司"),
# ("检修试验分公司"),
# ("宏源电力公司"),
# ("宏源电力限公司"),
# ("宏源电力限公司线路"),
# ("宏源电力限公司变电"),
# ("送一分"),
# ("送二分"),
# ("变电分"),
# ("建筑分"),
# ("检修试验分"),
# ("宏源电力"),
# ("红源电力"),
# ("宏源电力有限"),
# ("宏源电力限线路"),
# ("宏源电力限变电"),
# ]
2025-04-17 09:11:53 +08:00
#
2025-04-18 13:19:44 +08:00
# print(f"加权混合策略 分公司名匹配**********************")
# start = time.perf_counter()
# for item in test_cases:
# match_results = standardize_sub_company(item,simply_to_standard_company_name_map, pinyin_simply_to_standard_company_name_map,55,80)
# print(f"加权混合策略 分公司名匹配 输入: {item}-> 输出: {match_results}")
# end = time.perf_counter()
# print(f"加权混合策略 耗时: {end - start:.4f} 秒")
#
2025-04-17 09:11:53 +08:00
#
2025-04-18 16:39:06 +08:00
# #
2025-04-18 13:19:44 +08:00
# test_cases = [
# ("卢集"),
# ("芦集"),
# ("芦集变电站"),
# ("安庆四变电站"),
# ("锦绣变电站"),
# ("滁州护桥变电站"),
# ("合州换流站"),
# ("陕北合州换流站"),
# ("陕北安徽合州换流站"),
# ("金牛变电站"),
# ("香涧鹭岛工程"),
# ("延庆换流站"),
# ("国网延庆换流站"),
# ("国网北京延庆换流站"),
# ("陶楼广银线路工程"),
# ("紫蓬变电站"),
# ("宿州萧砀变电站"),
# ("冯井变电站"),
# ("富邦秋浦变电站"),
# ("包河玉龙变电站"),
2025-04-17 09:11:53 +08:00
#
2025-04-18 13:19:44 +08:00
# ("绿雪莲塘工程"),
# ("合肥循环园工程"),
# ("合肥长临河工程"),
# ("合肥中心变"),
# ("锁库变电站工程"),
# ("槽坊工程"),
#
2025-04-18 13:19:44 +08:00
# ("安庆四500kV变电站新建工程(PROJ-2024-0862)"),
# ("锦绣-常青π入中心变电站220kV架空线路工程(PROJ-2024-1206)"),
# ("渝北±800千伏换流站电气安装A包(调试部分)(PROJ-2024-1192)"),
# ("先锋-泉河π入安庆四变电站220kV线路工程(PROJ-2024-0834)"),
# ("安徽滁州护桥220kV变电站2号主变扩建工程(PROJ-2024-0821)"),
# ("合州士800千伏换流站电气安装A包(PROJ-2025-0056)"),
# ("卫田-陶楼T接首业变电站110kV电缆线路工程(PROJ-2024-1236)"),
# ("谯城(亳三)-希夷220kV线路工程(PROJ-2024-1205)"),
# ]
# print(f"去不重要词汇 工程名匹配******************************************")
# start = time.perf_counter()
# for item in test_cases:
2025-04-18 16:39:06 +08:00
# match_results = standardize_project_name(item, simply_to_standard_project_name_map,
# pinyin_simply_to_standard_project_name_map, 70, 90)
2025-04-18 13:19:44 +08:00
# print(f"工程名匹配 输入: {item}-> 输出: {match_results}")
# end = time.perf_counter()
# print(f"词集匹配 耗时: {end - start:.4f} 秒")
#
2025-04-18 13:19:44 +08:00
# print(f"项目名匹配******************************************")
# oral_program_name_list = [
# ("第1项目部"), # 期望返回所有"第三项目管理部"
# ("第2项目部"),
# ("第3项目部"),
# ("第4项目部"),
# ("第5项目部"),
# ("第6项目部"),
# ("第7项目部"),
# ("第8项目部"),
# ("第9项目部"),
# ("第10项目部"),
# ("第11项目部"),
# ("第12项目部"),
# ("第13项目部"),
# ("电缆班"),
# ("调试1队"),
# ("调试2队"),
# ("调试3队"),
# ("调试4队"),
# ("调试5队"),
# ("第一项目管理部"),
# ("第二项目管理部"),
# ("第五项目管理部"),
# ("第十一项目管理部(萧砀线路)"),
# ("第三项目管理部(张店线路)"),
# ("第三项目管理部(岳西线路)"),
# ("第五项目管理部(蚌埠)"),
# ("第三项目管理部(六安线路)"),
# ("第十一项目管理部(宿州线路)"),
# ("调试一队"),
# ("调试二队"),
# ("调试三队"),
# ("电缆班"),
# ]
#
2025-04-18 13:19:44 +08:00
# for company in standard_company_name_list:
# for program in oral_program_name_list:
# match_results = standardize_projectDepartment(company, program, standard_company_program, high_score=90)
# print(f"加权混合策略 项目部名称 输入: 公司:{company},项目部:{program}-> 输出: {match_results}")
2025-04-17 09:11:53 +08:00
2025-02-25 09:27:14 +08:00
if __name__ == '__main__':
2025-02-27 16:33:26 +08:00
app.run(host='0.0.0.0', port=18074, debug=True)