多轮问询优化
This commit is contained in:
parent
4bc78e36fc
commit
62db786d2b
81
api/mian.py
81
api/mian.py
|
|
@ -201,45 +201,74 @@ def agent():
|
||||||
# 再进行槽位抽取
|
# 再进行槽位抽取
|
||||||
entities = slot_recognizer.recognize(query)
|
entities = slot_recognizer.recognize(query)
|
||||||
|
|
||||||
print(f"意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}")
|
print(f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}")
|
||||||
|
|
||||||
#必须槽位缺失检查
|
|
||||||
status, sk = check_lost(predicted_id, entities)
|
|
||||||
if status == CheckResult.NEEDS_MORE_ROUNDS:
|
|
||||||
return jsonify({"code": 10001, "msg": "成功",
|
|
||||||
"answer": { "miss": sk},
|
|
||||||
})
|
|
||||||
|
|
||||||
#工程名和项目名标准化
|
|
||||||
print(f"start to check_project_standard_slot")
|
|
||||||
result, information = check_project_standard_slot(predicted_id, entities)
|
|
||||||
print(f"end check_project_standard_slot,{result},{information}")
|
|
||||||
if result == CheckResult.NEEDS_MORE_ROUNDS:
|
|
||||||
return jsonify({
|
|
||||||
"code": 10001, "msg": "成功",
|
|
||||||
"answer": {"miss": information},
|
|
||||||
})
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
"code": 200,"msg": "成功",
|
|
||||||
"answer": {"int": predicted_id, "label": predicted_label, "probability": predicted_probability, "slot": entities },
|
|
||||||
})
|
|
||||||
|
|
||||||
# 如果是后续轮次(多轮对话),这里只做示例,可能需要根据具体需求进行处理
|
# 如果是后续轮次(多轮对话),这里只做示例,可能需要根据具体需求进行处理
|
||||||
else:
|
else:
|
||||||
query = messages[0].content # 使用 Message 对象的 .content 属性
|
query = messages[0].content # 使用 Message 对象的 .content 属性
|
||||||
|
# 先进行意图识别
|
||||||
|
predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)
|
||||||
|
entities = multi_slot_recognizer(predicted_id, messages)
|
||||||
|
|
||||||
|
print(f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}")
|
||||||
|
|
||||||
|
#必须槽位缺失检查
|
||||||
|
status, sk = check_lost(predicted_id, entities)
|
||||||
|
if status == CheckResult.NEEDS_MORE_ROUNDS:
|
||||||
|
return jsonify({"code": 10001, "msg": "成功",
|
||||||
|
"answer": { "miss": sk},
|
||||||
|
})
|
||||||
|
|
||||||
|
#工程名和项目名标准化
|
||||||
|
print(f"start to check_project_standard_slot")
|
||||||
|
result, information = check_project_standard_slot(predicted_id, entities)
|
||||||
|
print(f"end check_project_standard_slot,{result},{information}")
|
||||||
|
if result == CheckResult.NEEDS_MORE_ROUNDS:
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"user_id": user_id,
|
"code": 10001, "msg": "成功",
|
||||||
"query": query,
|
"answer": {"miss": information},
|
||||||
"message_count": len(messages)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
"code": 200,"msg": "成功",
|
||||||
|
"answer": {"int": predicted_id, "label": predicted_label, "probability": predicted_probability, "slot": entities },
|
||||||
|
})
|
||||||
|
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
return jsonify({"error": e.errors()}), 400 # 捕捉 Pydantic 错误并返回
|
return jsonify({"error": e.errors()}), 400 # 捕捉 Pydantic 错误并返回
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
|
return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回
|
||||||
|
|
||||||
|
|
||||||
|
def multi_slot_recognizer(intention_id, messages):
|
||||||
|
from openai import OpenAI
|
||||||
|
final_slot = {}
|
||||||
|
api_base_url = "http://36.33.26.201:27861/v1"
|
||||||
|
api_key = 'EMPTY'
|
||||||
|
model_name = 'qwen2.5-instruct'
|
||||||
|
client = OpenAI(base_url = api_base_url, api_key = api_key)
|
||||||
|
|
||||||
|
prompt = f'''
|
||||||
|
根据用户的输入{messages},抽取出用户想了解的问题,要求:保持客观真实,简单明了,不要多余解释和阐述
|
||||||
|
'''
|
||||||
|
|
||||||
|
message = [{"role": "system", "content": prompt}]
|
||||||
|
message.extend(messages)
|
||||||
|
# print(message)
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
messages=message,
|
||||||
|
model=model_name,
|
||||||
|
max_tokens=1000,
|
||||||
|
temperature=0.001,
|
||||||
|
stream=False
|
||||||
|
)
|
||||||
|
res = response.choices[0].message.content
|
||||||
|
|
||||||
|
print(f"多轮意图后用户想要的问题是{res}")
|
||||||
|
entries = slot_recognizer.recognize(res)
|
||||||
|
|
||||||
|
return entries
|
||||||
|
|
||||||
def check_lost(int_res, slot):
|
def check_lost(int_res, slot):
|
||||||
#labels: ["天气查询","互联网查询","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"]
|
#labels: ["天气查询","互联网查询","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"]
|
||||||
mapping = {
|
mapping = {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue