Intention/api/main.py

415 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from flask import Flask, jsonify, request
from pydantic import BaseModel, Field
from werkzeug.exceptions import HTTPException
from typing import List
from pydantic import ValidationError
from intentRecognition import IntentRecognition
from slotRecognition import SlotRecognition
from fuzzywuzzy import process
from utils import CheckResult, StandardType, load_standard_name
from constants import PROJECT_NAME, PROJECT_DEPARTMENT, SIMILARITY_VALUE, IMPLEMENTATION_ORG
from langchain_openai import OpenAIEmbeddings
from config import *
# Checkpoint locations handed to the two recogniser wrappers below.
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-3540"
MODEL_UIE_PATH = R"../uie/output/checkpoint-3190"
# Intent category names. The list positions line up with the integer intent
# ids used by check_lost (e.g. index 2 is "页面切换", index 3 is "日计划数量查询").
labels = [
"天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
"日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答",
"通用对话"
]
# Tag-id -> tag-name mapping for slot recognition. Id 0 is the "outside" tag,
# ids 1-13 are the B-<slot> tags and ids 14-26 are the matching I-<slot> tags.
label_map = {
0: 'O', # not part of any entity
1: 'B-date', 14: 'I-date',
2: 'B-projectName', 15: 'I-projectName',
3: 'B-projectType', 16: 'I-projectType',
4: 'B-constructionUnit', 17: 'I-constructionUnit',
5: 'B-implementationOrganization', 18: 'I-implementationOrganization',
6: 'B-projectDepartment', 19: 'I-projectDepartment',
7: 'B-projectManager', 20: 'I-projectManager',
8: 'B-subcontractor', 21: 'I-subcontractor',
9: 'B-teamLeader', 22: 'I-teamLeader',
10: 'B-riskLevel', 23: 'I-riskLevel',
11: 'B-page', 24: 'I-page',
12: 'B-operating', 25: 'I-operating',
13: 'B-teamName', 26: 'I-teamName',
}
# Keyword arguments for the OpenAIEmbeddings client created below.
params = {'model': 'bge-large-zh-v1.5',
'openai_api_base': 'http://127.0.0.1:9997/v1',
'openai_api_key': 'EMPTY',
'openai_proxy': ''}
# Build the intent recogniser from the ERNIE checkpoint path and the label list.
intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
# Build the slot recogniser from the UIE checkpoint path and the tag mapping.
slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
# Standard engineering-project (工程) names
standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
# Standard project/programme (项目) names
standard_program_name_list = load_standard_name('./standard_data/standard_program.txt')
# Standard company names
standard_company_name_list = load_standard_name('./standard_data/standard_company.txt')
# Create the embedding client.
embedding = OpenAIEmbeddings(**params)
# Embed every standard-name list once at import time; fuzzy_match then only
# has to embed the incoming query and compare it against these vectors.
standard_program_embeddings = embedding.embed_documents(standard_program_name_list, chunk_size=500)
standard_project_embeddings = embedding.embed_documents(standard_project_name_list, chunk_size=500)
standard_company_embeddings = embedding.embed_documents(standard_company_name_list, chunk_size=500)
print(f":standard_project_name_list:{standard_project_name_list}")
# Create the Flask application that the route decorators below attach to.
app = Flask(__name__)
# Application-wide fallback handler for exceptions that escape a view.
@app.errorhandler(Exception)
def handle_exception(e):
    """Convert an unhandled exception into a JSON error response.

    :param e: The exception raised while handling the request.
    :return: ``(response, status)``. A werkzeug ``HTTPException`` keeps its
        own name, description and status code; anything else is reported as
        ``InternalServerError`` with status 500 and ``str(e)`` as the message.
    """
    if not isinstance(e, HTTPException):
        generic_error = {
            "type": "InternalServerError",
            "message": str(e),
        }
        return jsonify({"error": generic_error}), 500

    http_error = {
        "type": e.name,
        "message": e.description,
        "status_code": e.code,
    }
    return jsonify({"error": http_error}), e.code
def validate_user(data):
    """Check that the request payload carries the authorised ``user_id``.

    :param data: Decoded JSON request payload. A payload without a ``get``
        method (``None``, a list, ...) is treated as unauthorised instead of
        raising ``AttributeError``.
    :return: ``None`` when the caller is authorised; otherwise a
        ``(response, 401)`` tuple that a view can return directly.
    """
    import hmac

    expected_user_id = '3bb66776-1722-4c36-b14a-73dd210fe750'

    try:
        supplied_user_id = data.get("user_id")
    except AttributeError:
        supplied_user_id = None

    # The id acts as a shared secret, so compare in constant time.
    # compare_digest rejects non-ASCII str, hence the comparison on UTF-8
    # bytes; a non-string value can never match.
    if isinstance(supplied_user_id, str) and hmac.compare_digest(
        supplied_user_id.encode("utf-8"), expected_user_id.encode("utf-8")
    ):
        return None

    return jsonify(
        code=401,
        msg='权限验证失败,请联系接口开发人员',
        label=-1,
        probability=-1
    ), 401
class LabelMessage(BaseModel):
    """Request body of the single-utterance endpoints.

    Both ``/intent_reco`` and ``/slot_reco`` build this model from the
    decoded JSON payload.
    """
    # Utterance to classify / extract slots from.
    text: str = Field(..., description="消息内容")
    # Caller identifier; the description previously duplicated the one for
    # ``text`` and now matches the wording used by RequestData.user_id.
    user_id: str = Field(..., description="用户ID")
# Structure of a single conversation message
class Message(BaseModel):
    """One turn of the conversation sent to the ``/agent`` endpoint."""
    # Speaker role of the turn; the description previously duplicated the
    # one for ``content``.
    role: str = Field(..., description="消息角色")
    # Text of the turn.
    content: str = Field(..., description="消息内容")
    # A timestamp field is sketched but currently disabled:
    # timestamp: str = Field(..., description="消息时间戳")
# Structure of the request payload
class RequestData(BaseModel):
    """Request body of the ``/agent`` endpoint: the conversation plus the caller id."""
    # Conversation turns; the agent view reads messages[0].content as the query.
    messages: List[Message] = Field(..., description="消息列表")
    # Caller identifier.
    user_id: str = Field(..., description="用户ID")
# Intent recognition endpoint
@app.route('/intent_reco', methods=['POST'])
def intent_reco():
    """Classify the intent of a single utterance.

    Expects a JSON object with ``text`` and ``user_id``. On success the
    response carries ``code``, ``msg``, ``int`` (predicted intent id),
    ``label`` and ``probability``.

    Error responses:
        400 -- the body is not a JSON object, fails ``LabelMessage``
               validation, or ``text`` / ``user_id`` is empty.
        401 -- ``validate_user`` rejected the caller.
        other HTTP errors raised while reading the request keep their own
               status code.
        500 -- any other failure; the exception text is returned.
    """
    try:
        # Read the JSON payload from the request.
        data = request.get_json()
        if not isinstance(data, dict):
            # LabelMessage(**data) would raise TypeError and surface as a 500.
            return jsonify({"error": "request body must be a JSON object"}), 400

        request_data = LabelMessage(**data)  # pydantic validates the structure
        text = request_data.text
        user_id = request_data.user_id

        # Required fields must be non-empty.
        if not text:
            return jsonify({"error": "text is required"}), 400
        if not user_id:
            return jsonify({"error": "user_id is required"}), 400

        # Authorise the caller.
        user_validation_error = validate_user(data)
        if user_validation_error:
            return user_validation_error

        # Run the intent classifier.
        predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(text)
        return jsonify(
            code=200,
            msg="成功",
            int=predicted_id,
            label=predicted_label,
            probability=float(predicted_probability)
        )
    except ValidationError as e:
        # Malformed payloads are a client error, same as in the agent view.
        return jsonify({"error": e.errors()}), 400
    except HTTPException as e:
        # e.g. an unparsable body rejected by request.get_json(): keep the
        # HTTP status instead of reporting it as a server error.
        return jsonify({"error": str(e)}), e.code or 500
    except Exception as e:
        return jsonify({"error": str(e)}), 500
# Slot extraction endpoint
@app.route('/slot_reco', methods=['POST'])
def slot_reco():
    """Extract slots (entities) from a single utterance.

    Expects a JSON object with ``text`` and ``user_id``. On success the
    response carries ``code``, ``msg`` and ``slot`` (the value returned by
    ``slot_recognizer.recognize``).

    Error responses:
        400 -- the body is not a JSON object, fails ``LabelMessage``
               validation, or ``text`` / ``user_id`` is empty.
        401 -- ``validate_user`` rejected the caller.
        other HTTP errors raised while reading the request keep their own
               status code.
        500 -- any other failure; the exception text is returned.
    """
    try:
        # Read the JSON payload from the request.
        data = request.get_json()
        if not isinstance(data, dict):
            # LabelMessage(**data) would raise TypeError and surface as a 500.
            return jsonify({"error": "request body must be a JSON object"}), 400

        request_data = LabelMessage(**data)  # pydantic validates the structure
        text = request_data.text
        user_id = request_data.user_id

        # Required fields must be non-empty.
        if not text:
            return jsonify({"error": "text is required"}), 400
        if not user_id:
            return jsonify({"error": "user_id is required"}), 400

        # Authorise the caller.
        user_validation_error = validate_user(data)
        if user_validation_error:
            return user_validation_error

        # Run the slot recogniser.
        entities = slot_recognizer.recognize(text)
        return jsonify(
            code=200,
            msg="成功",
            slot=entities)
    except ValidationError as e:
        # Malformed payloads are a client error, same as in the agent view.
        return jsonify({"error": e.errors()}), 400
    except HTTPException as e:
        # e.g. an unparsable body rejected by request.get_json(): keep the
        # HTTP status instead of reporting it as a server error.
        return jsonify({"error": str(e)}), e.code or 500
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route('/agent', methods=['POST'])
def agent():
    """Run intent recognition plus slot filling for a (multi-turn) conversation.

    Expects a JSON object with ``messages`` (list of ``{role, content}``) and
    ``user_id``. The intent is always predicted from the first message. Slots
    come from that same message for a single-turn request and from
    ``multi_slot_recognizer`` when there are several messages.

    Responses:
        code 200   -- ``answer`` holds ``int``, ``label``, ``probability`` and ``slot``.
        code 10001 -- more input is needed; ``answer.miss`` holds the
                      follow-up question for the user.
        HTTP 400   -- the body is not a JSON object, fails ``RequestData``
                      validation, or ``messages`` / ``user_id`` is empty.
        HTTP 401   -- ``validate_user`` rejected the caller.
        HTTP 500   -- any other failure; the exception text is returned.
    """
    try:
        data = request.get_json()
        if not isinstance(data, dict):
            # RequestData(**data) would raise TypeError and surface as a 500.
            return jsonify({"error": "request body must be a JSON object"}), 400

        # pydantic validates the payload structure.
        request_data = RequestData(**data)
        messages = request_data.messages
        user_id = request_data.user_id

        # Required fields must be non-empty.
        if not messages:
            return jsonify({"error": "messages is required"}), 400
        if not user_id:
            return jsonify({"error": "user_id is required"}), 400

        # Authorise the caller.
        user_validation_error = validate_user(data)
        if user_validation_error:
            return user_validation_error

        # Intent recognition runs on the first message in both modes.
        query = messages[0].content
        predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)

        if len(messages) == 1:  # first round: extract slots from the query itself
            entities = slot_recognizer.recognize(query)
            print(f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}")
        else:  # follow-up rounds: extract slots from the whole conversation
            entities = multi_slot_recognizer(predicted_id, messages)
            print(f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}")

        # Check that the slots required by this intent are present.
        status, sk = check_lost(predicted_id, entities)
        if status == CheckResult.NEEDS_MORE_ROUNDS:
            return jsonify({
                "code": 10001, "msg": "成功",
                "answer": {"miss": sk},
            })

        # Normalise project-related slot values against the standard names.
        print("start to check_project_standard_slot")
        result, information = check_project_standard_slot(predicted_id, entities)
        print(f"end check_project_standard_slot,{result},{information}")
        if result == CheckResult.NEEDS_MORE_ROUNDS:
            return jsonify({
                "code": 10001, "msg": "成功",
                "answer": {"miss": information},
            })

        return jsonify({
            "code": 200, "msg": "成功",
            "answer": {
                "int": predicted_id,
                "label": predicted_label,
                # Same conversion as intent_reco: the raw score may be a
                # framework scalar that the JSON encoder cannot serialise.
                "probability": float(predicted_probability),
                "slot": entities,
            },
        })
    except ValidationError as e:
        return jsonify({"error": e.errors()}), 400  # pydantic validation failure
    except HTTPException as e:
        # e.g. an unparsable body rejected by request.get_json(): keep the
        # HTTP status instead of reporting it as a server error.
        return jsonify({"error": str(e)}), e.code or 500
    except Exception as e:
        return jsonify({"error": str(e)}), 500  # anything else
def _to_chat_message(msg):
    """Return *msg* as a plain ``{"role", "content"}`` dict for the chat API."""
    if isinstance(msg, dict):
        return msg
    return {"role": msg.role, "content": msg.content}


def multi_slot_recognizer(intention_id, messages):
    """Extract slots from a multi-turn conversation.

    The conversation is sent to the chat-completion model configured through
    ``api_base_url`` / ``api_key`` / ``model_name``, which is asked to restate
    the user's latest request as one complete question. That question is then
    passed to ``slot_recognizer.recognize``.

    :param intention_id: Predicted intent id. Currently unused; kept so the
        call signature stays stable.
    :param messages: Conversation turns, either plain dicts or objects with
        ``role`` and ``content`` attributes.
    :return: Whatever ``slot_recognizer.recognize`` returns for the restated
        question.
    """
    from openai import OpenAI
    client = OpenAI(base_url=api_base_url, api_key=api_key)
    prompt = f'''根据以下对话记录,提取用户最近一次提问的核心意图,根据上下文的回答内容并且关注用户最后的问题,提取出的意图需表述为完整的问题句式:
对话记录:{messages}'''
    # Send plain dicts only: the turns may arrive as pydantic models, and
    # mixing those with the system dict relies on the client library being
    # able to serialise them.
    chat_messages = [{"role": "system", "content": prompt}]
    chat_messages.extend(_to_chat_message(turn) for turn in messages)
    response = client.chat.completions.create(
        messages=chat_messages,
        model=model_name,
        max_tokens=1000,
        temperature=0.001,
        stream=False
    )
    res = response.choices[0].message.content
    print(f"多轮意图后用户想要的问题是{res}")
    entries = slot_recognizer.recognize(res)
    return entries
def check_lost(int_res, slot):
    """Check whether the slots required by an intent were extracted.

    :param int_res: Predicted intent id. The ids used here correspond to the
        positions in the module-level ``labels`` list (2 = "页面切换", ...).
    :param slot: Mapping of extracted slot name to value.
    :return: A ``(status, payload)`` tuple:

        * ``(0, "")`` when the intent has no required slots;
        * ``(CheckResult.NO_MATCH, <list of extracted slot names>)`` when at
          least one group of required slots is fully present;
        * ``(CheckResult.NEEDS_MORE_ROUNDS, <follow-up question>)`` otherwise.
    """
    # Intent id -> alternative groups of required slots. Satisfying any one
    # group is enough.
    mapping = {
        2: [['page'], ['app'], ['module']],
        3: [['date']],
        4: [['date']],
        5: [['date']],
        6: [['date']],
        7: [['date']],
        8: [['date']],
    }
    intention_mapping = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容",
                         6: "周计划作业内容", 7: "施工人数", 8: "作业考勤人数"}
    if int_res not in mapping:
        # NOTE(review): a bare 0 (not a CheckResult member) is returned here;
        # kept as-is so existing callers see the same value.
        return 0, ""

    # Names of the slots that were extracted.
    cur_k = list(slot.keys())

    # For each alternative group, the required slots that were NOT extracted.
    # min() keeps the first group with the fewest gaps, so a fully satisfied
    # group (no gaps) always wins.
    missing_per_group = [
        [name for name in group if name not in cur_k]
        for group in mapping[int_res]
    ]
    left = min(missing_per_group, key=len)
    if not left:  # every required slot of some group is present
        return CheckResult.NO_MATCH, cur_k

    print(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}")
    apologize_str = "非常抱歉,"
    if int_res == 2:
        return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询哪个页面?"
    # Every remaining key of ``mapping`` (3-8) requires a date.
    return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}"
# Standardise project-related slot values
def check_project_standard_slot(int_res, slot) -> tuple:
    """Replace recognised names in *slot* with their closest standard names.

    Only intents 3-8 are processed. For the ``PROJECT_NAME``,
    ``PROJECT_DEPARTMENT`` and ``IMPLEMENTATION_ORG`` slots the value is
    matched against the corresponding standard-name list with
    ``fuzzy_match``. A match scoring at least ``SIMILARITY_VALUE`` overwrites
    the slot value in place; a weaker match stops processing and is sent
    back to the user for confirmation. An ``IMPLEMENTATION_ORG`` value of
    exactly "公司" is left untouched.

    :param int_res: Predicted intent id.
    :param slot: Mapping of slot name to extracted value; modified in place.
    :return: ``(CheckResult.NO_MATCH, "")`` when nothing needs confirming,
        otherwise ``(CheckResult.NEEDS_MORE_ROUNDS, <confirmation question>)``.
    :raises RuntimeError: if ``fuzzy_match`` could not produce a score
        (it returns ``(None, None)`` when the similarity lookup fails).
    """
    intention_list = {3, 4, 5, 6, 7, 8}
    if int_res not in intention_list:
        return CheckResult.NO_MATCH, ""

    # slot key -> (standard embeddings, standard names, log tag,
    #              wording used in the confirmation question)
    targets = {
        PROJECT_NAME: (standard_project_embeddings, standard_project_name_list, "project", "工程名"),
        PROJECT_DEPARTMENT: (standard_program_embeddings, standard_program_name_list, "program", "项目名"),
        IMPLEMENTATION_ORG: (standard_company_embeddings, standard_company_name_list, "company", "分公司名"),
    }

    # Iterate over a snapshot because slot values are overwritten in the loop.
    for key, value in list(slot.items()):
        target = targets.get(key)
        if target is None:
            continue
        if key == IMPLEMENTATION_ORG and value == "公司":
            continue

        embeddings, names, log_tag, wording = target
        match_name, match_possibility = fuzzy_match(value, embeddings, names)
        print(f"fuzzy_match {log_tag} result:{match_name}, {match_possibility}")

        if match_possibility is None:
            # Previously this surfaced as an opaque TypeError from comparing
            # None with the threshold.
            raise RuntimeError(
                f"failed to standardise slot {key!r}: similarity lookup returned no result"
            )
        if match_possibility >= SIMILARITY_VALUE:
            slot[key] = match_name
        else:
            return CheckResult.NEEDS_MORE_ROUNDS, f"抱歉,您说的{wording}是{match_name}"
    return CheckResult.NO_MATCH, ""
# def fuzzy_match(user_input, standard_name):
# result = process.extract(user_input, standard_name)
# return result[0][0], result[0][1]/100
def fuzzy_match(query, standard_embeddings, standard_name_list):
    """
    Find the standard name whose embedding is closest to the query.

    The query is embedded with the module-level ``embedding`` client and
    compared to every standard embedding by cosine similarity.

    :param query: Name to look up
    :param standard_embeddings: Embedding vectors of the standard names
    :param standard_name_list: Standard names, in the same order as the vectors
    :return: (closest name, similarity rounded to 2 decimals), or
             (None, None) if anything raises during the lookup
    """
    from sklearn.metrics.pairwise import cosine_similarity
    import numpy as np
    try:
        # Embed the query and score it against every standard vector.
        query_vector = embedding.embed_query(query)
        score_row = cosine_similarity([query_vector], standard_embeddings)
        scores = score_row[0]

        # Pick the best-scoring candidate.
        best_position = np.argmax(scores)
        best_score = scores[best_position]
        best_name = standard_name_list[best_position]

        # Log the outcome.
        print(f"输入名称: {query}")
        print(f"最相似的名称: {best_name}")
        print(f"相似度: {best_score:.4f}")
        return best_name, round(best_score, 2)
    except Exception as e:
        print(f"相似性判断错误: {e}")
        return None, None
# match_program, match_possibility = fuzzy_match("第一项目部定西")
# print(f"fuzzy_match program result:{match_program}, {match_possibility}")
if __name__ == '__main__':
    import os

    # The server listens on every interface. With debug=True the Werkzeug
    # interactive debugger would be reachable from the network and allows
    # arbitrary code execution, so debug mode is now opt-in:
    # set INTENTION_API_DEBUG=1 for local development.
    debug_enabled = os.environ.get("INTENTION_API_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(host='0.0.0.0', port=18074, debug=debug_enabled)