新增10中意图数据

This commit is contained in:
weiweiw 2025-05-04 15:29:03 +08:00
parent 90e9a919e3
commit fe34128d83
13 changed files with 597 additions and 164 deletions

View File

@ -15,34 +15,41 @@ from config import *
from globalData import GlobalData
from apscheduler.schedulers.background import BackgroundScheduler
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-25910"
MODEL_UIE_PATH = R"../uie/output/checkpoint-32750"
# MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-25910"
# MODEL_UIE_PATH = R"../uie/output/checkpoint-32750"
MODEL_ERNIE_PATH = R"../ernie/output_temp/checkpoint-33510"
MODEL_UIE_PATH = R"../uie/output_temp/checkpoint-33220"
# 类别名称列表
labels = [
"天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
"日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答",
"通用对话", "作业面查询", "班组人数查询", "班组数查询", "作业面内容", "班组详情",
"工程进度查询"
"工程进度查询", "人员查询", "分公司查询","工程数量查询","工程详情查询","项目部数量查询",
"建管单位数量查询","建管单位详情","分包单位数量查询","分包单位详情"
]
# 标签映射
label_map = {
0: 'O', # 非实体
1: 'B-date', 15: 'I-date',
2: 'B-projectName', 16: 'I-projectName',
3: 'B-projectType', 17: 'I-projectType',
4: 'B-constructionUnit', 18: 'I-constructionUnit',
5: 'B-implementationOrganization', 19: 'I-implementationOrganization',
6: 'B-projectDepartment', 20: 'I-projectDepartment',
7: 'B-projectManager', 21: 'I-projectManager',
8: 'B-subcontractor', 22: 'I-subcontractor',
9: 'B-teamLeader', 23: 'I-teamLeader',
10: 'B-riskLevel', 24: 'I-riskLevel',
11: 'B-page', 25: 'I-page',
12: 'B-operating', 26: 'I-operating',
13: 'B-teamName', 27: 'I-teamName',
14: 'B-constructionArea', 28: 'I-constructionArea',
1: 'B-date', 18: 'I-date',
2: 'B-projectName', 19: 'I-projectName',
3: 'B-projectType', 20: 'I-projectType',
4: 'B-constructionUnit', 21: 'I-constructionUnit',
5: 'B-implementationOrganization', 22: 'I-implementationOrganization',
6: 'B-projectDepartment', 23: 'I-projectDepartment',
7: 'B-projectManager', 24: 'I-projectManager',
8: 'B-subcontractor', 25: 'I-subcontractor',
9: 'B-teamLeader', 26: 'I-teamLeader',
10: 'B-riskLevel', 27: 'I-riskLevel',
11: 'B-page', 28: 'I-page',
12: 'B-operating', 29: 'I-operating',
13: 'B-teamName', 30: 'I-teamName',
14: 'B-constructionArea', 31: 'I-constructionArea',
15: 'B-personName', 32: 'I-personName',
16: 'B-personQueryType', 33: 'I-personQueryType',
17: 'B-projectStatus', 34: 'I-projectStatus',
}
logger = setup_logger("main", level=logging.DEBUG)

View File

@ -15,34 +15,38 @@ from config import *
from globalData import GlobalData
from apscheduler.schedulers.background import BackgroundScheduler
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-25910"
MODEL_UIE_PATH = R"../uie/output/checkpoint-32750"
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-33510"
MODEL_UIE_PATH = R"../uie/output/checkpoint-33220"
# 类别名称列表
labels = [
"天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
"日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答",
"通用对话", "作业面查询", "班组人数查询", "班组数查询", "作业面内容", "班组详情",
"工程进度查询"
"工程进度查询", "人员查询", "分公司查询","工程数量查询","工程详情查询","项目部数量查询",
"建管单位数量查询","建管单位详情","分包单位数量查询","分包单位详情"
]
# 标签映射
label_map = {
0: 'O', # 非实体
1: 'B-date', 15: 'I-date',
2: 'B-projectName', 16: 'I-projectName',
3: 'B-projectType', 17: 'I-projectType',
4: 'B-constructionUnit', 18: 'I-constructionUnit',
5: 'B-implementationOrganization', 19: 'I-implementationOrganization',
6: 'B-projectDepartment', 20: 'I-projectDepartment',
7: 'B-projectManager', 21: 'I-projectManager',
8: 'B-subcontractor', 22: 'I-subcontractor',
9: 'B-teamLeader', 23: 'I-teamLeader',
10: 'B-riskLevel', 24: 'I-riskLevel',
11: 'B-page', 25: 'I-page',
12: 'B-operating', 26: 'I-operating',
13: 'B-teamName', 27: 'I-teamName',
14: 'B-constructionArea', 28: 'I-constructionArea',
1: 'B-date', 18: 'I-date',
2: 'B-projectName', 19: 'I-projectName',
3: 'B-projectType', 20: 'I-projectType',
4: 'B-constructionUnit', 21: 'I-constructionUnit',
5: 'B-implementationOrganization', 22: 'I-implementationOrganization',
6: 'B-projectDepartment', 23: 'I-projectDepartment',
7: 'B-projectManager', 24: 'I-projectManager',
8: 'B-subcontractor', 25: 'I-subcontractor',
9: 'B-teamLeader', 26: 'I-teamLeader',
10: 'B-riskLevel', 27: 'I-riskLevel',
11: 'B-page', 28: 'I-page',
12: 'B-operating', 29: 'I-operating',
13: 'B-teamName', 30: 'I-teamName',
14: 'B-constructionArea', 31: 'I-constructionArea',
15: 'B-personName', 32: 'I-personName',
16: 'B-personQueryType', 33: 'I-personQueryType',
17: 'B-projectStatus', 34: 'I-projectStatus',
}
logger = setup_logger("main", level=logging.DEBUG)

View File

@ -1,53 +0,0 @@
第八项目管理部(淮北宿州)
第七项目管理部(阜阳)
第十一项目管理部(马鞍山)
第四项目管理部(安庆)
第九项目管理部(合肥轨道线)
第五项目管理部(合肥)
第一项目管理部(池州黄山)
第十项目管理部(特高压)
第二项目管理部(宣城)
第六项目管理部(滁州)
第三项目管理部(芜湖)
第四项目管理部(六安变电)
第七项目管理部(淮南线路)
第十项目管理部(亳州变电)
第九项目管理部(亳州线路)
第五项目管理部(蚌埠线路)
第十一项目管理部(萧砀线路)
第三项目管理部(张店线路)
第三项目管理部(岳西线路)
第八项目管理部(淮南变电)
第六项目管理部(蚌埠变电)
第十一项目管理部(宿州线路)
第二项目管理部(合肥变电)
第三项目管理部(谯城变、亳州楼)
第五项目管理部(金牛变)
第二项目管理部(合州站、阜四变)
第一项目管理部(萧砀变、锁库变)
第七项目管理部(合肥中心变)
第二项目管理部(修试)
第三项目管理部(香鹭东段)
第二项目管理部(香鹭西段)
第五项目管理部(阜阳)
第十三项目管理部(黄山)
第八项目管理部(芜湖)
第九项目管理部(马鞍山)
第四项目管理部(甘浙)
第十一项目管理部(宣城)
第九项目管理部(淮北)
第十二项目管理部(陕皖)
第一项目管理部(肥东)
第四项目管理部(池州)
第二项目管理部(紫蓬)
第六项目管理部(安庆)
第八项目管理部(宿州分部)
第七项目管理部(安庆四)
第三项目管理部(庐江)
第三项目管理部(六安线路)
第六项目管理部(阜阳综合楼、省营销楼)
第四项目管理部(安庆四、明生楼)
第一项目管理部(金上)
第一项目管理部(修试)
第五项目管理部(铜陵)
第八项目管理部(宿州)

View File

@ -1,7 +0,0 @@
检修试验分公司
送电一分公司
送电二分公司
变电分公司
建筑分公司
安徽宏源电力建设有限公司
安徽顺安电网建设有限公司

View File

@ -325,7 +325,7 @@
500kV当涂变220kV当马安控系统升级改造项目((PROJ-2021-0114))
淄博高青500kV返厂大修(PROJ-2024-0850)
合肥云谷-江汽新港变110kV线路改造工程(PROJ-2020-0524)
淮北庙台220kV变电站新建工程 (电气部分(PROJ-2025-0115)
淮北220千伏庙台变电站新建工程 (电气(PROJ-2025-0115)
香涧-梨花π入固镇南牵引站220kV线路工程(PROJ-2024-0441)
香涧500kV变电站间隔扩建工程(PROJ-2024-0890)
安徽铜陵枞阳县破罡35KV变电站增容改造工程(PROJ-2021-0064)
@ -388,6 +388,7 @@
武昌-古港π入江调变电站110kV线路工程(PROJ-2021-0088)
宏源大厦维修改造项目(PROJ-2021-0089)
安徽合肥大学城220kV变电站110kV站前路主所间隔扩建工程(PROJ-2021-0091)
亳州伯阳500kV变电站220kV道仁间隔扩建工程(PROJ-2024-0781)
尧天湖220kV变电站35kV田营园区出线间隔扩建工程(PROJ-2020-0525)
蓼城-白莲牵引站Ⅱ回线改接俞林变电站110kV架空线路工程(PROJ-2020-0526)
±500KV延庆换流站备用换流变安装(PROJ-2025-0162)
@ -1767,6 +1768,7 @@ G205九华南路快速化改造工程二期火龙岗北至南陵渡桥段5
国网安徽合肥供电公司2024年220kV学苑变电站一键顺控系统完善提升(PROJ-2024-0244)
蕴山-沙埂110kV线路工程(PROJ-2024-0245)
国国网安徽合肥供电公司2024年220kV秋浦变电站一键顺控系统完善提升(PROJ-2024-0249)
国网安徽合肥供电公司2024年500kV龙门变电站一键顺控系统完善提升(PROJ-2024-0250)
500kV潘清5710线#89-#100段改造工程(PROJ-2024-0252)
500kV墨孔5339线#234-#245段改造工程(PROJ-2024-0253)
国网安徽电力超高压公司1000kV特高压芜湖站1000kVI母线GM104气室B相、GM113气室C相特高频局放异常处理项目(PROJ-2024-0254)
@ -1795,7 +1797,6 @@ G205九华南路快速化改造工程二期火龙岗北至南陵渡桥段5
安徽亳州华都110kV变电站110kV真源间隔改造工程调试部分(PROJ-2024-1180)
国网北京检修公司2024年±500kV延庆换流站阀冷系统设备驻站(PROJ-2024-0849)
安徽亳州鲲鹏110kV变电站新建工程(PROJ-2025-0095)
围屏220kV变电站110kV万济间隔保护改造工程(PROJ-2025-0182)
1000kV芜湖站1000kV T042开关拆除工作(PROJ-2025-0113)
庆阳±800kV换流站工程大件运输工程(PROJ-2024-1264)
同乐-明都开断接入林楼变220kV线路工程(含光缆)(PROJ-2021-0018)
@ -1944,7 +1945,7 @@ S334峨山路东延伸沿江高速至芜宣高速新建工程二期500kV
谢桥电厂-原鹿220kV线路工程(PROJ-2024-0785)
围屏220kV变电站220kV万济间隔扩建工程调试部分(PROJ-2025-0135)
马鞍山华阳110kV变电站110kV含山间隔改造工程调试部分(PROJ-2024-1169)
淮北庙台220kV变电站新建工程土建部分(PROJ-2025-0114)
淮北220千伏庙台变电站新建工程土建(PROJ-2025-0114)
碱河220千伏变电站110千伏凌云间隔扩建工程(调试部分)(PROJ-2025-0006)
国网淮南供电公司500kV孔店变500kV5043电流互感器更换项目(PROJ-2025-0124)
花园220kV变电站薛桥、郭王间隔改造工程(PROJ-2025-0116)
@ -2024,7 +2025,6 @@ S334峨山路东延伸沿江高速至芜宣高速新建工程二期500kV
国网安徽电力超高压分公司2024年超特高压变电站应急保障服务(PROJ-2024-0843)
太和-李腰π入城南变电站110kV架空线路工程(PROJ-2024-0709)
金牛500kV变电站新建工程调试部分(PROJ-2025-0032)
章塘220kV变电站110kV万济间隔保护改造工程(PROJ-2025-0181)
安徽送变电工程有限公司2025年度扩建及技改工程电气安装劳务分包框架(PROJ-2025-0099)
国网蚌埠供电公司500kV怀洪变500kV5379线路及5380线路间隔检查和试验项目(PROJ-2025-0097)
椿树220千伏变电站220千伏阜四间隔扩建工程(电气安装)(PROJ-2024-0942)
@ -2088,7 +2088,6 @@ S334峨山路东延伸沿江高速至芜宣高速新建工程二期500kV
谯城(亳三)-谯城220kV电缆线路工程(PROJ-2024-1203)
陕北-安徽直流工程合州±800千伏换流站土建A包(PROJ-2024-0312)
宿州蟠龙220kV变电站220kV大庄风电间隔扩建工程(电气安装)(PROJ-2024-0466)
S19淮桐高速合肥段涉500kV皋铭5357线皋传5358线改造工程(PROJ-2025-0183)
石岗-施桥110kV线路工程(PROJ-2024-0276)
国网安徽宣城供电公司500kV河沥变加装固定融冰装置调试部分(PROJ-2025-0087)
香涧-鹭岛500kV线路工程一般线路东段(PROJ-2024-0725)

View File

@ -36,13 +36,13 @@
李章贵班组
吴义新班组
郑云龙班组
徐钦文班组
张江福班组
盛平班组
储友健班组
熊宗书班组
李光明班组
朱明雷班组
刘亚锋班组
耿兴海班组
王威班组
张小斌班组
@ -104,6 +104,7 @@
蔡道平班组
汪吉祥班组
徐朝军班组
汪自存班组
徐本家班组
任培培班组
陈兰荣班组
@ -157,6 +158,7 @@
邓君班组
李赛北班组
何东洋班组
古伟班组
叶从进班组
席胜利班组
孔永高班组
@ -169,7 +171,6 @@
王亮班组
方勇勇班组
徐兴才班组
龙金华班组
蔡大羊班组
乔帅兵班组
刘荣君班组
@ -202,7 +203,6 @@
杨廷泽班组
董可祥班组
王新班组
鲁从伟班组
陈双双班组
郭翔翔班组
刘鹏班组
@ -300,7 +300,6 @@
洪金涛班组
单海峰班组
张海涛班组
岳林班组
唐孝明班组
叶朝磊班组
左华彪班组
@ -323,7 +322,6 @@
段祥宇班组
王坤班组
刘士平班组
殷书尔班组
赵振强班组
方年春班组
黄勇班组
@ -381,6 +379,7 @@
江军班组
杨长和班组
朱纪倍班组
夏万年班组
赵华健班组
杨松伟班组
王务红班组
@ -422,6 +421,7 @@
蒲民班组
黄本初班组
高磊班组
张志班组
姚海强班组
吴庆欢班组
徐南班组

View File

@ -4,7 +4,8 @@ from logger_util import setup_logger
from globalData import GlobalData
from utils import standardize_name, clean_useless_team_leader_name, standardize_sub_company, standardize_project_name, \
standardize_projectDepartment, standardize_team_name, check_standard_name_slot_probability, \
clean_useless_project_name
clean_useless_project_name, save_standard_name_list_to_file, load_standard_name_list, save_dict_to_file, \
load_standard_json_data
import time
from apscheduler.schedulers.blocking import BlockingScheduler
@ -16,7 +17,7 @@ from globalData import GlobalData
#
#
logger = setup_logger("utils", level=logging.DEBUG)
GlobalData.update_from_redis()
# GlobalData.update_from_redis()
def check_standard_name_slot_probability_test():
slot_list = [{"constructionUnit": "合肥供电公司"},
@ -90,34 +91,34 @@ def standardize_team_leader_test():
def standardize_company_test():
test_cases = [
("送一分公司"),
("送二分公司"),
("变电分公司"),
("建筑分公司"),
("检修试验分公司"),
("宏源电力公司"),
("宏源电力限公司"),
("宏源电力限公司线路"),
("宏源电力限公司变电"),
("送一分"),
("送二分"),
("变电分"),
("建筑分"),
("检修试验分"),
("宏源电力"),
("红源电力"),
("宏源电力有限"),
("宏源电力限线路"),
("宏源电力限变电"),
("宋轶分公司"),
# ("送二分公司"),
# ("变电分公司"),
# ("建筑分公司"),
# ("检修试验分公司"),
# ("宏源电力公司"),
# ("宏源电力限公司"),
# ("宏源电力限公司线路"),
# ("宏源电力限公司变电"),
# ("送一分"),
# ("送二分"),
# ("变电分"),
# ("建筑分"),
# ("检修试验分"),
# ("宏源电力"),
# ("红源电力"),
# ("宏源电力有限"),
# ("宏源电力限线路"),
# ("宏源电力限变电"),
]
logger.info(f"加权混合策略 分公司名匹配**********************")
print(f"加权混合策略 分公司名匹配**********************")
start = time.perf_counter()
for item in test_cases:
match_results = standardize_sub_company(item,GlobalData.simply_to_standard_company_name_map, GlobalData.pinyin_simply_to_standard_company_name_map,70,90)
logger.info(f"加权混合策略 分公司名匹配 输入: {item}-> 输出: {match_results}")
print(f"加权混合策略 分公司名匹配 输入: {item}-> 输出: {match_results}")
end = time.perf_counter()
logger.info(f"加权混合策略 耗时: {end - start:.4f}")
print(f"加权混合策略 耗时: {end - start:.4f}")
def standardize_construction_test():
@ -134,8 +135,8 @@ def standardize_construction_test():
def standardize_project_test():
test_cases = [
("众兴500kV变电站220kV杜岗Ⅱ间隔改造工程(PROJ-2023-0435)"),
("众兴500kv变电站220kv杜岗ⅱ间隔改造工程(proj-2023-0435)")
("陶楼夏塘工程"),
("宗阳黄桥工程")
# ("合肥卫田变电站工程").
# ("金牛变电站新建建筑"),
# ("金牛变电站建筑工程"),
@ -178,14 +179,14 @@ def standardize_project_test():
# ("卫田-陶楼T接首业变电站110kV电缆线路工程(PROJ-2024-1236)"),
# ("谯城(亳三)-希夷220kV线路工程(PROJ-2024-1205)"),
]
logger.info(f"去不重要词汇 工程名匹配******************************************")
print(f"去不重要词汇 工程名匹配******************************************")
start = time.perf_counter()
for item in test_cases:
match_results = standardize_project_name(item, GlobalData.simply_to_standard_project_name_map,
GlobalData.pinyin_simply_to_standard_project_name_map, 70, 90)
logger.info(f"***************工程名匹配 输入: {item}-> 输出: {match_results}")
print(f"***************工程名匹配 输入: {item}-> 输出: {match_results}")
end = time.perf_counter()
logger.info(f"词集匹配 耗时: {end - start:.4f}")
print(f"词集匹配 耗时: {end - start:.4f}")
def standardize_program_test():
logger.info(f"项目名匹配******************************************")
@ -358,14 +359,226 @@ def get_size():
standardize_project_test()
unuselessStr = clean_useless_project_name("众兴杜岗ⅱ间隔改造")
print(f"众兴杜岗ⅱ间隔改造:{unuselessStr}")
unuselessStr = clean_useless_project_name("众兴杜岗Ⅱ间隔改造")
print(f"众兴杜岗Ⅱ间隔改造:{unuselessStr}")
print("今天的长度:",len("今天"))
def exact_hot_words():
save_standard_name_list_to_file(list(GlobalData.simply_to_standard_company_name_map.keys()),"./hot_word/company.txt")
#save_standard_name_list_to_file(list(GlobalData.simply_to_standard_project_name_map.keys()),"./hot_word/project.txt")
save_standard_name_list_to_file(list(GlobalData.simply_to_standard_constractor_name_map.keys()),"./hot_word/constractor.txt")
save_standard_name_list_to_file(list(GlobalData.simply_to_standard_construct_name_map.keys()),"./hot_word/construct.txt")
# def exact_project_hot_words():
# import re
#
# # 示例数据,换成 GlobalData.simply_to_standard_project_name_map.keys()
# project_names = [
# "安庆四500kV变电站新建工程(PROJ-2024-0862)",
# "淮南芦集 220 千伏变电站 220 千伏配电装置改造工程(调试部分)(PROJ-2025-0022)",
# "屏显220kV变电站220kV杜岗Ⅱ间隔改造工程调试部分(PROJ-2025-0169)",
# "漆园220kV变电站220kV杨柳间隔改造工程调试部分(PROJ-2025-0042)",
# "宝桥220kV变电站220kV红桥间隔保护改造工程(调试部分)(PROJ-2025-0088)",
# "蟠龙220kV变电站220kV灵泗间隔改造工程调试部分(PROJ-2025-0018)",
# "锦绣-常青π入中心变电站220kV架空线路工程(PROJ-2024-1206)",
# "安庆和平220kV变电站新建工程调试部分(PROJ-2024-1238)",
# "渝北±800千伏换流站电气安装A包(调试部分)(PROJ-2024-1192)",
# "安徽亳州华佗220kV变电站220kV新华风电间隔扩建工程调试部分(PROJ-2024-1171)",
# "先锋-泉河π入安庆四变电站220kV线路工程(PROJ-2024-0834)",
# "安徽滁州护桥220kV变电站2号主变扩建工程(PROJ-2024-0821)",
# "双岭500kV变电站间隔改造工程(PROJ-2024-0863)",
# "合州±800千伏换流站电气安装A包(PROJ-2025-0056)",
# "金牛500kV变电站新建工程(PROJ-2024-0866)",
# "况楼220kV变电站间隔扩建工程调试部分(PROJ-2025-0144)",
# "国网安徽合肥供电公司2023年GIS带电显示器维护(PROJ-2024-1260)",
# "亳州木兰220kV变电站220kV改造工程安徽亳州木兰200kV变电站GIS设备检修及调试技术服务(PROJ-2024-1256)",
# "香涧-鹭岛500kV线路工程淮河大跨越段(PROJ-2024-0722)",
# "安徽蚌埠濠州220kV变电站220千伏大唐凤阳红心镇光伏间隔扩建工程调试部分)(PROJ-2025-0164)",
# "陶楼-广银T接智迪改接首业变电站110kV电缆线路工程(PROJ-2024-1233)",
# "国网北京检修公司2024年±500kV延庆换流站直流主设备及辅助设备不停电检修维护(PROJ-2024-0841)"
# ]
#
#
# # 用来存关键词
# keywords = []
# #
# for name in GlobalData.simply_to_standard_project_name_map.keys():
# # 去掉括号和括号里的内容
# cleaned_name = re.sub(r"\(.*?\)", "", name)
#
# # 提取“-”连接的词
# if "-" in cleaned_name:
# parts = cleaned_name.split("-")
# first = parts[0].strip()
# second = re.split(r"[^\u4e00-\u9fa5]", parts[1])[0].strip() # 只取中文部分
# if first:
# keywords.append(first)
# if second:
# keywords.append(second)
# if first and second:
# keywords.append(first + second)
# else:
# # 正常提取,取第一个连续的中文词组
# match = re.match(r"([\u4e00-\u9fa5]+)", cleaned_name)
# if match:
# keywords.append(match.group(1))
#
# # 去重,且保持原顺序
# seen = set()
# unique_keywords = []
# for kw in keywords:
# if kw not in seen:
# seen.add(kw)
# unique_keywords.append(kw)
#
# # 写入到文件
# with open("new_project.txt", "w", encoding="utf-8") as f:
# for kw in unique_keywords:
# f.write(kw + "\n")
#
# print("提取完成,已写入 new_project.txt")
# # 去重且保持顺序
# seen = set()
# unique_keywords = []
# for kw in keywords:
# if kw not in seen:
# seen.add(kw)
# unique_keywords.append(kw)
#
# # 写入到文件
# with open("new_project.txt", "w", encoding="utf-8") as f:
# for kw in unique_keywords:
# f.write(kw + "\n")
#
# print("提取完成,已写入 new_project.txt")
def exact_project_hot_words():
import re
# 示例数据,换成 GlobalData.simply_to_standard_project_name_map.keys()
# project_names = [
# "安庆四500kV变电站新建工程(PROJ-2024-0862)",
# "淮南芦集 220 千伏变电站 220 千伏配电装置改造工程(调试部分)(PROJ-2025-0022)",
# "锦绣-常青π入中心变电站220kV架空线路工程(PROJ-2024-1206)",
# "渝北±800千伏换流站电气安装A包(调试部分)(PROJ-2024-1192)",
# "屏显220kV变电站220kV杜岗Ⅱ间隔改造工程调试部分(PROJ-2025-0169)",
# "漆园220kV变电站220kV杨柳间隔改造工程调试部分(PROJ-2025-0042)",
# "宝桥220kV变电站220kV红桥间隔保护改造工程(调试部分)(PROJ-2025-0088)",
# "蟠龙220kV变电站220kV灵泗间隔改造工程调试部分(PROJ-2025-0018)",
# "安庆和平220kV变电站新建工程调试部分(PROJ-2024-1238)",
# "安徽亳州华佗220kV变电站220kV新华风电间隔扩建工程调试部分(PROJ-2024-1171)",
# "先锋-泉河π入安庆四变电站220kV线路工程(PROJ-2024-0834)",
# "安徽滁州护桥220kV变电站2号主变扩建工程(PROJ-2024-0821)",
# "双岭500kV变电站间隔改造工程(PROJ-2024-0863)",
# "合州±800千伏换流站电气安装A包(PROJ-2025-0056)",
# "金牛500kV变电站新建工程(PROJ-2024-0866)",
# "况楼220kV变电站间隔扩建工程调试部分(PROJ-2025-0144)",
# "国网安徽合肥供电公司2023年GIS带电显示器维护(PROJ-2024-1260)",
# "亳州木兰220kV变电站220kV改造工程安徽亳州木兰200kV变电站GIS设备检修及调试技术服务(PROJ-2024-1256)",
# "香涧-鹭岛500kV线路工程淮河大跨越段(PROJ-2024-0722)",
# "安徽蚌埠濠州220kV变电站220千伏大唐凤阳红心镇光伏间隔扩建工程调试部分)(PROJ-2025-0164)",
# "陶楼-广银T接智迪改接首业变电站110kV电缆线路工程(PROJ-2024-1233)",
# "国网北京检修公司2024年±500kV延庆换流站直流主设备及辅助设备不停电检修维护(PROJ-2024-0841)"
# ]
# 排除含有这些词的项目名
exclude_keywords = ["国网", "公司"]
keywords = []
for name in list(GlobalData.simply_to_standard_project_name_map.values()): # 正式换成
print(f"name:{name}")
# 去掉括号及里面内容
cleaned_name = re.sub(r".*?|\(.*?\)", "", name)
# 如果包含排除关键词,跳过
if any(ek in cleaned_name for ek in exclude_keywords):
continue
# 处理有"-"连接的情况
if "-" in cleaned_name:
parts = cleaned_name.split("-")
# 切出来的每一段,也要去括号内容
part0 = re.sub(r".*?|\(.*?\)", "", parts[0])
part1 = re.sub(r".*?|\(.*?\)", "", parts[1])
first = re.match(r"[\u4e00-\u9fa5]+", part0)
second = re.match(r"[\u4e00-\u9fa5]+", part1)
first_word = first.group(0) if first else ""
second_word = second.group(0) if second else ""
if first_word:
keywords.append(first_word)
if second_word:
keywords.append(second_word)
if first_word and second_word:
keywords.append(first_word + second_word)
else:
# 没有"-",提取第一个连续中文
match = re.match(r"([\u4e00-\u9fa5]+)", cleaned_name)
if match:
word = match.group(1)
if word:
keywords.append(word)
# 去重且保持顺序
seen = set()
unique_keywords = []
for kw in keywords:
if kw not in seen:
seen.add(kw)
unique_keywords.append(kw)
# 写入文件
with open("new_project.txt", "w", encoding="utf-8") as f:
for kw in unique_keywords:
f.write(kw + "\n")
print("提取完成,已写入 new_project.txt")
def removte_reduant_list():
temp_list = load_standard_name_list("./new_project.txt")
# 去重且保持顺序
seen = set()
unique_keywords = []
for kw in temp_list:
if kw not in seen:
seen.add(kw)
unique_keywords.append(kw)
with open("hot_word/final_new_project.txt", "w", encoding="utf-8") as f:
for kw in unique_keywords:
f.write(kw + "\n")
save_standard_name_list_to_file(list(GlobalData.simply_to_standard_project_name_map.keys()),"./hot_word/project.txt")
def list_to_json():
import json
my_list = load_standard_name_list("./hot_word/final_new_project.txt")
data = {
"hotwordList": my_list
}
# 转换为 JSON 格式字符串ensure_ascii=False 确保中文正常显示)
json_str = json.dumps(data, ensure_ascii=False, indent=4)
save_dict_to_file(json_str,'./hot_word/final_new_project.json')
list_to_json()
# exact_hot_words()
# exact_project_hot_words()
# unuselessStr = clean_useless_project_name("众兴杜岗ⅱ间隔改造")
# print(f"众兴杜岗ⅱ间隔改造:{unuselessStr}")
# unuselessStr = clean_useless_project_name("众兴杜岗Ⅱ间隔改造")
# print(f"众兴杜岗Ⅱ间隔改造:{unuselessStr}")
# print("今天的长度:",len("今天"))
# standardize_program()
# history_message()
# standardize_project_test()
# standardize_company_test()
# standardize_team_leader_test()
#
# standardize_sub_constractor_test()

View File

@ -561,7 +561,7 @@ def check_lost(int_res, slot):
def check_standard_name_slot_probability(int_res, slot) -> tuple:
intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15,16}
intention_list = {3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26}
if int_res not in intention_list:
return CheckResult.NO_MATCH, ""

View File

@ -7,4 +7,5 @@ test: ./data/test.json # (可选) 测试集路径
nc: 13 # 目标类别数
labels: ["天气查询","互联网查询","页面切换","日计划数量查询","周计划数量查询","日计划作业内容","周计划作业内容",
"施工人数","作业考勤人数","知识问答","通用对话","作业面查询","班组人数查询","班组数查询","作业面内容",
"班组详情","工程进度查询"] # 类别名称
"班组详情","工程进度查询", "人员查询", "分公司查询","工程数量查询","工程详情查询","项目部数量查询",
"建管单位数量查询","建管单位详情","分包单位数量查询","分包单位详情"] # 类别名称

View File

@ -7,7 +7,7 @@ import numpy as np
import functools
from paddle.nn import CrossEntropyLoss
from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.trainer import Trainer, TrainingArguments
from paddlenlp.trainer import Trainer, TrainingArguments, EarlyStoppingCallback
import os
from sklearn.metrics import precision_score, recall_score, f1_score
@ -116,10 +116,10 @@ def main():
output_dir="./output_temp",
evaluation_strategy="epoch",
save_strategy="epoch",
eval_steps=2000, # 每100步评估一次
save_steps=2000,
eval_steps=2000, # 每2000步评估一次evaluation_strategy="steps"时生效
save_steps=2000, # 每2000步保存一次save_strategy="steps"时生效
logging_dir="./logs",
logging_steps=100, # 每50步输出一次日志
logging_steps=100, # 每100步输出一次日志
num_train_epochs=10, # 训练轮数
per_device_train_batch_size=32,
per_device_eval_batch_size=32,
@ -140,6 +140,7 @@ def main():
eval_dataset=test_ds,
data_collator=data_collator,
compute_metrics=compute_metrics, # 使用自定义的评估指标
callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)
# 训练模型

View File

@ -78,7 +78,14 @@ BASE_DATA = {
"operatings": ["8+2工况", "8加2工况"],
# 页面切换
"pages": ["风险管控", "日计划", "周风险", "日计划统计报表", "日计划推送", "生产管控中心", "考勤统计详情",
"今日作业计划", "周风险统计报表", "周风险推送"]
"今日作业计划", "周风险统计报表", "周风险推送"],
# 具体人名
"person_names": ["何东洋", "李东","王孙强林"],
# 人名查询目标
"person_query_types": ["班组", "工程", "分公司", "实时组织", "项目部", "项目管理部"],
# 工程状态
"project_status_s": ["在建", "在作业", "在施工"]
}
@ -465,11 +472,11 @@ TEMPLATE_CONFIG = {
("{date}{construction_unit}{operating}的作业内容是什么?", ["date", "construction_unit","operating"]),
#分包单位
("{subcontractor}{date}作业内容是什么", ["subcontractor", "date"]),
("{date}{subcontractor}具体作业内容", ["date","subcontractor"]),
("{subcontractor}{date}作业内容", ["subcontractor", "date"]),
("{date}{subcontractor}具体作业内容", ["date","subcontractor"]),
("{date}分包单位为{subcontractor}有哪些{risk_level}风险作业计划?", ["date", "subcontractor", "risk_level"]),
("{date}{subcontractor}风险等级为{risk_level}的作业计划是什么", ["date", "subcontractor", "risk_level"]),
("{date}{subcontractor}{operating}的作业内容是什么?", ["date", "subcontractor","operating"]),
("{date}{subcontractor}风险等级为{risk_level}的作业计划", ["date", "subcontractor", "risk_level"]),
("{date}{subcontractor}{operating}具体作业计划", ["date", "subcontractor","operating"]),
]
},
@ -933,7 +940,7 @@ TEMPLATE_CONFIG = {
# 1. 查询特定日期和项目的作业安排
("{date}{project_name}作业面是什么?", ["date", "project_name"]),
("{date}属于{operating}作业面内容是什么?", ["date", "operating"]),
("{date}存在{operating}作业面是什么?", ["date", "operating"]),
("{date}存在{operating}作业面", ["date", "operating"]),
# 3. 查询特定日期和项目类型的工程计划
("{date}{project_type}类具体作业面有哪些?", ["date", "project_type"]),
@ -994,7 +1001,7 @@ TEMPLATE_CONFIG = {
("{date}{construction_area}地区{operating}具体作业面是什么", ["date", "construction_area","operating"]),
#建管单位
("{construction_unit}{date}具体作业面内容是什么", ["construction_unit", "date"]),
("{construction_unit}{date}具体作业面内容", ["construction_unit", "date"]),
("{date}{construction_unit}具体作业面有哪些", ["date", "construction_unit",]),
("{date}建管单位为{construction_unit}作业面是什么?", ["date", "construction_unit"]),
("{date}{construction_unit}风险等级为{risk_level}两项作业面分别是什么?", ["date", "construction_unit", "risk_level"]),
@ -1003,9 +1010,9 @@ TEMPLATE_CONFIG = {
#分包单位
("{subcontractor}{date}具体作业面内容是什么", ["subcontractor", "date"]),
("{date}{subcontractor}具体作业面有哪些", ["date","subcontractor"]),
("{date}{subcontractor}具体作业面", ["date","subcontractor"]),
("{date}分包单位为{subcontractor}{risk_level}作业面是什么", ["date", "subcontractor", "risk_level"]),
("{date}{subcontractor}风险等级为{risk_level}两项作业面分别是什么?", ["date", "subcontractor", "risk_level"]),
("{date}{subcontractor}风险等级为{risk_level}两项作业面", ["date", "subcontractor", "risk_level"]),
("{date}{subcontractor}{operating}作业面内容有哪些?", ["date", "subcontractor","operating"]),
]
},
@ -1111,7 +1118,244 @@ TEMPLATE_CONFIG = {
("{project_name}今日工程进度情况", ["project_name"]),
("{project_name}今天工程进展怎么样?", ["project_name"]),
]
}
},
"人员查询": {
"date": ["今天","最近"],
"templates": [
("{person_name}在哪个{person_query_type}", ["person_name","person_query_type"]),
("请帮我查一下{person_name}在哪个{person_query_type}", ["person_name","person_query_type"]),
]
},
"分公司查询": {
"date": ["今天","最近"],
"templates": [
("安徽送变电公司有多少分公司", []),
("公司有多少分公司", []),
("请查一下分公司数量", []),
("分公司数量", []),
("实施组织数量", []),
("分公司详情", []),
("实施组织详情", []),
("公司有哪些分公司", []),
("公司有哪些实施组织", []),
("{implementation_organization}详情", ["implementation_organization"]),
("{implementation_organization}情况", ["implementation_organization"]),
("请帮我查一下具体分公司详情", []),
("请帮我查一下具体实施组织详情", []),
("请帮我查一下具体{implementation_organization}详情", ["implementation_organization"]),
]
},
# "工程数量查询": {
# "date": ["今日", "昨日", "2024年5月24日", "5月24日", "今天", "昨天"],
# "templates": [
# #公司
# ("{date}公司有多少工程", ["date"]),
# ("{date}安徽送变电公司有多少工程{project_status}", ["date", "project_status"]),
# #分公司和项目部
# ("{implementation_organization}{date}有多少工程{project_status}",
# ["implementation_organization", "date", "project_status"]),
# ("{implementation_organization}{project_department}{date}有多少工程{project_status}",
# ["implementation_organization", "project_department", "date", "project_status"]),
# #建管区域和单位
# ("{construction_area}地区{date}风险等级为{risk_level}有多少工程?", ["construction_area", "date", "risk_level"]),
#
# ("{construction_area}地区{date}有多少工程{project_status}", ["construction_area", "date", "project_status"]),
#
# ("{construction_unit}{date}有多少工程{project_status}", ["construction_unit", "date","project_status"]),
#
# #分包商
# ("{date}{subcontractor}有多少工程{project_status}", ["date", "subcontractor", "project_status"]),
# ("{date}送变电公司{project_department}有多少工程?", ["date", "project_department"]),
# #项目经理
# ("{date}{project_manager}有多少工程{project_status}", ["date", "project_manager","project_status"]),
# #班组名称
# ("{team_leader}{date}有多少工程", ["team_leader", "date"]),
# #工程性质
# ("公司{date}{project_type}的工程有多少?", ["date", "project_type"]),
# #风险等级
# ("公司{date}{risk_level}风险的{project_status}工程有多少?", ["date", "risk_level", "project_status"]),
# #询问工程数量时有工程性质和风险等级吗
# ]
# },
#
# "工程详情查询": {
# "date": ["今日", "昨日", "2024年5月24日", "5月24日", "今天", "昨天"],
# "templates": [
# #公司
# ("{date}公司有哪些工程", ["date"]),
# ("{date}安徽送变电公司有哪些工程{project_status}", ["date", "project_status"]),
# #分公司和项目部
# ("{implementation_organization}{date}工程详情{project_status}",
# ["implementation_organization", "date", "project_status"]),
# ("{implementation_organization}{project_department}{date}有哪些工程{project_status}",
# ["implementation_organization", "project_department", "date", "project_status"]),
# #建管区域和单位
# ("{construction_area}地区{date}风险等级为{risk_level}工程具体情况?", ["construction_area", "date", "risk_level"]),
#
# ("{construction_area}地区{date}有哪些工程{project_status}", ["construction_area", "date", "project_status"]),
#
# ("{construction_unit}{date}有哪些工程{project_status}", ["construction_unit", "date","project_status"]),
#
# #分包商
# ("{date}{subcontractor}有多少工程{project_status}", ["date", "subcontractor", "project_status"]),
# ("{date}送变电公司{project_department}工程详情?", ["date", "project_department"]),
# #项目经理
# ("{date}{project_manager}有多少工程{project_status}", ["date", "project_manager","project_status"]),
# #班组名称
# ("{team_leader}{date}工程具体情况", ["team_leader", "date"]),
# #工程性质
# ("公司{date}{project_type}的工程有哪些?", ["date", "project_type"]),
# #风险等级
# ("公司{date}{risk_level}风险的{project_status}工程有那些?", ["date", "risk_level", "project_status"]),
# #询问工程数量时有工程性质和风险等级吗
# ]
# },
"工程数量查询": {
"date": ["今天","最近"],
"templates": [
#公司
("公司有多少工程", []),
("安徽送变电公司有多少工程{project_status}", ["project_status"]),
#分公司和项目部
("{implementation_organization}有多少工程{project_status}",
["implementation_organization", "project_status"]),
("{implementation_organization}{project_department}有多少工程{project_status}",
["implementation_organization", "project_department", "project_status"]),
#建管区域和单位
("{construction_area}地区风险等级为{risk_level}有多少工程?", ["construction_area", "risk_level"]),
("{construction_area}地区有多少工程{project_status}", ["construction_area", "project_status"]),
("{construction_unit}有多少工程{project_status}", ["construction_unit","project_status"]),
#分包商
("{subcontractor}有多少工程{project_status}", ["subcontractor", "project_status"]),
("安徽送变电公司{project_department}有多少工程?", ["project_department"]),
#项目经理
("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]),
#班组名称
("{team_leader}有多少{project_status}工程", ["team_leader", "project_status"]),
#工程性质
("公司{project_type}类的工程有多少?", ["project_type"]),
#风险等级
("公司{risk_level}风险的{project_status}工程有多少?", ["risk_level", "project_status"]),
#询问工程数量时有工程性质和风险等级吗
]
},
"工程详情查询": {
"date": ["今天","最近"],
"templates": [
#公司
("公司有哪些工程", []),
("截止目前公司有哪些{project_status}工程", ["project_status"]),
("安徽送变电公司有哪些工程{project_status}", ["project_status"]),
#分公司和项目部
("{implementation_organization}工程详情{project_status}",
["implementation_organization", "project_status"]),
("{implementation_organization}{project_department}有哪些工程{project_status}",
["implementation_organization", "project_department", "project_status"]),
#建管区域和单位
("{construction_area}地区风险等级为{risk_level}工程具体情况?", ["construction_area", "risk_level"]),
("{construction_area}地区有哪些{project_status}工程?", ["construction_area", "project_status"]),
("{construction_unit}有哪些工程{project_status}", ["construction_unit","project_status"]),
#分包商
("{subcontractor}有多少{project_status}工程", ["subcontractor", "project_status"]),
("送变电公司{project_department}工程详情?", ["project_department"]),
#项目经理
("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]),
#班组名称
("{team_leader}工程具体情况", ["team_leader"]),
#工程性质
("公司{project_type}类的工程有哪些?", ["project_type"]),
#风险等级
("公司{risk_level}风险的{project_status}工程有那些?", ["risk_level", "project_status"]),
#询问工程数量时有工程性质和风险等级吗
]
},
"项目部数量查询": {
"date": ["今天","最近"],
"templates": [
#公司
("公司有多少项目部", []),
("安徽送变电公司有多少项目管理部", []),
#分公司
("{implementation_organization}有多少项目部", ["implementation_organization"]),
("{implementation_organization}有多少项目管理部", ["implementation_organization"]),
("{implementation_organization}项目管理部的数量", ["implementation_organization"]),
#请帮我查一下
("请帮我查一下公司项目部数量", []),
("请帮我查一下{implementation_organization}有多少项目管理部",
["implementation_organization"]),
]
},
"项目部详情": {
"date": ["今天","最近"],
"templates": [
#公司
("公司有哪些项目部", []),
("安徽送变电公司项目管理部详情", []),
#分公司
("{implementation_organization}项目部详情", ["implementation_organization"]),
("{implementation_organization}有哪些项目管理部", ["implementation_organization"]),
#请帮我查一下
("请帮我查一下公司项目部详情", []),
("请帮我查一下{implementation_organization}有哪些项目管理部",
["implementation_organization"]),
]
},
"建管单位数量查询": {
"date": ["今天","最近"],
"templates": [
#公司
("公司有多少建管单位", []),
("安徽送变电公司有多少建管单位", []),
]
},
"建管单位详情": {
"date": ["今天","最近"],
"templates": [
#公司
("{project_name}建管单位情况", ["project_name"]),
("{project_name}建管单位详情", ["project_name"]),
("请介绍下{construction_unit}详情", ["construction_unit"]),
("请介绍下{construction_unit}情况", ["construction_unit"]),
]
},
"分包单位数量查询": {
"date": ["今天","最近"],
"templates": [
#公司
("公司有多少分包单位", []),
("公司有多少分包商", []),
("安徽送变电公司有多少分包单位", []),
("安徽送变电公司有多少分包商", []),
]
},
"分包单位详情": {
"date": ["今天","最近"],
"templates": [
#公司
("{project_name}分包单位详情", ["project_name"]),
("{project_name}分包商情况", ["project_name"]),
("请介绍下{subcontractor}详情", ["subcontractor"]),
("请介绍下{subcontractor}情况", ["subcontractor"]),
]
},
}
@ -1133,6 +1377,9 @@ def generate_natural_samples(config, label):
"operating": BASE_DATA["operatings"],
"team_name": BASE_DATA["team_names"],
"construction_area": BASE_DATA["construction_areas"],
"person_name": BASE_DATA["person_names"],
"person_query_type": BASE_DATA["person_query_types"],
"project_status": BASE_DATA["project_status_s"],
}
for template, variables in config["templates"]:

View File

@ -25,7 +25,9 @@ def merge_json_files(file_list, output_file):
files = ['互联网查询.json','天气查询.json','知识问答.json','作业考勤人数.json', '周计划作业内容.json',
'周计划数量查询.json','施工人数.json','日计划作业内容.json','日计划数量查询.json',
'页面切换.json','通用对话.json','作业面查询.json','班组人数查询.json','班组数查询.json','作业面内容.json',
'班组详情.json','工程进度查询.json']
'班组详情.json','工程进度查询.json','人员查询.json','分公司查询.json','工程数量查询.json',
'工程详情查询.json','项目部数量查询.json','建管单位数量查询.json','建管单位详情.json','分包单位数量查询.json',
'分包单位详情.json']
output_file = 'output/merged_data.json'
# 执行合并

View File

@ -2,22 +2,25 @@ import json
import paddle
from paddlenlp.datasets import MapDataset
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
from paddlenlp.trainer import Trainer, TrainingArguments
from paddlenlp.trainer import Trainer, TrainingArguments, EarlyStoppingCallback
from paddlenlp.data import DataCollatorForTokenClassification
# === 1. 加载数据 ===
def load_dataset(data_path):
with open(data_path, "r", encoding="utf-8") as f:
data = json.load(f)
return MapDataset(data)
# === 2. 预处理数据 ===
def preprocess_function(example, tokenizer):
# 预定义实体类型列表
entity_types = [
'date', 'project_name', 'project_type', 'construction_unit',
'implementation_organization', 'project_department', 'project_manager',
'subcontractor', 'team_leader', 'risk_level', 'page', 'operating', 'team_name', 'construction_area'
'subcontractor', 'team_leader', 'risk_level', 'page', 'operating', 'team_name',
'construction_area', 'person_name', 'person_query_type', 'project_status'
]
# 文本 Tokenization
@ -59,7 +62,7 @@ def preprocess_function(example, tokenizer):
# === 3. 加载 UIE 预训练模型 ===
model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=29) # 3 类 (O, B, I)
lsmodel = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=35) # 3 类 (O, B, I)
tokenizer = ErnieTokenizer.from_pretrained("uie-base")
# === 4. 加载数据集 ===
@ -70,7 +73,6 @@ dev_dataset = load_dataset("data/val.json") # 验证数据集
train_dataset = train_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
dev_dataset = dev_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
# === 6. 数据整理 ===
data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)
@ -87,13 +89,30 @@ training_args = TrainingArguments(
save_total_limit=1, # 只保留最新 2 个模型
logging_dir="./logs",
logging_steps=100,
eval_steps=5000,
save_steps=5000,
eval_steps=5000, #evaluation_strategy="steps"时生效
save_steps=5000, #save_strategy="steps"时生效
seed=1000,
load_best_model_at_end=True,
)
# === 8. 训练 ===
# === 8. 创建 EarlyStoppingCallback 实例 ===
early_stopping_callback = EarlyStoppingCallback(
early_stopping_patience=2, # 连续多少次评估不提升就停
early_stopping_threshold=0.01 # 最小提升幅度例如设为0.01表示至少提升1%
)
def compute_metrics(eval_preds):
predictions, labels = eval_preds
preds = predictions.argmax(axis=-1)
correct = (preds == labels).astype(int)
accuracy = correct.sum() / correct.size
return {"accuracy": accuracy}
training_args.metric_for_best_model = "accuracy"
# === 9. 训练 ===
trainer = Trainer(
model=model,
args=training_args,
@ -101,11 +120,12 @@ trainer = Trainer(
eval_dataset=dev_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
callbacks=[early_stopping_callback], # 添加 EarlyStopping 回调
)
trainer.train()
# 为模型定义输入规格
input_spec = [
paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="input_ids"),
@ -113,4 +133,3 @@ input_spec = [
paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="position_ids"),
paddle.static.InputSpec(shape=[None, 512], dtype="float32", name="attention_mask")
]