541 lines
28 KiB
Python
541 lines
28 KiB
Python
|
|
from flask import Flask, jsonify, request
|
|||
|
|
from pydantic import BaseModel, Field
|
|||
|
|
from werkzeug.exceptions import HTTPException
|
|||
|
|
from typing import List
|
|||
|
|
from pydantic import ValidationError
|
|||
|
|
|
|||
|
|
from intentRecognition import IntentRecognition
|
|||
|
|
from slotRecognition import SlotRecognition
|
|||
|
|
from utils import CheckResult, load_standard_name, generate_project_prompt, \
|
|||
|
|
load_standard_data, text_to_pinyin, multiple_standardize_single_name, \
|
|||
|
|
standardize_company_and_projectDepartment
|
|||
|
|
|
|||
|
|
from constants import PROJECT_NAME, PROJECT_DEPARTMENT, SIMILARITY_VALUE, IMPLEMENTATION_ORG, RISK_LEVEL
|
|||
|
|
from langchain_openai import OpenAIEmbeddings
|
|||
|
|
from config import *
|
|||
|
|
|
|||
|
|
# MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-16470"
|
|||
|
|
# MODEL_UIE_PATH = R"../uie/output_temp/checkpoint-17060"
|
|||
|
|
|
|||
|
|
# 类别名称列表
|
|||
|
|
# Intent class names; the list index of each entry is the intent id used by
# the endpoints and by check_lost / check_project_standard_slot.
labels = [
    "天气查询",        # 0
    "互联网查询",      # 1
    "页面切换",        # 2
    "日计划数量查询",  # 3
    "周计划数量查询",  # 4
    "日计划作业内容",  # 5
    "周计划作业内容",  # 6
    "施工人数",        # 7
    "作业考勤人数",    # 8
    "知识问答",        # 9
    "通用对话",        # 10
    "作业面查询",      # 11
    "班组人数查询",    # 12
    "班组数查询",      # 13
    "作业面内容",      # 14
    "班组详情",        # 15
]
|
|||
|
|
|
|||
|
|
# 标签映射
|
|||
|
|
# BIO tag-id -> tag-name table for slot recognition.
# Id 0 is the non-entity tag; entity number k (1..13) owns "B-<entity>" at id k
# and "I-<entity>" at id k + 13. Keys are inserted in the interleaved order
# 0, 1, 14, 2, 15, ... 13, 26.
label_map = {
    0: 'O',
    **{
        tag_id: f'{prefix}-{entity}'
        for position, entity in enumerate(
            (
                'date',
                'projectName',
                'projectType',
                'constructionUnit',
                'implementationOrganization',
                'projectDepartment',
                'projectManager',
                'subcontractor',
                'teamLeader',
                'riskLevel',
                'page',
                'operating',
                'teamName',
            ),
            start=1,
        )
        for tag_id, prefix in ((position, 'B'), (position + 13, 'I'))
    },
}
|
|||
|
|
|
|||
|
|
# # 初始化工具类
|
|||
|
|
# intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
|
|||
|
|
#
|
|||
|
|
# # 初始化槽位识别工具类
|
|||
|
|
# slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
|
|||
|
|
# 设置Flask应用
|
|||
|
|
|
|||
|
|
#标准公司名和项目名
|
|||
|
|
# Reference data loaded once at import time and shared by the name
# standardisation helpers below.

# Standard company data; its keys are used as the standard company names.
standard_company_program = load_standard_data("./standard_data/standard_company_program.json")

# Standard project names, plus a pinyin -> project-name lookup and the list of
# pinyin keys. Names that produce the same pinyin collapse to a single entry
# (the later name wins).
standard_project_name_list = load_standard_name('./standard_data/standard_project.txt')
pinyin_to_standard_project_name_map = {text_to_pinyin(kw): kw for kw in standard_project_name_list}
standard_project_name_pinyin_list = list(pinyin_to_standard_project_name_map.keys())

# Standard company names with the same pinyin lookup structures.
standard_company_name_list = list(standard_company_program.keys())
pinyin_to_standard_company_name_map = {text_to_pinyin(kw): kw for kw in standard_company_name_list}
standard_company_name_pinyin_list = list(pinyin_to_standard_company_name_map.keys())

# Startup diagnostics for the project-name tables.
print(f"标准化的工程名是:{standard_project_name_list}", flush=True)
print(f"pinyin标准化的工程名是 list:{standard_project_name_pinyin_list}", flush=True)
# Fixed: this line is labelled as the pinyin -> project-name mapping but used
# to print the company map (and misspelled "工程名" as "工程民").
print(f"pinyin-工程名对应关系 map:{pinyin_to_standard_project_name_map}", flush=True)

# Flask application object that the route and error-handler decorators bind to.
app = Flask(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 统一的异常处理函数
|
|||
|
|
@app.errorhandler(Exception)
def handle_exception(e):
    """Turn an exception that reaches Flask into a JSON error response.

    Werkzeug ``HTTPException`` instances keep their own status code and are
    reported with their name, description and code. Any other exception is
    reported as ``InternalServerError`` with ``str(e)`` and HTTP 500.
    """
    if not isinstance(e, HTTPException):
        unexpected = {
            "type": "InternalServerError",
            "message": str(e),
        }
        return jsonify({"error": unexpected}), 500

    http_error = {
        "type": e.name,
        "message": e.description,
        "status_code": e.code,
    }
    return jsonify({"error": http_error}), e.code
|
|||
|
|
|
|||
|
|
|
|||
|
|
def validate_user(data):
    """Check the ``user_id`` entry of a request payload.

    Args:
        data: Mapping holding the request payload; only ``user_id`` is read.

    Returns:
        ``None`` when ``user_id`` equals the single accepted id, otherwise a
        ``(response, 401)`` tuple that the caller can return directly.
    """
    # NOTE(review): the accepted id is hard-coded in source; consider moving
    # it to configuration.
    accepted_user_id = '3bb66776-1722-4c36-b14a-73dd210fe750'
    if data.get("user_id") == accepted_user_id:
        return None

    rejection = jsonify(
        code=401,
        msg='权限验证失败,请联系接口开发人员',
        label=-1,
        probability=-1,
    )
    return rejection, 401
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LabelMessage(BaseModel):
    # Request body validated by the /intent_reco and /slot_reco endpoints.

    # Text handed to the intent / slot recognizer.
    text: str = Field(..., description="消息内容")
    # Caller id, checked by validate_user. The description used to be a
    # copy of the text field's ("消息内容"); it now matches RequestData.user_id.
    user_id: str = Field(..., description="用户ID")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 每条消息的结构
|
|||
|
|
class Message(BaseModel):
    # One entry of the conversation history sent to /agent.

    # Speaker of the message; rendered as the "role: content" prefix when the
    # history is formatted for the LLM. The description used to be a copy of
    # the content field's ("消息内容").
    role: str = Field(..., description="消息角色")
    # Text of the message.
    content: str = Field(..., description="消息内容")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 请求数据的结构
|
|||
|
|
class RequestData(BaseModel):
    # Request body validated by the /agent endpoint. Both fields are required.

    # Conversation history; a single entry is treated as the first round.
    messages: List[Message] = Field(default=..., description="消息列表")
    # Caller id, checked by validate_user.
    user_id: str = Field(default=..., description="用户ID")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 意图识别
|
|||
|
|
@app.route('/intent_reco', methods=['POST'])
def intent_reco():
    """Classify the intent of a single utterance.

    Expects a JSON object with ``text`` and ``user_id`` (see ``LabelMessage``).

    Returns:
        200 with ``code``, ``msg``, ``int`` (intent id), ``label`` and
        ``probability`` on success; 400 when the body is not a JSON object,
        fails validation, or has an empty ``text`` / ``user_id``; 401 when
        ``validate_user`` rejects the caller; 500 for any other error.

    NOTE(review): ``intent_recognizer`` is only initialised in commented-out
    code in the visible part of this file; presumably it is provided
    elsewhere (e.g. by ``from config import *``) - confirm.
    """
    try:
        data = request.get_json()
        # A missing body or a non-object payload cannot be unpacked into the
        # model; report it as a client error rather than a 500.
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400

        request_data = LabelMessage(**data)
        text = request_data.text
        user_id = request_data.user_id

        # The model only guarantees presence; reject empty strings as well.
        if not text:
            return jsonify({"error": "text is required"}), 400
        if not user_id:
            return jsonify({"error": "user_id is required"}), 400

        user_validation_error = validate_user(data)
        if user_validation_error:
            return user_validation_error

        predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(text)

        return jsonify(
            code=200,
            msg="成功",
            int=predicted_id,
            label=predicted_label,
            probability=float(predicted_probability),
        )

    except ValidationError as e:
        # Same contract as /agent: schema violations are client errors.
        return jsonify({"error": e.errors()}), 400
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 槽位抽取
|
|||
|
|
@app.route('/slot_reco', methods=['POST'])
def slot_reco():
    """Extract slots (entities) from a single utterance.

    Expects a JSON object with ``text`` and ``user_id`` (see ``LabelMessage``).

    Returns:
        200 with ``code``, ``msg`` and ``slot`` on success; 400 when the body
        is not a JSON object, fails validation, or has an empty ``text`` /
        ``user_id``; 401 when ``validate_user`` rejects the caller; 500 for
        any other error.

    NOTE(review): ``slot_recognizer`` is only initialised in commented-out
    code in the visible part of this file; presumably it is provided
    elsewhere (e.g. by ``from config import *``) - confirm.
    """
    try:
        data = request.get_json()
        # A missing body or a non-object payload cannot be unpacked into the
        # model; report it as a client error rather than a 500.
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400

        request_data = LabelMessage(**data)
        text = request_data.text
        user_id = request_data.user_id

        # The model only guarantees presence; reject empty strings as well.
        if not text:
            return jsonify({"error": "text is required"}), 400
        if not user_id:
            return jsonify({"error": "user_id is required"}), 400

        user_validation_error = validate_user(data)
        if user_validation_error:
            return user_validation_error

        entities = slot_recognizer.recognize(text)

        return jsonify(
            code=200,
            msg="成功",
            slot=entities,
        )

    except ValidationError as e:
        # Same contract as /agent: schema violations are client errors.
        return jsonify({"error": e.errors()}), 400
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.route('/agent', methods=['POST'])
def agent():
    """Run intent recognition and slot extraction over a conversation.

    Expects a JSON object with ``messages`` and ``user_id`` (see
    ``RequestData``).

    Flow:
        * One message: it is classified and slot-extracted directly.
        * Several messages: ``extract_multi_chat`` rewrites the latest user
          question into a standalone query first. Intent ids 0, 1, 9 and 10
          are answered immediately, together with that query as ``finalQuery``.
        * ``check_lost`` and ``check_project_standard_slot`` may then ask for
          another round (``code`` 10001 with the question in ``answer.miss``).

    Returns:
        200 JSON with ``code`` 200 or 10001 as described above; 400 when the
        body is not a JSON object, fails validation, or has empty ``messages``
        / ``user_id``; 401 when ``validate_user`` rejects the caller; 500 for
        any other error.

    NOTE(review): ``intent_recognizer`` and ``slot_recognizer`` are only
    initialised in commented-out code in the visible part of this file;
    presumably they are provided elsewhere - confirm.
    """
    try:
        data = request.get_json()
    except Exception as e:
        print("body不是一个有效的json")
        return jsonify({"error": str(e)}), 500

    try:
        # A missing body or a non-object payload cannot be unpacked into the
        # model; report it as a client error rather than a 500.
        if not isinstance(data, dict):
            return jsonify({"error": "request body must be a JSON object"}), 400

        request_data = RequestData(**data)
        messages = request_data.messages
        user_id = request_data.user_id

        # The model only guarantees presence; reject empty values as well.
        if not messages:
            return jsonify({"error": "messages is required"}), 400
        if not user_id:
            return jsonify({"error": "user_id is required"}), 400

        user_validation_error = validate_user(data)
        if user_validation_error:
            return user_validation_error

        if len(messages) == 1:
            # First round: the only message is the query.
            query = messages[0].content
            predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)
            entities = slot_recognizer.recognize(query)
            print(
                f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",
                flush=True)
        else:
            # Multi-round: resolve the latest question against the history.
            res = extract_multi_chat(messages)
            predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(res)
            # 0: weather, 1: internet search, 9: knowledge Q&A, 10: general chat.
            if predicted_id in [0, 1, 9, 10]:
                print(f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},message:{messages}",
                      flush=True)
                return jsonify({
                    "code": 200, "msg": "成功",
                    # float(): same conversion /intent_reco applies before
                    # serialising the probability.
                    "answer": {"int": predicted_id, "label": predicted_label,
                               "probability": float(predicted_probability)},
                    "finalQuery": res
                })
            entities = slot_recognizer.recognize(res)
            print(
                f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}",
                flush=True)

        # Ask again when a slot required by the intent is missing.
        status, sk = check_lost(predicted_id, entities)
        if status == CheckResult.NEEDS_MORE_ROUNDS:
            return jsonify({"code": 10001, "msg": "成功",
                            "answer": {"miss": sk},
                            })

        # Standardise project / company / department names; this may also
        # require another round.
        result, information = check_project_standard_slot(predicted_id, entities)
        if result == CheckResult.NEEDS_MORE_ROUNDS:
            return jsonify({
                "code": 10001, "msg": "成功",
                "answer": {"miss": information},
            })

        return jsonify({
            "code": 200, "msg": "成功",
            "answer": {"int": predicted_id, "label": predicted_label,
                       "probability": float(predicted_probability),
                       "slot": entities},
        })

    except ValidationError as e:
        return jsonify({"error": e.errors()}), 400
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|||
|
|
|
|||
|
|
def extract_multi_chat(messages):
    """Rewrite the latest user question of a conversation as a standalone query.

    The history is rendered as ``role: content`` lines, embedded in an
    instruction prompt, and sent to the chat-completions endpoint.

    Args:
        messages: Sequence of objects exposing ``role`` and ``content``.

    Returns:
        The stripped model answer. If the model returns no text, the content
        of the last message is returned unchanged instead.

    NOTE(review): ``api_base_url``, ``api_key`` and ``model_name`` are not
    defined in the visible part of this file; presumably they come from
    ``from config import *`` - confirm.
    """
    from openai import OpenAI
    client = OpenAI(base_url=api_base_url, api_key=api_key)

    # Flatten the conversation into one "role: content" line per message.
    chat_history = "\n".join([f"{msg.role}: {msg.content}" for msg in messages])

    prompt = f'''你是一个智能助手,需要从以下对话记录中提取用户最近一次提问的完整问题:
1. **仅关注用户的最后一个问题**,无论之前用户提问了什么,**不要受到之前用户问题的影响**。
2. **如果用户的最后一个问题包含指代词**(如“作业计划分别是什么”、“具体是哪2项”、“刚刚那个故事”、“明天呢”、“合肥中心变工程呢”等),请结合用户上一次的问题和**AI(助手)回答**,补充信息,使问题成为完整的句子。
3. **如果用户的最后一个问题的主语是“公司”这个字眼(如“公司今天有多少四级风险作业计划”或“公司今天有多少4级风险的作业面”)则不要参考对话历史进行补全,保持用户原始表达,不要替换为具体的公司名,工程名或项目部名等。**
4. **如果用户的最后一个问题本身是完整的**(即未使用上述2里的指代词),直接输出该问题,不要受前文影响。
5. **如果问题缺少上下文信息**(如工程、项目部和时间等),仅在**最近的 AI 回答**提供了明确的上下文时进行补全,否则保持用户的原始输入,不要添加错误的补全信息。
6. **如果用户的最新问题包含时间信息**(如“今天、明天、本周”),请确保其被保留,并且不改变时间表达方式。
- **如果用户的提问本身省略了时间信息,但最近 AI 回答包含时间信息,则补全时间**。
- **例如:用户问“具体是哪20项”时,最近 AI 回答是“今天送1分公司第二项目管理部有20项作业计划”,那么补全后的问题应为“今天送1分公司第二项目管理部具体是哪20项作业计划”**。
7. **不要改写问题的主体和语序**,仅在需要时补全信息,避免误修改用户原始表达。
8. 直接输出补全后的完整问题,不需要额外解释,也不需要输出“用户想了解的问题”这样的字眼。
9. **当用户的最后一条消息使用了“第一个”、“第1个”、“第2个”……等指代方式,且上一条 AI 回复中列出了多个选项(如多个工程名、公司名、项目部等),你需要:**
- 精确提取用户所指的序号(如“第3个”指第3个工程名、公司名或项目部名);
- 将该工程、公司或项目部的完整名称(包括括号中的编号)提取出来;
- **用完整名称替换掉用户上一个问题中出现的简称或模糊表达,并保留用户问题中的其它部分(如时间、计划数、内容)不变**;
- 示例:
- 原始问题:`2025年南苑调相机检修(PROJ-2023-0179)今天有多少作业计划`
- AI 回答:列出多个工程,第1个是`检修公司调相机一二次设备检修维护和改造服务框架-2025年南苑调相机检修(PROJ-2023-0179)`
- 用户回复:“第1个”
- 则最终提问应为:
`检修公司调相机一二次设备检修维护和改造服务框架-2025年南苑调相机检修(PROJ-2023-0179)今天有多少作业计划`
**对话记录:**
{chat_history}

请提取并补全用户的最新问题:'''

    message = [
        {"role": "system", "content": "你是一个智能助手,负责提取用户最近的问题,并自动补全缺失信息,使其成为完整的问题句子。"},
        {"role": "user", "content": prompt}
    ]

    response = client.chat.completions.create(
        messages=message,
        model=model_name,
        max_tokens=100,
        temperature=0.3,  # low randomness for more deterministic rewrites
        stream=False
    )

    # The API may return no text content (None); calling .strip() on it would
    # raise AttributeError, so treat it like an empty answer.
    content = response.choices[0].message.content
    res = (content or "").strip()
    if not res:
        # No usable rewrite: keep the latest message exactly as it was sent.
        res = messages[-1].content
    print(f"多轮意图后用户想要的问题是:{res}", flush=True)
    return res
|
|||
|
|
|
|||
|
|
# def multi_slot_recognizer(intention_id, messages):
|
|||
|
|
# from openai import OpenAI
|
|||
|
|
# client = OpenAI(base_url=api_base_url, api_key=api_key)
|
|||
|
|
#
|
|||
|
|
# # prompt = f'''
|
|||
|
|
# # 根据用户的输入{messages},抽取出用户最近最想了解的一个问题,要求:保持客观真实,简单明了,不要多余解释和阐述,不需要输出如“用户想了解的问题”类似的字眼
|
|||
|
|
# # '''
|
|||
|
|
# prompt = f'''根据以下对话记录,提取用户最近一次提问的核心意图,根据关键信息和上下文的回答内容并且关注用户最后的问题,提取出的意图需表述为完整的问题句式:
|
|||
|
|
# 对话记录:{messages}'''
|
|||
|
|
#
|
|||
|
|
# message = [{"role": "system", "content": prompt}]
|
|||
|
|
# message.extend(messages)
|
|||
|
|
# # print(message)
|
|||
|
|
# response = client.chat.completions.create(
|
|||
|
|
# messages=message,
|
|||
|
|
# model=model_name,
|
|||
|
|
# max_tokens=1000,
|
|||
|
|
# temperature=0.001,
|
|||
|
|
# stream=False
|
|||
|
|
# )
|
|||
|
|
# res = response.choices[0].message.content
|
|||
|
|
#
|
|||
|
|
# print(f"多轮意图后用户想要的问题是{res}",flush=True)
|
|||
|
|
# entries = slot_recognizer.recognize(res)
|
|||
|
|
#
|
|||
|
|
# return entries
|
|||
|
|
|
|||
|
|
def check_lost(int_res, slot):
    """Check whether the slots required by an intent were extracted.

    Args:
        int_res: Predicted intent id (index into ``labels``).
        slot: Mapping of extracted slot names to values; only its keys are used.

    Returns:
        A ``(status, payload)`` tuple:

        * ``(0, "")`` when the intent has no required slots;
        * ``(CheckResult.NO_MATCH, extracted_slot_names)`` when at least one
          acceptable slot combination is fully present;
        * ``(CheckResult.NEEDS_MORE_ROUNDS, question)`` otherwise, where
          ``question`` is the follow-up text to show the user.
    """
    # Intent id -> alternative slot combinations; satisfying any one of the
    # alternatives is enough.
    required_slots = {
        2: [['page'], ['app'], ['module']],
        3: [['date']],
        4: [['date']],
        5: [['date']],
        6: [['date']],
        7: [['date']],
        8: [['date']],
        11: [['date']],
        12: [['date']],
        13: [['date']],
        14: [['date']],
        15: [['date']],
    }
    # Intent id -> display name used in the follow-up question.
    intention_names = {
        2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容",
        6: "周计划作业内容", 7: "施工人数", 8: "作业考勤人数", 11: "作业面查询",
        12: "班组人数查询", 13: "班组数查询", 14: "作业面内容", 15: "班组详情",
    }

    if int_res not in required_slots:
        # NOTE(review): returns a bare 0 rather than a CheckResult member;
        # kept as-is because callers may rely on it.
        return 0, ""

    extracted = list(slot)
    # For every alternative, the required slot names that were not extracted.
    missing_per_alternative = [
        [name for name in alternative if name not in extracted]
        for alternative in required_slots[int_res]
    ]

    if any(not missing for missing in missing_per_alternative):
        return CheckResult.NO_MATCH, extracted

    # Report the alternative with the fewest gaps (the first one on ties).
    left = min(missing_per_alternative, key=len)
    print(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}", flush=True)

    apologize_str = "非常抱歉,"
    if int_res == 2:
        return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询哪个页面?"
    # Every remaining intent in required_slots needs a date.
    return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_names[int_res]}?"
|
|||
|
|
|
|||
|
|
|
|||
|
|
#标准化工程名
|
|||
|
|
def check_project_standard_slot(int_res, slot) -> tuple:
    """Standardise company, department, project and risk-level slots in place.

    Only intent ids 3-8 and 11-15 are processed; any other id is passed
    through untouched.

    Args:
        int_res: Predicted intent id.
        slot: Mapping of slot names to extracted values. Entries that resolve
            to exactly one standard name are overwritten with that name.

    Returns:
        ``(CheckResult.NO_MATCH, "")`` when nothing further is needed, or
        ``(CheckResult.NEEDS_MORE_ROUNDS, message)`` when the user has to
        clarify something; ``message`` is the text to show.
    """
    if int_res not in {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15}:
        return CheckResult.NO_MATCH, ""

    # A project department is only meaningful together with its company
    # (implementation organisation).
    if PROJECT_DEPARTMENT in slot:
        if IMPLEMENTATION_ORG not in slot:
            return CheckResult.NEEDS_MORE_ROUNDS, "请补充该项目部所属的分公司名称"

        raw_company = slot[IMPLEMENTATION_ORG]
        raw_department = slot[PROJECT_DEPARTMENT]
        standard_company, matched_projectDepartment = standardize_company_and_projectDepartment(
            raw_company,
            raw_department,
            standard_company_name_list,
            standard_company_program,
            pinyin_to_standard_company_name_map,
        )
        print(f"check_project_standard_slot : {raw_company}, {raw_department}")

        if not standard_company:
            return CheckResult.NEEDS_MORE_ROUNDS, f"未匹配到您说的分公司名:{raw_company},请提供更准确的分公司名"
        if not matched_projectDepartment:
            return CheckResult.NEEDS_MORE_ROUNDS, f"未匹配到您说的项目名:{raw_department},请提供更准确的项目名"

        department_count = len(matched_projectDepartment)
        if department_count == 1:
            slot[IMPLEMENTATION_ORG] = standard_company
            slot[PROJECT_DEPARTMENT] = matched_projectDepartment[0]
        elif department_count > 1:
            # Several candidates: let the user choose.
            return CheckResult.NEEDS_MORE_ROUNDS, generate_project_prompt(matched_projectDepartment)

    # Project name, company name and risk level, in the order the slots appear.
    for key, value in slot.items():
        if key == PROJECT_NAME:
            print(f"check_project_standard_slot original project : {value}")
            match_results = multiple_standardize_single_name(
                value,
                standard_project_name_list,
                standard_project_name_pinyin_list,
                pinyin_to_standard_project_name_map,
                20,
                70,
            )
            print(f"standardize_single_name 工程名 :result:{match_results}", flush=True)
            if not (match_results and len(match_results) == 1):
                prompt = generate_project_prompt(match_results, original_name=value, type="工程名")
                return CheckResult.NEEDS_MORE_ROUNDS, prompt
            slot[key] = match_results[0]

        # The bare word "公司" is kept as the user said it.
        if key == IMPLEMENTATION_ORG and slot[key] != "公司":
            print(f"check_project_standard_slot original company : {slot[IMPLEMENTATION_ORG]}")
            match_results = multiple_standardize_single_name(
                value,
                standard_company_name_list,
                standard_company_name_pinyin_list,
                pinyin_to_standard_company_name_map,
                lower_score=50,
                high_score=80,
                isArabicNumConv=True,
            )
            print(f"check_project_standard_slot 分公司名: result:{match_results}", flush=True)
            if not (match_results and len(match_results) == 1):
                prompt = generate_project_prompt(match_results, original_name=slot[IMPLEMENTATION_ORG], type="分公司名")
                return CheckResult.NEEDS_MORE_ROUNDS, prompt
            slot[key] = match_results[0]

        # Only levels 2-5 are accepted, in Arabic or Chinese numerals.
        if key == RISK_LEVEL and slot[RISK_LEVEL] not in (
            "2级", "3级", "4级", "5级", "二级", "三级", "四级", "五级",
        ):
            return CheckResult.NEEDS_MORE_ROUNDS, "您查询的风险等级在系统中未找到,请确认风险等级后再次提问"

    return CheckResult.NO_MATCH, ""
|
|||
|
|
|
|||
|
|
# test_cases = [
|
|||
|
|
# ("安徽宏源电力建设有限公司", "第三项目管理部"), # 期望返回所有"第三项目管理部"
|
|||
|
|
# ("安徽宏源电力建设有限公司", "第九项目部"), # 期望返回 "第九项目管理部"
|
|||
|
|
# ("顺安电网公司", "第二项目部"), # 期望匹配"顺安电网建设有限公司"下的"第二项目管理部"
|
|||
|
|
# ("送电一公司", "第三项目部"), # 期望返回"第三项目管理部"
|
|||
|
|
# ("送电2公司", "第三项目部"), # 期望返回"第三项目管理部"
|
|||
|
|
# ("消防分公司", "第七项目部"), # 期望返回"第七项目管理部
|
|||
|
|
# ("建筑分公司", "第七项目部"), # 期望返回"第七项目管理部"
|
|||
|
|
# ("建筑消防分公司", "第七项目部"), # 期望返回"第七项目管理部"
|
|||
|
|
# ("建筑分公司消防分公司", "第七项目部") # 期望返回"第七项目管理部"
|
|||
|
|
# ]
|
|||
|
|
#
|
|||
|
|
# for company, project in test_cases:
|
|||
|
|
# # result = standardize_company_and_project(company, project,standard_company_program)
|
|||
|
|
# result = standardize_company_and_projectDepartment(company, project,standard_company_name_list, standard_company_program, pinyin_to_standard_company_name_map)
|
|||
|
|
# # result = multiple_standardize_single_name("company", standard_project_name_list, standard_project_name_pinyin_list, pinyin_to_standard_project_name_map,40,70)
|
|||
|
|
# print(f"输入: {company}, {project} -> 输出: {result}")
|
|||
|
|
#
|
|||
|
|
# result = standardize_single_name("送电一公司", standard_company_name_list)
|
|||
|
|
# print(f"输入: 送一分公司-> 输出: {result}")
|
|||
|
|
#
|
|||
|
|
# prompt = generate_project_prompt(result, "分公司名")
|
|||
|
|
# print(f"prompt:{prompt}")
|
|||
|
|
#
|
|||
|
|
# result = standardize_single_name("合肥中心变", standard_project_name_list)
|
|||
|
|
# print(f"输入: 合肥中心变-> 输出: {result}")
|
|||
|
|
#
|
|||
|
|
# prompt = generate_project_prompt(result, "工程名")
|
|||
|
|
# print(f"prompt:{prompt}")
|
|||
|
|
|
|||
|
|
# result = standardize_single_name("合肥中心变", standard_project_name_list, 60, 75)
|
|||
|
|
# print(f"输入: 合肥中心变-> 输出: {result}")
|
|||
|
|
#
|
|||
|
|
# result = standardize_single_name("阜阳阜四变电站工程", standard_project_name_list, 60, 75)
|
|||
|
|
# print(f"输入: 阜阳阜四变电站工程-> 输出: {result}")
|
|||
|
|
#
|
|||
|
|
# result = standardize_single_name("芦集变电站", standard_project_name_list, 20, 50)
|
|||
|
|
# print(f"输入: 芦集变电站-> 输出: {result}")
|
|||
|
|
#
|
|||
|
|
# match_results = multiple_standardize_single_name("宋轶分公司", standard_company_name_list, standard_company_name_pinyin_list, pinyin_to_standard_company_name_map,75,80)
|
|||
|
|
# print(f"standardize_pinyin_single_name 输入: 宋轶分公司-> 输出: {match_results}")
|
|||
|
|
# #
|
|||
|
|
# match_results = multiple_standardize_single_name("合肥中心变", standard_project_name_list, standard_project_name_pinyin_list, pinyin_to_standard_project_name_map,40,70)
|
|||
|
|
# print(f"standardize_pinyin_single_name 输入: 合肥中心变-> 输出: {match_results}")
|
|||
|
|
#
|
|||
|
|
# match_results = multiple_standardize_single_name("淮南安丰", standard_project_name_list, standard_project_name_pinyin_list, pinyin_to_standard_project_name_map,40,70)
|
|||
|
|
# print(f"standardize_pinyin_single_name 输入: 淮南安丰工程-> 输出: {match_results}")
|
|||
|
|
#
|
|||
|
|
# match_results = multiple_standardize_single_name("芦集变电站", standard_project_name_list, standard_project_name_pinyin_list, pinyin_to_standard_project_name_map,20,70)
|
|||
|
|
# print(f"standardize_pinyin_single_name 输入: 芦集变电站-> 输出: {match_results}")
|
|||
|
|
#
|
|||
|
|
# company, project = standardize_company_and_projectDepartment("变电分公司","第一项目部", standard_company_name_list, standard_company_program, pinyin_to_standard_company_name_map)
|
|||
|
|
# print(f"company:{company}, project:{project}")
|
|||
|
|
#
|
|||
|
|
# company, project = standardize_company_and_projectDepartment("变电分公司","第十一项目部", standard_company_name_list, standard_company_program, pinyin_to_standard_company_name_map)
|
|||
|
|
# print(f"company:{company}, project:{project}")
|
|||
|
|
# company, project = standardize_company_and_projectDepartment("试验分公司","电缆班", standard_company_name_list, standard_company_program, pinyin_to_standard_company_name_map)
|
|||
|
|
# print(f"company:{company}, project:{project}")
|
|||
|
|
# Smoke check that runs every time the module is imported: standardise one
# sample company / department pair and print the result.
# NOTE(review): looks like leftover debugging; consider moving it under an
# ``if __name__ == "__main__":`` guard.
company, project = standardize_company_and_projectDepartment(
    "宏源电力投资有限公司",
    "第三项目部",
    standard_company_name_list,
    standard_company_program,
    pinyin_to_standard_company_name_map,
)
print(f"company:{company}, project:{project}")
|
|||
|
|
#
|
|||
|
|
# if __name__ == '__main__':
|
|||
|
|
# app.run(host='0.0.0.0', port=18073, debug=True)
|