From 62db786d2bac141c4b3fd005f59106c7606c2819 Mon Sep 17 00:00:00 2001 From: weiweiw Date: Mon, 3 Mar 2025 10:58:46 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A4=9A=E8=BD=AE=E9=97=AE=E8=AF=A2=E4=BC=98?= =?UTF-8?q?=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/mian.py | 81 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 55 insertions(+), 26 deletions(-) diff --git a/api/mian.py b/api/mian.py index 6d2e100..8acbf60 100644 --- a/api/mian.py +++ b/api/mian.py @@ -201,45 +201,74 @@ def agent(): # 再进行槽位抽取 entities = slot_recognizer.recognize(query) - print(f"意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}") - - #必须槽位缺失检查 - status, sk = check_lost(predicted_id, entities) - if status == CheckResult.NEEDS_MORE_ROUNDS: - return jsonify({"code": 10001, "msg": "成功", - "answer": { "miss": sk}, - }) - - #工程名和项目名标准化 - print(f"start to check_project_standard_slot") - result, information = check_project_standard_slot(predicted_id, entities) - print(f"end check_project_standard_slot,{result},{information}") - if result == CheckResult.NEEDS_MORE_ROUNDS: - return jsonify({ - "code": 10001, "msg": "成功", - "answer": {"miss": information}, - }) - - return jsonify({ - "code": 200,"msg": "成功", - "answer": {"int": predicted_id, "label": predicted_label, "probability": predicted_probability, "slot": entities }, - }) + print(f"第一轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}") # 如果是后续轮次(多轮对话),这里只做示例,可能需要根据具体需求进行处理 else: query = messages[0].content # 使用 Message 对象的 .content 属性 + # 先进行意图识别 + predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query) + entities = multi_slot_recognizer(predicted_id, messages) + + print(f"多轮意图识别后的label:{predicted_label}, id:{predicted_id},槽位抽取后的实体:{entities},message:{messages}") + + #必须槽位缺失检查 + status, sk = check_lost(predicted_id, entities) + if status == CheckResult.NEEDS_MORE_ROUNDS: + return jsonify({"code": 10001, "msg": "成功", + "answer": { "miss": sk}, + }) + + #工程名和项目名标准化 + print(f"start to check_project_standard_slot") + result, information = check_project_standard_slot(predicted_id, entities) + print(f"end check_project_standard_slot,{result},{information}") + if result == CheckResult.NEEDS_MORE_ROUNDS: return jsonify({ - "user_id": user_id, - "query": query, - "message_count": len(messages) + "code": 10001, "msg": "成功", + "answer": {"miss": information}, }) + return jsonify({ + "code": 200,"msg": "成功", + "answer": {"int": predicted_id, "label": predicted_label, "probability": predicted_probability, "slot": entities }, + }) + except ValidationError as e: return jsonify({"error": e.errors()}), 400 # 捕捉 Pydantic 错误并返回 except Exception as e: return jsonify({"error": str(e)}), 500 # 捕捉其他错误并返回 +def multi_slot_recognizer(intention_id, messages): + from openai import OpenAI + final_slot = {} + api_base_url = "http://36.33.26.201:27861/v1" + api_key = 'EMPTY' + model_name = 'qwen2.5-instruct' + client = OpenAI(base_url = api_base_url, api_key = api_key) + + prompt = f''' + 根据用户的输入{messages},抽取出用户想了解的问题,要求:保持客观真实,简单明了,不要多余解释和阐述 + ''' + + message = [{"role": "system", "content": prompt}] + message.extend(messages) + # print(message) + response = client.chat.completions.create( + messages=message, + model=model_name, + max_tokens=1000, + temperature=0.001, + stream=False + ) + res = response.choices[0].message.content + + print(f"多轮意图后用户想要的问题是{res}") + entries = slot_recognizer.recognize(res) + + return entries + def check_lost(int_res, slot): #labels: ["天气查询","互联网查询","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"] mapping = {