Intention/api/main.py

429 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import time
from flask import Flask, jsonify, request
from pydantic import BaseModel, Field
from werkzeug.exceptions import HTTPException
from typing import List
from pydantic import ValidationError
from logger_util import setup_logger
from intentRecognition import IntentRecognition
from slotRecognition import SlotRecognition
from utils import CheckResult, check_standard_name_slot_probability, check_lost, process_msg_content
from config import *
from globalData import GlobalData
from apscheduler.schedulers.background import BackgroundScheduler
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-14672"
MODEL_UIE_PATH = R"../uie/output/checkpoint-16380"
# 类别名称列表
labels = [
"天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
"日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答",
"通用对话", "作业面查询", "班组人数查询", "班组数查询", "作业面内容", "班组详情",
"工程进度查询", "人员查询", "分公司查询","工程数量查询","工程详情查询","项目部数量查询",
"建管单位数量查询","建管单位详情","分包单位数量查询","分包单位详情"
]
# 标签映射
label_map = {
0: 'O', # 非实体
1: 'B-date', 20: 'I-date',
2: 'B-projectName', 21: 'I-projectName',
3: 'B-projectType', 22: 'I-projectType',
4: 'B-constructionUnit', 23: 'I-constructionUnit',
5: 'B-implementationOrganization', 24: 'I-implementationOrganization',
6: 'B-projectDepartment', 25: 'I-projectDepartment',
7: 'B-projectManager', 26: 'I-projectManager',
8: 'B-subcontractor', 27: 'I-subcontractor',
9: 'B-teamLeader', 28: 'I-teamLeader',
10: 'B-riskLevel', 29: 'I-riskLevel',
11: 'B-page', 30: 'I-page',
12: 'B-operating', 31: 'I-operating',
13: 'B-teamName', 32: 'I-teamName',
14: 'B-constructionArea', 33: 'I-constructionArea',
15: 'B-personName', 34: 'I-personName',
16: 'B-personQueryType', 35: 'I-personQueryType',
17: 'B-projectStatus', 36: 'I-projectStatus',
18: 'B-skyNet', 37: 'I-skyNet',
19: 'B-programNavigation', 38: 'I-programNavigation'
}
logger = setup_logger("main", level=logging.DEBUG)
# 初始化工具类
intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
# 初始化槽位识别工具类
slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
# 设置Flask应用
app = Flask(__name__)
def job():
logger.info(f"✅ [Info] Executing update_from_redis...at {time.strftime('%Y-%m-%d %H:%M:%S')}")
GlobalData.update_from_redis()
job()
# 创建后台调度器
scheduler = BackgroundScheduler()
scheduler.add_job(job, 'cron', hour=3, minute=0) # 每天凌晨3点执行
scheduler.start()
# 统一的异常处理函数
@app.errorhandler(Exception)
def handle_exception(e):
"""统一异常处理"""
if isinstance(e, HTTPException):
return jsonify({
"error": {
"type": e.name,
"message": e.description,
"status_code": e.code
}
}), e.code
return jsonify({
"error": {
"type": "InternalServerError",
"message": str(e)
}
}), 500
def validate_user(data):
"""验证用户ID"""
if data.get("user_id") != '3bb66776-1722-4c36-b14a-73dd210fe750':
return jsonify(
code=401,
msg='权限验证失败,请联系接口开发人员',
label=-1,
probability=-1
), 401
return None
class LabelMessage(BaseModel):
text: str = Field(..., description="消息内容")
user_id: str = Field(..., description="消息内容")
# 每条消息的结构
class Message(BaseModel):
role: str = Field(..., description="消息内容")
content: str = Field(..., description="消息内容")
# 请求数据的结构
class RequestData(BaseModel):
messages: List[Message] = Field(..., description="消息列表")
user_id: str = Field(..., description="用户ID")
# 意图识别
@app.route('/intent_reco', methods=['POST'])
def intent_reco():
"""意图识别"""
try:
# 获取请求中的 JSON 数据
data = request.get_json()
request_data = LabelMessage(**data) # Pydantic 会验证数据结构
text = request_data.text
user_id = request_data.user_id
# 检查必需字段
if not text:
return jsonify({"error": "text is required"}), 400
if not user_id:
return jsonify({"error": "user_id is required"}), 400
# 验证用户ID
user_validation_error = validate_user(data)
if user_validation_error:
return user_validation_error
# 调用predict方法进行意图识别
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(text)
return jsonify(
code=200,
msg="成功",
int=predicted_id,
label=predicted_label,
probability=float(predicted_probability)
)
except Exception as e:
logger.error(f"error:{e}")
return jsonify({"error": str(e)}), 500
# 槽位抽取
@app.route('/slot_reco', methods=['POST'])
def slot_reco():
"""槽位识别"""
try:
# 获取请求中的 JSON 数据
data = request.get_json()
request_data = LabelMessage(**data) # Pydantic 会验证数据结构
text = request_data.text
user_id = request_data.user_id
# 检查必需字段
if not text:
return jsonify({"error": "text is required"}), 400
if not user_id:
return jsonify({"error": "user_id is required"}), 400
# 验证用户ID
user_validation_error = validate_user(data)
if user_validation_error:
return user_validation_error
# 调用 recognize 方法进行槽位识别
# entities = slot_recognizer.recognize(text)
entities, slot_probability = slot_recognizer.recognize_probability(text)
logger.info(f"槽位抽取后的实体:{entities},实体后的可能值:{slot_probability}")
return jsonify(
code=200,
msg="成功",
slot=entities)
except Exception as e:
logger.error(f"error:{e}")
return jsonify({"error": str(e)}), 500
@app.route('/agent', methods=['POST'])
def agent():
try:
data = request.get_json()
except Exception as e:
logger.error(f"body不是一个有效的json")
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
try:
# 使用 Pydantic 来验证数据结构
request_data = RequestData(**data) # Pydantic 会验证数据结构
messages = request_data.messages
user_id = request_data.user_id
# 检查必需字段是否存在
if not messages:
return jsonify({"error": "messages is required"}), 400
if not user_id:
return jsonify({"error": "user_id is required"}), 400
# 验证用户ID假设这个函数已经定义
user_validation_error = validate_user(data)
if user_validation_error:
return user_validation_error
if len(messages) == 1: # 首轮
query = messages[0].content # 使用 Message 对象的 .content 属性
# 先进行意图识别
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)
# 再进行槽位抽取
entities, slot_probability = slot_recognizer.recognize_probability(query)
logger.info(
f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},,slot_probability:{slot_probability},message:{messages}",
)
# 多轮
else:
res = extract_multi_chat(messages)
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(res)
#0:天气1互联网查询9知识问答10通用对话
if predicted_id in [0, 1, 9, 10]:
logger.info(f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},message:{messages}")
return jsonify({
"code": 200, "msg": "成功",
"answer": {"int": predicted_id, "label": predicted_label, "probability": predicted_probability},
"finalQuery": res
})
# entities = slot_recognizer.recognize(res)
entities, slot_probability = slot_recognizer.recognize_probability(res)
logger.info(
f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},slot_probability:{slot_probability},message:{messages}")
#必须槽位缺失检查
status, sk = check_lost(predicted_id, entities)
if status == CheckResult.NEEDS_MORE_ROUNDS:
return jsonify({"code": 10001, "msg": "成功",
"answer": {"miss": sk},
})
#工程名、分公司名和项目名标准化
result, information = check_standard_name_slot_probability(predicted_id, entities)
if result == CheckResult.NEEDS_MORE_ROUNDS:
return jsonify({
"code": 10001, "msg": "成功",
"answer": {"miss": information},
})
return jsonify({
"code": 200, "msg": "成功",
"answer": {"int": predicted_id, "label": predicted_label, "probability": predicted_probability,
"slot": entities},
})
except ValidationError as e:
return jsonify({"error": e.errors()}), 400 # 捕捉 Pydantic 错误并返回
except Exception as e:
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
def format_chat_history(history_messages):
reset_keywords = ["当前", "今天", "昨天", "本周", "下周", "明天", "今日", "打开","工程进度"]
keep_index = 0
# Step 1: 查找最靠近当前的 user 消息,若其包含关键词,则记录其索引
for i in reversed(range(len(history_messages))):
msg = history_messages[i]
if msg.role == "user" and any(kw in msg.content and len(msg.content) > 5 for kw in reset_keywords):
keep_index = i
break
# Step 2: 截取需要保留的历史消息
filtered_messages = history_messages[keep_index:]
# Step 3: 构建格式化历史
formatted_history = ""
for i, msg in enumerate(filtered_messages):
formatted_history += f"\n<turn id={i+1}>\n"
formatted_history += f"<role>{msg.role}</role>\n"
formatted_history += f"<content>{process_msg_content(msg.content)}</content>\n"
formatted_history += "</turn>\n"
return formatted_history
def extract_multi_chat(messages):
from openai import OpenAI
client = OpenAI(base_url=api_base_url, api_key=api_key)
history_messages = messages[-7:] if len(messages) >= 7 else messages
chat_history = format_chat_history(history_messages)
logger.info(f"chat_history:{chat_history}")
prompt = f'''
你是一个多轮对话理解和还原专家,擅长从复杂的上下文中提取关键信息,理清语义逻辑,最终还原出用户的真实意图。请你以"自然语言编程"的方式逐步思考并处理以下任务,最终生成完整明确的用户查询。
请严格按照以下步骤执行:
## 第一步:初始化变量
初始化以下变量(用于逻辑推理,请勿输出):
- 当前_用户问题 = "" # 表示最终需要还原的完整问题
- 当前_实体 = "" # 表示用户所提及的公司、项目部、工程名称等具体业务实体
- 当前_时间 = "" # 表示与查询相关的时间点或时间段
- 下一步_操作 = "" # 表示当前对话中模型需要用户补充的信息类型
- 下一步_选择列表 = "" # 表示需要从中选择具体内容的候选项列表
## 第二步:逐轮解析对话历史
从最早的对话轮次开始,依次处理每一轮对话。
### 轮次 XX从1开始递增
### 如果当前角色为user:
请根据之前 assistant 的引导语,结合当前用户输入,判断是否需要进行补全操作,并按以下规则处理:
- 如果 `下一步_操作 == "补充时间"`:调用函数 `补全时间(当前_用户问题, 用户输入)` 并更新 `当前_用户问题`
- 否则 如果 `下一步_操作 == "补充分公司"`:调用函数 `补全分公司(当前_用户问题, 用户输入)` 并更新 `当前_用户问题`
- 否则 如果 `下一步_操作 == "选择列表"`:调用函数 `替换序号为实体(当前_用户问题, 用户输入, 下一步_选择列表)` 并更新 `当前_用户问题`
- 否则 如果 有完整的句意(用户输入):将当前用户输入作为 `当前_用户问题`
- 否则 如果 包含模糊表达(用户输入):调用函数 `补全模糊表达(当前_用户问题, 用户输入)`并更新 `当前_用户问题`
- 否则 如果 是查询新属性(用户输入):调用函数 `替换新属性(当前_用户问题, 用户输入)`并更新 `当前_用户问题`
- 否则 将当前用户输入作为 `当前_用户问题`
处理完成后请清空 `下一步_操作` 和`下一步_选择列表`
### 如果当前角色为 assistant:
请根据 assistant 的输出内容,判断接下来用户是否被引导进行补全,并更新 `下一步_操作`
- 如果 assistant 的回复中包含"请问你想查询什么时间的"或类似引导时间的内容:设定 `下一步_操作 = "补充时间"`
- 否则 如果包含"请补充该项目部所属的分公司名称":设定 `下一步_操作 = "补充分公司"`
- 否则 如果包含"请确认您要选择哪一个":设定 `下一步_操作 = "选择列表"`,并将当前 assistant 回复中列出的选项存入 `下一步_选择列表`
- 否则 什么都不做
### 重复以上解析过程
请逐条处理每轮对话,直到所有历史对话处理完毕,然后进入第三步。
## 第三步:输出还原后的最终用户问题
在完成全部历史消息处理后,请输出变量 `当前_用户问题`,它即为根据上下文补全后的完整查询。
---
辅助函数说明:
函数 补全时间(文本, 时间词):
支持识别如"2025-5-15""昨天""等模糊时间
从时间词里提取时间
在文本开头添加提取到的时间并返回,保持其他内容不变
如果没有提取到时间则直接返回原文本内容
示例:补全时间(""第一项目部有多少作业计划", "今天的") 返回 "今天第一项目部有多少作业计划"
函数 补全分公司(文本, 分公司词):
从分公司词里提取分公司
在文本里的时间词之后添加提取到的分公司并返回,保持其他内容不变
如果没有提取到分公司信息则直接返回原文本内容
示例:补全分公司("今天第一项目有多少作业计划", "送二分公司") 返回 "今天送二分公司第一项目有多少作业计划"
函数 替换序号为实体(文本, 选择项, 选择列表):
从选择项如"第X个"中提取X的值并处理中文数字和阿拉伯数字作为序号
根据序号识别出"第1个XXX第2个YYY"格式的选项列表提取出完整的实体名称
保留文本中的动作词和目标对象,将文本中的实体引用替换为完整实体名称
示例:替换序号为实体("中心变工程进度", "第1个", "第1个燃气工程第2个给水工程") 返回 "燃气工程进度"
函数 包含模糊表达(文本):
检查是否包含"具体是哪些项""是哪两个"等模糊表达
返回布尔值表示是否包含
示例:包含模糊表达("具体是哪些项") 返回 True
函数 补全模糊表达(文本,模糊表达):
结合文本内容和模糊表达,补全语义并返回
示例:补全模糊表达("今天送一分公司有多少作业计划", "具体是哪些") 返回 "今天送一分公司具体有哪些作业计划"
函数 是查询新属性(文本, 新问题):
如果新问题中提取不到主体 且仅能提取到查询属性
且这个查询属性和文本中提取到的查询属性不同 则返回TRUE
其他情况均返回FALSE
示例:是查询新属性("今天送一分公司有多少作业计划", "作业内容") 返回 True
函数 替换新属性(文本,新查询属性):
先删除文本中的"有多少"等类似的表达数量表达,
再将文本里的查询属性替换为新查询属性,并保持其他内容不变并返回 且保持新查询属性的语气
示例:替换新属性("今天送一分公司有多少作业计划", "作业内容") 返回 "今天送一分公司的作业内容"
函数 有完整的句意(新问题):
如果新问题里有主体同时有操作对象或查询对象则返回TRUE
其他情况均返回FALSE
对话历史如下:
{chat_history}
请你仅输出还原后的完整问题,不要输出任何变量、中间步骤或解释说明,确保结果自然通顺,语义完整。
'''
message = [
{"role": "user", "content": prompt}
]
response = client.chat.completions.create(
messages=message,
model=model_name,
max_tokens=100,
temperature=0.1, # 降低随机性,提高确定性
stream=False
)
res = response.choices[0].message.content.strip()
logger.info(f"多轮意图后用户想要的问题是:{res}")
return res
if __name__ == '__main__':
# 启动时立即执行一次
app.run(host='0.0.0.0', port=18074, debug=False)