From 2f4dff403df09ce33b3cdf4461b8ef008114d863 Mon Sep 17 00:00:00 2001 From: weiweiw <14335254+weiweiw22@user.noreply.gitee.com> Date: Mon, 5 May 2025 15:02:37 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E10=E4=B8=AD=E6=84=8F=E5=9B=BE?= =?UTF-8?q?=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ernie/train.py | 3 +- generated_data/generated.py | 214 ++++++++++++++++++------------------ uie/train.py | 34 ++---- 3 files changed, 119 insertions(+), 132 deletions(-) diff --git a/ernie/train.py b/ernie/train.py index 2410cfd..1722e63 100644 --- a/ernie/train.py +++ b/ernie/train.py @@ -7,7 +7,7 @@ import numpy as np import functools from paddle.nn import CrossEntropyLoss from paddlenlp.data import DataCollatorWithPadding -from paddlenlp.trainer import Trainer, TrainingArguments, EarlyStoppingCallback +from paddlenlp.trainer import Trainer, TrainingArguments import os from sklearn.metrics import precision_score, recall_score, f1_score @@ -140,7 +140,6 @@ def main(): eval_dataset=test_ds, data_collator=data_collator, compute_metrics=compute_metrics, # 使用自定义的评估指标 - callbacks=[EarlyStoppingCallback(early_stopping_patience=2)], ) # 训练模型 diff --git a/generated_data/generated.py b/generated_data/generated.py index 618a7dd..7f9d625 100644 --- a/generated_data/generated.py +++ b/generated_data/generated.py @@ -85,7 +85,7 @@ BASE_DATA = { "person_query_types": ["班组", "工程", "分公司", "实时组织", "项目部", "项目管理部"], # 工程状态 - "project_status_s": ["在建", "在作业", "在施工"] + "project_status_s": ["在建", "在作业", "在施工",""] } @@ -1137,8 +1137,8 @@ TEMPLATE_CONFIG = { ("实施组织数量", []), ("分公司详情", []), ("实施组织详情", []), - ("公司有哪些分公司", []), - ("公司有哪些实施组织", []), + ("公司具体有哪些分公司", []), + ("公司具体有哪些实施组织", []), ("{implementation_organization}详情", ["implementation_organization"]), ("{implementation_organization}情况", ["implementation_organization"]), ("请帮我查一下具体分公司详情", []), @@ -1147,139 +1147,141 @@ TEMPLATE_CONFIG = { ] }, - # "工程数量查询": { - # "date": ["今日", "昨日", "2024年5月24日", "5月24日", "今天", "昨天"], - # "templates": [ - # #公司 - # ("{date}公司有多少工程", ["date"]), - # ("{date}安徽送变电公司有多少工程{project_status}", ["date", "project_status"]), - # #分公司和项目部 - # ("{implementation_organization}{date}有多少工程{project_status}", - # ["implementation_organization", "date", "project_status"]), - # ("{implementation_organization}{project_department}{date}有多少工程{project_status}", - # ["implementation_organization", "project_department", "date", "project_status"]), - # #建管区域和单位 - # ("{construction_area}地区{date}风险等级为{risk_level}有多少工程?", ["construction_area", "date", "risk_level"]), - # - # ("{construction_area}地区{date}有多少工程{project_status}?", ["construction_area", "date", "project_status"]), - # - # ("{construction_unit}{date}有多少工程{project_status}?", ["construction_unit", "date","project_status"]), - # - # #分包商 - # ("{date}{subcontractor}有多少工程{project_status}", ["date", "subcontractor", "project_status"]), - # ("{date}送变电公司{project_department}有多少工程?", ["date", "project_department"]), - # #项目经理 - # ("{date}{project_manager}有多少工程{project_status}", ["date", "project_manager","project_status"]), - # #班组名称 - # ("{team_leader}{date}有多少工程", ["team_leader", "date"]), - # #工程性质 - # ("公司{date}{project_type}的工程有多少?", ["date", "project_type"]), - # #风险等级 - # ("公司{date}{risk_level}风险的{project_status}工程有多少?", ["date", "risk_level", "project_status"]), - # #询问工程数量时有工程性质和风险等级吗 - # ] - # }, - # - # "工程详情查询": { - # "date": ["今日", "昨日", "2024年5月24日", "5月24日", "今天", "昨天"], - # "templates": [ - # #公司 - # ("{date}公司有哪些工程", ["date"]), - # ("{date}安徽送变电公司有哪些工程{project_status}", ["date", "project_status"]), - # #分公司和项目部 - # ("{implementation_organization}{date}工程详情{project_status}", - # ["implementation_organization", "date", "project_status"]), - # ("{implementation_organization}{project_department}{date}有哪些工程{project_status}", - # ["implementation_organization", "project_department", "date", "project_status"]), - # #建管区域和单位 - # ("{construction_area}地区{date}风险等级为{risk_level}工程具体情况?", ["construction_area", "date", "risk_level"]), - # - # ("{construction_area}地区{date}有哪些工程{project_status}?", ["construction_area", "date", "project_status"]), - # - # ("{construction_unit}{date}有哪些工程{project_status}?", ["construction_unit", "date","project_status"]), - # - # #分包商 - # ("{date}{subcontractor}有多少工程{project_status}", ["date", "subcontractor", "project_status"]), - # ("{date}送变电公司{project_department}工程详情?", ["date", "project_department"]), - # #项目经理 - # ("{date}{project_manager}有多少工程{project_status}", ["date", "project_manager","project_status"]), - # #班组名称 - # ("{team_leader}{date}工程具体情况", ["team_leader", "date"]), - # #工程性质 - # ("公司{date}{project_type}的工程有哪些?", ["date", "project_type"]), - # #风险等级 - # ("公司{date}{risk_level}风险的{project_status}工程有那些?", ["date", "risk_level", "project_status"]), - # #询问工程数量时有工程性质和风险等级吗 - # ] - # }, - "工程数量查询": { - "date": ["今天","最近"], + "date": ["今日","今天", ""], "templates": [ #公司 - ("公司有多少工程", []), - ("安徽送变电公司有多少工程{project_status}", ["project_status"]), + ("{date}公司有多少工程", ["date"]), + ("{date}安徽送变电公司有多少个工程{project_status}", ["date", "project_status"]), #分公司和项目部 - ("{implementation_organization}有多少工程{project_status}", - ["implementation_organization", "project_status"]), - ("{implementation_organization}{project_department}有多少工程{project_status}", - ["implementation_organization", "project_department", "project_status"]), + ("{implementation_organization}{date}有多少工程{project_status}", + ["implementation_organization", "date", "project_status"]), + ("{implementation_organization}{project_department}{date}有多少个工程{project_status}", + ["implementation_organization", "project_department", "date", "project_status"]), #建管区域和单位 - ("{construction_area}地区风险等级为{risk_level}有多少工程?", ["construction_area", "risk_level"]), + ("{construction_area}地区{date}风险等级为{risk_level}有多少工程?", ["construction_area", "date", "risk_level"]), - ("{construction_area}地区有多少工程{project_status}?", ["construction_area", "project_status"]), + ("{construction_area}地区{date}有多少个工程{project_status}?", ["construction_area", "date", "project_status"]), - ("{construction_unit}有多少工程{project_status}?", ["construction_unit","project_status"]), + ("{construction_unit}{date}有多少工程{project_status}?", ["construction_unit", "date","project_status"]), #分包商 - ("{subcontractor}有多少工程{project_status}", ["subcontractor", "project_status"]), - ("安徽送变电公司{project_department}有多少工程?", ["project_department"]), + ("{date}{subcontractor}有多少工程{project_status}", ["date", "subcontractor", "project_status"]), + ("{date}送变电公司{project_department}有多少个工程?", ["date", "project_department"]), #项目经理 - ("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]), + ("{date}{project_manager}有多少工程{project_status}", ["date", "project_manager","project_status"]), + ("{date}{project_manager}负责多少个工程{project_status}", ["date", "project_manager","project_status"]), #班组名称 - ("{team_leader}有多少{project_status}工程", ["team_leader", "project_status"]), + ("{team_leader}{date}有多少工程", ["team_leader", "date"]), #工程性质 - ("公司{project_type}类的工程有多少?", ["project_type"]), + ("公司{date}{project_type}的工程有多少个?", ["date", "project_type"]), #风险等级 - ("公司{risk_level}风险的{project_status}工程有多少?", ["risk_level", "project_status"]), + ("公司{date}{risk_level}风险的{project_status}工程有多少?", ["date", "risk_level", "project_status"]), #询问工程数量时有工程性质和风险等级吗 ] }, "工程详情查询": { - "date": ["今天","最近"], + "date": ["今日","今天", ""], "templates": [ #公司 - ("公司有哪些工程", []), - ("截止目前公司有哪些{project_status}工程", ["project_status"]), - ("安徽送变电公司有哪些工程{project_status}", ["project_status"]), + ("{date}公司有哪些工程", ["date"]), + ("{date}安徽送变电公司有哪些工程{project_status}", ["date", "project_status"]), #分公司和项目部 - ("{implementation_organization}工程详情{project_status}", - ["implementation_organization", "project_status"]), - ("{implementation_organization}{project_department}有哪些工程{project_status}", - ["implementation_organization", "project_department", "project_status"]), + ("{implementation_organization}{date}工程详情{project_status}", + ["implementation_organization", "date", "project_status"]), + ("{implementation_organization}{project_department}{date}有哪些工程{project_status}", + ["implementation_organization", "project_department", "date", "project_status"]), #建管区域和单位 - ("{construction_area}地区风险等级为{risk_level}工程具体情况?", ["construction_area", "risk_level"]), + ("{construction_area}地区{date}风险等级为{risk_level}工程具体情况?", ["construction_area", "date", "risk_level"]), - ("{construction_area}地区有哪些{project_status}工程?", ["construction_area", "project_status"]), + ("{construction_area}地区{date}有哪些工程{project_status}?", ["construction_area", "date", "project_status"]), - ("{construction_unit}有哪些工程{project_status}?", ["construction_unit","project_status"]), + ("{construction_unit}{date}有哪些工程{project_status}?", ["construction_unit", "date","project_status"]), #分包商 - ("{subcontractor}有多少{project_status}工程", ["subcontractor", "project_status"]), - ("送变电公司{project_department}工程详情?", ["project_department"]), + ("{date}{subcontractor}有哪些工程{project_status}", ["date", "subcontractor", "project_status"]), + ("{date}送变电公司{project_department}工程详情?", ["date", "project_department"]), #项目经理 - ("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]), + ("{date}{project_manager}负责哪些工程{project_status}", ["date", "project_manager","project_status"]), #班组名称 - ("{team_leader}工程具体情况", ["team_leader"]), + ("{team_leader}{date}工程具体情况", ["team_leader", "date"]), #工程性质 - ("公司{project_type}类的工程有哪些?", ["project_type"]), + ("公司{date}{project_type}的工程有哪些?", ["date", "project_type"]), #风险等级 - ("公司{risk_level}风险的{project_status}工程有那些?", ["risk_level", "project_status"]), + ("公司{date}{risk_level}风险的{project_status}工程有哪些?", ["date", "risk_level", "project_status"]), #询问工程数量时有工程性质和风险等级吗 ] }, + # "工程数量查询": { + # "date": ["今天","今日", ""], + # "templates": [ + # #公司 + # ("{date}公司有多少工程", ["date"]), + # + # ("安徽送变电公司有多少工程{project_status}", ["project_status"]), + # #分公司和项目部 + # ("{date}{implementation_organization}有多少工程{project_status}", + # ["date", "implementation_organization", "project_status"]), + # ("{implementation_organization}{date}{project_department}有多少工程{project_status}", + # ["implementation_organization", "date", "project_department", "project_status"]), + # #建管区域和单位 + # ("{date}{construction_area}地区风险等级为{risk_level}有多少工程?", ["""construction_area", "risk_level"]), + # + # ("{construction_area}地区有多少工程{project_status}?", ["construction_area", "project_status"]), + # + # ("{construction_unit}有多少工程{project_status}?", ["construction_unit","project_status"]), + # + # #分包商 + # ("{subcontractor}有多少工程{project_status}", ["subcontractor", "project_status"]), + # ("安徽送变电公司{project_department}有多少工程?", ["project_department"]), + # #项目经理 + # ("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]), + # #班组名称 + # ("{team_leader}有多少{project_status}工程", ["team_leader", "project_status"]), + # #工程性质 + # ("公司{project_type}类的工程有多少?", ["project_type"]), + # #风险等级 + # ("公司{risk_level}风险的{project_status}工程有多少?", ["risk_level", "project_status"]), + # #询问工程数量时有工程性质和风险等级吗 + # ] + # }, + # + # "工程详情查询": { + # "date": ["今天","今日",""], + # "templates": [ + # #公司 + # ("{date}公司有哪些工程", ["date"]), + # ("截止目前公司有哪些{project_status}工程", ["project_status"]), + # ("安徽送变电公司{date}有哪些工程{project_status}", ["date", "project_status"]), + # #分公司和项目部 + # ("{implementation_organization}{date}工程详情{project_status}", + # ["implementation_organization", "date", "project_status"]), + # ("{date}{implementation_organization}{project_department}有哪些工程{project_status}", + # ["date", "implementation_organization", "project_department", "project_status"]), + # #建管区域和单位 + # ("{date}{construction_area}地区风险等级为{risk_level}工程具体情况?", ["date", "construction_area", "risk_level"]), + # + # ("{construction_area}{date}地区有哪些{project_status}工程?", ["construction_area", "date", "project_status"]), + # + # ("{construction_unit}有哪些工程{project_status}?", ["construction_unit","project_status"]), + # + # #分包商 + # ("{subcontractor}有多少{project_status}工程", ["subcontractor", "project_status"]), + # ("送变电公司{project_department}工程详情?", ["project_department"]), + # #项目经理 + # ("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]), + # #班组名称 + # ("{team_leader}工程具体情况", ["team_leader"]), + # #工程性质 + # ("公司{project_type}类的工程有哪些?", ["project_type"]), + # #风险等级 + # ("公司{risk_level}风险的{project_status}工程有那些?", ["risk_level", "project_status"]), + # #询问工程数量时有工程性质和风险等级吗 + # ] + # }, + "项目部数量查询": { "date": ["今天","最近"], "templates": [ @@ -1301,10 +1303,13 @@ TEMPLATE_CONFIG = { "date": ["今天","最近"], "templates": [ #公司 + ("公司具体项目部", []), ("公司有哪些项目部", []), ("安徽送变电公司项目管理部详情", []), + ("安徽送变电公司具体项目部", []), #分公司 ("{implementation_organization}项目部详情", ["implementation_organization"]), + ("{implementation_organization}具体项目部", ["implementation_organization"]), ("{implementation_organization}有哪些项目管理部", ["implementation_organization"]), #请帮我查一下 ("请帮我查一下公司项目部详情", []), @@ -1328,6 +1333,7 @@ TEMPLATE_CONFIG = { #公司 ("{project_name}建管单位情况", ["project_name"]), ("{project_name}建管单位详情", ["project_name"]), + ("{project_name}具体建管单位", ["project_name"]), ("请介绍下{construction_unit}详情", ["construction_unit"]), ("请介绍下{construction_unit}情况", ["construction_unit"]), ] @@ -1348,10 +1354,10 @@ TEMPLATE_CONFIG = { "date": ["今天","最近"], "templates": [ #公司 - ("{project_name}分包单位详情", ["project_name"]), + ("{project_name}具体分包单位详情", ["project_name"]), ("{project_name}分包商情况", ["project_name"]), ("请介绍下{subcontractor}详情", ["subcontractor"]), - ("请介绍下{subcontractor}情况", ["subcontractor"]), + ("请介绍下具体{subcontractor}情况", ["subcontractor"]), ] }, diff --git a/uie/train.py b/uie/train.py index 16fe27f..51bb37a 100644 --- a/uie/train.py +++ b/uie/train.py @@ -2,17 +2,15 @@ import json import paddle from paddlenlp.datasets import MapDataset from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer -from paddlenlp.trainer import Trainer, TrainingArguments, EarlyStoppingCallback +from paddlenlp.trainer import Trainer, TrainingArguments from paddlenlp.data import DataCollatorForTokenClassification - # === 1. 加载数据 === def load_dataset(data_path): with open(data_path, "r", encoding="utf-8") as f: data = json.load(f) return MapDataset(data) - # === 2. 预处理数据 === def preprocess_function(example, tokenizer): # 预定义实体类型列表 @@ -62,7 +60,7 @@ def preprocess_function(example, tokenizer): # === 3. 加载 UIE 预训练模型 === -lsmodel = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=35) # 3 类 (O, B, I) +model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=35) # 3 类 (O, B, I) tokenizer = ErnieTokenizer.from_pretrained("uie-base") # === 4. 加载数据集 === @@ -73,6 +71,7 @@ dev_dataset = load_dataset("data/val.json") # 验证数据集 train_dataset = train_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False) dev_dataset = dev_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False) + # === 6. 数据整理 === data_collator = DataCollatorForTokenClassification(tokenizer, padding=True) @@ -89,30 +88,13 @@ training_args = TrainingArguments( save_total_limit=1, # 只保留最新 2 个模型 logging_dir="./logs", logging_steps=100, - eval_steps=5000, #evaluation_strategy="steps"时生效 - save_steps=5000, #save_strategy="steps"时生效 + eval_steps=5000, + save_steps=5000, seed=1000, load_best_model_at_end=True, ) -# === 8. 创建 EarlyStoppingCallback 实例 === -early_stopping_callback = EarlyStoppingCallback( - early_stopping_patience=2, # 连续多少次评估不提升就停 - early_stopping_threshold=0.01 # 最小提升幅度(例如设为0.01表示至少提升1%) -) - - -def compute_metrics(eval_preds): - predictions, labels = eval_preds - preds = predictions.argmax(axis=-1) - correct = (preds == labels).astype(int) - accuracy = correct.sum() / correct.size - return {"accuracy": accuracy} - - -training_args.metric_for_best_model = "accuracy" - -# === 9. 训练 === +# === 8. 训练 === trainer = Trainer( model=model, args=training_args, @@ -120,12 +102,11 @@ trainer = Trainer( eval_dataset=dev_dataset, tokenizer=tokenizer, data_collator=data_collator, - compute_metrics=compute_metrics, - callbacks=[early_stopping_callback], # 添加 EarlyStopping 回调 ) trainer.train() + # 为模型定义输入规格 input_spec = [ paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="input_ids"), @@ -133,3 +114,4 @@ input_spec = [ paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="position_ids"), paddle.static.InputSpec(shape=[None, 512], dtype="float32", name="attention_mask") ] +