Intention/api/main.py

348 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from flask import Flask, jsonify, request
from pydantic import BaseModel, Field
from werkzeug.exceptions import HTTPException
from typing import List
from pydantic import ValidationError
from intentRecognition import IntentRecognition
from slotRecognition import SlotRecognition
from fuzzywuzzy import process
from utils import CheckResult, StandardType, load_standard_name
from constants import PROJECT_NAME, PROJECT_DEPARTMENT, SIMILARITY_VALUE
from config import *
# Constants: checkpoint paths handed to the two recognizers constructed below.
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-4170"
MODEL_UIE_PATH = R"../uie/output/checkpoint-1740"
# Intent category names, passed to IntentRecognition below.
# check_lost() maps intent ids 2-8 to the names found at those positions in this list.
labels = [
"天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
"日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答"
]
# Tag-id -> BIO tag mapping passed to SlotRecognition below.
# Ids 1-12 are the "B-" tags; the matching "I-" tag of each slot type is id + 12.
label_map = {
0: 'O', # non-entity
1: 'B-date', 13: 'I-date',
2: 'B-projectName', 14: 'I-projectName',
3: 'B-projectType', 15: 'I-projectType',
4: 'B-constructionUnit', 16: 'I-constructionUnit',
5: 'B-implementationOrganization', 17: 'I-implementationOrganization',
6: 'B-projectDepartment', 18: 'I-projectDepartment',
7: 'B-projectManager', 19: 'I-projectManager',
8: 'B-subcontractor', 20: 'I-subcontractor',
9: 'B-teamLeader', 21: 'I-teamLeader',
10: 'B-riskLevel', 22: 'I-riskLevel',
11: 'B-page', 23: 'I-page',
12: 'B-operating', 24: 'I-operating',
}
# Initialise the intent recognition helper (constructed once at import time).
intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
# Initialise the slot recognition helper (constructed once at import time).
slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
# Standard project names; check_project_standard_slot() fuzzy-matches the PROJECT_NAME slot against them.
standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
# Standard program names; check_project_standard_slot() fuzzy-matches the PROJECT_DEPARTMENT slot against them.
standard_program_name_list = load_standard_name('./standard_data/standard_program.txt')
print(f":standard_project_name_list:{standard_project_name_list}")
# Set up the Flask application.
app = Flask(__name__)
# Application-wide fallback error handler.
@app.errorhandler(Exception)
def handle_exception(e):
    """Turn an exception that reached Flask into a JSON error response.

    Args:
        e: The exception being handled.

    Returns:
        A ``(response, status)`` tuple. For a werkzeug ``HTTPException`` the
        body carries its name, description and code, and the status is that
        code. Any other exception is reported as ``InternalServerError`` with
        ``str(e)`` as the message and status 500.
    """
    if not isinstance(e, HTTPException):
        # Not an HTTP error: report it generically.
        fallback = {"error": {"type": "InternalServerError", "message": str(e)}}
        return jsonify(fallback), 500

    http_error = {
        "type": e.name,
        "message": e.description,
        "status_code": e.code,
    }
    return jsonify({"error": http_error}), e.code
def validate_user(data):
    """Check the ``user_id`` of a request payload against the one accepted id.

    Args:
        data: Parsed request payload; only ``data.get("user_id")`` is read.

    Returns:
        ``None`` when the id matches. Otherwise a ``(response, 401)`` tuple
        whose JSON body has ``code=401``, a failure message, and ``label`` and
        ``probability`` both set to ``-1``.
    """
    accepted_id = '3bb66776-1722-4c36-b14a-73dd210fe750'
    if data.get("user_id") == accepted_id:
        return None

    rejection = jsonify(
        code=401,
        msg='权限验证失败,请联系接口开发人员',
        label=-1,
        probability=-1,
    )
    return rejection, 401
class LabelMessage(BaseModel):
    """Request body of the single-utterance endpoints /intent_reco and /slot_reco."""

    # The utterance to classify or extract slots from.
    text: str = Field(..., description="消息内容")
    # Caller identity checked by validate_user(). The description previously
    # duplicated the one for `text`; it now matches RequestData.user_id.
    user_id: str = Field(..., description="用户ID")
# Structure of a single chat message.
class Message(BaseModel):
    """One turn of a conversation as sent to the /agent endpoint."""

    # Speaker of the turn. The example documented on multi_slot_recognizer uses
    # "user" and "assistant". The description previously read "message content",
    # copied from the `content` field.
    role: str = Field(..., description="消息角色")
    # Text of the turn.
    content: str = Field(..., description="消息内容")
    # A `timestamp: str` field (message timestamp) is currently disabled.
# Structure of the /agent request payload: the conversation so far plus the
# caller's id. Both fields are required.
class RequestData(BaseModel):
    # Conversation turns, in the order supplied by the client.
    messages: List[Message] = Field(default=..., description="消息列表")
    # Caller identity checked by validate_user().
    user_id: str = Field(default=..., description="用户ID")
# Intent recognition endpoint.
@app.route('/intent_reco', methods=['POST'])
def intent_reco():
    """Classify the intent of a single utterance.

    Expects a JSON object matching ``LabelMessage`` (``text``, ``user_id``).

    Returns:
        On success, JSON with ``code=200``, ``msg``, ``int`` (predicted id),
        ``label`` and ``probability``.
        400 when the body is not a JSON object, fails schema validation, or
        has an empty ``text`` / ``user_id``.
        401 (from ``validate_user``) when the user id is not accepted.
        500 for any other failure, with ``str(e)`` as the error.
    """
    try:
        # silent=True yields None for a missing or malformed body instead of
        # raising, so the problem is reported as a client error rather than
        # falling into the generic 500 branch.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400

        request_data = LabelMessage(**data)  # Pydantic validates the structure
        text = request_data.text
        user_id = request_data.user_id

        # Required fields must be non-empty.
        if not text:
            return jsonify({"error": "text is required"}), 400
        if not user_id:
            return jsonify({"error": "user_id is required"}), 400

        # Verify the user id.
        user_validation_error = validate_user(data)
        if user_validation_error:
            return user_validation_error

        # Run intent prediction.
        predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(text)
        return jsonify(
            code=200,
            msg="成功",
            int=predicted_id,
            label=predicted_label,
            probability=float(predicted_probability),
        )
    except ValidationError as e:
        # Schema errors are the client's fault: 400, consistent with /agent.
        return jsonify({"error": e.errors()}), 400
    except Exception as e:
        return jsonify({"error": str(e)}), 500
# Slot extraction endpoint.
@app.route('/slot_reco', methods=['POST'])
def slot_reco():
    """Extract slots (entities) from a single utterance.

    Expects a JSON object matching ``LabelMessage`` (``text``, ``user_id``).

    Returns:
        On success, JSON with ``code=200``, ``msg`` and ``slot`` (the value
        returned by ``slot_recognizer.recognize``).
        400 when the body is not a JSON object, fails schema validation, or
        has an empty ``text`` / ``user_id``.
        401 (from ``validate_user``) when the user id is not accepted.
        500 for any other failure, with ``str(e)`` as the error.
    """
    try:
        # silent=True yields None for a missing or malformed body instead of
        # raising, so the problem is reported as a client error rather than
        # falling into the generic 500 branch.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400

        request_data = LabelMessage(**data)  # Pydantic validates the structure
        text = request_data.text
        user_id = request_data.user_id

        # Required fields must be non-empty.
        if not text:
            return jsonify({"error": "text is required"}), 400
        if not user_id:
            return jsonify({"error": "user_id is required"}), 400

        # Verify the user id.
        user_validation_error = validate_user(data)
        if user_validation_error:
            return user_validation_error

        # Run slot recognition.
        entities = slot_recognizer.recognize(text)
        return jsonify(
            code=200,
            msg="成功",
            slot=entities,
        )
    except ValidationError as e:
        # Schema errors are the client's fault: 400, consistent with /agent.
        return jsonify({"error": e.errors()}), 400
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route('/agent', methods=['POST'])
def agent():
    """Run intent recognition, slot extraction and slot checks on a conversation.

    Expects a JSON object matching ``RequestData`` (``messages``, ``user_id``).
    The intent is always predicted from the first message. With a single
    message the slots are extracted from that message directly; with several,
    ``multi_slot_recognizer`` condenses the dialogue first.

    Returns:
        ``code=200`` with ``answer = {int, label, probability, slot}`` when all
        checks pass.
        ``code=10001`` with ``answer = {"miss": <question>}`` when
        ``check_lost`` or ``check_project_standard_slot`` asks for another
        round.
        HTTP 400 when the body is not a JSON object, fails schema validation,
        or has empty ``messages`` / ``user_id``.
        HTTP 401 (from ``validate_user``) when the user id is not accepted.
        HTTP 500 for any other failure, with ``str(e)`` as the error.
    """
    try:
        # silent=True yields None for a missing or malformed body instead of
        # raising, so the problem is reported as a client error rather than
        # falling into the generic 500 branch.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400

        request_data = RequestData(**data)  # Pydantic validates the structure
        messages = request_data.messages
        user_id = request_data.user_id

        # Required fields must be non-empty.
        if not messages:
            return jsonify({"error": "messages is required"}), 400
        if not user_id:
            return jsonify({"error": "user_id is required"}), 400

        # Verify the user id.
        user_validation_error = validate_user(data)
        if user_validation_error:
            return user_validation_error

        # Intent recognition always runs on the first message, for both the
        # first round and follow-up rounds.
        query = messages[0].content
        predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)

        if len(messages) == 1:  # first round
            entities = slot_recognizer.recognize(query)
            print(f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}")
        else:  # follow-up rounds of a multi-turn dialogue
            entities = multi_slot_recognizer(predicted_id, messages)
            print(f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}")

        # Required-slot check.
        status, sk = check_lost(predicted_id, entities)
        if status == CheckResult.NEEDS_MORE_ROUNDS:
            return jsonify({
                "code": 10001, "msg": "成功",
                "answer": {"miss": sk},
            })

        # Normalise project / program names against the standard lists.
        print("start to check_project_standard_slot")
        result, information = check_project_standard_slot(predicted_id, entities)
        print(f"end check_project_standard_slot,{result},{information}")
        if result == CheckResult.NEEDS_MORE_ROUNDS:
            return jsonify({
                "code": 10001, "msg": "成功",
                "answer": {"miss": information},
            })

        return jsonify({
            "code": 200, "msg": "成功",
            "answer": {
                "int": predicted_id,
                "label": predicted_label,
                # Cast as /intent_reco does, so the value is a plain float for
                # the JSON encoder.
                "probability": float(predicted_probability),
                "slot": entities,
            },
        })
    except ValidationError as e:
        return jsonify({"error": e.errors()}), 400  # report Pydantic errors to the client
    except Exception as e:
        return jsonify({"error": str(e)}), 500  # any other failure
# After a multi-turn dialogue, distil the question the user most wants answered.
# Documented message shape:
#   [{"role": "user", "content": "..."},
#    {"role": "assistant", "content": "..."},
#    {"role": "user", "content": "..."}]
def _to_chat_dict(message):
    """Return *message* as a ``{"role", "content"}`` dict for the chat API."""
    if isinstance(message, dict):
        return message
    # e.g. the Pydantic Message objects that agent() passes in.
    return {"role": message.role, "content": message.content}


def multi_slot_recognizer(intention_id, messages):
    """Condense a multi-turn dialogue into one question and extract its slots.

    The dialogue is sent to a chat-completion model together with a system
    prompt asking for the user's actual question. The model's answer is then
    passed to ``slot_recognizer.recognize``.

    Args:
        intention_id: Predicted intent id. Currently unused; kept for callers.
        messages: Conversation turns, either role/content dicts or objects
            exposing ``.role`` and ``.content``.

    Returns:
        Whatever ``slot_recognizer.recognize`` returns for the condensed
        question.
    """
    from openai import OpenAI

    # api_base_url, api_key and model_name are not defined in this file;
    # presumably they come from `from config import *` -- confirm.
    client = OpenAI(base_url=api_base_url, api_key=api_key)
    prompt = f"\n根据用户的输入{messages},抽取出用户想了解的问题,要求:保持客观真实,简单明了,不要多余解释和阐述\n"

    # The chat API takes plain role/content mappings, so normalise model
    # objects explicitly rather than relying on the client to serialise them.
    chat_messages = [{"role": "system", "content": prompt}]
    chat_messages.extend(_to_chat_dict(m) for m in messages)

    response = client.chat.completions.create(
        messages=chat_messages,
        model=model_name,
        max_tokens=1000,
        temperature=0.001,
        stream=False,
    )
    res = response.choices[0].message.content
    print(f"多轮意图后用户想要的问题是{res}")
    entries = slot_recognizer.recognize(res)
    return entries
# Required-slot check.
def check_lost(int_res, slot):
    """Check whether the extracted slots satisfy the intent's required slots.

    Each intent id maps to a list of alternative required-slot sets; the
    intent is satisfied when every slot of at least one alternative is present.

    Args:
        int_res: Predicted intent id.
        slot: Mapping of extracted slots; only its keys are inspected.

    Returns:
        ``(CheckResult.NO_MATCH, "")`` when the intent has no required slots.
        ``(CheckResult.NO_MATCH, <list of extracted slot names>)`` when one
        alternative is fully satisfied.
        ``(CheckResult.NEEDS_MORE_ROUNDS, <follow-up question>)`` otherwise.
    """
    # Intent ids follow the order of the module-level `labels` list:
    # 0 weather, 1 internet search, 2 page switch, 3 daily plan count,
    # 4 weekly plan count, 5 daily plan content, 6 weekly plan content,
    # 7 construction headcount, 8 attendance headcount, 9 knowledge Q&A.
    required_by_intent = {
        2: [['page'], ['app'], ['module']],
        3: [['date']],
        4: [['date']],
        5: [['date']],
        6: [['date']],
        7: [['date']],
        8: [['date']],
    }
    intention_mapping = {
        2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容",
        6: "周计划作业内容", 7: "施工人数", 8: "作业考勤人数",
    }

    if int_res not in required_by_intent:
        # Return a CheckResult member like every other exit of this function,
        # not a bare 0 that the caller compares against an enum member.
        return CheckResult.NO_MATCH, ""

    # Names of the slots that were actually extracted.
    cur_k = list(slot.keys())

    # For every alternative, list the required slots that were not extracted.
    # min() keeps the first alternative with the fewest missing slots, which is
    # the one the original index loop selected.
    missing_per_alternative = [
        [name for name in required if name not in cur_k]
        for required in required_by_intent[int_res]
    ]
    left = min(missing_per_alternative, key=len)

    if not left:  # an alternative is fully satisfied
        return CheckResult.NO_MATCH, cur_k

    print(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}")
    apologize_str = "非常抱歉,"
    if int_res == 2:
        return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询哪个页面?"
    # Every remaining key of required_by_intent (3-8) requires a date.
    return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}"
# Normalise the project name and project-department name slots.
def check_project_standard_slot(int_res, slot) -> tuple:
    """Replace project / program slot values with their standard names.

    Only intents 3-8 are processed. For each slot keyed ``PROJECT_NAME`` or
    ``PROJECT_DEPARTMENT`` the value is fuzzy-matched against the matching
    standard list. A match scoring at least ``SIMILARITY_VALUE`` overwrites
    the slot value in place; a weaker match ends the check with a
    clarification message.

    Args:
        int_res: Predicted intent id.
        slot: Mapping of extracted slots; mutated in place on a good match.

    Returns:
        ``(CheckResult.NEEDS_MORE_ROUNDS, <message>)`` on the first weak match,
        otherwise ``(CheckResult.NO_MATCH, "")``.
    """
    if int_res not in {3, 4, 5, 6, 7, 8}:
        return CheckResult.NO_MATCH, ""

    for slot_name, raw_value in slot.items():
        if slot_name == PROJECT_NAME:
            best_project, score = fuzzy_match(raw_value, standard_project_name_list)
            print(f"fuzzy_match project result:{best_project}, {score}")
            if score < SIMILARITY_VALUE:
                return CheckResult.NEEDS_MORE_ROUNDS, f"抱歉,您说的工程名是{best_project}"
            slot[slot_name] = best_project

        if slot_name == PROJECT_DEPARTMENT:
            best_program, score = fuzzy_match(raw_value, standard_program_name_list)
            print(f"fuzzy_match program result:{best_program}, {score}")
            if score < SIMILARITY_VALUE:
                return CheckResult.NEEDS_MORE_ROUNDS, f"抱歉,您说的项目名是{best_program}"
            slot[slot_name] = best_program

    return CheckResult.NO_MATCH, ""
def fuzzy_match(user_input, standard_name):
    """Return the best fuzzy match for *user_input* and its score.

    Args:
        user_input: Text to look up.
        standard_name: Candidate names handed to ``process.extract``.

    Returns:
        ``(name, score / 100)`` taken from the first entry that
        ``process.extract`` returns.
    """
    top_hit = process.extract(user_input, standard_name)[0]
    matched_name = top_hit[0]
    scaled_score = top_hit[1] / 100
    return matched_name, scaled_score
if __name__ == '__main__':
    import os

    # The server listens on all interfaces, and Flask's debug mode enables the
    # Werkzeug interactive debugger, which lets anyone who can reach the port
    # execute code. Debug is therefore opt-in: set FLASK_DEBUG=1 for local work.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host='0.0.0.0', port=18074, debug=debug_enabled)