重构模型训练 (Refactor model training)

jiang 2025-03-16 14:40:56 +08:00
parent cd4dcd5429
commit ee0380545f
15 changed files with 4770 additions and 144 deletions

View File

@ -13,7 +13,7 @@ from constants import PROJECT_NAME, PROJECT_DEPARTMENT, SIMILARITY_VALUE
from config import *
# 常量
MODEL_ERNIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\ernie\output\checkpoint-4160"
MODEL_UIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\uie\output\checkpoint-1740"
MODEL_UIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\uie\output\checkpoint-2430"
# 类别名称列表
labels = [
@ -24,17 +24,19 @@ labels = [
# 标签映射
label_map = {
0: 'O', # 非实体
1: 'B-date', 12: 'I-date',
2: 'B-project_name', 13: 'I-project_name',
3: 'B-project_type', 14: 'I-project_type',
4: 'B-construction_unit', 15: 'I-construction_unit',
5: 'B-implementation_organization', 16: 'I-implementation_organization',
6: 'B-project_department', 17: 'I-project_department',
7: 'B-project_manager', 18: 'I-project_manager',
8: 'B-subcontractor', 19: 'I-subcontractor',
9: 'B-team_leader', 20: 'I-team_leader',
10: 'B-risk_level', 21: 'I-risk_level',
11: 'B-page', 22: 'I-page',
1: 'B-date', 14: 'I-date',
2: 'B-projectName', 15: 'I-projectName',
3: 'B-projectType', 16: 'I-projectType',
4: 'B-constructionUnit', 17: 'I-constructionUnit',
5: 'B-implementationOrganization', 18: 'I-implementationOrganization',
6: 'B-projectDepartment', 19: 'I-projectDepartment',
7: 'B-projectManager', 20: 'I-projectManager',
8: 'B-subcontractor', 21: 'I-subcontractor',
9: 'B-teamLeader', 22: 'I-teamLeader',
10: 'B-riskLevel', 23: 'I-riskLevel',
11: 'B-page', 24: 'I-page',
12: 'B-operating', 25: 'I-operating',
13: 'B-teamName', 26: 'I-teamName',
}
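
Because the I- offsets shift whenever an entity type is added (the I- block now starts at 14 since there are 13 entity types), the mapping is easy to get wrong by hand. A minimal sketch, not part of this commit, of deriving the same table from the ordered type list:

# Sketch: rebuild the BIO label map from the ordered entity-type list.
ENTITY_TYPES = [
    "date", "projectName", "projectType", "constructionUnit",
    "implementationOrganization", "projectDepartment", "projectManager",
    "subcontractor", "teamLeader", "riskLevel", "page", "operating", "teamName",
]
label_map = {0: "O"}  # index 0 stays the non-entity tag
for i, name in enumerate(ENTITY_TYPES, start=1):
    label_map[i] = f"B-{name}"                      # B- tags take indices 1..13
    label_map[i + len(ENTITY_TYPES)] = f"I-{name}"  # I- tags take indices 14..26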
# 初始化工具类

ernie/1.py Normal file
View File

@ -0,0 +1,16 @@
import pandas as pd
import json
# 读取 Excel 文件
excel_file = r"D:\bonus\Desktop\问题.xlsx"
df = pd.read_excel(excel_file)
# 只保留问题文本并添加固定 label,转换为目标格式
json_data = [{"text": row["问题"], "label": "知识问答"} for _, row in df.iterrows()]
# 保存为 JSON 文件
json_file = "知识问答.json"
with open(json_file, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False, indent=4)
print(f"Excel 数据已转换为 JSON 并保存到 {json_file}")

View File

@ -14,7 +14,7 @@ model = ErnieForSequenceClassification.from_pretrained(R"E:\workingSpace\Pycharm
tokenizer = ErnieTokenizer.from_pretrained(R"E:\workingSpace\PycharmProjects\Intention_dev\ernie\output\checkpoint-4160")
# 创建输入示例
text = "胡彬项目经理上一周作业内容是什么"
text = "宇宙中发现的第一个脉冲星是由谁发现的"
inputs = tokenizer(text, max_length=256, truncation=True, padding='max_length', return_tensors="pd")
# 将输入数据转化为 Paddle tensor 格式
@ -34,3 +34,4 @@ max_prob_value = np.max(probabilities.numpy(), axis=-1) # 获取最大概率值
# 根据预测的标签索引映射到类别名称
predicted_label = labels[max_prob_idx[0]] # 根据索引获取对应的标签
predicted_probability = max_prob_value[0] # 获取最大概率值
print(predicted_label, predicted_probability)
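
The snippet prints the top label together with its softmax probability; before acting on the intent, the caller normally checks that score against a threshold. A minimal sketch, with an assumed threshold name and fallback intent (neither comes from this repo):

# Sketch: only trust the classifier above an assumed confidence threshold.
CONFIDENCE_THRESHOLD = 0.8  # illustrative value; tune on the validation set

def route_intent(label, probability):
    if probability < CONFIDENCE_THRESHOLD:
        return "知识问答"  # assumed fallback intent when confidence is low
    return label

print(route_intent(predicted_label, predicted_probability))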

View File

@ -1,7 +1,7 @@
import json
# 读取 text 文件
with open("data/test.txt", "r", encoding="utf-8") as f:
with open("data/train.txt", "r", encoding="utf-8") as f:
data = f.readlines() # 按行读取
# 解析数据
@ -17,8 +17,7 @@ for line in data:
json_output = json.dumps(json_list, ensure_ascii=False, indent=4)
# 保存到 JSON 文件
with open("data/test.json", "w", encoding="utf-8") as f:
with open("data1/train.json", "w", encoding="utf-8") as f:
f.write(json_output)
# 打印 JSON 结果
print(json_output)

View File

@ -0,0 +1,282 @@
[
{
"text": "当前俄罗斯与乌克兰的局势如何?",
"label": "互联网查询"
},
{
"text": "最近中国经济增长的最新数据是什么?",
"label": "互联网查询"
},
{
"text": "中国和欧盟的贸易关系目前怎么样?",
"label": "互联网查询"
},
{
"text": "美国总统大选的最新动态是什么?",
"label": "互联网查询"
},
{
"text": "最近中东局势有没有新的变化?",
"label": "互联网查询"
},
{
"text": "当前全球通胀情况如何?",
"label": "互联网查询"
},
{
"text": "最近联合国有什么重要决议?",
"label": "互联网查询"
},
{
"text": "中印边境局势最近有什么新进展?",
"label": "互联网查询"
},
{
"text": "全球气候变化会议有哪些新的决策?",
"label": "互联网查询"
},
{
"text": "当前中国房产市场的最新趋势是什么?",
"label": "互联网查询"
},
{
"text": "最近中国对台政策有何新动向?",
"label": "互联网查询"
},
{
"text": "北约最近有哪些重要的军事动态?",
"label": "互联网查询"
},
{
"text": "中国最近在科技领域有哪些突破?",
"label": "互联网查询"
},
{
"text": "全球供应链危机目前的状况如何?",
"label": "互联网查询"
},
{
"text": "中国最近有哪些新的外交举措?",
"label": "互联网查询"
},
{
"text": "最近美日韩三国的关系如何?",
"label": "互联网查询"
},
{
"text": "当前中国对新能源汽车的政策是什么?",
"label": "互联网查询"
},
{
"text": "全球能源价格目前的走势如何?",
"label": "互联网查询"
},
{
"text": "中国在人工智能领域的最新进展是什么?",
"label": "互联网查询"
},
{
"text": "东南亚局势最近有没有新的变化?",
"label": "互联网查询"
},
{
"text": "当前全球移民政策的最新动向是什么?",
"label": "互联网查询"
},
{
"text": "日本近期的经济政策有哪些变化?",
"label": "互联网查询"
},
{
"text": "中国最新的军事实力发展情况如何?",
"label": "互联网查询"
},
{
"text": "目前全球粮食安全形势如何?",
"label": "互联网查询"
},
{
"text": "中国政府近期对房地产行业有哪些新规?",
"label": "互联网查询"
},
{
"text": "当前国际原油价格走势如何?",
"label": "互联网查询"
},
{
"text": "最近中美科技竞争的最新动态是什么?",
"label": "互联网查询"
},
{
"text": "中国最近有哪些重大基建项目开工?",
"label": "互联网查询"
},
{
"text": "当前非洲国家的经济发展趋势如何?",
"label": "互联网查询"
},
{
"text": "最近中国在太空探索方面有哪些新进展?",
"label": "互联网查询"
},
{
"text": "关于百度的最新新闻是什么?",
"label": "互联网查询"
},
{
"text": "关于美国加州大火的最新消息是什么?",
"label": "互联网查询"
},
{
"text": "关于人工智能的最新发展是什么?",
"label": "互联网查询"
},
{
"text": "最新的北京教育政策变化是什么",
"label": "互联网查询"
},
{
"text": "最新的北京旅游景点推荐。",
"label": "互联网查询"
},
{
"text": "最新的NBA比赛结果在哪里可以查看",
"label": "互联网查询"
},
{
"text": "2025年最新的科技新闻。",
"label": "互联网查询"
},
{
"text": "2025年最新的股票市场行情是什么",
"label": "互联网查询"
},
{
"text": "量子计算取得突破性进展是什么",
"label": "互联网查询"
},
{
"text": "人工智能国内的最新进展是什么",
"label": "互联网查询"
},
{
"text": "中国机器人最好的厂家有哪些",
"label": "互联网查询"
},
{
"text": "截止目前世界最富有的国家是哪一个",
"label": "互联网查询"
},
{
"text": "现在中非关系怎么样",
"label": "互联网查询"
},
{
"text": "现在世界格局是怎样的",
"label": "互联网查询"
},
{
"text": "现在中美关系的怎么样",
"label": "互联网查询"
},
{
"text": "现在欧盟有哪些国家",
"label": "互联网查询"
},
{
"text": "现在中国的世界关系是怎样的",
"label": "互联网查询"
},
{
"text": "哪吒闹海这个电影的放映时间是什么",
"label": "互联网查询"
},
{
"text": "李晨这个明星的最新动态是什么",
"label": "互联网查询"
},
{
"text": "当前俄乌冲突的最新进展如何?",
"label": "互联网查询"
},
{
"text": "巴以冲突的现状如何?",
"label": "互联网查询"
},
{
"text": "全球经济形势目前如何?",
"label": "互联网查询"
},
{
"text": "美国对华最新政策有哪些变化?",
"label": "互联网查询"
},
{
"text": "中国经济增长率最新数据是多少?",
"label": "互联网查询"
},
{
"text": "近期有哪些重要的国际峰会?",
"label": "互联网查询"
},
{
"text": "中国和欧盟的关系如何?",
"label": "互联网查询"
},
{
"text": "当前中东局势如何发展?",
"label": "互联网查询"
},
{
"text": "美国大选的最新动态是什么?",
"label": "互联网查询"
},
{
"text": "中国在人工智能领域有哪些最新进展?",
"label": "互联网查询"
},
{
"text": "中国和东盟国家的合作现状如何?",
"label": "互联网查询"
},
{
"text": "全球气候变化的最新趋势是什么?",
"label": "互联网查询"
},
{
"text": "当前全球通胀水平如何?",
"label": "互联网查询"
},
{
"text": "北约近期有哪些新动态?",
"label": "互联网查询"
},
{
"text": "中国与非洲国家的最新合作进展如何?",
"label": "互联网查询"
},
{
"text": "世界能源市场当前状况如何?",
"label": "互联网查询"
},
{
"text": "中日韩关系近期有哪些变化?",
"label": "互联网查询"
},
{
"text": "当前全球供应链的恢复情况如何?",
"label": "互联网查询"
},
{
"text": "最近的国际热点事件有哪些?",
"label": "互联网查询"
},
{
"text": "中国在国际事务中的影响力如何变化?",
"label": "互联网查询"
},
{
"text": "最近的新闻有那些说一下?",
"label": "互联网查询"
}
]

View File

@ -0,0 +1,266 @@
[
{
"text": "今天下午亳州的降水量大概有多少?",
"label": "天气查询"
},
{
"text": "明天上午池州的天气预报是否有雨?",
"label": "天气查询"
},
{
"text": "今天晚上宣城的降雨量预计有多少?",
"label": "天气查询"
},
{
"text": "今天傍晚合肥的天气预报显示会有雷阵雨吗?",
"label": "天气查询"
},
{
"text": "明天铜陵的天气适合登山吗?",
"label": "天气查询"
},
{
"text": "宿州明天的气温区间是多少?",
"label": "天气查询"
},
{
"text": "今天南京的天气怎么样?",
"label": "天气查询"
},
{
"text": "明天蚌埠的天气是否有明显变化?",
"label": "天气查询"
},
{
"text": "未来三天淮南的降水频率是多少?",
"label": "天气查询"
},
{
"text": "合肥本周的最高气温是多少?",
"label": "天气查询"
},
{
"text": "明天芜湖会刮风吗?",
"label": "天气查询"
},
{
"text": "未来一周内阜阳的空气质量会改善吗?",
"label": "天气查询"
},
{
"text": "本周五亳州的天气会有强风吗?",
"label": "天气查询"
},
{
"text": "未来三天巢湖的天气预报是什么?",
"label": "天气查询"
},
{
"text": "六安本周的降水量预计是多少?",
"label": "天气查询"
},
{
"text": "合肥今天体感温度是多少?",
"label": "天气查询"
},
{
"text": "今天合肥的天气怎么样?",
"label": "天气查询"
},
{
"text": "本周安庆的湿度会达到多少?",
"label": "天气查询"
},
{
"text": "明天合肥会有雨吗?",
"label": "天气查询"
},
{
"text": "明天巢湖的日照强度预计如何?",
"label": "天气查询"
},
{
"text": "明天宿州的天气是否适合进行户外露营?",
"label": "天气查询"
},
{
"text": "本周滁州的天气适合出行吗?",
"label": "天气查询"
},
{
"text": "明天宿州的天气会转晴吗?",
"label": "天气查询"
},
{
"text": "本周蚌埠的降雨概率有多大?",
"label": "天气查询"
},
{
"text": "合肥明天早上的风速大约是多少?",
"label": "天气查询"
},
{
"text": "本周芜湖的阳光时长有多少?",
"label": "天气查询"
},
{
"text": "明天安庆的天气怎样",
"label": "天气查询"
},
{
"text": "本周六黄山的天气如何?",
"label": "天气查询"
},
{
"text": "明天阜阳的能见度如何?",
"label": "天气查询"
},
{
"text": "宿州明天的平均气温是多少?",
"label": "天气查询"
},
{
"text": "今天池州的体感温度会很低吗?",
"label": "天气查询"
},
{
"text": "未来两天芜湖的最高气温是多少?",
"label": "天气查询"
},
{
"text": "明天宣城的气温相比今天会升高吗?",
"label": "天气查询"
},
{
"text": "今天淮南的最高气温是多少?",
"label": "天气查询"
},
{
"text": "本周合肥会有降雨吗?",
"label": "天气查询"
},
{
"text": "今晚铜陵会出现大风天气吗?",
"label": "天气查询"
},
{
"text": "今天晚上阜阳有可能出现冰雹吗?",
"label": "天气查询"
},
{
"text": "2025年清明节巢湖的天气如何",
"label": "天气查询"
},
{
"text": "今天下午安庆有可能出现强降雨吗?",
"label": "天气查询"
},
{
"text": "今晚滁州的最低气温是多少?",
"label": "天气查询"
},
{
"text": "明天马鞍山的空气质量指数是多少?",
"label": "天气查询"
},
{
"text": "今天黄山的山顶温度会低于零度吗?",
"label": "天气查询"
},
{
"text": "本周五晚上阜阳的气温预计最低会降到多少度?",
"label": "天气查询"
},
{
"text": "未来一周内蚌埠的天气如何?",
"label": "天气查询"
},
{
"text": "今天下午马鞍山的空气质量如何?",
"label": "天气查询"
},
{
"text": "今天晚上滁州有雷雨吗?",
"label": "天气查询"
},
{
"text": "明天滁州的气温相比今天是否更高?",
"label": "天气查询"
},
{
"text": "今天晚上宿州的温度会低于10°C吗",
"label": "天气查询"
},
{
"text": "淮北明天的紫外线指数高吗?",
"label": "天气查询"
},
{
"text": "本周蚌埠的气温最低会达到多少?",
"label": "天气查询"
},
{
"text": "本周合肥的天气怎么样?",
"label": "天气查询"
},
{
"text": "2025年春节前合肥最低气温多少",
"label": "天气查询"
},
{
"text": "今天下午巢湖的空气湿度是多少?",
"label": "天气查询"
},
{
"text": "接下来一个月亳州的天气是什么样的",
"label": "天气查询"
},
{
"text": "明天六安会出现霜冻现象吗?",
"label": "天气查询"
},
{
"text": "六安今天中午会下小雨吗?",
"label": "天气查询"
},
{
"text": "本周合肥的风向主要是什么?",
"label": "天气查询"
},
{
"text": "本周安庆的最低气温预计出现在哪一天?",
"label": "天气查询"
},
{
"text": "本周黄山的气温波动大吗?",
"label": "天气查询"
},
{
"text": "今天晚上合肥的风力会达到几级?",
"label": "天气查询"
},
{
"text": "本周五安庆的天气适合户外活动吗?",
"label": "天气查询"
},
{
"text": "本周合肥的最低气温预计出现在哪一天?",
"label": "天气查询"
},
{
"text": "今天晚上铜陵会出现大雾天气吗?",
"label": "天气查询"
},
{
"text": "淮北明天的体感温度是否较高?",
"label": "天气查询"
},
{
"text": "本周亳州的气温是否一直偏低?",
"label": "天气查询"
},
{
"text": "明天安徽省宿州市的日出和日落时间分别是什么时候?",
"label": "天气查询"
}
]

File diff suppressed because it is too large

View File

@ -1,6 +1,7 @@
import json
import os
from itertools import product
# 目录路径
directory = "data"
@ -9,69 +10,101 @@ if not os.path.exists(directory):
os.makedirs(directory)
# 基础数据定义
BASE_DATA = {
"implementation_organizations": ["送电一分公司", "送电二分公司", "变电分公司", "建筑分公司", "消防分公司"],
# 实施组织
"implementation_organizations": ["送电一分公司", "送电二分公司", "变电分公司", "消防分公司"],
# 工程性质
"project_types": ["基建", "技改大修", "用户工程", "小型基建"],
# 工程名
"project_names": [
"国网北京检修公司2024年±500kV延庆换流站直流主设备年度检修维护",
"合肥二电厂-彭郢π入长临河变电站220kV线路工程"
"1号工程",
"淮南芦集改造工程",
"第十号工程",
"合肥二电厂220kV线路工程",
"九号工程",
],
"construction_units": ["国网安徽省电力有限公司建设分公司", "国网安徽省电力有限公司马鞍山供电公司", "中铁二局集团电务工程有限公司"],
"project_departments": ["第九项目管理部", "第十一项目管理部", "第八项目管理部"],
# 建管单位
"construction_units": ["国网安徽省电力有限公司建设分公司", "国网安徽省电力有限公司马鞍山供电公司",
"中铁二局集团电务工程有限公司"],
# 项目部名称
"project_departments": ["第9项目管理部", "第十一项目部", "第八项目管理部", "9号项目部"],
# 项目经理
"project_managers": ["陈少平项目经理", "范文立项目经理", "何东洋项目经理"],
# 分包单位
"subcontractors": ["安徽劦力建筑装饰有限责任公司", "安徽苏亚建设集团有限公司"],
# 班组名称
"team_names": ["张朵班组", "刘梁玉班组", "魏玉龙班组"],
# 班组长
"team_leaders": ["李元帅班组长", "刘雨豪班组长"],
"risk_levels": ["1级", "2级", "3级", "4级", "5级"],
"pages": ["风险管控", "日计划", "周风险" ,"日计划统计报表","日计划推送"],
"operatings": ["8+2工况","8加2工况"]
# 风险等级
"risk_levels": ["1级", "一级", "二级", "5级", "四级"],
# 8+2工况
"operatings": ["8+2工况", "8加2工况"],
# 页面切换
"pages": ["风险管控", "日计划", "周风险", "日计划统计报表", "日计划推送"]
}
# 自然语言模板配置
TEMPLATE_CONFIG = {
"日计划数量查询": {
"date": ["今日", "昨日", "2024年5月24日", "5月24日","今天","昨天"],
"date": ["今日", "昨日", "2024年5月24日", "5月24日", "今天", "昨天"],
"templates": [
("{date}{project_name}有多少作业计划?", ["date", "project_name"]),
("{project_name}{date}有多少作业计划?", ["project_name","date"]),
("{date}{project_type}类的作业计划有多少?", ["date", "project_type"]),
("{project_name}{date}有多少作业计划?", ["project_name", "date"]),
("工程性质是{project_type}{date}有多少作业计划?", ["project_type", "date"]),
("{date}风险等级为{risk_level}的作业计划有多少?", ["date", "risk_level"]),
("{date}工程性质为{project_type}的有多少作业计划?", ["date", "project_type"]),
("工程性质为{project_type}{date}有多少作业计划?", ["project_type", "date"]),
("查询{project_name}{date}的作业计划有多少?", ["project_name", "date"]),
("工程性质为{project_type}{date}有多少作业计划?", ["project_type", "date"]),
("查询{project_name}{date}的作业计划数量", ["project_name", "date"]),
("{date}{project_type}类作业计划有多少?", ["date", "project_type"]),
("{project_type}{date}作业计划有多少?", ["date", "project_type"]),
("{project_type}{date}作业计划有多少?", ["project_type", "date"]),
("{construction_unit}{date}有多少作业计划?", ["construction_unit", "date"]),
("{date}{construction_unit}有多少作业计划?", ["date", "construction_unit"]),
("{date}有多少作业计划?", ["date"]),
("公司{date}有多少作业计划?", ["date"]),
("{date}属于{operating}有多少作业计划?", ["date","operating"]),
("{date}属于{operating}有多少作业计划?", ["date", "operating"]),
("{date}{implementation_organization}有多少作业计划?", ["date", "implementation_organization"]),
("{date}{project_department}有多少作业计划?", ["date", "project_department"]),
("{project_department}{date}有多少{risk_level}风险作业计划?", ["project_department", "date", "risk_level"]),
("{date}{project_manager}有多少作业计划?", ["date", "project_manager"]),
("{date}{subcontractor}有多少作业计划?", ["date", "subcontractor"]),
("{date}{team_leader}有多少作业计划?", ["date", "team_leader"]),
("{date}风险等级为{risk_level}的作业计划有多少?", ["date", "risk_level"]),
("{date}{project_department}有多少{risk_level}风险作业计划?", ["date", "project_department", "risk_level"]),
("{date}{project_type}中,风险等级为{risk_level}的作业计划有多少?",["date", "project_type", "risk_level"]),
("{date}{construction_unit}中,风险等级为{risk_level}的计划有多少?",["date", "construction_unit", "risk_level"]),
("{date}{project_type}风险等级为{risk_level}的作业计划有多少?", ["date", "project_type", "risk_level"]),
("{date}{construction_unit}有多少{risk_level}风险作业计划?", ["date", "construction_unit", "risk_level"]),
("{date}{project_type}类中,由{construction_unit}负责的作业计划有多少?",["date", "project_type", "construction_unit"]),
("{date}{project_type}类中,由{implementation_organization}组织实施的作业计划有多少?",["date", "project_type", "implementation_organization"]),
("{date}{project_department}管理的{project_type}类作业计划有多少?",["date", "project_department", "project_type"]),
("{date}{subcontractor}承包的{project_type}类作业计划有多少?",["date", "subcontractor", "project_type"]),
("{date}{project_manager}负责的{project_type}类作业计划有多少?",["date", "project_manager", "project_type"]),
("{date}{project_type}{construction_unit}负责的作业计划有多少?",
["date", "project_type", "construction_unit"]),
("{date}{project_type}{implementation_organization}组织实施的作业计划有多少?",
["date", "project_type", "implementation_organization"]),
("{date}{project_department}管理的{project_type}类作业计划有多少?",
["date", "project_department", "project_type"]),
("{date}{subcontractor}承包的{project_type}类作业计划有多少?", ["date", "subcontractor", "project_type"]),
("{date}{project_manager}负责的{project_type}类作业计划有多少?",
["date", "project_manager", "project_type"]),
("{date}{team_leader}带领的{project_type}类作业计划有多少?", ["date", "team_leader", "project_type"]),
("{date}{project_name}{project_manager}作业计划有多少?",["date", "project_name", "project_manager"]),
("{date}{project_name}{project_manager}作业计划有多少?", ["date", "project_name", "project_manager"]),
("{date}{project_name}中,风险等级为{risk_level}的作业计划有多少?", ["date", "project_name", "risk_level"]),
("{date}{project_manager}作业计划有多少?", ["date","project_manager"]),
("{date}{project_manager}作业计划有多少?", ["date", "project_manager"]),
("{project_manager}{date}作业计划有多少?", ["project_manager", "date"]),
("{date}{project_manager}的作业计划数量", ["date", "project_manager"]),
("{project_manager}{date}的作业计划数量", ["project_manager", "date"]),
# 班组
("{date}{team_name}有多少项作业计划?", ["date", "team_name"]),
("{team_name}{date}有多少作业计划?", ["team_name", "date"]),
("{team_name}{date}作业计划数量", ["team_name", "date"]),
("{date}{team_name}作业计划数量", ["date", "team_name"]),
]
},
"周计划数量查询": {
"date": ["本周", "上周","上一周", "下周", "下一周", "最近一周", "本周内", "这一周"],
"date": ["本周", "上周", "上一周", "下周", "下一周", "最近一周", "本周内", "这一周"],
"templates": [
("{date}{project_name}作业计划有多少?", ["date", "project_name"]),
("{project_name}{date}作业计划有多少?", ["project_name", "date"]),
@ -80,7 +113,7 @@ TEMPLATE_CONFIG = {
("{date}作业计划有多少?", ["date"]),
# 🎯 date + 其他单个维度
("{date}{project_name}有多少作业计划?", ["date", "project_name"]),
("{date}{project_name}有多少作业计划?", ["date", "project_name"]),
("{date}{construction_unit}作业计划有多少?", ["date", "construction_unit"]),
("{date}{implementation_organization}作业计划有多少?", ["date", "implementation_organization"]),
@ -89,33 +122,45 @@ TEMPLATE_CONFIG = {
("{date}{subcontractor}作业计划有多少?", ["date", "subcontractor"]),
("{date}{team_leader}作业计划有多少?", ["date", "team_leader"]),
("{date}{project_department}作业计划数量", ["date", "project_department"]),
("{date}{subcontractor}作业计划数量?", ["date", "subcontractor"]),
# 🎯 date + 风险维度
("{date}风险等级为{risk_level}的作业计划有多少?", ["date", "risk_level"]),
("{date}有多少{risk_level}风险作业计划", ["date", "risk_level"]),
# 🎯 date + construction_unit + risk_level
("{date}{construction_unit}风险等级为{risk_level}的作业计划有多少?", ["date", "construction_unit", "risk_level"]),
("{construction_unit}{date}有多少项{risk_level}风险作业计划", ["construction_unit", "date", "risk_level"]),
# 🎯 date + implementation_organization + risk_level
("{date}{implementation_organization}风险等级为{risk_level}的作业计划有多少?",["date", "implementation_organization", "risk_level"]),
("{date}{implementation_organization}风险等级为{risk_level}的作业计划有多少?",
["date", "implementation_organization", "risk_level"]),
# 🎯 date + project_name + project_manager
("{date}{project_name}{project_manager}负责的作业计划有多少?", ["date", "project_name", "project_manager"]),
# 🎯 date + project_name + risk_level
("{date}{project_name}中,风险等级为{risk_level}的作业计划有多少", ["date", "project_name", "risk_level"]),
("{date}{project_name}有多少项{risk_level}风险作业计划", ["date", "project_name", "risk_level"]),
# 🎯 project_manager 维度
("{project_manager}{date}作业计划数量?", ["project_manager", "date"]),
("{project_manager}{date}作业计划有多少?", ["project_manager", "date"]),
("{project_manager}{date}负责的风险等级为{risk_level}的作业计划有多少?", ["project_manager", "date", "risk_level"]),
("{project_manager}{date}负责的风险等级为{risk_level}的作业计划有多少?",
["project_manager", "date", "risk_level"]),
("{date}{team_name}有多少项作业计划?", ["date", "team_name"]),
("{team_name}{date}有多少作业计划?", ["team_name", "date"]),
("{team_name}{date}作业计划数量", ["team_name", "date"]),
("{date}{team_name}的作业计划数量", ["date", "team_name"]),
]
},
"日计划作业内容": {
"date": ["今日", "昨日", "2024年5月24日", "5月24日","今天","昨天"],
"date": ["今日", "昨日", "2024年5月24日", "5月24日", "今天", "昨天"],
"templates": [
("{date}{project_name}作业内容是什么?", ["date", "project_name"]),
("{project_name}{date}作业内容是什么", ["project_name", "date"]),
("{date}{project_type}类作业内容是什么?", ["date", "project_type"]),
("{project_type}{date}作业内容是什么?", ["project_type", "date"]),
("{date}工程性质为{project_type}的作业内容是什么?", ["date", "project_type"]),
("工程性质为{project_type}{date}作业内容是什么?", ["project_type", "date"]),
("{construction_unit}{date}作业内容是什么?", ["construction_unit", "date"]),
@ -125,6 +170,10 @@ TEMPLATE_CONFIG = {
# 3. 查询特定日期和项目类型的工程计划
("{date}{project_type}类计划作业内容是什么?", ["date", "project_type"]),
("{date}{construction_unit}{risk_level}风险的作业内容是什么?", ["date", "construction_unit", "risk_level"]),
("{date}{implementation_organization}{risk_level}风险的作业内容是什么?",
["date", "implementation_organization", "risk_level"]),
# 5. 查询特定日期和项目经理的任务安排
("{project_manager}{date}作业内容是什么?", ["project_manager", "date"]),
@ -139,11 +188,13 @@ TEMPLATE_CONFIG = {
("{team_leader}{date}作业内容是什么?", ["team_leader", "date"]),
# 9. 查询特定日期和项目类型下的高风险任务
("{date}{project_type}中,风险等级为{risk_level}的作业内容是什么?", ["date", "project_type", "risk_level"]),
("{date}{project_type}风险等级为{risk_level}的作业内容是什么?", ["date", "project_type", "risk_level"]),
# 10. 查询特定日期和风险等级的任务安排
("{date}风险等级为{risk_level}的作业内容是什么?", ["date", "risk_level"]),
("{date}有多少项{risk_level}风险作业计划?", ["date", "risk_level"]),
# 11. 查询特定日期和施工单位的任务进展
("{construction_unit}{date}作业内容是什么?", ["construction_unit", "date"]),
@ -151,17 +202,21 @@ TEMPLATE_CONFIG = {
("{project_manager}{date}作业内容是什么?", ["project_manager", "date"]),
# 13. 查询特定日期和项目经理的高风险任务
("{project_manager}{date}的风险等级为{risk_level}的作业内容是什么?", ["project_manager", "date", "risk_level"]),
("{project_manager}{date}的风险等级为{risk_level}的作业内容是什么?",
["project_manager", "date", "risk_level"]),
# 15. 查询特定日期和所有任务安排
("{date}作业内容是什么?", ["date"]),
# 16. 查询特定日期和项目进度
("{date}{project_name}作业内容是什么?", ["date", "project_name"]),
]
# 班组
("{date}{team_name}作业内容是什么?", ["date", "team_name"]),
("{team_name}{date}作业内容", ["team_name", "date"]),
]
},
"周计划作业内容": {
"date": ["本周", "上周","上一周", "下周", "下一周", "最近一周", "本周内", "这一周"],
"date": ["本周", "上周", "上一周", "下周", "下一周", "最近一周", "本周内", "这一周"],
"templates": [
("工程性质为{project_type}{date}作业内容是什么?", ["project_type", "date"]),
("{date}工程性质为{project_type}作业内容是什么?", ["date", "project_type"]),
@ -172,25 +227,27 @@ TEMPLATE_CONFIG = {
# 4. 查询某项目在指定周的所有作业计划
("{project_name}{date}作业内容是什么?", ["project_name", "date"]),
# 5. 查询指定周的所有项目类型作业内容
("{date}{project_type}类作业内容是什么?", ["date", "project_type"]),
# 6. 查询某施工单位在指定周的作业任务
("{construction_unit}{date}作业内容是什么?", ["construction_unit", "date"]),
# 7. 查询某项目经理在指定周负责的作业内容
("{project_manager}{date}作业内容是什么?", ["project_manager", "date"]),
# 8. 查询某团队负责人在指定周的作业安排
("{team_leader}{date}作业内容是什么?", ["team_leader", "date"]),
# 9. 查询某项目类型在指定周的高风险作业内容
("{date}{project_type}类中,风险等级为{risk_level}的作业内容是什么?", ["date", "project_type", "risk_level"]),
("{date}{project_type}类并且风险等级为{risk_level}的作业内容是什么?",
["date", "project_type", "risk_level"]),
# 10. 查询某风险等级在指定周的作业内容
("{date}风险等级为{risk_level}的作业内容是什么?", ["date", "risk_level"]),
("{date}{risk_level}风险的作业内容是什么?", ["date", "risk_level"]),
# 11. 查询某施工单位在指定周的作业进展
("{construction_unit}{date}作业内容是什么?", ["construction_unit", "date"]),
@ -199,10 +256,13 @@ TEMPLATE_CONFIG = {
# 15. 查询某项目部门在指定周的作业安排
("{project_department}{date}作业内容是什么?", ["project_department", "date"]),
]
("{date}{team_name}作业内容是什么", ["date", "team_name"]),
("{team_name}{date}作业内容", ["team_name", "date"]),
]
},
"施工人数": {
"date": ["今日", "昨日", "2024年5月24日", "5月24日","今天","昨天"],
"date": ["今日", "昨日", "2024年5月24日", "5月24日", "今天", "昨天"],
"templates": [
("{date}{project_name}施工人员有多少?", ["date", "project_name"]),
("{date}{project_name}施工人数是多少?", ["date", "project_name"]),
@ -235,14 +295,12 @@ TEMPLATE_CONFIG = {
("{date}{team_leader}的施工人员有多少?", ["date", "team_leader"]),
("{date}{team_leader}的施工人数是多少?", ["date", "team_leader"]),
# 11. 查询某实施单位在指定日期的施工人员总数
("{implementation_organization}{date}的施工人数是多少?", ["implementation_organization", "date"]),
("{implementation_organization}{date}的施工人员有多少?", ["implementation_organization", "date"]),
("{date}{team_leader}的施工人员有多少?", ["date", "team_leader"]),
("{date}{team_leader}的施工人数是多少?", ["date", "team_leader"]),
# 16. 统计某项目部门在指定日期的施工人员数量
("{project_department}{date}的施工人员有多少?", ["project_department", "date"]),
("{project_department}{date}的施工人数是多少?", ["project_department", "date"]),
@ -254,7 +312,13 @@ TEMPLATE_CONFIG = {
("{subcontractor}{date}的施工人数是多少?", ["subcontractor", "date"]),
# 22. 统计某施工单位在指定周的高风险作业人员数量
("{construction_unit}{date}风险等级为{risk_level}的施工人数是多少?", ["construction_unit", "date", "risk_level"]),
("{construction_unit}{date}风险等级为{risk_level}的施工人数是多少?",
["construction_unit", "date", "risk_level"]),
("{date}{team_name}施工人数是多少", ["date", "team_name"]),
("{date}{team_name}施工人数", ["date", "team_name"]),
("{team_name}{date}施工人数是多少", ["team_name", "date"]),
("{team_name}{date}施工人数", ["team_name", "date"]),
]
},
@ -284,11 +348,13 @@ TEMPLATE_CONFIG = {
# 11. 查询某分包商在指定周的出勤情况
("{subcontractor}{date}的出勤情况如何?", ["subcontractor", "date"]),
("{date}{team_name}考勤人数是多少", ["date", "team_name"]),
("{team_name}{date}考勤人数", ["team_name", "date"]),
]
},
"页面切换": {
"date": ["本周", "上周", "过去一周", "最近一周", "本周内", "这一周", "上个星期", "这个星期", "今日", "昨日",
"2024年5月24日", "5月24日", "24日", "周一"],
"date": ["今日", "昨日", "2024年5月24日", "5月24日", "今天", "昨天"],
"templates": [
("打开{page}页面", ["page"]),
("打开{page}", ["page"]),
@ -296,24 +362,18 @@ TEMPLATE_CONFIG = {
("进入{page}", ["page"]),
("进入{page}模块", ["page"]),
("进入{page}页面", ["page"]),
("查看{page}", ["page"]),
("查看{page}模块", ["page"]),
("查看{page}页面", ["page"]),
("跳转到{page}", ["page"]),
("跳转到{page}模块", ["page"]),
("跳转到{page}页面", ["page"]),
("访问{page}页面", ["page"]),
("访问{page}模块", ["page"]),
("访问{page}", ["page"]),
("显示{page}模块", ["page"]),
("显示{page}", ["page"]),
("请打开{page}模块", ["page"]),
("请打开{page}", ["page"]),
("显示{page}页面", ["page"]),
("加载{page}模块", ["page"]),
("加载{page}", ["page"]),
("加载{page}页面", ["page"]),
("查询{page}模块", ["page"]),
("查询{page}", ["page"]),
("查询{page}页面", ["page"]),
]
}
}
@ -334,7 +394,8 @@ def generate_natural_samples(config, label):
"project_department": BASE_DATA["project_departments"],
"project_manager": BASE_DATA["project_managers"],
"page": BASE_DATA["pages"],
"operating": BASE_DATA["operatings"]
"operating": BASE_DATA["operatings"],
"team_name": BASE_DATA["team_names"]
}
for template, variables in config["templates"]:
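
The hunk ends inside generate_natural_samples, right before the template loop. For context, a minimal sketch (not this repository's implementation) of how one template and its slot values can be expanded with itertools.product:

# Sketch: expand one template with every combination of its slot values.
from itertools import product

def expand_template(template, variables, value_map, label):
    samples = []
    for combo in product(*(value_map[v] for v in variables)):
        text = template.format(**dict(zip(variables, combo)))
        samples.append({"text": text, "label": label})
    return samples

# expand_template("{date}{project_name}有多少作业计划?", ["date", "project_name"],
#                 {"date": ["今日"], "project_name": ["1号工程"]}, "日计划数量查询")
# -> [{"text": "今日1号工程有多少作业计划?", "label": "日计划数量查询"}]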

View File

@ -1,5 +1,6 @@
import json
import os
import random
# 目录路径
directory = "output/ernie"
@ -25,16 +26,16 @@ def load_json(file_path):
def convert_data_format(data):
converted_list = []
for item in data:
if "text" in item and "prompt" in item:
if "text" in item and "label" in item:
converted_list.append({
"text": item["text"],
"label": item["prompt"] # prompt → label
"label": item["label"] # prompt → label
})
return converted_list
# 按 7:3 比例分割 JSON 数据
def split_json(input_file, output_file1, output_file2):
# 随机按 7:3 比例分割 JSON 数据
def split_json_random(input_file, output_file1, output_file2):
# 读取数据
data = load_json(input_file)
@ -46,12 +47,15 @@ def split_json(input_file, output_file1, output_file2):
# 转换数据格式
converted_data = convert_data_format(data)
# 打乱数据顺序
random.shuffle(converted_data)
# 计算数据的分割点(7:3)
split_point = int(len(converted_data) * 0.8)
split_point = int(len(converted_data) * 0.7)
# 按比例分割数据
data_part1 = converted_data[:split_point] # 70% 数据
data_part2 = converted_data[split_point:] # 30% 数据
data_part1 = converted_data[:split_point] # 70% 训练数据
data_part2 = converted_data[split_point:] # 30% 验证数据
# 保存数据到两个文件
with open(output_file1, 'w', encoding='utf-8') as f1:
@ -60,8 +64,7 @@ def split_json(input_file, output_file1, output_file2):
with open(output_file2, 'w', encoding='utf-8') as f2:
json.dump(data_part2, f2, ensure_ascii=False, indent=4)
print(
f"数据已转换并按 7:3 比例分割,保存至:\n - {output_file1}{len(data_part1)} 条)\n - {output_file2}{len(data_part2)} 条)")
print(f"数据已随机打乱并按 7:3 分割,保存至:\n - {output_file1}{len(data_part1)} 条)\n - {output_file2}{len(data_part2)} 条)")
# 输入的 JSON 文件路径
@ -70,5 +73,5 @@ input_file = 'output/merged_data.json'
output_file1 = 'output/ernie/train.json'
output_file2 = 'output/ernie/val.json'
# 执行数据转换和分割
split_json(input_file, output_file1, output_file2)
# 执行数据转换和随机分割
split_json_random(input_file, output_file1, output_file2)
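
random.shuffle gives a uniform 7:3 split, but rare intents can end up under-represented in the validation set. A stratified variant, sketched below for the same {"text", "label"} format (not part of this commit), splits each label group separately:

# Sketch: stratified 7:3 split that keeps the label distribution in both parts.
import random
from collections import defaultdict

def split_stratified(data, train_ratio=0.7, seed=42):
    random.seed(seed)
    by_label = defaultdict(list)
    for item in data:
        by_label[item["label"]].append(item)
    train, val = [], []
    for items in by_label.values():
        random.shuffle(items)
        cut = int(len(items) * train_ratio)
        train.extend(items[:cut])
        val.extend(items[cut:])
    return train, val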

View File

@ -1,39 +0,0 @@
import json
# 读取 JSON 文件
def load_json(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
return json.load(f)
# 按7:3比例将一个JSON文件分成两个
def split_json(input_file, output_file1, output_file2):
# 读取数据
data = load_json(input_file)
# 计算数据的分割点
split_point = int(len(data) * 0.7)
# 按比例分割数据
data_part1 = data[:split_point] # 前70%数据
data_part2 = data[split_point:] # 后30%数据
# 保存数据到两个文件
with open(output_file1, 'w', encoding='utf-8') as f1:
json.dump(data_part1, f1, ensure_ascii=False, indent=4)
with open(output_file2, 'w', encoding='utf-8') as f2:
json.dump(data_part2, f2, ensure_ascii=False, indent=4)
print(f"数据已按 7:3 比例分割并保存到 {output_file1}{output_file2}")
# 输入的 JSON 文件路径
input_file = 'merged_data.json'
# 输出的两个文件路径
output_file1 = 'data_part1.json'
output_file2 = 'data_part2.json'
# 按 7:3 比例分割并保存
split_json(input_file, output_file1, output_file2)

View File

@ -1,4 +1,4 @@
labels: ['date', 'project_name', 'project_type', 'construction_unit','implementation_organization', 'project_department', 'project_manager','subcontractor', 'team_leader', 'risk_level','page'] # 类别名称
labels: ['date', 'project_name', 'project_type', 'construction_unit','implementation_organization', 'project_department', 'project_manager','subcontractor', 'team_leader', 'risk_level','page','operating'] # 类别名称
# Model configuration for selecting the pretrained model and other model-related settings

File diff suppressed because one or more lines are too long

View File

@ -2,12 +2,12 @@ from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
import paddle
# 1. 加载模型和 tokenizer
model_path = R"E:\workingSpace\PycharmProjects\Intention_dev\uie\output\checkpoint-16920" # 你的模型路径
model_path = R"E:\workingSpace\PycharmProjects\Intention_dev\uie\output\checkpoint-2440" # 你的模型路径
model = ErnieForTokenClassification.from_pretrained(model_path)
tokenizer = ErnieTokenizer.from_pretrained(model_path)
# 2. 处理输入文本
text = "今天杨柳220kV变电站220kV南坪间隔扩建工程有多少作业计划"
text = "李四班组今天有多少作业计划"
inputs = tokenizer(text, max_len=512, return_tensors="pd")
# 3. 进行预测
@ -19,17 +19,19 @@ with paddle.no_grad():
# 4. 标签映射
label_map = {
0: 'O', # 非实体
1: 'B-date', 12: 'I-date',
2: 'B-project_name', 13: 'I-project_name',
3: 'B-project_type', 14: 'I-project_type',
4: 'B-construction_unit', 15: 'I-construction_unit',
5: 'B-implementation_organization', 16: 'I-implementation_organization',
6: 'B-project_department', 17: 'I-project_department',
7: 'B-project_manager', 18: 'I-project_manager',
8: 'B-subcontractor', 19: 'I-subcontractor',
9: 'B-team_leader', 20: 'I-team_leader',
10: 'B-risk_level', 21: 'I-risk_level',
11: 'B-page', 22: 'I-page',
1: 'B-date', 14: 'I-date',
2: 'B-projectName', 15: 'I-projectName',
3: 'B-projectType', 16: 'I-projectType',
4: 'B-constructionUnit', 17: 'I-constructionUnit',
5: 'B-implementationOrganization', 18: 'I-implementationOrganization',
6: 'B-projectDepartment', 19: 'I-projectDepartment',
7: 'B-projectManager', 20: 'I-projectManager',
8: 'B-subcontractor', 21: 'I-subcontractor',
9: 'B-teamLeader', 22: 'I-teamLeader',
10: 'B-riskLevel', 23: 'I-riskLevel',
11: 'B-page', 24: 'I-page',
12: 'B-operating', 25: 'I-operating',
13: 'B-teamName', 26: 'I-teamName',
}
# 5. 解析预测结果
@ -41,6 +43,7 @@ current_entity = None
current_label = None
for token, label_id in zip(tokens, predicted_labels):
print(label_id)
label = label_map.get(label_id, "O")
if label.startswith("B-"): # 开始新实体
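
The hunk is cut off inside the decoding loop. For reference, a minimal sketch (not the committed code) of a complete BIO span extraction over the (token, label_id) pairs produced above:

# Sketch: group consecutive B-/I- tags into (entity_text, entity_type) spans.
def decode_bio(tokens, label_ids, label_map):
    entities, current_tokens, current_type = [], [], None
    for token, label_id in zip(tokens, label_ids):
        label = label_map.get(int(label_id), "O")
        if label.startswith("B-"):
            if current_tokens:
                entities.append(("".join(current_tokens), current_type))
            current_tokens, current_type = [token], label[2:]
        elif label.startswith("I-") and current_type == label[2:]:
            current_tokens.append(token)
        else:
            if current_tokens:
                entities.append(("".join(current_tokens), current_type))
            current_tokens, current_type = [], None
    if current_tokens:
        entities.append(("".join(current_tokens), current_type))
    return entities

# entities = decode_bio(tokens, predicted_labels, label_map)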

View File

@ -18,7 +18,7 @@ def preprocess_function(example, tokenizer):
entity_types = [
'date', 'project_name', 'project_type', 'construction_unit',
'implementation_organization', 'project_department', 'project_manager',
'subcontractor', 'team_leader', 'risk_level','page'
'subcontractor', 'team_leader', 'risk_level','page','operating','team_name'
]
# 文本 Tokenization
@ -60,7 +60,7 @@ def preprocess_function(example, tokenizer):
# === 3. 加载 UIE 预训练模型 ===
model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=25) # 3 类 (O, B, I)
model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=27)  # 27 类 = 1 个 O + 13 个 B-实体 + 13 个 I-实体
tokenizer = ErnieTokenizer.from_pretrained("uie-base")
# === 4. 加载数据集 ===

uie/train1.py Normal file
View File

@ -0,0 +1,247 @@
import json
import os
import yaml
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional
import paddle
from paddlenlp.metrics import SpanEvaluator
from sklearn.metrics import classification_report
from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import MapDataset
from paddlenlp.trainer import Trainer, TrainingArguments, CompressionArguments, get_last_checkpoint
from paddlenlp.transformers import ErnieForTokenClassification, AutoTokenizer, UIEM, UIE, export_model
from paddlenlp.utils.log import logger
def load_config(config_path):
"""加载YAML配置文件"""
with open(config_path, "r", encoding="utf-8") as f:
return yaml.safe_load(f)
# === 1. 加载数据 ===
def load_dataset(data_path):
with open(data_path, "r", encoding="utf-8") as f:
data = json.load(f)
return MapDataset(data)
@dataclass
class DataArguments:
train_path: str
dev_path: str
max_seq_length: Optional[int] = 512
dynamic_max_length: Optional[List[int]] = None
@dataclass
class ModelArguments:
model_name_or_path: str = "uie-base"
export_model_dir: Optional[str] = None
multilingual: bool = False
def preprocess_function(example, tokenizer):
# 文本 Tokenization
inputs = tokenizer(example["text"], max_length=512, truncation=True, return_offsets_mapping=True)
offset_mapping = inputs["offset_mapping"]
# 初始化 label_ids0 表示 O 标签)
label_ids = [0] * len(offset_mapping) # 0: O, 1: B-XXX, 2: I-XXX
# 处理实体
if "annotations" in example:
for entity in example["annotations"]:
start, end, entity_label = entity["start"], entity["end"], entity["label"]
# 确保 entity_label 在我们的标签范围内
if entity_label not in config["labels"]:
continue # 如果实体标签不在范围内,则跳过
# 将实体类型映射到索引编号
entity_class = config["labels"].index(entity_label) + 1 # 1: B-XXX, 2: B-XXX, ...
# 处理实体的起始位置
entity_started = False # 标记实体是否已开始
for idx, (char_start, char_end) in enumerate(offset_mapping):
token = inputs['input_ids'][idx]
# 排除特殊 token
if token == tokenizer.cls_token_id or token == tokenizer.sep_token_id:
continue # 跳过 [CLS] 和 [SEP] token
if char_start >= start and char_end <= end:
if not entity_started:
label_ids[idx] = entity_class # B-实体
entity_started = True
else:
label_ids[idx] = entity_class + len(config["labels"]) # I-实体
# 将标注结果加到输入
inputs["labels"] = label_ids
del inputs["offset_mapping"] # 删除 offset_mapping
return inputs
# 加载配置文件
config = load_config("data.yaml")
# 从配置文件中提取参数
model_args = ModelArguments(**config["model_args"])
data_args = DataArguments(**config["data_args"])
# 确保学习率是浮动数值
if isinstance(config["training_args"]["learning_rate"], str):
config["training_args"]["learning_rate"] = float(config["training_args"]["learning_rate"])
training_args = CompressionArguments(**config["training_args"])
# 打印模型和数据配置
print(f"Model config: {model_args}")
print(f"Data config: {data_args}")
print(f"Training config: {training_args}")
paddle.set_device(training_args.device)
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
# 检查是否存在上次训练的检查点
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
if model_args.multilingual:
model = UIEM.from_pretrained(model_args.model_name_or_path)
else:
model = UIE.from_pretrained(model_args.model_name_or_path)
# === 4. 加载数据集 ===
train_dataset = load_dataset(R"data/train.json") # 训练数据集
dev_dataset = load_dataset(R"data/val.json") # 验证数据集
# === 5. 处理数据 ===
train_ds = train_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
dev_ds = dev_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
if training_args.device == "npu":
data_collator = DataCollatorWithPadding(tokenizer, padding="longest")
else:
data_collator = DataCollatorWithPadding(tokenizer)
criterion = paddle.nn.BCELoss()
def uie_loss_func(outputs, labels):
start_ids, end_ids = labels
start_prob, end_prob = outputs
start_ids = paddle.cast(start_ids, "float32")
end_ids = paddle.cast(end_ids, "float32")
loss_start = criterion(start_prob, start_ids)
loss_end = criterion(end_prob, end_ids)
loss = (loss_start + loss_end) / 2.0
return loss
def compute_metrics(p):
metric = SpanEvaluator()
start_prob, end_prob = p.predictions
start_ids, end_ids = p.label_ids
metric.reset()
num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
metric.update(num_correct, num_infer, num_label)
precision, recall, f1 = metric.accumulate()
metric.reset()
return {"precision": precision, "recall": recall, "f1": f1}
trainer = Trainer(
model=model,
criterion=uie_loss_func,
args=training_args,
data_collator=data_collator,
train_dataset=train_ds if training_args.do_train or training_args.do_compress else None,
eval_dataset=dev_ds if training_args.do_eval or training_args.do_compress else None,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
)
trainer.optimizer = paddle.optimizer.AdamW(
learning_rate=training_args.learning_rate, parameters=model.parameters()
)
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
# 训练过程
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
trainer.save_model()
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
# 评估模型
if training_args.do_eval:
eval_metrics = trainer.evaluate()
trainer.log_metrics("eval", eval_metrics)
# 导出推理模型
if training_args.do_export:
if training_args.device == "npu":
input_spec_dtype = "int32"
else:
input_spec_dtype = "int64"
input_spec = [
paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="input_ids"),
paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="position_ids"),
]
if model_args.export_model_dir is None:
model_args.export_model_dir = os.path.join(training_args.output_dir, "export")
export_model(model=trainer.model, input_spec=input_spec, path=model_args.export_model_dir)
trainer.tokenizer.save_pretrained(model_args.export_model_dir)
# 如果需要压缩模型
if training_args.do_compress:
@paddle.no_grad()
def custom_evaluate(self, model, data_loader):
metric = SpanEvaluator()
model.eval()
metric.reset()
for batch in data_loader:
if model_args.multilingual:
logits = model(input_ids=batch["input_ids"], position_ids=batch["position_ids"])
else:
logits = model(
input_ids=batch["input_ids"],
token_type_ids=batch["token_type_ids"],
position_ids=batch["position_ids"],
attention_mask=batch["attention_mask"],
)
start_prob, end_prob = logits
start_ids, end_ids = batch["start_positions"], batch["end_positions"]
num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
metric.update(num_correct, num_infer, num_label)
precision, recall, f1 = metric.accumulate()
logger.info("f1: %s, precision: %s, recall: %s" % (f1, precision, recall))
model.train()
return f1
trainer.compress(custom_evaluate=custom_evaluate)
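
Everything in this script is driven by data.yaml: the labels list plus the three blocks unpacked into ModelArguments, DataArguments and CompressionArguments. A minimal illustration of that structure as a Python dict (field names come from the dataclasses above and PaddleNLP's standard training arguments; the values are placeholders, not taken from this repo):

# Sketch: the shape of the configuration train1.py expects to load from data.yaml.
EXAMPLE_CONFIG = {
    "labels": ["date", "project_name", "team_name"],  # truncated for brevity
    "model_args": {"model_name_or_path": "uie-base", "multilingual": False},
    "data_args": {"train_path": "data/train.json", "dev_path": "data/val.json",
                  "max_seq_length": 512},
    "training_args": {"output_dir": "output", "do_train": True, "do_eval": True,
                      "learning_rate": 1e-5, "num_train_epochs": 10,
                      "per_device_train_batch_size": 16},
}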