diff --git a/api/main.py b/api/main.py index a75ecba..6dc3dd3 100644 --- a/api/main.py +++ b/api/main.py @@ -15,34 +15,41 @@ from config import * from globalData import GlobalData from apscheduler.schedulers.background import BackgroundScheduler -MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-25910" -MODEL_UIE_PATH = R"../uie/output/checkpoint-32750" +# MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-25910" +# MODEL_UIE_PATH = R"../uie/output/checkpoint-32750" + +MODEL_ERNIE_PATH = R"../ernie/output_temp/checkpoint-33510" +MODEL_UIE_PATH = R"../uie/output_temp/checkpoint-33220" # 类别名称列表 labels = [ "天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询", "日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答", "通用对话", "作业面查询", "班组人数查询", "班组数查询", "作业面内容", "班组详情", - "工程进度查询" + "工程进度查询", "人员查询", "分公司查询","工程数量查询","工程详情查询","项目部数量查询", + "建管单位数量查询","建管单位详情","分包单位数量查询","分包单位详情" ] # 标签映射 label_map = { 0: 'O', # 非实体 - 1: 'B-date', 15: 'I-date', - 2: 'B-projectName', 16: 'I-projectName', - 3: 'B-projectType', 17: 'I-projectType', - 4: 'B-constructionUnit', 18: 'I-constructionUnit', - 5: 'B-implementationOrganization', 19: 'I-implementationOrganization', - 6: 'B-projectDepartment', 20: 'I-projectDepartment', - 7: 'B-projectManager', 21: 'I-projectManager', - 8: 'B-subcontractor', 22: 'I-subcontractor', - 9: 'B-teamLeader', 23: 'I-teamLeader', - 10: 'B-riskLevel', 24: 'I-riskLevel', - 11: 'B-page', 25: 'I-page', - 12: 'B-operating', 26: 'I-operating', - 13: 'B-teamName', 27: 'I-teamName', - 14: 'B-constructionArea', 28: 'I-constructionArea', + 1: 'B-date', 18: 'I-date', + 2: 'B-projectName', 19: 'I-projectName', + 3: 'B-projectType', 20: 'I-projectType', + 4: 'B-constructionUnit', 21: 'I-constructionUnit', + 5: 'B-implementationOrganization', 22: 'I-implementationOrganization', + 6: 'B-projectDepartment', 23: 'I-projectDepartment', + 7: 'B-projectManager', 24: 'I-projectManager', + 8: 'B-subcontractor', 25: 'I-subcontractor', + 9: 'B-teamLeader', 26: 'I-teamLeader', + 10: 'B-riskLevel', 27: 'I-riskLevel', + 11: 'B-page', 28: 'I-page', + 12: 'B-operating', 29: 'I-operating', + 13: 'B-teamName', 30: 'I-teamName', + 14: 'B-constructionArea', 31: 'I-constructionArea', + 15: 'B-personName', 32: 'I-personName', + 16: 'B-personQueryType', 33: 'I-personQueryType', + 17: 'B-projectStatus', 34: 'I-projectStatus', } logger = setup_logger("main", level=logging.DEBUG) diff --git a/api/main_temp.py b/api/main_temp.py index 440c35b..ba76890 100644 --- a/api/main_temp.py +++ b/api/main_temp.py @@ -15,34 +15,38 @@ from config import * from globalData import GlobalData from apscheduler.schedulers.background import BackgroundScheduler -MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-25910" -MODEL_UIE_PATH = R"../uie/output/checkpoint-32750" +MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-33510" +MODEL_UIE_PATH = R"../uie/output/checkpoint-33220" # 类别名称列表 labels = [ "天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询", "日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答", "通用对话", "作业面查询", "班组人数查询", "班组数查询", "作业面内容", "班组详情", - "工程进度查询" + "工程进度查询", "人员查询", "分公司查询","工程数量查询","工程详情查询","项目部数量查询", + "建管单位数量查询","建管单位详情","分包单位数量查询","分包单位详情" ] # 标签映射 label_map = { 0: 'O', # 非实体 - 1: 'B-date', 15: 'I-date', - 2: 'B-projectName', 16: 'I-projectName', - 3: 'B-projectType', 17: 'I-projectType', - 4: 'B-constructionUnit', 18: 'I-constructionUnit', - 5: 'B-implementationOrganization', 19: 'I-implementationOrganization', - 6: 'B-projectDepartment', 20: 'I-projectDepartment', - 7: 'B-projectManager', 21: 'I-projectManager', - 8: 'B-subcontractor', 22: 'I-subcontractor', - 9: 'B-teamLeader', 23: 'I-teamLeader', - 10: 'B-riskLevel', 24: 'I-riskLevel', - 11: 'B-page', 25: 'I-page', - 12: 'B-operating', 26: 'I-operating', - 13: 'B-teamName', 27: 'I-teamName', - 14: 'B-constructionArea', 28: 'I-constructionArea', + 1: 'B-date', 18: 'I-date', + 2: 'B-projectName', 19: 'I-projectName', + 3: 'B-projectType', 20: 'I-projectType', + 4: 'B-constructionUnit', 21: 'I-constructionUnit', + 5: 'B-implementationOrganization', 22: 'I-implementationOrganization', + 6: 'B-projectDepartment', 23: 'I-projectDepartment', + 7: 'B-projectManager', 24: 'I-projectManager', + 8: 'B-subcontractor', 25: 'I-subcontractor', + 9: 'B-teamLeader', 26: 'I-teamLeader', + 10: 'B-riskLevel', 27: 'I-riskLevel', + 11: 'B-page', 28: 'I-page', + 12: 'B-operating', 29: 'I-operating', + 13: 'B-teamName', 30: 'I-teamName', + 14: 'B-constructionArea', 31: 'I-constructionArea', + 15: 'B-personName', 32: 'I-personName', + 16: 'B-personQueryType', 33: 'I-personQueryType', + 17: 'B-projectStatus', 34: 'I-projectStatus', } logger = setup_logger("main", level=logging.DEBUG) diff --git a/api/new_algroth/standard_data/standard_program.txt b/api/new_algroth/standard_data/standard_program.txt deleted file mode 100644 index 736a26a..0000000 --- a/api/new_algroth/standard_data/standard_program.txt +++ /dev/null @@ -1,53 +0,0 @@ -第八项目管理部(淮北宿州) -第七项目管理部(阜阳) -第十一项目管理部(马鞍山) -第四项目管理部(安庆) -第九项目管理部(合肥轨道线) -第五项目管理部(合肥) -第一项目管理部(池州黄山) -第十项目管理部(特高压) -第二项目管理部(宣城) -第六项目管理部(滁州) -第三项目管理部(芜湖) -第四项目管理部(六安变电) -第七项目管理部(淮南线路) -第十项目管理部(亳州变电) -第九项目管理部(亳州线路) -第五项目管理部(蚌埠线路) -第十一项目管理部(萧砀线路) -第三项目管理部(张店线路) -第三项目管理部(岳西线路) -第八项目管理部(淮南变电) -第六项目管理部(蚌埠变电) -第十一项目管理部(宿州线路) -第二项目管理部(合肥变电) -第三项目管理部(谯城变、亳州楼) -第五项目管理部(金牛变) -第二项目管理部(合州站、阜四变) -第一项目管理部(萧砀变、锁库变) -第七项目管理部(合肥中心变) -第二项目管理部(修试) -第三项目管理部(香鹭东段) -第二项目管理部(香鹭西段) -第五项目管理部(阜阳) -第十三项目管理部(黄山) -第八项目管理部(芜湖) -第九项目管理部(马鞍山) -第四项目管理部(甘浙) -第十一项目管理部(宣城) -第九项目管理部(淮北) -第十二项目管理部(陕皖) -第一项目管理部(肥东) -第四项目管理部(池州) -第二项目管理部(紫蓬) -第六项目管理部(安庆) -第八项目管理部(宿州分部) -第七项目管理部(安庆四) -第三项目管理部(庐江) -第三项目管理部(六安线路) -第六项目管理部(阜阳综合楼、省营销楼) -第四项目管理部(安庆四、明生楼) -第一项目管理部(金上) -第一项目管理部(修试) -第五项目管理部(铜陵) -第八项目管理部(宿州) \ No newline at end of file diff --git a/api/standard_data/standard_company.txt b/api/standard_data/standard_company.txt deleted file mode 100644 index 0aba818..0000000 --- a/api/standard_data/standard_company.txt +++ /dev/null @@ -1,7 +0,0 @@ -检修试验分公司 -送电一分公司 -送电二分公司 -变电分公司 -建筑分公司 -安徽宏源电力建设有限公司 -安徽顺安电网建设有限公司 \ No newline at end of file diff --git a/api/standard_data/standard_project.txt b/api/standard_data/standard_project.txt index d48257e..78471d4 100644 --- a/api/standard_data/standard_project.txt +++ b/api/standard_data/standard_project.txt @@ -325,7 +325,7 @@ 500kV当涂变220kV当马安控系统升级改造项目((PROJ-2021-0114)) 淄博高青500kV返厂大修(PROJ-2024-0850) 合肥云谷-江汽新港变110kV线路改造工程(PROJ-2020-0524) -淮北庙台220kV变电站新建工程 (电气部分)(PROJ-2025-0115) +淮北220千伏庙台变电站新建工程 (电气)(PROJ-2025-0115) 香涧-梨花π入固镇南牵引站220kV线路工程(PROJ-2024-0441) 香涧500kV变电站间隔扩建工程(PROJ-2024-0890) 安徽铜陵枞阳县破罡35KV变电站增容改造工程(PROJ-2021-0064) @@ -388,6 +388,7 @@ 武昌-古港π入江调变电站110kV线路工程(PROJ-2021-0088) 宏源大厦维修改造项目(PROJ-2021-0089) 安徽合肥大学城220kV变电站110kV站前路主所间隔扩建工程(PROJ-2021-0091) +亳州伯阳500kV变电站220kV道仁间隔扩建工程(PROJ-2024-0781) 尧天湖220kV变电站35kV田营园区出线间隔扩建工程(PROJ-2020-0525) 蓼城-白莲牵引站Ⅱ回线改接俞林变电站110kV架空线路工程(PROJ-2020-0526) ±500KV延庆换流站备用换流变安装(PROJ-2025-0162) @@ -1767,6 +1768,7 @@ G205九华南路快速化改造工程(二期)火龙岗北至南陵渡桥段5 国网安徽合肥供电公司2024年220kV学苑变电站一键顺控系统完善提升(PROJ-2024-0244) 蕴山-沙埂110kV线路工程(PROJ-2024-0245) 国国网安徽合肥供电公司2024年220kV秋浦变电站一键顺控系统完善提升(PROJ-2024-0249) +国网安徽合肥供电公司2024年500kV龙门变电站一键顺控系统完善提升(PROJ-2024-0250) 500kV潘清5710线#89-#100段改造工程(PROJ-2024-0252) 500kV墨孔5339线#234-#245段改造工程(PROJ-2024-0253) 国网安徽电力超高压公司1000kV特高压芜湖站1000kVI母线GM104气室B相、GM113气室C相特高频局放异常处理项目(PROJ-2024-0254) @@ -1795,7 +1797,6 @@ G205九华南路快速化改造工程(二期)火龙岗北至南陵渡桥段5 安徽亳州华都110kV变电站110kV真源间隔改造工程(调试部分)(PROJ-2024-1180) 国网北京检修公司2024年±500kV延庆换流站阀冷系统设备驻站(PROJ-2024-0849) 安徽亳州鲲鹏110kV变电站新建工程(PROJ-2025-0095) -围屏220kV变电站110kV万济间隔保护改造工程(PROJ-2025-0182) 1000kV芜湖站1000kV T042开关拆除工作(PROJ-2025-0113) 庆阳±800kV换流站工程大件运输工程(PROJ-2024-1264) 同乐-明都开断接入林楼变220kV线路工程(含光缆)(PROJ-2021-0018) @@ -1944,7 +1945,7 @@ S334峨山路东延伸(沿江高速至芜宣高速)新建工程二期500kV 谢桥电厂-原鹿220kV线路工程(PROJ-2024-0785) 围屏220kV变电站220kV万济间隔扩建工程(调试部分)(PROJ-2025-0135) 马鞍山华阳110kV变电站110kV含山间隔改造工程(调试部分)(PROJ-2024-1169) -淮北庙台220kV变电站新建工程(土建部分)(PROJ-2025-0114) +淮北220千伏庙台变电站新建工程(土建)(PROJ-2025-0114) 碱河220千伏变电站110千伏凌云间隔扩建工程(调试部分)(PROJ-2025-0006) 国网淮南供电公司500kV孔店变500kV5043电流互感器更换项目(PROJ-2025-0124) 花园220kV变电站薛桥、郭王间隔改造工程(PROJ-2025-0116) @@ -2024,7 +2025,6 @@ S334峨山路东延伸(沿江高速至芜宣高速)新建工程二期500kV 国网安徽电力超高压分公司2024年超特高压变电站应急保障服务(PROJ-2024-0843) 太和-李腰π入城南变电站110kV架空线路工程(PROJ-2024-0709) 金牛500kV变电站新建工程(调试部分)(PROJ-2025-0032) -章塘220kV变电站110kV万济间隔保护改造工程(PROJ-2025-0181) 安徽送变电工程有限公司2025年度扩建及技改工程电气安装劳务分包框架(PROJ-2025-0099) 国网蚌埠供电公司500kV怀洪变500kV5379线路及5380线路间隔检查和试验项目(PROJ-2025-0097) 椿树220千伏变电站220千伏阜四间隔扩建工程(电气安装)(PROJ-2024-0942) @@ -2088,7 +2088,6 @@ S334峨山路东延伸(沿江高速至芜宣高速)新建工程二期500kV 谯城(亳三)-谯城220kV电缆线路工程(PROJ-2024-1203) 陕北-安徽直流工程合州±800千伏换流站土建A包(PROJ-2024-0312) 宿州蟠龙220kV变电站220kV大庄风电间隔扩建工程(电气安装)(PROJ-2024-0466) -S19淮桐高速合肥段涉500kV皋铭5357线皋传5358线改造工程(PROJ-2025-0183) 石岗-施桥110kV线路工程(PROJ-2024-0276) 国网安徽宣城供电公司500kV河沥变加装固定融冰装置(调试部分)(PROJ-2025-0087) 香涧-鹭岛500kV线路工程(一般线路东段)(PROJ-2024-0725) diff --git a/api/standard_data/team_leader.txt b/api/standard_data/team_leader.txt index 4535dcc..1d147e3 100644 --- a/api/standard_data/team_leader.txt +++ b/api/standard_data/team_leader.txt @@ -36,13 +36,13 @@ 李章贵班组 吴义新班组 郑云龙班组 +徐钦文班组 张江福班组 盛平班组 储友健班组 熊宗书班组 李光明班组 朱明雷班组 -刘亚锋班组 耿兴海班组 王威班组 张小斌班组 @@ -104,6 +104,7 @@ 蔡道平班组 汪吉祥班组 徐朝军班组 +汪自存班组 徐本家班组 任培培班组 陈兰荣班组 @@ -157,6 +158,7 @@ 邓君班组 李赛北班组 何东洋班组 +古伟班组 叶从进班组 席胜利班组 孔永高班组 @@ -169,7 +171,6 @@ 王亮班组 方勇勇班组 徐兴才班组 -龙金华班组 蔡大羊班组 乔帅兵班组 刘荣君班组 @@ -202,7 +203,6 @@ 杨廷泽班组 董可祥班组 王新班组 -鲁从伟班组 陈双双班组 郭翔翔班组 刘鹏班组 @@ -300,7 +300,6 @@ 洪金涛班组 单海峰班组 张海涛班组 -岳林班组 唐孝明班组 叶朝磊班组 左华彪班组 @@ -323,7 +322,6 @@ 段祥宇班组 王坤班组 刘士平班组 -殷书尔班组 赵振强班组 方年春班组 黄勇班组 @@ -381,6 +379,7 @@ 江军班组 杨长和班组 朱纪倍班组 +夏万年班组 赵华健班组 杨松伟班组 王务红班组 @@ -422,6 +421,7 @@ 蒲民班组 黄本初班组 高磊班组 +张志班组 姚海强班组 吴庆欢班组 徐南班组 diff --git a/api/standard_test.py b/api/standard_test.py index 6e28f70..975cd68 100644 --- a/api/standard_test.py +++ b/api/standard_test.py @@ -4,7 +4,8 @@ from logger_util import setup_logger from globalData import GlobalData from utils import standardize_name, clean_useless_team_leader_name, standardize_sub_company, standardize_project_name, \ standardize_projectDepartment, standardize_team_name, check_standard_name_slot_probability, \ - clean_useless_project_name + clean_useless_project_name, save_standard_name_list_to_file, load_standard_name_list, save_dict_to_file, \ + load_standard_json_data import time from apscheduler.schedulers.blocking import BlockingScheduler @@ -16,7 +17,7 @@ from globalData import GlobalData # # logger = setup_logger("utils", level=logging.DEBUG) -GlobalData.update_from_redis() +# GlobalData.update_from_redis() def check_standard_name_slot_probability_test(): slot_list = [{"constructionUnit": "合肥供电公司"}, @@ -90,34 +91,34 @@ def standardize_team_leader_test(): def standardize_company_test(): test_cases = [ - ("送一分公司"), - ("送二分公司"), - ("变电分公司"), - ("建筑分公司"), - ("检修试验分公司"), - ("宏源电力公司"), - ("宏源电力限公司"), - ("宏源电力限公司线路"), - ("宏源电力限公司变电"), - ("送一分"), - ("送二分"), - ("变电分"), - ("建筑分"), - ("检修试验分"), - ("宏源电力"), - ("红源电力"), - ("宏源电力有限"), - ("宏源电力限线路"), - ("宏源电力限变电"), + ("宋轶分公司"), + # ("送二分公司"), + # ("变电分公司"), + # ("建筑分公司"), + # ("检修试验分公司"), + # ("宏源电力公司"), + # ("宏源电力限公司"), + # ("宏源电力限公司线路"), + # ("宏源电力限公司变电"), + # ("送一分"), + # ("送二分"), + # ("变电分"), + # ("建筑分"), + # ("检修试验分"), + # ("宏源电力"), + # ("红源电力"), + # ("宏源电力有限"), + # ("宏源电力限线路"), + # ("宏源电力限变电"), ] - logger.info(f"加权混合策略 分公司名匹配**********************") + print(f"加权混合策略 分公司名匹配**********************") start = time.perf_counter() for item in test_cases: match_results = standardize_sub_company(item,GlobalData.simply_to_standard_company_name_map, GlobalData.pinyin_simply_to_standard_company_name_map,70,90) - logger.info(f"加权混合策略 分公司名匹配 输入: {item}-> 输出: {match_results}") + print(f"加权混合策略 分公司名匹配 输入: {item}-> 输出: {match_results}") end = time.perf_counter() - logger.info(f"加权混合策略 耗时: {end - start:.4f} 秒") + print(f"加权混合策略 耗时: {end - start:.4f} 秒") def standardize_construction_test(): @@ -134,8 +135,8 @@ def standardize_construction_test(): def standardize_project_test(): test_cases = [ - ("众兴500kV变电站220kV杜岗Ⅱ间隔改造工程(PROJ-2023-0435)"), - ("众兴500kv变电站220kv杜岗ⅱ间隔改造工程(proj-2023-0435)") + ("陶楼夏塘工程"), + ("宗阳黄桥工程") # ("合肥卫田变电站工程"). # ("金牛变电站新建建筑"), # ("金牛变电站建筑工程"), @@ -178,14 +179,14 @@ def standardize_project_test(): # ("卫田-陶楼T接首业变电站110kV电缆线路工程(PROJ-2024-1236)"), # ("谯城(亳三)-希夷220kV线路工程(PROJ-2024-1205)"), ] - logger.info(f"去不重要词汇 工程名匹配******************************************") + print(f"去不重要词汇 工程名匹配******************************************") start = time.perf_counter() for item in test_cases: match_results = standardize_project_name(item, GlobalData.simply_to_standard_project_name_map, GlobalData.pinyin_simply_to_standard_project_name_map, 70, 90) - logger.info(f"***************工程名匹配 输入: {item}-> 输出: {match_results}") + print(f"***************工程名匹配 输入: {item}-> 输出: {match_results}") end = time.perf_counter() - logger.info(f"词集匹配 耗时: {end - start:.4f} 秒") + print(f"词集匹配 耗时: {end - start:.4f} 秒") def standardize_program_test(): logger.info(f"项目名匹配******************************************") @@ -358,14 +359,226 @@ def get_size(): standardize_project_test() -unuselessStr = clean_useless_project_name("众兴杜岗ⅱ间隔改造") -print(f"众兴杜岗ⅱ间隔改造:{unuselessStr}") -unuselessStr = clean_useless_project_name("众兴杜岗Ⅱ间隔改造") -print(f"众兴杜岗Ⅱ间隔改造:{unuselessStr}") -print("今天的长度:",len("今天")) +def exact_hot_words(): + save_standard_name_list_to_file(list(GlobalData.simply_to_standard_company_name_map.keys()),"./hot_word/company.txt") + #save_standard_name_list_to_file(list(GlobalData.simply_to_standard_project_name_map.keys()),"./hot_word/project.txt") + save_standard_name_list_to_file(list(GlobalData.simply_to_standard_constractor_name_map.keys()),"./hot_word/constractor.txt") + save_standard_name_list_to_file(list(GlobalData.simply_to_standard_construct_name_map.keys()),"./hot_word/construct.txt") + + +# def exact_project_hot_words(): +# import re +# +# # 示例数据,换成 GlobalData.simply_to_standard_project_name_map.keys() +# project_names = [ +# "安庆四500kV变电站新建工程(PROJ-2024-0862)", +# "淮南芦集 220 千伏变电站 220 千伏配电装置改造工程(调试部分)(PROJ-2025-0022)", +# "屏显220kV变电站220kV杜岗Ⅱ间隔改造工程(调试部分)(PROJ-2025-0169)", +# "漆园220kV变电站220kV杨柳间隔改造工程(调试部分)(PROJ-2025-0042)", +# "宝桥220kV变电站220kV红桥间隔保护改造工程(调试部分)(PROJ-2025-0088)", +# "蟠龙220kV变电站220kV灵泗间隔改造工程(调试部分)(PROJ-2025-0018)", +# "锦绣-常青π入中心变电站220kV架空线路工程(PROJ-2024-1206)", +# "安庆和平220kV变电站新建工程(调试部分)(PROJ-2024-1238)", +# "渝北±800千伏换流站电气安装A包(调试部分)(PROJ-2024-1192)", +# "安徽亳州华佗220kV变电站220kV新华风电间隔扩建工程(调试部分)(PROJ-2024-1171)", +# "先锋-泉河π入安庆四变电站220kV线路工程(PROJ-2024-0834)", +# "安徽滁州护桥220kV变电站2号主变扩建工程(PROJ-2024-0821)", +# "双岭500kV变电站间隔改造工程(PROJ-2024-0863)", +# "合州±800千伏换流站电气安装A包(PROJ-2025-0056)", +# "金牛500kV变电站新建工程(PROJ-2024-0866)", +# "况楼220kV变电站间隔扩建工程(调试部分)(PROJ-2025-0144)", +# "国网安徽合肥供电公司2023年GIS带电显示器维护(PROJ-2024-1260)", +# "亳州木兰220kV变电站220kV改造工程(安徽亳州木兰200kV变电站GIS设备检修及调试技术服务)(PROJ-2024-1256)", +# "香涧-鹭岛500kV线路工程(淮河大跨越段)(PROJ-2024-0722)", +# "安徽蚌埠濠州220kV变电站220千伏大唐凤阳红心镇光伏间隔扩建工程(调试部分)(PROJ-2025-0164)", +# "陶楼-广银(T接智迪)改接首业变电站110kV电缆线路工程(PROJ-2024-1233)", +# "国网北京检修公司2024年±500kV延庆换流站直流主设备及辅助设备不停电检修维护(PROJ-2024-0841)" +# ] +# +# +# # 用来存关键词 +# keywords = [] +# # +# for name in GlobalData.simply_to_standard_project_name_map.keys(): +# # 去掉括号和括号里的内容 +# cleaned_name = re.sub(r"\(.*?\)", "", name) +# +# # 提取“-”连接的词 +# if "-" in cleaned_name: +# parts = cleaned_name.split("-") +# first = parts[0].strip() +# second = re.split(r"[^\u4e00-\u9fa5]", parts[1])[0].strip() # 只取中文部分 +# if first: +# keywords.append(first) +# if second: +# keywords.append(second) +# if first and second: +# keywords.append(first + second) +# else: +# # 正常提取,取第一个连续的中文词组 +# match = re.match(r"([\u4e00-\u9fa5]+)", cleaned_name) +# if match: +# keywords.append(match.group(1)) +# +# # 去重,且保持原顺序 +# seen = set() +# unique_keywords = [] +# for kw in keywords: +# if kw not in seen: +# seen.add(kw) +# unique_keywords.append(kw) +# +# # 写入到文件 +# with open("new_project.txt", "w", encoding="utf-8") as f: +# for kw in unique_keywords: +# f.write(kw + "\n") +# +# print("提取完成,已写入 new_project.txt") + + + # # 去重且保持顺序 + # seen = set() + # unique_keywords = [] + # for kw in keywords: + # if kw not in seen: + # seen.add(kw) + # unique_keywords.append(kw) + # + # # 写入到文件 + # with open("new_project.txt", "w", encoding="utf-8") as f: + # for kw in unique_keywords: + # f.write(kw + "\n") + # + # print("提取完成,已写入 new_project.txt") + + +def exact_project_hot_words(): + import re + + # 示例数据,换成 GlobalData.simply_to_standard_project_name_map.keys() + # project_names = [ + # "安庆四500kV变电站新建工程(PROJ-2024-0862)", + # "淮南芦集 220 千伏变电站 220 千伏配电装置改造工程(调试部分)(PROJ-2025-0022)", + # "锦绣-常青π入中心变电站220kV架空线路工程(PROJ-2024-1206)", + # "渝北±800千伏换流站电气安装A包(调试部分)(PROJ-2024-1192)", + # "屏显220kV变电站220kV杜岗Ⅱ间隔改造工程(调试部分)(PROJ-2025-0169)", + # "漆园220kV变电站220kV杨柳间隔改造工程(调试部分)(PROJ-2025-0042)", + # "宝桥220kV变电站220kV红桥间隔保护改造工程(调试部分)(PROJ-2025-0088)", + # "蟠龙220kV变电站220kV灵泗间隔改造工程(调试部分)(PROJ-2025-0018)", + # "安庆和平220kV变电站新建工程(调试部分)(PROJ-2024-1238)", + # "安徽亳州华佗220kV变电站220kV新华风电间隔扩建工程(调试部分)(PROJ-2024-1171)", + # "先锋-泉河π入安庆四变电站220kV线路工程(PROJ-2024-0834)", + # "安徽滁州护桥220kV变电站2号主变扩建工程(PROJ-2024-0821)", + # "双岭500kV变电站间隔改造工程(PROJ-2024-0863)", + # "合州±800千伏换流站电气安装A包(PROJ-2025-0056)", + # "金牛500kV变电站新建工程(PROJ-2024-0866)", + # "况楼220kV变电站间隔扩建工程(调试部分)(PROJ-2025-0144)", + # "国网安徽合肥供电公司2023年GIS带电显示器维护(PROJ-2024-1260)", + # "亳州木兰220kV变电站220kV改造工程(安徽亳州木兰200kV变电站GIS设备检修及调试技术服务)(PROJ-2024-1256)", + # "香涧-鹭岛500kV线路工程(淮河大跨越段)(PROJ-2024-0722)", + # "安徽蚌埠濠州220kV变电站220千伏大唐凤阳红心镇光伏间隔扩建工程(调试部分)(PROJ-2025-0164)", + # "陶楼-广银(T接智迪)改接首业变电站110kV电缆线路工程(PROJ-2024-1233)", + # "国网北京检修公司2024年±500kV延庆换流站直流主设备及辅助设备不停电检修维护(PROJ-2024-0841)" + # ] + # 排除含有这些词的项目名 + exclude_keywords = ["国网", "公司"] + + keywords = [] + + for name in list(GlobalData.simply_to_standard_project_name_map.values()): # 正式换成 + print(f"name:{name}") + # 去掉括号及里面内容 + cleaned_name = re.sub(r"(.*?)|\(.*?\)", "", name) + + # 如果包含排除关键词,跳过 + if any(ek in cleaned_name for ek in exclude_keywords): + continue + + # 处理有"-"连接的情况 + if "-" in cleaned_name: + parts = cleaned_name.split("-") + + # 切出来的每一段,也要去括号内容 + part0 = re.sub(r"(.*?)|\(.*?\)", "", parts[0]) + part1 = re.sub(r"(.*?)|\(.*?\)", "", parts[1]) + + first = re.match(r"[\u4e00-\u9fa5]+", part0) + second = re.match(r"[\u4e00-\u9fa5]+", part1) + + first_word = first.group(0) if first else "" + second_word = second.group(0) if second else "" + + if first_word: + keywords.append(first_word) + if second_word: + keywords.append(second_word) + if first_word and second_word: + keywords.append(first_word + second_word) + else: + # 没有"-",提取第一个连续中文 + match = re.match(r"([\u4e00-\u9fa5]+)", cleaned_name) + if match: + word = match.group(1) + if word: + keywords.append(word) + + # 去重且保持顺序 + seen = set() + unique_keywords = [] + for kw in keywords: + if kw not in seen: + seen.add(kw) + unique_keywords.append(kw) + + # 写入文件 + with open("new_project.txt", "w", encoding="utf-8") as f: + for kw in unique_keywords: + f.write(kw + "\n") + + print("提取完成,已写入 new_project.txt") + + +def removte_reduant_list(): + temp_list = load_standard_name_list("./new_project.txt") + # 去重且保持顺序 + seen = set() + unique_keywords = [] + for kw in temp_list: + if kw not in seen: + seen.add(kw) + unique_keywords.append(kw) + + with open("hot_word/final_new_project.txt", "w", encoding="utf-8") as f: + for kw in unique_keywords: + f.write(kw + "\n") + save_standard_name_list_to_file(list(GlobalData.simply_to_standard_project_name_map.keys()),"./hot_word/project.txt") + +def list_to_json(): + import json + my_list = load_standard_name_list("./hot_word/final_new_project.txt") + data = { + "hotwordList": my_list + } + + # 转换为 JSON 格式字符串(ensure_ascii=False 确保中文正常显示) + json_str = json.dumps(data, ensure_ascii=False, indent=4) + + save_dict_to_file(json_str,'./hot_word/final_new_project.json') + + +list_to_json() +# exact_hot_words() +# exact_project_hot_words() +# unuselessStr = clean_useless_project_name("众兴杜岗ⅱ间隔改造") +# print(f"众兴杜岗ⅱ间隔改造:{unuselessStr}") +# unuselessStr = clean_useless_project_name("众兴杜岗Ⅱ间隔改造") +# print(f"众兴杜岗Ⅱ间隔改造:{unuselessStr}") +# print("今天的长度:",len("今天")) # standardize_program() # history_message() +# standardize_project_test() +# standardize_company_test() # standardize_team_leader_test() # # standardize_sub_constractor_test() diff --git a/api/utils.py b/api/utils.py index ca5fb51..bba9cff 100644 --- a/api/utils.py +++ b/api/utils.py @@ -561,7 +561,7 @@ def check_lost(int_res, slot): def check_standard_name_slot_probability(int_res, slot) -> tuple: - intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15,16} + intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26} if int_res not in intention_list: return CheckResult.NO_MATCH, "" diff --git a/ernie/data.yaml b/ernie/data.yaml index 21c0b05..0580fac 100644 --- a/ernie/data.yaml +++ b/ernie/data.yaml @@ -7,4 +7,5 @@ test: ./data/test.json # (可选) 测试集路径 nc: 13 # 目标类别数 labels: ["天气查询","互联网查询","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容", "施工人数","作业考勤人数","知识问答","通用对话","作业面查询","班组人数查询","班组数查询","作业面内容", - "班组详情","工程进度查询"] # 类别名称 + "班组详情","工程进度查询", "人员查询", "分公司查询","工程数量查询","工程详情查询","项目部数量查询", + "建管单位数量查询","建管单位详情","分包单位数量查询","分包单位详情"] # 类别名称 diff --git a/ernie/train.py b/ernie/train.py index 5fb2199..2410cfd 100644 --- a/ernie/train.py +++ b/ernie/train.py @@ -7,7 +7,7 @@ import numpy as np import functools from paddle.nn import CrossEntropyLoss from paddlenlp.data import DataCollatorWithPadding -from paddlenlp.trainer import Trainer, TrainingArguments +from paddlenlp.trainer import Trainer, TrainingArguments, EarlyStoppingCallback import os from sklearn.metrics import precision_score, recall_score, f1_score @@ -116,10 +116,10 @@ def main(): output_dir="./output_temp", evaluation_strategy="epoch", save_strategy="epoch", - eval_steps=2000, # 每100步评估一次 - save_steps=2000, + eval_steps=2000, # 每2000步评估一次,evaluation_strategy="steps"时生效 + save_steps=2000, # 每2000步保存一次,save_strategy="steps"时生效 logging_dir="./logs", - logging_steps=100, # 每50步输出一次日志 + logging_steps=100, # 每100步输出一次日志 num_train_epochs=10, # 训练轮数 per_device_train_batch_size=32, per_device_eval_batch_size=32, @@ -140,6 +140,7 @@ def main(): eval_dataset=test_ds, data_collator=data_collator, compute_metrics=compute_metrics, # 使用自定义的评估指标 + callbacks=[EarlyStoppingCallback(early_stopping_patience=2)], ) # 训练模型 diff --git a/generated_data/generated.py b/generated_data/generated.py index 32fce7e..618a7dd 100644 --- a/generated_data/generated.py +++ b/generated_data/generated.py @@ -78,7 +78,14 @@ BASE_DATA = { "operatings": ["8+2工况", "8加2工况"], # 页面切换 "pages": ["风险管控", "日计划", "周风险", "日计划统计报表", "日计划推送", "生产管控中心", "考勤统计详情", - "今日作业计划", "周风险统计报表", "周风险推送"] + "今日作业计划", "周风险统计报表", "周风险推送"], + # 具体人名 + "person_names": ["何东洋", "李东","王孙强林"], + # 人名查询目标 + "person_query_types": ["班组", "工程", "分公司", "实时组织", "项目部", "项目管理部"], + + # 工程状态 + "project_status_s": ["在建", "在作业", "在施工"] } @@ -465,11 +472,11 @@ TEMPLATE_CONFIG = { ("{date}{construction_unit}{operating}的作业内容是什么?", ["date", "construction_unit","operating"]), #分包单位 - ("{subcontractor}{date}作业内容是什么", ["subcontractor", "date"]), - ("{date}{subcontractor}具体的作业内容", ["date","subcontractor"]), + ("{subcontractor}{date}作业内容", ["subcontractor", "date"]), + ("{date}{subcontractor}具体作业内容", ["date","subcontractor"]), ("{date}分包单位为{subcontractor}有哪些{risk_level}风险作业计划?", ["date", "subcontractor", "risk_level"]), - ("{date}{subcontractor}风险等级为{risk_level}的作业计划是什么?", ["date", "subcontractor", "risk_level"]), - ("{date}{subcontractor}{operating}的作业内容是什么?", ["date", "subcontractor","operating"]), + ("{date}{subcontractor}风险等级为{risk_level}的作业计划?", ["date", "subcontractor", "risk_level"]), + ("{date}{subcontractor}{operating}具体作业计划", ["date", "subcontractor","operating"]), ] }, @@ -933,7 +940,7 @@ TEMPLATE_CONFIG = { # 1. 查询特定日期和项目的作业安排 ("{date}{project_name}作业面是什么?", ["date", "project_name"]), ("{date}属于{operating}作业面内容是什么?", ["date", "operating"]), - ("{date}存在{operating}作业面是什么?", ["date", "operating"]), + ("{date}存在{operating}作业面", ["date", "operating"]), # 3. 查询特定日期和项目类型的工程计划 ("{date}{project_type}类具体作业面有哪些?", ["date", "project_type"]), @@ -994,7 +1001,7 @@ TEMPLATE_CONFIG = { ("{date}{construction_area}地区{operating}具体作业面是什么", ["date", "construction_area","operating"]), #建管单位 - ("{construction_unit}{date}具体作业面内容是什么", ["construction_unit", "date"]), + ("{construction_unit}{date}具体作业面内容", ["construction_unit", "date"]), ("{date}{construction_unit}具体作业面有哪些", ["date", "construction_unit",]), ("{date}建管单位为{construction_unit}作业面是什么?", ["date", "construction_unit"]), ("{date}{construction_unit}风险等级为{risk_level}两项作业面分别是什么?", ["date", "construction_unit", "risk_level"]), @@ -1003,9 +1010,9 @@ TEMPLATE_CONFIG = { #分包单位 ("{subcontractor}{date}具体作业面内容是什么", ["subcontractor", "date"]), - ("{date}{subcontractor}具体作业面有哪些", ["date","subcontractor"]), + ("{date}{subcontractor}具体作业面", ["date","subcontractor"]), ("{date}分包单位为{subcontractor}{risk_level}作业面是什么", ["date", "subcontractor", "risk_level"]), - ("{date}{subcontractor}风险等级为{risk_level}两项作业面分别是什么?", ["date", "subcontractor", "risk_level"]), + ("{date}{subcontractor}风险等级为{risk_level}两项作业面", ["date", "subcontractor", "risk_level"]), ("{date}{subcontractor}{operating}作业面内容有哪些?", ["date", "subcontractor","operating"]), ] }, @@ -1111,7 +1118,244 @@ TEMPLATE_CONFIG = { ("{project_name}今日工程进度情况", ["project_name"]), ("{project_name}今天工程进展怎么样?", ["project_name"]), ] - } + }, + "人员查询": { + "date": ["今天","最近"], + "templates": [ + ("{person_name}在哪个{person_query_type}", ["person_name","person_query_type"]), + ("请帮我查一下{person_name}在哪个{person_query_type}", ["person_name","person_query_type"]), + ] + }, + + "分公司查询": { + "date": ["今天","最近"], + "templates": [ + ("安徽送变电公司有多少分公司", []), + ("公司有多少分公司", []), + ("请查一下分公司数量", []), + ("分公司数量", []), + ("实施组织数量", []), + ("分公司详情", []), + ("实施组织详情", []), + ("公司有哪些分公司", []), + ("公司有哪些实施组织", []), + ("{implementation_organization}详情", ["implementation_organization"]), + ("{implementation_organization}情况", ["implementation_organization"]), + ("请帮我查一下具体分公司详情", []), + ("请帮我查一下具体实施组织详情", []), + ("请帮我查一下具体{implementation_organization}详情", ["implementation_organization"]), + ] + }, + + # "工程数量查询": { + # "date": ["今日", "昨日", "2024年5月24日", "5月24日", "今天", "昨天"], + # "templates": [ + # #公司 + # ("{date}公司有多少工程", ["date"]), + # ("{date}安徽送变电公司有多少工程{project_status}", ["date", "project_status"]), + # #分公司和项目部 + # ("{implementation_organization}{date}有多少工程{project_status}", + # ["implementation_organization", "date", "project_status"]), + # ("{implementation_organization}{project_department}{date}有多少工程{project_status}", + # ["implementation_organization", "project_department", "date", "project_status"]), + # #建管区域和单位 + # ("{construction_area}地区{date}风险等级为{risk_level}有多少工程?", ["construction_area", "date", "risk_level"]), + # + # ("{construction_area}地区{date}有多少工程{project_status}?", ["construction_area", "date", "project_status"]), + # + # ("{construction_unit}{date}有多少工程{project_status}?", ["construction_unit", "date","project_status"]), + # + # #分包商 + # ("{date}{subcontractor}有多少工程{project_status}", ["date", "subcontractor", "project_status"]), + # ("{date}送变电公司{project_department}有多少工程?", ["date", "project_department"]), + # #项目经理 + # ("{date}{project_manager}有多少工程{project_status}", ["date", "project_manager","project_status"]), + # #班组名称 + # ("{team_leader}{date}有多少工程", ["team_leader", "date"]), + # #工程性质 + # ("公司{date}{project_type}的工程有多少?", ["date", "project_type"]), + # #风险等级 + # ("公司{date}{risk_level}风险的{project_status}工程有多少?", ["date", "risk_level", "project_status"]), + # #询问工程数量时有工程性质和风险等级吗 + # ] + # }, + # + # "工程详情查询": { + # "date": ["今日", "昨日", "2024年5月24日", "5月24日", "今天", "昨天"], + # "templates": [ + # #公司 + # ("{date}公司有哪些工程", ["date"]), + # ("{date}安徽送变电公司有哪些工程{project_status}", ["date", "project_status"]), + # #分公司和项目部 + # ("{implementation_organization}{date}工程详情{project_status}", + # ["implementation_organization", "date", "project_status"]), + # ("{implementation_organization}{project_department}{date}有哪些工程{project_status}", + # ["implementation_organization", "project_department", "date", "project_status"]), + # #建管区域和单位 + # ("{construction_area}地区{date}风险等级为{risk_level}工程具体情况?", ["construction_area", "date", "risk_level"]), + # + # ("{construction_area}地区{date}有哪些工程{project_status}?", ["construction_area", "date", "project_status"]), + # + # ("{construction_unit}{date}有哪些工程{project_status}?", ["construction_unit", "date","project_status"]), + # + # #分包商 + # ("{date}{subcontractor}有多少工程{project_status}", ["date", "subcontractor", "project_status"]), + # ("{date}送变电公司{project_department}工程详情?", ["date", "project_department"]), + # #项目经理 + # ("{date}{project_manager}有多少工程{project_status}", ["date", "project_manager","project_status"]), + # #班组名称 + # ("{team_leader}{date}工程具体情况", ["team_leader", "date"]), + # #工程性质 + # ("公司{date}{project_type}的工程有哪些?", ["date", "project_type"]), + # #风险等级 + # ("公司{date}{risk_level}风险的{project_status}工程有那些?", ["date", "risk_level", "project_status"]), + # #询问工程数量时有工程性质和风险等级吗 + # ] + # }, + + "工程数量查询": { + "date": ["今天","最近"], + "templates": [ + #公司 + ("公司有多少工程", []), + ("安徽送变电公司有多少工程{project_status}", ["project_status"]), + #分公司和项目部 + ("{implementation_organization}有多少工程{project_status}", + ["implementation_organization", "project_status"]), + ("{implementation_organization}{project_department}有多少工程{project_status}", + ["implementation_organization", "project_department", "project_status"]), + #建管区域和单位 + ("{construction_area}地区风险等级为{risk_level}有多少工程?", ["construction_area", "risk_level"]), + + ("{construction_area}地区有多少工程{project_status}?", ["construction_area", "project_status"]), + + ("{construction_unit}有多少工程{project_status}?", ["construction_unit","project_status"]), + + #分包商 + ("{subcontractor}有多少工程{project_status}", ["subcontractor", "project_status"]), + ("安徽送变电公司{project_department}有多少工程?", ["project_department"]), + #项目经理 + ("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]), + #班组名称 + ("{team_leader}有多少{project_status}工程", ["team_leader", "project_status"]), + #工程性质 + ("公司{project_type}类的工程有多少?", ["project_type"]), + #风险等级 + ("公司{risk_level}风险的{project_status}工程有多少?", ["risk_level", "project_status"]), + #询问工程数量时有工程性质和风险等级吗 + ] + }, + + "工程详情查询": { + "date": ["今天","最近"], + "templates": [ + #公司 + ("公司有哪些工程", []), + ("截止目前公司有哪些{project_status}工程", ["project_status"]), + ("安徽送变电公司有哪些工程{project_status}", ["project_status"]), + #分公司和项目部 + ("{implementation_organization}工程详情{project_status}", + ["implementation_organization", "project_status"]), + ("{implementation_organization}{project_department}有哪些工程{project_status}", + ["implementation_organization", "project_department", "project_status"]), + #建管区域和单位 + ("{construction_area}地区风险等级为{risk_level}工程具体情况?", ["construction_area", "risk_level"]), + + ("{construction_area}地区有哪些{project_status}工程?", ["construction_area", "project_status"]), + + ("{construction_unit}有哪些工程{project_status}?", ["construction_unit","project_status"]), + + #分包商 + ("{subcontractor}有多少{project_status}工程", ["subcontractor", "project_status"]), + ("送变电公司{project_department}工程详情?", ["project_department"]), + #项目经理 + ("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]), + #班组名称 + ("{team_leader}工程具体情况", ["team_leader"]), + #工程性质 + ("公司{project_type}类的工程有哪些?", ["project_type"]), + #风险等级 + ("公司{risk_level}风险的{project_status}工程有那些?", ["risk_level", "project_status"]), + #询问工程数量时有工程性质和风险等级吗 + ] + }, + + "项目部数量查询": { + "date": ["今天","最近"], + "templates": [ + #公司 + ("公司有多少项目部", []), + ("安徽送变电公司有多少项目管理部", []), + #分公司 + ("{implementation_organization}有多少项目部", ["implementation_organization"]), + ("{implementation_organization}有多少项目管理部", ["implementation_organization"]), + ("{implementation_organization}项目管理部的数量", ["implementation_organization"]), + #请帮我查一下 + ("请帮我查一下公司项目部数量", []), + ("请帮我查一下{implementation_organization}有多少项目管理部", + ["implementation_organization"]), + ] + }, + + "项目部详情": { + "date": ["今天","最近"], + "templates": [ + #公司 + ("公司有哪些项目部", []), + ("安徽送变电公司项目管理部详情", []), + #分公司 + ("{implementation_organization}项目部详情", ["implementation_organization"]), + ("{implementation_organization}有哪些项目管理部", ["implementation_organization"]), + #请帮我查一下 + ("请帮我查一下公司项目部详情", []), + ("请帮我查一下{implementation_organization}有哪些项目管理部", + ["implementation_organization"]), + ] + }, + + "建管单位数量查询": { + "date": ["今天","最近"], + "templates": [ + #公司 + ("公司有多少建管单位", []), + ("安徽送变电公司有多少建管单位", []), + ] + }, + + "建管单位详情": { + "date": ["今天","最近"], + "templates": [ + #公司 + ("{project_name}建管单位情况", ["project_name"]), + ("{project_name}建管单位详情", ["project_name"]), + ("请介绍下{construction_unit}详情", ["construction_unit"]), + ("请介绍下{construction_unit}情况", ["construction_unit"]), + ] + }, + + "分包单位数量查询": { + "date": ["今天","最近"], + "templates": [ + #公司 + ("公司有多少分包单位", []), + ("公司有多少分包商", []), + ("安徽送变电公司有多少分包单位", []), + ("安徽送变电公司有多少分包商", []), + ] + }, + + "分包单位详情": { + "date": ["今天","最近"], + "templates": [ + #公司 + ("{project_name}分包单位详情", ["project_name"]), + ("{project_name}分包商情况", ["project_name"]), + ("请介绍下{subcontractor}详情", ["subcontractor"]), + ("请介绍下{subcontractor}情况", ["subcontractor"]), + ] + }, + + } @@ -1133,6 +1377,9 @@ def generate_natural_samples(config, label): "operating": BASE_DATA["operatings"], "team_name": BASE_DATA["team_names"], "construction_area": BASE_DATA["construction_areas"], + "person_name": BASE_DATA["person_names"], + "person_query_type": BASE_DATA["person_query_types"], + "project_status": BASE_DATA["project_status_s"], } for template, variables in config["templates"]: diff --git a/generated_data/合并数据.py b/generated_data/合并数据.py index 1b9e611..422ff35 100644 --- a/generated_data/合并数据.py +++ b/generated_data/合并数据.py @@ -25,7 +25,9 @@ def merge_json_files(file_list, output_file): files = ['互联网查询.json','天气查询.json','知识问答.json','作业考勤人数.json', '周计划作业内容.json', '周计划数量查询.json','施工人数.json','日计划作业内容.json','日计划数量查询.json', '页面切换.json','通用对话.json','作业面查询.json','班组人数查询.json','班组数查询.json','作业面内容.json', - '班组详情.json','工程进度查询.json'] + '班组详情.json','工程进度查询.json','人员查询.json','分公司查询.json','工程数量查询.json', + '工程详情查询.json','项目部数量查询.json','建管单位数量查询.json','建管单位详情.json','分包单位数量查询.json', + '分包单位详情.json'] output_file = 'output/merged_data.json' # 执行合并 diff --git a/uie/train.py b/uie/train.py index 35dc1c7..16fe27f 100644 --- a/uie/train.py +++ b/uie/train.py @@ -2,22 +2,25 @@ import json import paddle from paddlenlp.datasets import MapDataset from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer -from paddlenlp.trainer import Trainer, TrainingArguments +from paddlenlp.trainer import Trainer, TrainingArguments, EarlyStoppingCallback from paddlenlp.data import DataCollatorForTokenClassification + # === 1. 加载数据 === def load_dataset(data_path): with open(data_path, "r", encoding="utf-8") as f: data = json.load(f) return MapDataset(data) + # === 2. 预处理数据 === def preprocess_function(example, tokenizer): # 预定义实体类型列表 entity_types = [ 'date', 'project_name', 'project_type', 'construction_unit', 'implementation_organization', 'project_department', 'project_manager', - 'subcontractor', 'team_leader', 'risk_level', 'page', 'operating', 'team_name', 'construction_area' + 'subcontractor', 'team_leader', 'risk_level', 'page', 'operating', 'team_name', + 'construction_area', 'person_name', 'person_query_type', 'project_status' ] # 文本 Tokenization @@ -59,7 +62,7 @@ def preprocess_function(example, tokenizer): # === 3. 加载 UIE 预训练模型 === -model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=29) # 3 类 (O, B, I) +lsmodel = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=35) # 3 类 (O, B, I) tokenizer = ErnieTokenizer.from_pretrained("uie-base") # === 4. 加载数据集 === @@ -70,7 +73,6 @@ dev_dataset = load_dataset("data/val.json") # 验证数据集 train_dataset = train_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False) dev_dataset = dev_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False) - # === 6. 数据整理 === data_collator = DataCollatorForTokenClassification(tokenizer, padding=True) @@ -87,13 +89,30 @@ training_args = TrainingArguments( save_total_limit=1, # 只保留最新 2 个模型 logging_dir="./logs", logging_steps=100, - eval_steps=5000, - save_steps=5000, + eval_steps=5000, #evaluation_strategy="steps"时生效 + save_steps=5000, #save_strategy="steps"时生效 seed=1000, load_best_model_at_end=True, ) -# === 8. 训练 === +# === 8. 创建 EarlyStoppingCallback 实例 === +early_stopping_callback = EarlyStoppingCallback( + early_stopping_patience=2, # 连续多少次评估不提升就停 + early_stopping_threshold=0.01 # 最小提升幅度(例如设为0.01表示至少提升1%) +) + + +def compute_metrics(eval_preds): + predictions, labels = eval_preds + preds = predictions.argmax(axis=-1) + correct = (preds == labels).astype(int) + accuracy = correct.sum() / correct.size + return {"accuracy": accuracy} + + +training_args.metric_for_best_model = "accuracy" + +# === 9. 训练 === trainer = Trainer( model=model, args=training_args, @@ -101,11 +120,12 @@ trainer = Trainer( eval_dataset=dev_dataset, tokenizer=tokenizer, data_collator=data_collator, + compute_metrics=compute_metrics, + callbacks=[early_stopping_callback], # 添加 EarlyStopping 回调 ) trainer.train() - # 为模型定义输入规格 input_spec = [ paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="input_ids"), @@ -113,4 +133,3 @@ input_spec = [ paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="position_ids"), paddle.static.InputSpec(shape=[None, 512], dtype="float32", name="attention_mask") ] -