diff --git a/api/constants.py b/api/constants.py index 0293e3c..44563e7 100644 --- a/api/constants.py +++ b/api/constants.py @@ -9,7 +9,7 @@ USELESS_PROJECT_WORDS = ["项目", "工程", "千伏", "公司", "直流"] #项目名标准化时需要过滤掉的词汇 -USELESS_PROGRAM_DEPARTMENT_WORDS = {"项目管理部","项目部", "项目", "管理"} +USELESS_PROGRAM_DEPARTMENT_WORDS = {"项目管理部", "项目部"} #公司名标准化时需要过滤掉的词汇 USELESS_COMPANY_WORDS = ["公司","有限","责任","工程","科技"] diff --git a/api/main.py b/api/main.py index 3b293a5..9163860 100644 --- a/api/main.py +++ b/api/main.py @@ -18,8 +18,8 @@ from apscheduler.schedulers.background import BackgroundScheduler # MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-25910" # MODEL_UIE_PATH = R"../uie/output/checkpoint-32750" -MODEL_ERNIE_PATH = R"../ernie/output_temp/checkpoint-34340" -MODEL_UIE_PATH = R"../uie/output_temp/checkpoint-34050" +MODEL_ERNIE_PATH = R"../ernie/output_temp/checkpoint-25627" +MODEL_UIE_PATH = R"../uie/output_temp/checkpoint-36320" # 类别名称列表 labels = [ @@ -33,23 +33,25 @@ labels = [ # 标签映射 label_map = { 0: 'O', # 非实体 - 1: 'B-date', 18: 'I-date', - 2: 'B-projectName', 19: 'I-projectName', - 3: 'B-projectType', 20: 'I-projectType', - 4: 'B-constructionUnit', 21: 'I-constructionUnit', - 5: 'B-implementationOrganization', 22: 'I-implementationOrganization', - 6: 'B-projectDepartment', 23: 'I-projectDepartment', - 7: 'B-projectManager', 24: 'I-projectManager', - 8: 'B-subcontractor', 25: 'I-subcontractor', - 9: 'B-teamLeader', 26: 'I-teamLeader', - 10: 'B-riskLevel', 27: 'I-riskLevel', - 11: 'B-page', 28: 'I-page', - 12: 'B-operating', 29: 'I-operating', - 13: 'B-teamName', 30: 'I-teamName', - 14: 'B-constructionArea', 31: 'I-constructionArea', - 15: 'B-personName', 32: 'I-personName', - 16: 'B-personQueryType', 33: 'I-personQueryType', - 17: 'B-projectStatus', 34: 'I-projectStatus', + 1: 'B-date', 20: 'I-date', + 2: 'B-projectName', 21: 'I-projectName', + 3: 'B-projectType', 22: 'I-projectType', + 4: 'B-constructionUnit', 23: 'I-constructionUnit', + 5: 'B-implementationOrganization', 24: 'I-implementationOrganization', + 6: 'B-projectDepartment', 25: 'I-projectDepartment', + 7: 'B-projectManager', 26: 'I-projectManager', + 8: 'B-subcontractor', 27: 'I-subcontractor', + 9: 'B-teamLeader', 28: 'I-teamLeader', + 10: 'B-riskLevel', 29: 'I-riskLevel', + 11: 'B-page', 30: 'I-page', + 12: 'B-operating', 31: 'I-operating', + 13: 'B-teamName', 32: 'I-teamName', + 14: 'B-constructionArea', 33: 'I-constructionArea', + 15: 'B-personName', 34: 'I-personName', + 16: 'B-personQueryType', 35: 'I-personQueryType', + 17: 'B-projectStatus', 36: 'I-projectStatus', + 18: 'B-skyNet', 37: 'I-skyNet', + 19: 'B-programNavigation', 38: 'I-programNavigation' } logger = setup_logger("main", level=logging.DEBUG) @@ -70,7 +72,7 @@ job() # 创建后台调度器 scheduler = BackgroundScheduler() -scheduler.add_job(job, 'cron', hour=3, minute=0) # 每天凌晨1点执行 +scheduler.add_job(job, 'cron', hour=3, minute=0) # 每天凌晨3点执行 scheduler.start() # 统一的异常处理函数 @@ -278,7 +280,7 @@ def extract_multi_chat(messages): latest_message = messages[-1] latest_user_question = latest_message.content if latest_message.role == "user" else "" - time_prefixes = ["当前","今天", "昨天", "本周", "下周", "明天", "今日"] + time_prefixes = ["当前","今天", "昨天", "本周", "下周", "明天", "今日","打开"] history_messages = [] if any(prefix in latest_user_question and prefix != latest_user_question for prefix in time_prefixes) else messages[:-1] logger.info(f"len(history_messages):{len(history_messages)}") @@ -293,7 +295,7 @@ def extract_multi_chat(messages): oldest_chat_history = "" if has_time_prefix else "\n".join([f"{msg.role}: {msg.content}" for msg in history_messages[:2]]) logger.info(f"last_chat_history:{last_chat_history}") - logger.info(f"oldest_chat_history):{oldest_chat_history}") + logger.info(f"oldest_chat_history:{oldest_chat_history}") prompt = f''' 你是一个意图识别与补全助手,你的任务是根据用户的最新问题判断是否需要补全,如果不需要补全,则原样返回用户的最新问题,否则需要结合最新对话历史和最老对话历史补全用户的最新问题,并只返回最终的完整问题。请严格按照如下逻辑判断并执行: diff --git a/api/main_temp.py b/api/main_temp.py index ba76890..5e0a664 100644 --- a/api/main_temp.py +++ b/api/main_temp.py @@ -15,8 +15,8 @@ from config import * from globalData import GlobalData from apscheduler.schedulers.background import BackgroundScheduler -MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-33510" -MODEL_UIE_PATH = R"../uie/output/checkpoint-33220" +MODEL_ERNIE_PATH = R"../ernie/output_temp/checkpoint-29288" +MODEL_UIE_PATH = R"../uie/output_temp/checkpoint-36320" # 类别名称列表 labels = [ @@ -30,23 +30,25 @@ labels = [ # 标签映射 label_map = { 0: 'O', # 非实体 - 1: 'B-date', 18: 'I-date', - 2: 'B-projectName', 19: 'I-projectName', - 3: 'B-projectType', 20: 'I-projectType', - 4: 'B-constructionUnit', 21: 'I-constructionUnit', - 5: 'B-implementationOrganization', 22: 'I-implementationOrganization', - 6: 'B-projectDepartment', 23: 'I-projectDepartment', - 7: 'B-projectManager', 24: 'I-projectManager', - 8: 'B-subcontractor', 25: 'I-subcontractor', - 9: 'B-teamLeader', 26: 'I-teamLeader', - 10: 'B-riskLevel', 27: 'I-riskLevel', - 11: 'B-page', 28: 'I-page', - 12: 'B-operating', 29: 'I-operating', - 13: 'B-teamName', 30: 'I-teamName', - 14: 'B-constructionArea', 31: 'I-constructionArea', - 15: 'B-personName', 32: 'I-personName', - 16: 'B-personQueryType', 33: 'I-personQueryType', - 17: 'B-projectStatus', 34: 'I-projectStatus', + 1: 'B-date', 20: 'I-date', + 2: 'B-projectName', 21: 'I-projectName', + 3: 'B-projectType', 22: 'I-projectType', + 4: 'B-constructionUnit', 23: 'I-constructionUnit', + 5: 'B-implementationOrganization', 24: 'I-implementationOrganization', + 6: 'B-projectDepartment', 25: 'I-projectDepartment', + 7: 'B-projectManager', 26: 'I-projectManager', + 8: 'B-subcontractor', 27: 'I-subcontractor', + 9: 'B-teamLeader', 28: 'I-teamLeader', + 10: 'B-riskLevel', 29: 'I-riskLevel', + 11: 'B-page', 30: 'I-page', + 12: 'B-operating', 31: 'I-operating', + 13: 'B-teamName', 32: 'I-teamName', + 14: 'B-constructionArea', 33: 'I-constructionArea', + 15: 'B-personName', 34: 'I-personName', + 16: 'B-personQueryType', 35: 'I-personQueryType', + 17: 'B-projectStatus', 36: 'I-projectStatus', + 18: 'B-skyNet', 37: 'I-skyNet', + 19: 'B-programNavigation', 38: 'I-programNavigation' } logger = setup_logger("main", level=logging.DEBUG) @@ -67,7 +69,7 @@ job() # 创建后台调度器 scheduler = BackgroundScheduler() -scheduler.add_job(job, 'cron', hour=3, minute=0) # 每天凌晨1点执行 +scheduler.add_job(job, 'cron', hour=3, minute=0) # 每天凌晨3点执行 scheduler.start() # 统一的异常处理函数 @@ -275,12 +277,12 @@ def extract_multi_chat(messages): latest_message = messages[-1] latest_user_question = latest_message.content if latest_message.role == "user" else "" - time_prefixes = ["当前","今天", "昨天", "本周", "下周", "明天", "今日"] + time_prefixes = ["当前","今天", "昨天", "本周", "下周", "明天", "今日","打开"] history_messages = [] if any(prefix in latest_user_question and prefix != latest_user_question for prefix in time_prefixes) else messages[:-1] logger.info(f"len(history_messages):{len(history_messages)}") - #最新问题的上一个问题里如何含有时间,则清空最老的历史对话 + #最新问题的上一个问题里如果含有时间,则清空最老的历史对话 last_two_messages = history_messages[-2:] has_time_prefix = any( msg.role == "user" and any(prefix in msg.content and prefix != msg.content for prefix in time_prefixes) @@ -290,7 +292,7 @@ def extract_multi_chat(messages): oldest_chat_history = "" if has_time_prefix else "\n".join([f"{msg.role}: {msg.content}" for msg in history_messages[:2]]) logger.info(f"last_chat_history:{last_chat_history}") - logger.info(f"oldest_chat_history):{oldest_chat_history}") + logger.info(f"oldest_chat_history:{oldest_chat_history}") prompt = f''' 你是一个意图识别与补全助手,你的任务是根据用户的最新问题判断是否需要补全,如果不需要补全,则原样返回用户的最新问题,否则需要结合最新对话历史和最老对话历史补全用户的最新问题,并只返回最终的完整问题。请严格按照如下逻辑判断并执行: @@ -298,45 +300,27 @@ def extract_multi_chat(messages): --- 【规则判断与补全流程】 - - 第一步:用户最新问题是否以“公司”为主语?→ 原样返回,无需补全 - - 若用户最新问题主语是“公司”,直接返回原句,无需补全。 - - 主语为“公司”的典型句式: - - 以“公司”开头; - - 以“今天”“昨天”“本周”“下周”等时间词开头,紧跟“公司”作为主语; - - 示例: - - 用户的最新问题:“今天公司有多少四级风险作业计划?” - - 用户的最新问题:“今天公司有多少作业计划” - - 用户的最新问题:“公司今天有多少4级风险的作业面?” - - 最终提问均为: 原句不变。 - - 第二步:用户最新问题是否是完整的问题?→ 原样返回,无需补全 - - 若用户最新问题中包含下列之一:具体的项目部名、工程名、分公司名、班组名、地区名等信息,且同时出现作业计划、作业面、班组等查询对象,视为完整问题,直接返回原句,无需补全。 - - 示例: - - 用户最新问题:“今天张三班组有多少作业计划?” - - 用户最新问题:“今天绿雪莲塘工程有多少作业计划” - - 最终提问均为: 原句不变。 - - 第三步:用户最新问题是否存在指代词?→ 结合用户最新问题和最新对话历史进行补全 - - 若用户最新问题问题中出现模糊表达,如“具体是哪些项”、“是哪两个”、“作业计划分别是什么”、“合肥中心变工程呢”、“具体是哪20项”等,请只使用紧邻最新问题之前的用户问题和AI回复补全问题信息。 + + 第一步:用户最新问题是否存在指代词?→ 结合用户最新问题和最新对话历史进行补全 + - 若用户最新问题问题中出现模糊表达,如“具体是哪些项”、“是哪两个”、“作业计划分别是什么”、“合肥中心变工程呢”、“具体是哪20项”、“考勤人数呢”等,请根据需要结合最新对话历史或最老对话历史补全问题信息。 - 示例1: - 用户最新问题:“具体的作业计划分别是什么” - - 紧邻最新问题的对话历史的用户问题:“今天公司有多少项作业计划” - - 紧邻最新问题的对话历史的AI回答:“2025-04-25公司一共有421项作业计划,分别如下:风险等级为2级的有15项,3级的有144项,4级的有262项,5级的有0项” + - 最新对话历史的用户问题:“今天公司有多少项作业计划” + - 最新对话历史的AI回答:“2025-04-25公司一共有421项作业计划,分别如下:风险等级为2级的有15项,3级的有144项,4级的有262项,5级的有0项” - 则最终提问应为: “今天公司的421项作业计划分别是什么” - 示例2: - 用户最新问题:“具体的作业内容是什么” - - 最新对话历史的用户问题:今天送一分公司第一项目部有多少项作业计划 + - 最新对话历史的用户问题:送一分公司第一项目部今天有多少项作业计划 - 最新对话历史的AI回答:今天送电一分公司第一项目管理部有21项作业计划 - 则最终提问应为: “今天送电一分公司第一项目管理部的21项作业计划分别是什么” - 第四步:用户最新问题是否为序号指代(第一个/第2个)?→ 用完整工程/项目/公司名替换补全 + 第二步:用户最新问题是否为序号指代(第一个/第2个)?→ 用完整工程/项目/公司名替换补全 - 精确提取用户所指的序号(如“第3个”指第3个工程名、公司名或项目部名); - 将该工程、公司或项目部的完整名称(包括括号中的编号)提取出来; - - 用完整名称替换掉最新对话历史的用户问题中出现的简称或模糊表达; - - 必须保留最新对话历史的用户问题中的所有其他关键信息(包括但不限于:项目部名称、时间、计划数、内容如"进度情况""作业计划""作业内容"等); + - 用完整工程、公司或项目部的名称替换掉最新对话历史的用户问题中出现的简称或模糊表达,但保留其他信息不变; + - 必须保留最新对话历史的用户问题中的所有其他关键信息(如具体的动作和操作的内容包括但不限于:项目部名称、时间、计划数、内容如"进度情况""作业计划""作业内容"“摄像头”“视频”等); - 示例1: - 用户最新问题:"第二个" 或"第2个" - 最新对话历史的用户问题:"2025年南苑调相机检修(PROJ-2023-0179)今天有多少作业计划"" @@ -349,14 +333,20 @@ def extract_multi_chat(messages): - 最新对话历史的AI回答:你说的工程名可能是,第1个:芦集-古沟π入潘集变电站220kV线路工程(PROJ-2024-0189),第二个:淮南芦集220千伏变电站220千伏配电装置改造工程(PROJ-2024-0265),请确认您要选择哪一个? - 则最终提问应为: "请帮我查一下今天淮南芦集220千伏变电站220千伏配电装置改造工程(PROJ-2024-0265)的进度情况" - - 示例3(新增关键保留示例): + - 示例3: - 用户最新问题:"第2个" - 最新对话历史的用户问题:"宏源电力公司第三项目部今天有多少项作业计划" - 最新对话历史的AI回答:您说的实施组织名可能是,第1个:安徽宏源电力建设有限公司(线路),第2个:安徽宏源电力建设有限公司(变电),请选择哪一个 - 则最终提问应为: "安徽宏源电力建设有限公司(变电)第三项目部今天有多少项作业计划" + - 示例4: + - 用户最新问题:"第2个" + - 最新对话历史的用户问题:"打开中心变摄像头" + - 最新对话历史的AI回答:您说的工程名可能是,第1个:锦绣-常青π入中心变电站220kV架空线路工程(PROJ-2024-1206),第2个:合肥中心变B包(PROJ-2024-0176),请选择哪一个 + - 则最终提问应为: + "打开合肥中心变B包(PROJ-2024-0176)摄像头" - 第五步:输出最终问题 + 第三步:输出最终问题 - 直接输出最终问题(无解释、无多余前缀或后缀) - 保持句式自然清晰 diff --git a/api/utils.py b/api/utils.py index fe07e5e..04875dc 100644 --- a/api/utils.py +++ b/api/utils.py @@ -315,10 +315,11 @@ def standardize_projectDepartment(standard_company, input_project, company_proje temp_input_project = replace_arabic_with_chinese(input_project) temp_input_project = clean_useless_program_departement_name(temp_input_project) - + # logger.info(f"temp_input_project: {temp_input_project}") program_list = company_project_department_map.get(standard_company, []) cleaned_map = {clean_useless_program_departement_name(p): p for p in program_list} + # logger.info(f"cleaned_map: {cleaned_map}") project_match = process.extractOne(temp_input_project, list(cleaned_map.keys()), scorer=cast(Callable, WRatio)) @@ -525,7 +526,7 @@ def clean_useless_program_departement_name(name: str) -> str: def check_lost(int_res, slot): #labels: ["天气查询","通用对话","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容","施工人数","作业考勤人数","知识问答"] mapping = { - 2: [['page'], ['app'], ['module']], + # 2: [['page'], ['app'], ['module']], 3: [['date']], 4: [['date']], 5: [['date']], @@ -538,8 +539,8 @@ def check_lost(int_res, slot): 14: [['date']], 15: [['date']], } - - intention_mapping = {2: "页面切换", 3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容", + # 2: "页面切换", + intention_mapping = {3: "日计划数量查询", 4: "周计划数量查询", 5: "日计划作业内容", 6: "周计划作业内容", 7: "施工人数", 8: "作业考勤人数", 11: "作业面查询", 12: "班组人数查询", 13: "班组数查询", 14: "作业面内容", 15: "班组详情"} if not mapping.__contains__(int_res): @@ -568,14 +569,14 @@ def check_lost(int_res, slot): left = [x for x in mapping[int_res][idx] if x not in cur_k] logger.info(f"符合当前意图的的必须槽位,但是不在提取的槽位信息里, {left}") apologize_str = "非常抱歉," - if int_res == 2: - return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询哪个页面?" - elif int_res in [3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15,16]: + # if int_res == 2: + # return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询哪个页面?" + if int_res in [3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15,16]: return CheckResult.NEEDS_MORE_ROUNDS, f"{apologize_str}请问你想查询什么时间的{intention_mapping[int_res]}?" def check_standard_name_slot_probability(int_res, slot) -> tuple: - intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26} + intention_list = {2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26} if int_res not in intention_list: return CheckResult.NO_MATCH, "" @@ -633,7 +634,7 @@ def check_standard_name_slot_probability(int_res, slot) -> tuple: if key == PROJECT_DEPARTMENT: logger.info(f"check_standard_name_slot 原始项目部名 : {slot[PROJECT_DEPARTMENT]}") match_results = standardize_projectDepartment(slot[IMPLEMENTATION_ORG], value, GlobalData.standard_company_program, - high_score=90) + high_score=95) logger.info(f"check_standard_name_slot 匹配后项目部名: result:{match_results}") if match_results and len(match_results) == 1: slot[key] = match_results[0] diff --git a/generated_data/generated.py b/generated_data/generated.py index c6b068e..e2fe397 100644 --- a/generated_data/generated.py +++ b/generated_data/generated.py @@ -67,7 +67,7 @@ BASE_DATA = { "国网安徽省电力有限公司建设分公司","中铁四局建设公司","中铁四局建设公司","银联黄山园区开发有限公司"], # 分包单位 "subcontractors": ["劦力建筑责任公司","安徽劦力建筑装饰有限责任公司", "安徽苏亚建设集团有限公司","大信电力建设有限公司","优越电力公司", - "安徽国腾电力工程有限公司","安徽京硚建设有限公司","中国能源建设集团安徽省电力设计院有限公司"], + "安徽国腾电力工程有限公司","安徽京硚建设有限公司","中国能源建设集团安徽省电力设计院有限公司","作业班组管理中心"], # 班组名称 "team_names": ["张朵班组", "刘梁玉班组", "魏玉龙班组","周可富班组"], # 班组长 @@ -85,7 +85,13 @@ BASE_DATA = { "person_query_types": ["班组", "工程", "分公司", "实时组织", "项目部", "项目管理部"], # 工程状态 - "project_status_s": ["在建", "在作业", "在施工",""] + "project_status_s": ["在建", "在作业", "在施工",""], + + #皖送天网 + "sky_nets": ["摄像头", "视频"], + #项目巡航 + "program_navigations": ["数字化项目部", "数字化项目部管理平台", "施工生产管理平台"], + } @@ -140,8 +146,8 @@ TEMPLATE_CONFIG = { ["project_department", "date", "risk_level"]), ("安徽送变电{project_department}{date}有多少项{risk_level}风险作业计划?", ["project_department", "date", "risk_level"]), - ("{project_department}{date}有多少项{risk_level}风险作业计划?", - ["project_department", "date", "risk_level"]), + # ("{project_department}{date}有多少项{risk_level}风险作业计划?", + # ["project_department", "date", "risk_level"]), ("{project_department}{date}有多少{risk_level}风险作业计划?", ["project_department", "date", "risk_level"]), # 请帮我查一下 ("请帮我查一下{date}{project_manager}作业计划是多少?", ["date", "project_manager"]), @@ -161,10 +167,10 @@ TEMPLATE_CONFIG = { ("请帮我查一下{implementation_organization}{date}存在{risk_level}风险的有多少", ["implementation_organization","date", "risk_level"]), - ("{date}{project_type}类{implementation_organization}组织实施的作业计划有多少?", - ["date", "project_type", "implementation_organization"]), - ("{date}{project_department}管理的{project_type}类作业计划有多少?", - ["date", "project_department", "project_type"]), + # ("{date}{project_type}类{implementation_organization}组织实施的作业计划有多少?", + # ["date", "project_type", "implementation_organization"]), + ("{date}{implementation_organization}{project_department}{project_type}类作业计划有多少?", + ["date", "implementation_organization", "project_department", "project_type"]), ("{date}分包单位{subcontractor}承包的{project_type}类作业计划有多少?", ["date", "subcontractor", "project_type"]), ("{date}分包单位为{project_manager}负责的{project_type}类作业计划有多少?", ["date", "project_manager", "project_type"]), @@ -199,7 +205,6 @@ TEMPLATE_CONFIG = { ("{date}公司{project_department}有多少作业?", ["date", "project_department"]), ("{date}送变电公司{project_department}有多少作业?", ["date", "project_department"]), ("{date}安徽送变电{project_department}有多少作业?", ["date", "project_department"]), - ("{date}{project_department}有多少项作业?", ["date", "project_department"]), #有多少 ("{date}{implementation_organization}{project_department}有多少?", ["date", "implementation_organization", "project_department"]), @@ -760,6 +765,22 @@ TEMPLATE_CONFIG = { ("加载{page}模块", ["page"]), ("切换{page}", ["page"]), ("加载{page}页面", ["page"]), + #施工生产管理平台 + #项目巡航:分公司 + ("打开{implementation_organization}{program_navigation}", ["implementation_organization", "program_navigation"]), + #项目巡航:分公司、项目部 + ("打开{implementation_organization}{project_department}{program_navigation}", + ["implementation_organization", "project_department", "program_navigation"]), + #项目巡航:分公司 + ("切换到{implementation_organization}{program_navigation}", + ["implementation_organization", "program_navigation"]), + #项目巡航,工程 + ("打开{project_name}{program_navigation}", ["project_name", "program_navigation"]), + #皖智天网,工程名摄像头 + ("打开{project_name}{sky_net}", ["project_name", "sky_net"]), + #皖智天网,班组名摄像头 + ("切换到{team_leader}{sky_net}", ["team_leader", "sky_net"]), + #施工生产管理平台 ] }, "作业面查询": { @@ -1213,75 +1234,6 @@ TEMPLATE_CONFIG = { #询问工程数量时有工程性质和风险等级吗 ] }, - - # "工程数量查询": { - # "date": ["今天","今日", ""], - # "templates": [ - # #公司 - # ("{date}公司有多少工程", ["date"]), - # - # ("安徽送变电公司有多少工程{project_status}", ["project_status"]), - # #分公司和项目部 - # ("{date}{implementation_organization}有多少工程{project_status}", - # ["date", "implementation_organization", "project_status"]), - # ("{implementation_organization}{date}{project_department}有多少工程{project_status}", - # ["implementation_organization", "date", "project_department", "project_status"]), - # #建管区域和单位 - # ("{date}{construction_area}地区风险等级为{risk_level}有多少工程?", ["""construction_area", "risk_level"]), - # - # ("{construction_area}地区有多少工程{project_status}?", ["construction_area", "project_status"]), - # - # ("{construction_unit}有多少工程{project_status}?", ["construction_unit","project_status"]), - # - # #分包商 - # ("{subcontractor}有多少工程{project_status}", ["subcontractor", "project_status"]), - # ("安徽送变电公司{project_department}有多少工程?", ["project_department"]), - # #项目经理 - # ("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]), - # #班组名称 - # ("{team_leader}有多少{project_status}工程", ["team_leader", "project_status"]), - # #工程性质 - # ("公司{project_type}类的工程有多少?", ["project_type"]), - # #风险等级 - # ("公司{risk_level}风险的{project_status}工程有多少?", ["risk_level", "project_status"]), - # #询问工程数量时有工程性质和风险等级吗 - # ] - # }, - # - # "工程详情查询": { - # "date": ["今天","今日",""], - # "templates": [ - # #公司 - # ("{date}公司有哪些工程", ["date"]), - # ("截止目前公司有哪些{project_status}工程", ["project_status"]), - # ("安徽送变电公司{date}有哪些工程{project_status}", ["date", "project_status"]), - # #分公司和项目部 - # ("{implementation_organization}{date}工程详情{project_status}", - # ["implementation_organization", "date", "project_status"]), - # ("{date}{implementation_organization}{project_department}有哪些工程{project_status}", - # ["date", "implementation_organization", "project_department", "project_status"]), - # #建管区域和单位 - # ("{date}{construction_area}地区风险等级为{risk_level}工程具体情况?", ["date", "construction_area", "risk_level"]), - # - # ("{construction_area}{date}地区有哪些{project_status}工程?", ["construction_area", "date", "project_status"]), - # - # ("{construction_unit}有哪些工程{project_status}?", ["construction_unit","project_status"]), - # - # #分包商 - # ("{subcontractor}有多少{project_status}工程", ["subcontractor", "project_status"]), - # ("送变电公司{project_department}工程详情?", ["project_department"]), - # #项目经理 - # ("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]), - # #班组名称 - # ("{team_leader}工程具体情况", ["team_leader"]), - # #工程性质 - # ("公司{project_type}类的工程有哪些?", ["project_type"]), - # #风险等级 - # ("公司{risk_level}风险的{project_status}工程有那些?", ["risk_level", "project_status"]), - # #询问工程数量时有工程性质和风险等级吗 - # ] - # }, - "项目部数量查询": { "date": ["今天","最近"], "templates": [ @@ -1386,6 +1338,8 @@ def generate_natural_samples(config, label): "person_name": BASE_DATA["person_names"], "person_query_type": BASE_DATA["person_query_types"], "project_status": BASE_DATA["project_status_s"], + "sky_net": BASE_DATA["sky_nets"], + "program_navigation": BASE_DATA["program_navigations"], } for template, variables in config["templates"]: diff --git a/uie/train.py b/uie/train.py index 51bb37a..23c6f2d 100644 --- a/uie/train.py +++ b/uie/train.py @@ -18,7 +18,8 @@ def preprocess_function(example, tokenizer): 'date', 'project_name', 'project_type', 'construction_unit', 'implementation_organization', 'project_department', 'project_manager', 'subcontractor', 'team_leader', 'risk_level', 'page', 'operating', 'team_name', - 'construction_area', 'person_name', 'person_query_type', 'project_status' + 'construction_area', 'person_name', 'person_query_type', 'project_status', + "sky_net", "program_navigation" ] # 文本 Tokenization @@ -60,7 +61,7 @@ def preprocess_function(example, tokenizer): # === 3. 加载 UIE 预训练模型 === -model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=35) # 3 类 (O, B, I) +model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=39) # 3 类 (O, B, I) tokenizer = ErnieTokenizer.from_pretrained("uie-base") # === 4. 加载数据集 ===