Intention/api/mian.py

350 lines
12 KiB
Python
Raw Normal View History

2025-02-27 09:06:34 +08:00
import pydantic
2025-02-25 09:27:14 +08:00
from flask import Flask, jsonify, request
2025-02-27 09:06:34 +08:00
from pydantic import BaseModel, Field
2025-02-25 09:27:14 +08:00
from werkzeug.exceptions import HTTPException
2025-02-27 09:06:34 +08:00
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer, ErnieForSequenceClassification
2025-02-25 09:27:14 +08:00
import paddle
2025-02-27 09:06:34 +08:00
import numpy as np
import paddle.nn.functional as F # 用于 Softmax
from typing import List, Dict
from pydantic import ValidationError
2025-02-25 09:27:14 +08:00
2025-02-27 16:33:26 +08:00
from intentRecognition import IntentRecognition
from slotRecognition import SlotRecognition
2025-02-27 17:32:47 +08:00
from fuzzywuzzy import process
2025-02-27 20:22:00 +08:00
from utils import CheckResult, StandardType, load_standard_name
2025-02-28 07:49:40 +08:00
from constants import PROJECT_NAME, PROJECT_DEPARTMENT, SIMILARITY_VALUE
2025-02-27 09:06:34 +08:00
# --- Model checkpoints ------------------------------------------------------
# Relative paths, resolved against the process working directory.
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-4160"  # intent classifier checkpoint
MODEL_UIE_PATH = R"../uie/output/checkpoint-4320"  # slot extractor checkpoint

# Intent class names. The position in this list is the intent id
# (check_lost's intention_mapping relies on this order).
labels = [
    "天气查询",
    "互联网查询",
    "页面切换",
    "日计划数量查询",
    "周计划数量查询",
    "日计划作业内容",
    "周计划作业内容",
    "施工人数",
    "作业考勤人数",
    "知识问答",
]

# BIO tag map for the slot recogniser: id 0 is the outside tag "O";
# slot type number k (1..10) gets id k for its B- tag and id k + 10 for
# its I- tag.
_SLOT_TYPES = (
    "date",
    "projectName",
    "projectType",
    "constructionUnit",
    "implementationOrganization",
    "projectDepartment",
    "projectManager",
    "subcontractor",
    "teamLeader",
    "riskLevel",
)
label_map = {0: 'O'}
for _slot_id, _slot_type in enumerate(_SLOT_TYPES, start=1):
    label_map[_slot_id] = f"B-{_slot_type}"
    label_map[_slot_id + 10] = f"I-{_slot_type}"
del _slot_id, _slot_type
2025-02-27 09:06:34 +08:00
# Recognisers are built once, at import time, and reused by every request.
intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)  # intent classifier
slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)  # slot extractor

# Reference name lists used to normalise extracted slot values by fuzzy
# matching. Both paths are resolved against the process working directory.
standard_project_name_list = load_standard_name(
    './standard_data/standard_project.txt'
)  # standard engineering-project names
standard_program_name_list = load_standard_name(
    './standard_data/standard_program.txt'
)  # standard project (department-level) names
print(f":standard_project_name_list:{standard_project_name_list}")

# The Flask application the routes are registered on.
app = Flask(__name__)
2025-02-25 09:27:14 +08:00
2025-02-27 16:33:26 +08:00
2025-02-25 09:27:14 +08:00
# Unified exception handler: every error escaping a view is rendered as JSON.
@app.errorhandler(Exception)
def handle_exception(e):
    """Render any exception that escapes a view as a JSON error response.

    HTTP errors (werkzeug ``HTTPException``) keep their own name, description
    and status code. Any other exception becomes a 500 whose message is
    ``str(e)``.

    A handler registered for ``Exception`` replaces Flask's default handling,
    which also means Flask no longer logs the error. Non-HTTP exceptions are
    therefore logged here, with traceback, before the response is built.
    """
    if isinstance(e, HTTPException):
        return jsonify({
            "error": {
                "type": e.name,
                "message": e.description,
                "status_code": e.code,
            }
        }), e.code
    # Without this, the traceback of an unexpected error would be lost.
    app.logger.error("Unhandled exception: %s", e, exc_info=e)
    return jsonify({
        "error": {
            "type": "InternalServerError",
            "message": str(e),
        }
    }), 500
2025-02-27 09:06:34 +08:00
# The only caller id accepted by the API.
# NOTE(review): a shared secret committed to source control is weak
# authentication; consider loading it from configuration instead.
_AUTHORIZED_USER_ID = '3bb66776-1722-4c36-b14a-73dd210fe750'


def validate_user(data):
    """Check that the request payload carries the authorised ``user_id``.

    Args:
        data: the decoded JSON request body (a dict).

    Returns:
        ``None`` when the id matches. Otherwise a ``(response, 401)`` tuple,
        which the view should return unchanged.
    """
    import hmac

    user_id = data.get("user_id")
    # compare_digest runs in time independent of where the values differ, so
    # the id cannot be recovered byte by byte from response timings. Both
    # sides are encoded first because compare_digest rejects non-ASCII str.
    if not isinstance(user_id, str) or not hmac.compare_digest(
        user_id.encode("utf-8"), _AUTHORIZED_USER_ID.encode("utf-8")
    ):
        return jsonify(
            code=401,
            msg='权限验证失败,请联系接口开发人员',
            label=-1,
            probability=-1
        ), 401
    return None
class LabelMessage(BaseModel):
    """Request body of the single-utterance endpoints (/intent_reco, /slot_reco)."""

    text: str = Field(..., description="消息内容")
    user_id: str = Field(..., description="用户ID")


# Structure of one conversation turn.
class Message(BaseModel):
    """One turn of the conversation sent to /agent."""

    # Forwarded as the chat "role" when the conversation is sent to the LLM.
    role: str = Field(..., description="消息角色")
    content: str = Field(..., description="消息内容")
    # timestamp: str = Field(..., description="消息时间戳")


# Request body of /agent.
class RequestData(BaseModel):
    """Request body of /agent: the conversation so far plus the caller id."""

    messages: List[Message] = Field(..., description="消息列表")
    user_id: str = Field(..., description="用户ID")
# Intent recognition endpoint.
@app.route('/intent_reco', methods=['POST'])
def intent_reco():
    """Classify the intent of a single utterance.

    Expects a JSON object ``{"text": ..., "user_id": ...}`` (see
    ``LabelMessage``). Responds with ``code``/``msg`` plus the predicted intent
    id (``int``), ``label`` and ``probability``.

    Client mistakes (non-JSON body, body not an object, missing, empty or
    mistyped fields) get a 400 and an unknown user gets a 401. Only failures
    during recognition produce a 500.
    """
    try:
        # silent=True makes a non-JSON body yield None instead of raising, so
        # it is reported as a client error below rather than as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400
        request_data = LabelMessage(**data)  # Pydantic validates the field types
        text = request_data.text
        user_id = request_data.user_id
        # Pydantic accepts empty strings, so reject them explicitly.
        if not text:
            return jsonify({"error": "text is required"}), 400
        if not user_id:
            return jsonify({"error": "user_id is required"}), 400
        user_validation_error = validate_user(data)
        if user_validation_error:
            return user_validation_error
        predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(text)
        return jsonify(
            code=200,
            msg="成功",
            int=predicted_id,
            label=predicted_label,
            probability=float(predicted_probability)
        )
    except ValidationError as e:
        # Same response shape as /agent uses for schema errors.
        return jsonify({"error": e.errors()}), 400
    except Exception as e:
        app.logger.exception("intent_reco failed")
        return jsonify({"error": str(e)}), 500
# Slot extraction endpoint.
@app.route('/slot_reco', methods=['POST'])
def slot_reco():
    """Extract slots (named entities) from a single utterance.

    Expects a JSON object ``{"text": ..., "user_id": ...}`` (see
    ``LabelMessage``). Responds with ``code``/``msg`` plus whatever
    ``slot_recognizer.recognize`` returns, under ``slot``.

    Client mistakes (non-JSON body, body not an object, missing, empty or
    mistyped fields) get a 400 and an unknown user gets a 401. Only failures
    during recognition produce a 500.
    """
    try:
        # silent=True makes a non-JSON body yield None instead of raising, so
        # it is reported as a client error below rather than as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400
        request_data = LabelMessage(**data)  # Pydantic validates the field types
        text = request_data.text
        user_id = request_data.user_id
        # Pydantic accepts empty strings, so reject them explicitly.
        if not text:
            return jsonify({"error": "text is required"}), 400
        if not user_id:
            return jsonify({"error": "user_id is required"}), 400
        user_validation_error = validate_user(data)
        if user_validation_error:
            return user_validation_error

        entities = slot_recognizer.recognize(text)

        return jsonify(
            code=200,
            msg="成功",
            slot=entities)
    except ValidationError as e:
        # Same response shape as /agent uses for schema errors.
        return jsonify({"error": e.errors()}), 400
    except Exception as e:
        app.logger.exception("slot_reco failed")
        return jsonify({"error": str(e)}), 500
2025-02-25 09:27:14 +08:00
2025-02-27 09:06:34 +08:00
@app.route('/agent', methods=['POST'])
def agent():
    """Conversational entry point: intent and slots, with follow-up questions.

    Expects ``{"messages": [{"role": ..., "content": ...}, ...], "user_id": ...}``
    (see ``RequestData``). Steps:

    1. Predict the intent from the first message.
    2. Extract slots: from that message directly on the first turn, or via
       ``multi_slot_recognizer`` over the whole conversation on later turns.
    3. If required slots are missing (``check_lost``), or a project or
       department name is not close enough to a standard name
       (``check_project_standard_slot``), reply with code 10001 and a
       follow-up question under ``answer.miss``.
    4. Otherwise reply with code 200 and the intent id, label, probability
       and slots.

    Malformed requests get a 400 and an unknown user gets a 401. Failures
    during recognition produce a 500.
    """
    try:
        # silent=True makes a non-JSON body yield None instead of raising, so
        # it is reported as a client error below rather than as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400
        request_data = RequestData(**data)  # Pydantic validates the structure
        messages = request_data.messages
        user_id = request_data.user_id

        # Pydantic accepts an empty list / empty string, so reject them here.
        if not messages:
            return jsonify({"error": "messages is required"}), 400
        if not user_id:
            return jsonify({"error": "user_id is required"}), 400

        user_validation_error = validate_user(data)
        if user_validation_error:
            return user_validation_error

        # The intent is always predicted from the opening message, on later
        # turns as well. NOTE(review): presumably deliberate, since follow-up
        # turns answer a clarifying question (e.g. only a date) and carry no
        # intent of their own. Confirm before switching to messages[-1].
        query = messages[0].content
        predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)
        if len(messages) == 1:  # first turn
            entities = slot_recognizer.recognize(query)
            print(f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}")
        else:  # later turns: let the LLM condense the conversation first
            entities = multi_slot_recognizer(predicted_id, messages)
            print(f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}")

        # Ask a follow-up question if required slots are missing.
        status, sk = check_lost(predicted_id, entities)
        if status == CheckResult.NEEDS_MORE_ROUNDS:
            return jsonify({"code": 10001, "msg": "成功",
                            "answer": {"miss": sk},
                            })

        # Normalise project / department names to their standard spellings.
        print("start to check_project_standard_slot")
        result, information = check_project_standard_slot(predicted_id, entities)
        print(f"end check_project_standard_slot,{result},{information}")
        if result == CheckResult.NEEDS_MORE_ROUNDS:
            return jsonify({
                "code": 10001, "msg": "成功",
                "answer": {"miss": information},
            })

        # float(): same conversion /intent_reco applies to this value, so the
        # response stays JSON-serialisable if predict() returns a numpy scalar.
        return jsonify({
            "code": 200, "msg": "成功",
            "answer": {"int": predicted_id, "label": predicted_label,
                       "probability": float(predicted_probability), "slot": entities},
        })
    except ValidationError as e:
        return jsonify({"error": e.errors()}), 400  # schema errors are client errors
    except Exception as e:
        app.logger.exception("agent failed")
        return jsonify({"error": str(e)}), 500
2025-02-25 09:27:14 +08:00
2025-03-03 10:58:46 +08:00
# OpenAI-compatible LLM endpoint used to condense a multi-turn conversation
# into a single question.
_LLM_API_BASE_URL = "http://36.33.26.201:27861/v1"
_LLM_API_KEY = 'EMPTY'
_LLM_MODEL_NAME = 'qwen2.5-instruct'
_llm_client = None


def _get_llm_client():
    """Return the shared OpenAI client, creating it on first use."""
    # Reusing one client avoids building (and never closing) a fresh HTTP
    # connection pool on every request. A concurrent first call may build
    # two clients; the last assignment wins, which is harmless.
    global _llm_client
    if _llm_client is None:
        from openai import OpenAI
        _llm_client = OpenAI(base_url=_LLM_API_BASE_URL, api_key=_LLM_API_KEY)
    return _llm_client


def multi_slot_recognizer(intention_id, messages):
    """Extract slots from a multi-turn conversation.

    The LLM first condenses the conversation into one question, and the
    regular slot recogniser then runs on that question.

    Args:
        intention_id: predicted intent id. Currently unused; kept for the
            caller's signature.
        messages: the conversation as ``Message`` models (``role``/``content``).

    Returns:
        Whatever ``slot_recognizer.recognize`` returns for the condensed
        question.
    """
    prompt = f'''
根据用户的输入{messages}抽取出用户想了解的问题要求保持客观真实简单明了不要多余解释和阐述
'''
    # The chat API is typed for plain {"role", "content"} dicts, so the
    # pydantic models are converted explicitly.
    chat = [{"role": "system", "content": prompt}]
    chat.extend({"role": m.role, "content": m.content} for m in messages)
    response = _get_llm_client().chat.completions.create(
        messages=chat,
        model=_LLM_MODEL_NAME,
        max_tokens=1000,
        temperature=0.001,
        stream=False,
    )
    # message.content is Optional in the API; treat a missing one as "".
    question = response.choices[0].message.content or ""
    print(f"多轮意图后用户想要的问题是{question}")
    return slot_recognizer.recognize(question)
2025-02-27 09:06:34 +08:00
# Required slots per intent id (ids index the module-level ``labels`` list).
# Each entry lists alternative slot sets; a request is complete when every
# slot of any one set was extracted.
_REQUIRED_SLOTS_BY_INTENT = {
    2: [['page'], ['app'], ['module']],
    3: [['date']],
    4: [['date']],
    5: [['date']],
    6: [['date']],
    7: [['date']],
    8: [['date']],
}
# Intent names used in follow-up questions. Keep these keys in sync with
# _REQUIRED_SLOTS_BY_INTENT.
_INTENT_NAMES = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容",
                 6: "周计划作业内容", 7: "施工人数", 8: "作业考勤人数"}


def check_lost(int_res, slot):
    """Check whether the slots required by an intent were extracted.

    Args:
        int_res: predicted intent id.
        slot: extracted slots, a mapping keyed by slot name.

    Returns:
        ``(CheckResult.NO_MATCH, "")`` for intents without required slots.
        ``(CheckResult.NO_MATCH, names)`` when some required slot set is
        complete, where ``names`` is the list of extracted slot names.
        Otherwise ``(CheckResult.NEEDS_MORE_ROUNDS, question)``, where
        ``question`` asks the user for the missing information.
    """
    alternatives = _REQUIRED_SLOTS_BY_INTENT.get(int_res)
    if alternatives is None:
        # Same status the other "nothing to ask" paths return.
        return CheckResult.NO_MATCH, ""
    present = list(slot.keys())
    # Pick the alternative with the fewest missing slots (the first on ties).
    missing_per_alternative = [[name for name in required if name not in present]
                               for required in alternatives]
    left = min(missing_per_alternative, key=len)
    if not left:  # some alternative is fully satisfied
        return CheckResult.NO_MATCH, present
    print(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}")
    apologize_str = "非常抱歉,"
    if int_res == 2:
        return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询哪个页面?"
    return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{_INTENT_NAMES[int_res]}"
2025-02-25 09:27:14 +08:00
2025-02-27 17:32:47 +08:00
# Normalise project and department names to their standard spellings.
def check_project_standard_slot(int_res, slot) -> tuple:
    """Replace fuzzily spelled project/department names in ``slot`` in place.

    Only intents 3-8 are checked. For each ``PROJECT_NAME`` or
    ``PROJECT_DEPARTMENT`` slot, the closest standard name is found with
    ``fuzzy_match``. If its similarity reaches ``SIMILARITY_VALUE``, the slot
    value is overwritten with it. Otherwise checking stops and a confirmation
    prompt naming the best candidate is returned.

    Returns:
        ``(CheckResult.NEEDS_MORE_ROUNDS, prompt)`` for the first name that
        is not similar enough, otherwise ``(CheckResult.NO_MATCH, "")``.
    """
    if int_res not in {3, 4, 5, 6, 7, 8}:
        return CheckResult.NO_MATCH, ""
    # (slot key, standard names, word used in the debug log, prompt prefix)
    rules = (
        (PROJECT_NAME, standard_project_name_list, "project", "抱歉,您说的工程名是"),
        (PROJECT_DEPARTMENT, standard_program_name_list, "program", "抱歉,您说的项目名是"),
    )
    for key, value in slot.items():
        for rule_key, candidates, kind, prompt_prefix in rules:
            if key != rule_key:
                continue
            best, score = fuzzy_match(value, candidates)
            print(f"fuzzy_match {kind} result:{best}, {score}")
            if score < SIMILARITY_VALUE:
                return CheckResult.NEEDS_MORE_ROUNDS, f"{prompt_prefix}{best}"
            slot[key] = best
    return CheckResult.NO_MATCH, ""
2025-02-27 17:32:47 +08:00
def fuzzy_match(user_input, standard_name):
    """Return the entry of ``standard_name`` most similar to ``user_input``.

    Args:
        user_input: the name as extracted from the user's text.
        standard_name: list of standard names to choose from.

    Returns:
        ``(best_name, similarity)``, where similarity is fuzzywuzzy's 0-100
        score scaled to 0.0-1.0.

    Raises:
        ValueError: if ``standard_name`` is empty, so nothing can be matched.
    """
    # extractOne yields the same top (choice, score) pair as extract(...)[0],
    # without ranking a top-5 list, and returns None when there are no choices.
    best = process.extractOne(user_input, standard_name)
    if best is None:
        raise ValueError("standard name list is empty; cannot fuzzy-match")
    return best[0], best[1] / 100
2025-02-27 17:32:47 +08:00
2025-02-25 09:27:14 +08:00
if __name__ == '__main__':
    # debug=True is deliberately not hard-coded. The Werkzeug debugger lets
    # anyone who can reach the server execute arbitrary code, and the server
    # listens on all interfaces. For local development, set FLASK_DEBUG=1,
    # which Flask.run() honours when debug is not passed.
    app.run(host='0.0.0.0', port=18074)