2025-02-27 09:06:34 +08:00
|
|
|
|
import pydantic
|
2025-02-25 09:27:14 +08:00
|
|
|
|
from flask import Flask, jsonify, request
|
2025-02-27 09:06:34 +08:00
|
|
|
|
from pydantic import BaseModel, Field
|
2025-02-25 09:27:14 +08:00
|
|
|
|
from werkzeug.exceptions import HTTPException
|
2025-02-27 09:06:34 +08:00
|
|
|
|
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer, ErnieForSequenceClassification
|
2025-02-25 09:27:14 +08:00
|
|
|
|
import paddle
|
2025-02-27 09:06:34 +08:00
|
|
|
|
import numpy as np
|
|
|
|
|
|
import paddle.nn.functional as F # 用于 Softmax
|
|
|
|
|
|
from typing import List, Dict
|
|
|
|
|
|
from pydantic import ValidationError
|
2025-02-25 09:27:14 +08:00
|
|
|
|
|
2025-02-27 16:33:26 +08:00
|
|
|
|
from intentRecognition import IntentRecognition
|
|
|
|
|
|
from slotRecognition import SlotRecognition
|
2025-02-27 09:06:34 +08:00
|
|
|
|
|
|
|
|
|
|
# Constants: checkpoint directories handed to the two recognizers below.
# Both are relative paths, so they resolve against the process working directory.
MODEL_ERNIE_PATH = r"../ernie/output/checkpoint-4160"  # passed to IntentRecognition
MODEL_UIE_PATH = r"../uie/output/checkpoint-4320"  # passed to SlotRecognition
|
2025-02-27 09:06:34 +08:00
|
|
|
|
|
|
|
|
|
|
# Intent class names. The position of a name in this list is the intent id
# that check_lost() keys its required-slot table on (e.g. 2 -> "页面切换").
labels = [
    "天气查询",
    "互联网查询",
    "页面切换",
    "日计划数量查询",
    "周计划数量查询",
    "日计划作业内容",
    "周计划作业内容",
    "施工人数",
    "作业考勤人数",
    "知识问答",
]
|
2025-02-25 09:27:14 +08:00
|
|
|
|
|
|
|
|
|
|
# Tag-id -> BIO tag-name mapping handed to SlotRecognition.
# Id 0 is the outside tag 'O'. Each slot type at 1-based position k owns
# 'B-<type>' at id k and 'I-<type>' at id k + (number of slot types).
_SLOT_TYPES = (
    'date',
    'projectName',
    'projectType',
    'constructionUnit',
    'implementationOrganization',
    'projectDepartment',
    'projectManager',
    'subcontractor',
    'teamLeader',
    'riskLevel',
)
label_map = {
    0: 'O',
    **{
        tag_id: tag_name
        for position, slot_type in enumerate(_SLOT_TYPES, start=1)
        for tag_id, tag_name in (
            (position, f'B-{slot_type}'),
            (position + len(_SLOT_TYPES), f'I-{slot_type}'),
        )
    },
}
|
|
|
|
|
|
|
2025-02-27 09:06:34 +08:00
|
|
|
|
# Build the intent recognizer once at import time from the ERNIE checkpoint
# path and the intent label list.
intent_recognizer = IntentRecognition(
    MODEL_ERNIE_PATH,
    labels,
)

# Build the slot recognizer once at import time from the UIE checkpoint
# path and the tag-id -> tag-name mapping.
slot_recognizer = SlotRecognition(
    MODEL_UIE_PATH,
    label_map,
)

# Flask application that the route and error-handler decorators attach to.
app = Flask(__name__)
|
2025-02-25 09:27:14 +08:00
|
|
|
|
|
2025-02-27 16:33:26 +08:00
|
|
|
|
|
2025-02-25 09:27:14 +08:00
|
|
|
|
# Application-wide exception handler
@app.errorhandler(Exception)
def handle_exception(e):
    """Convert an exception into a JSON error response.

    An ``HTTPException`` is reported with its own name, description and
    status code. Any other exception is reported as an
    ``InternalServerError`` with ``str(e)`` as the message and HTTP 500.
    """
    if not isinstance(e, HTTPException):
        generic_error = {
            "type": "InternalServerError",
            "message": str(e),
        }
        return jsonify({"error": generic_error}), 500

    http_error = {
        "type": e.name,
        "message": e.description,
        "status_code": e.code,
    }
    return jsonify({"error": http_error}), e.code
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-02-27 09:06:34 +08:00
|
|
|
|
def validate_user(data):
    """Check the ``user_id`` entry of the request payload.

    Returns ``None`` when ``data["user_id"]`` equals the single accepted id,
    otherwise a ``(response, 401)`` tuple that the view can return as-is.
    """
    accepted_user_id = '3bb66776-1722-4c36-b14a-73dd210fe750'
    if data.get("user_id") == accepted_user_id:
        return None

    # label/probability are -1 in the rejection payload.
    rejection = jsonify(
        code=401,
        msg='权限验证失败,请联系接口开发人员',
        label=-1,
        probability=-1,
    )
    return rejection, 401
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Request body of the single-text endpoints (/intent_reco and /slot_reco).
class LabelMessage(BaseModel):
    # Text to run recognition on.
    text: str = Field(..., description="消息内容")
    # Caller id checked by validate_user(). The description previously
    # repeated the one for `text` (copy-paste slip).
    user_id: str = Field(..., description="用户ID")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Structure of one chat message inside RequestData.messages.
class Message(BaseModel):
    # Role attached to the message. The description previously repeated the
    # one for `content` (copy-paste slip).
    role: str = Field(..., description="消息角色")
    # Message text.
    content: str = Field(..., description="消息内容")
    # Disabled field, kept for reference:
    # timestamp: str = Field(..., description="消息时间戳")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Request body of the /agent endpoint.
class RequestData(BaseModel):
    # Conversation messages; the endpoint treats a single-element list as
    # the first turn.
    messages: List[Message] = Field(default=..., description="消息列表")
    # Caller id checked by validate_user().
    user_id: str = Field(default=..., description="用户ID")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Intent recognition endpoint
@app.route('/intent_reco', methods=['POST'])
def intent_reco():
    """Predict the intent of the ``text`` field of a JSON request.

    The body must be a JSON object with non-empty ``text`` and ``user_id``
    string fields. A successful response carries ``code``, ``msg``, ``int``,
    ``label`` and ``probability``, taken from ``intent_recognizer.predict``.

    Status codes: 400 for a missing/malformed body or invalid fields, 401
    when ``validate_user`` rejects the caller, 500 for any other failure.
    """
    try:
        # silent=True yields None instead of raising on a missing or
        # malformed JSON body, so the problem is reported as a client error
        # rather than falling into the generic 500 handler below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400

        request_data = LabelMessage(**data)  # Pydantic validates the structure
        text = request_data.text
        user_id = request_data.user_id

        # Pydantic accepts empty strings, so reject them explicitly.
        if not text:
            return jsonify({"error": "text is required"}), 400
        if not user_id:
            return jsonify({"error": "user_id is required"}), 400

        # Authorization check; a non-None result is a ready-made 401 response.
        user_validation_error = validate_user(data)
        if user_validation_error:
            return user_validation_error

        predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(text)

        return jsonify(
            code=200,
            msg="成功",
            int=predicted_id,
            label=predicted_label,
            # Cast so the value is a plain float for JSON serialization.
            probability=float(predicted_probability),
        )

    except ValidationError as e:
        # Invalid request fields are the client's fault: 400, same as /agent.
        return jsonify({"error": e.errors()}), 400
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Slot extraction endpoint
@app.route('/slot_reco', methods=['POST'])
def slot_reco():
    """Extract slots from the ``text`` field of a JSON request.

    The body must be a JSON object with non-empty ``text`` and ``user_id``
    string fields. A successful response carries ``code``, ``msg`` and
    ``slot`` (the value returned by ``slot_recognizer.recognize``).

    Status codes: 400 for a missing/malformed body or invalid fields, 401
    when ``validate_user`` rejects the caller, 500 for any other failure.
    """
    try:
        # silent=True yields None instead of raising on a missing or
        # malformed JSON body, so the problem is reported as a client error
        # rather than falling into the generic 500 handler below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400

        request_data = LabelMessage(**data)  # Pydantic validates the structure
        text = request_data.text
        user_id = request_data.user_id

        # Pydantic accepts empty strings, so reject them explicitly.
        if not text:
            return jsonify({"error": "text is required"}), 400
        if not user_id:
            return jsonify({"error": "user_id is required"}), 400

        # Authorization check; a non-None result is a ready-made 401 response.
        user_validation_error = validate_user(data)
        if user_validation_error:
            return user_validation_error

        entities = slot_recognizer.recognize(text)

        return jsonify(
            code=200,
            msg="成功",
            slot=entities,
        )

    except ValidationError as e:
        # Invalid request fields are the client's fault: 400, same as /agent.
        return jsonify({"error": e.errors()}), 400
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
2025-02-25 09:27:14 +08:00
|
|
|
|
|
|
|
|
|
|
|
2025-02-27 09:06:34 +08:00
|
|
|
|
@app.route('/agent', methods=['POST'])
def agent():
    """Run intent recognition plus slot extraction for a conversation.

    The body must be a JSON object with a non-empty ``messages`` list (each
    item having ``role`` and ``content``) and a non-empty ``user_id``.

    First turn (exactly one message): the message content is classified and
    its slots extracted. If ``check_lost`` reports a missing required slot,
    the answer is ``{"miss": <follow-up question>}``; otherwise it carries
    ``int``, ``label``, ``probability`` and ``slot``.

    Later turns (more than one message): placeholder behavior that echoes
    ``user_id``, the first message's content and the message count.

    Status codes: 400 for a missing/malformed body or invalid fields, 401
    when ``validate_user`` rejects the caller, 500 for any other failure.
    """
    try:
        # silent=True yields None instead of raising on a missing or
        # malformed JSON body, so the problem is reported as a client error
        # rather than falling into the generic 500 handler below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400

        request_data = RequestData(**data)  # Pydantic validates the structure
        messages = request_data.messages
        user_id = request_data.user_id

        # Pydantic accepts an empty list / empty string, so reject them here.
        if not messages:
            return jsonify({"error": "messages is required"}), 400
        if not user_id:
            return jsonify({"error": "user_id is required"}), 400

        # Authorization check; a non-None result is a ready-made 401 response.
        user_validation_error = validate_user(data)
        if user_validation_error:
            return user_validation_error

        if len(messages) > 1:
            # Multi-turn placeholder: only echoes request metadata.
            # NOTE(review): this reads the FIRST message, not the latest one —
            # confirm that is intended once multi-turn handling is implemented.
            query = messages[0].content
            return jsonify({
                "user_id": user_id,
                "query": query,
                "message_count": len(messages)
            })

        # First turn: intent recognition, then slot extraction on the same text.
        query = messages[0].content
        predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)
        entities = slot_recognizer.recognize(query)

        print(f"意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}")

        # Required-slot check: status 1 means a slot is missing and `sk` is
        # the follow-up question to send back.
        status, sk = check_lost(predicted_id, entities)
        if status == 1:
            return jsonify({
                "code": 200,
                "msg": "成功",
                "answer": {
                    "miss": sk
                },
            })

        return jsonify({
            "code": 200,
            "msg": "成功",
            "answer": {
                "int": predicted_id,
                "label": predicted_label,
                # Cast to a plain float, as /intent_reco does, so the value
                # is always JSON serializable.
                "probability": float(predicted_probability),
                "slot": entities
            },
        })

    except ValidationError as e:
        return jsonify({"error": e.errors()}), 400  # Pydantic validation failure
    except Exception as e:
        return jsonify({"error": str(e)}), 500  # any other failure
|
2025-02-25 09:27:14 +08:00
|
|
|
|
|
|
|
|
|
|
|
2025-02-27 16:33:26 +08:00
|
|
|
|
# def check_lost(int_res, slot):
|
|
|
|
|
|
# return 0, ""
|
2025-02-27 09:06:34 +08:00
|
|
|
|
def check_lost(int_res, slot):
    """Check whether the slots required by an intent are present.

    Args:
        int_res: Intent id. Ids follow the intent label order: 0 天气查询,
            1 互联网查询, 2 页面切换, 3 日计划数量查询, 4 周计划数量查询,
            5 日计划作业内容, 6 周计划作业内容, 7 施工人数, 8 作业考勤人数,
            9 知识问答.
        slot: Mapping of extracted slot names to values; only its keys are
            inspected.

    Returns:
        A ``(status, payload)`` tuple:

        * ``(0, "")`` when the intent has no required slots;
        * ``(0, <list of slot names in slot>)`` when at least one acceptable
          slot combination is fully present;
        * ``(1, <follow-up question>)`` when a required slot is missing.
    """
    # Per intent id: the alternative slot combinations that satisfy it.
    # Any one inner list being fully present is enough.
    mapping = {
        2: [['page'], ['app'], ['module']],
        3: [['date']],
        4: [['date']],
        5: [['date']],
        6: [['date']],
        7: [['date']],
        8: [['date']],
    }
    # Intent names interpolated into the follow-up question.
    intention_mapping = {
        2: "页面切换",
        3: "日计划数量查询",
        4: "周计划数量查询",
        5: "日计划作业内容",
        6: "周计划作业内容",
        7: "施工人数",
        8: "作业考勤人数",
    }

    if int_res not in mapping:
        return 0, ""

    cur_k = list(slot.keys())
    # Extra slots are irrelevant; one fully covered alternative passes.
    if any(all(name in cur_k for name in option) for option in mapping[int_res]):
        return 0, cur_k

    apologize_str = "非常抱歉,"
    if int_res == 2:
        return 1, f"{apologize_str}请问你想查询哪个页面?"
    # Every remaining key of `mapping` is a date-driven query.
    return 1, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}?"
|
2025-02-25 09:27:14 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    import os

    # The server listens on all interfaces, so the interactive debugger must
    # not be on by default: it lets anyone who can reach the port execute
    # arbitrary code. Opt in explicitly with FLASK_DEBUG=1 for local work.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host='0.0.0.0', port=18074, debug=debug_enabled)
|