新增10种意图数据

This commit is contained in:
weiweiw 2025-05-05 15:02:37 +08:00
parent fe34128d83
commit 2f4dff403d
3 changed files with 119 additions and 132 deletions

View File

@ -7,7 +7,7 @@ import numpy as np
import functools import functools
from paddle.nn import CrossEntropyLoss from paddle.nn import CrossEntropyLoss
from paddlenlp.data import DataCollatorWithPadding from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.trainer import Trainer, TrainingArguments, EarlyStoppingCallback from paddlenlp.trainer import Trainer, TrainingArguments
import os import os
from sklearn.metrics import precision_score, recall_score, f1_score from sklearn.metrics import precision_score, recall_score, f1_score
@ -140,7 +140,6 @@ def main():
eval_dataset=test_ds, eval_dataset=test_ds,
data_collator=data_collator, data_collator=data_collator,
compute_metrics=compute_metrics, # 使用自定义的评估指标 compute_metrics=compute_metrics, # 使用自定义的评估指标
callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
) )
# 训练模型 # 训练模型

View File

@ -85,7 +85,7 @@ BASE_DATA = {
"person_query_types": ["班组", "工程", "分公司", "实时组织", "项目部", "项目管理部"], "person_query_types": ["班组", "工程", "分公司", "实时组织", "项目部", "项目管理部"],
# 工程状态 # 工程状态
"project_status_s": ["在建", "在作业", "在施工"] "project_status_s": ["在建", "在作业", "在施工",""]
} }
@ -1137,8 +1137,8 @@ TEMPLATE_CONFIG = {
("实施组织数量", []), ("实施组织数量", []),
("分公司详情", []), ("分公司详情", []),
("实施组织详情", []), ("实施组织详情", []),
("公司有哪些分公司", []), ("公司具体有哪些分公司", []),
("公司有哪些实施组织", []), ("公司具体有哪些实施组织", []),
("{implementation_organization}详情", ["implementation_organization"]), ("{implementation_organization}详情", ["implementation_organization"]),
("{implementation_organization}情况", ["implementation_organization"]), ("{implementation_organization}情况", ["implementation_organization"]),
("请帮我查一下具体分公司详情", []), ("请帮我查一下具体分公司详情", []),
@ -1147,139 +1147,141 @@ TEMPLATE_CONFIG = {
] ]
}, },
# "工程数量查询": {
# "date": ["今日", "昨日", "2024年5月24日", "5月24日", "今天", "昨天"],
# "templates": [
# #公司
# ("{date}公司有多少工程", ["date"]),
# ("{date}安徽送变电公司有多少工程{project_status}", ["date", "project_status"]),
# #分公司和项目部
# ("{implementation_organization}{date}有多少工程{project_status}",
# ["implementation_organization", "date", "project_status"]),
# ("{implementation_organization}{project_department}{date}有多少工程{project_status}",
# ["implementation_organization", "project_department", "date", "project_status"]),
# #建管区域和单位
# ("{construction_area}地区{date}风险等级为{risk_level}有多少工程?", ["construction_area", "date", "risk_level"]),
#
# ("{construction_area}地区{date}有多少工程{project_status}", ["construction_area", "date", "project_status"]),
#
# ("{construction_unit}{date}有多少工程{project_status}", ["construction_unit", "date","project_status"]),
#
# #分包商
# ("{date}{subcontractor}有多少工程{project_status}", ["date", "subcontractor", "project_status"]),
# ("{date}送变电公司{project_department}有多少工程?", ["date", "project_department"]),
# #项目经理
# ("{date}{project_manager}有多少工程{project_status}", ["date", "project_manager","project_status"]),
# #班组名称
# ("{team_leader}{date}有多少工程", ["team_leader", "date"]),
# #工程性质
# ("公司{date}{project_type}的工程有多少?", ["date", "project_type"]),
# #风险等级
# ("公司{date}{risk_level}风险的{project_status}工程有多少?", ["date", "risk_level", "project_status"]),
# #询问工程数量时有工程性质和风险等级吗
# ]
# },
#
# "工程详情查询": {
# "date": ["今日", "昨日", "2024年5月24日", "5月24日", "今天", "昨天"],
# "templates": [
# #公司
# ("{date}公司有哪些工程", ["date"]),
# ("{date}安徽送变电公司有哪些工程{project_status}", ["date", "project_status"]),
# #分公司和项目部
# ("{implementation_organization}{date}工程详情{project_status}",
# ["implementation_organization", "date", "project_status"]),
# ("{implementation_organization}{project_department}{date}有哪些工程{project_status}",
# ["implementation_organization", "project_department", "date", "project_status"]),
# #建管区域和单位
# ("{construction_area}地区{date}风险等级为{risk_level}工程具体情况?", ["construction_area", "date", "risk_level"]),
#
# ("{construction_area}地区{date}有哪些工程{project_status}", ["construction_area", "date", "project_status"]),
#
# ("{construction_unit}{date}有哪些工程{project_status}", ["construction_unit", "date","project_status"]),
#
# #分包商
# ("{date}{subcontractor}有多少工程{project_status}", ["date", "subcontractor", "project_status"]),
# ("{date}送变电公司{project_department}工程详情?", ["date", "project_department"]),
# #项目经理
# ("{date}{project_manager}有多少工程{project_status}", ["date", "project_manager","project_status"]),
# #班组名称
# ("{team_leader}{date}工程具体情况", ["team_leader", "date"]),
# #工程性质
# ("公司{date}{project_type}的工程有哪些?", ["date", "project_type"]),
# #风险等级
# ("公司{date}{risk_level}风险的{project_status}工程有那些?", ["date", "risk_level", "project_status"]),
# #询问工程数量时有工程性质和风险等级吗
# ]
# },
"工程数量查询": { "工程数量查询": {
"date": ["","最近"], "date": ["今日","今天", ""],
"templates": [ "templates": [
#公司 #公司
("公司有多少工程", []), ("{date}公司有多少工程", ["date"]),
("安徽送变电公司有多少工程{project_status}", ["project_status"]), ("{date}安徽送变电公司有多少个工程{project_status}", ["date", "project_status"]),
#分公司和项目部 #分公司和项目部
("{implementation_organization}有多少工程{project_status}", ("{implementation_organization}{date}有多少工程{project_status}",
["implementation_organization", "project_status"]), ["implementation_organization", "date", "project_status"]),
("{implementation_organization}{project_department}有多少工程{project_status}", ("{implementation_organization}{project_department}{date}有多少个工程{project_status}",
["implementation_organization", "project_department", "project_status"]), ["implementation_organization", "project_department", "date", "project_status"]),
#建管区域和单位 #建管区域和单位
("{construction_area}地区风险等级为{risk_level}有多少工程?", ["construction_area", "risk_level"]), ("{construction_area}地区{date}风险等级为{risk_level}有多少工程?", ["construction_area", "date", "risk_level"]),
("{construction_area}地区有多少工程{project_status}", ["construction_area", "project_status"]), ("{construction_area}地区{date}有多少个工程{project_status}", ["construction_area", "date", "project_status"]),
("{construction_unit}有多少工程{project_status}", ["construction_unit","project_status"]), ("{construction_unit}{date}有多少工程{project_status}", ["construction_unit", "date","project_status"]),
#分包商 #分包商
("{subcontractor}有多少工程{project_status}", ["subcontractor", "project_status"]), ("{date}{subcontractor}有多少工程{project_status}", ["date", "subcontractor", "project_status"]),
("安徽送变电公司{project_department}有多少工程?", ["project_department"]), ("{date}送变电公司{project_department}有多少个工程?", ["date", "project_department"]),
#项目经理 #项目经理
("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]), ("{date}{project_manager}有多少工程{project_status}", ["date", "project_manager","project_status"]),
("{date}{project_manager}负责多少个工程{project_status}", ["date", "project_manager","project_status"]),
#班组名称 #班组名称
("{team_leader}有多少{project_status}工程", ["team_leader", "project_status"]), ("{team_leader}{date}有多少工程", ["team_leader", "date"]),
#工程性质 #工程性质
("公司{project_type}的工程有多少?", ["project_type"]), ("公司{date}{project_type}的工程有多少个?", ["date", "project_type"]),
#风险等级 #风险等级
("公司{risk_level}风险的{project_status}工程有多少?", ["risk_level", "project_status"]), ("公司{date}{risk_level}风险的{project_status}工程有多少?", ["date", "risk_level", "project_status"]),
#询问工程数量时有工程性质和风险等级吗 #询问工程数量时有工程性质和风险等级吗
] ]
}, },
"工程详情查询": { "工程详情查询": {
"date": ["","最近"], "date": ["今日","今天", ""],
"templates": [ "templates": [
#公司 #公司
("公司有哪些工程", []), ("{date}公司有哪些工程", ["date"]),
("截止目前公司有哪些{project_status}工程", ["project_status"]), ("{date}安徽送变电公司有哪些工程{project_status}", ["date", "project_status"]),
("安徽送变电公司有哪些工程{project_status}", ["project_status"]),
#分公司和项目部 #分公司和项目部
("{implementation_organization}工程详情{project_status}", ("{implementation_organization}{date}工程详情{project_status}",
["implementation_organization", "project_status"]), ["implementation_organization", "date", "project_status"]),
("{implementation_organization}{project_department}有哪些工程{project_status}", ("{implementation_organization}{project_department}{date}有哪些工程{project_status}",
["implementation_organization", "project_department", "project_status"]), ["implementation_organization", "project_department", "date", "project_status"]),
#建管区域和单位 #建管区域和单位
("{construction_area}地区风险等级为{risk_level}工程具体情况?", ["construction_area", "risk_level"]), ("{construction_area}地区{date}风险等级为{risk_level}工程具体情况?", ["construction_area", "date", "risk_level"]),
("{construction_area}地区有哪些{project_status}工程", ["construction_area", "project_status"]), ("{construction_area}地区{date}有哪些工程{project_status}", ["construction_area", "date", "project_status"]),
("{construction_unit}有哪些工程{project_status}", ["construction_unit","project_status"]), ("{construction_unit}{date}有哪些工程{project_status}", ["construction_unit", "date","project_status"]),
#分包商 #分包商
("{subcontractor}有多少{project_status}工程", ["subcontractor", "project_status"]), ("{date}{subcontractor}有哪些工程{project_status}", ["date", "subcontractor", "project_status"]),
("送变电公司{project_department}工程详情?", ["project_department"]), ("{date}送变电公司{project_department}工程详情?", ["date", "project_department"]),
#项目经理 #项目经理
("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]), ("{date}{project_manager}负责哪些工程{project_status}", ["date", "project_manager","project_status"]),
#班组名称 #班组名称
("{team_leader}工程具体情况", ["team_leader"]), ("{team_leader}{date}工程具体情况", ["team_leader", "date"]),
#工程性质 #工程性质
("公司{project_type}的工程有哪些?", ["project_type"]), ("公司{date}{project_type}的工程有哪些?", ["date", "project_type"]),
#风险等级 #风险等级
("公司{risk_level}风险的{project_status}工程有那些?", ["risk_level", "project_status"]), ("公司{date}{risk_level}风险的{project_status}工程有哪些?", ["date", "risk_level", "project_status"]),
#询问工程数量时有工程性质和风险等级吗 #询问工程数量时有工程性质和风险等级吗
] ]
}, },
# "工程数量查询": {
# "date": ["今天","今日", ""],
# "templates": [
# #公司
# ("{date}公司有多少工程", ["date"]),
#
# ("安徽送变电公司有多少工程{project_status}", ["project_status"]),
# #分公司和项目部
# ("{date}{implementation_organization}有多少工程{project_status}",
# ["date", "implementation_organization", "project_status"]),
# ("{implementation_organization}{date}{project_department}有多少工程{project_status}",
# ["implementation_organization", "date", "project_department", "project_status"]),
# #建管区域和单位
# ("{date}{construction_area}地区风险等级为{risk_level}有多少工程?", ["""construction_area", "risk_level"]),
#
# ("{construction_area}地区有多少工程{project_status}", ["construction_area", "project_status"]),
#
# ("{construction_unit}有多少工程{project_status}", ["construction_unit","project_status"]),
#
# #分包商
# ("{subcontractor}有多少工程{project_status}", ["subcontractor", "project_status"]),
# ("安徽送变电公司{project_department}有多少工程?", ["project_department"]),
# #项目经理
# ("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]),
# #班组名称
# ("{team_leader}有多少{project_status}工程", ["team_leader", "project_status"]),
# #工程性质
# ("公司{project_type}类的工程有多少?", ["project_type"]),
# #风险等级
# ("公司{risk_level}风险的{project_status}工程有多少?", ["risk_level", "project_status"]),
# #询问工程数量时有工程性质和风险等级吗
# ]
# },
#
# "工程详情查询": {
# "date": ["今天","今日",""],
# "templates": [
# #公司
# ("{date}公司有哪些工程", ["date"]),
# ("截止目前公司有哪些{project_status}工程", ["project_status"]),
# ("安徽送变电公司{date}有哪些工程{project_status}", ["date", "project_status"]),
# #分公司和项目部
# ("{implementation_organization}{date}工程详情{project_status}",
# ["implementation_organization", "date", "project_status"]),
# ("{date}{implementation_organization}{project_department}有哪些工程{project_status}",
# ["date", "implementation_organization", "project_department", "project_status"]),
# #建管区域和单位
# ("{date}{construction_area}地区风险等级为{risk_level}工程具体情况?", ["date", "construction_area", "risk_level"]),
#
# ("{construction_area}{date}地区有哪些{project_status}工程?", ["construction_area", "date", "project_status"]),
#
# ("{construction_unit}有哪些工程{project_status}", ["construction_unit","project_status"]),
#
# #分包商
# ("{subcontractor}有多少{project_status}工程", ["subcontractor", "project_status"]),
# ("送变电公司{project_department}工程详情?", ["project_department"]),
# #项目经理
# ("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]),
# #班组名称
# ("{team_leader}工程具体情况", ["team_leader"]),
# #工程性质
# ("公司{project_type}类的工程有哪些?", ["project_type"]),
# #风险等级
# ("公司{risk_level}风险的{project_status}工程有那些?", ["risk_level", "project_status"]),
# #询问工程数量时有工程性质和风险等级吗
# ]
# },
"项目部数量查询": { "项目部数量查询": {
"date": ["今天","最近"], "date": ["今天","最近"],
"templates": [ "templates": [
@ -1301,10 +1303,13 @@ TEMPLATE_CONFIG = {
"date": ["今天","最近"], "date": ["今天","最近"],
"templates": [ "templates": [
#公司 #公司
("公司具体项目部", []),
("公司有哪些项目部", []), ("公司有哪些项目部", []),
("安徽送变电公司项目管理部详情", []), ("安徽送变电公司项目管理部详情", []),
("安徽送变电公司具体项目部", []),
#分公司 #分公司
("{implementation_organization}项目部详情", ["implementation_organization"]), ("{implementation_organization}项目部详情", ["implementation_organization"]),
("{implementation_organization}具体项目部", ["implementation_organization"]),
("{implementation_organization}有哪些项目管理部", ["implementation_organization"]), ("{implementation_organization}有哪些项目管理部", ["implementation_organization"]),
#请帮我查一下 #请帮我查一下
("请帮我查一下公司项目部详情", []), ("请帮我查一下公司项目部详情", []),
@ -1328,6 +1333,7 @@ TEMPLATE_CONFIG = {
#公司 #公司
("{project_name}建管单位情况", ["project_name"]), ("{project_name}建管单位情况", ["project_name"]),
("{project_name}建管单位详情", ["project_name"]), ("{project_name}建管单位详情", ["project_name"]),
("{project_name}具体建管单位", ["project_name"]),
("请介绍下{construction_unit}详情", ["construction_unit"]), ("请介绍下{construction_unit}详情", ["construction_unit"]),
("请介绍下{construction_unit}情况", ["construction_unit"]), ("请介绍下{construction_unit}情况", ["construction_unit"]),
] ]
@ -1348,10 +1354,10 @@ TEMPLATE_CONFIG = {
"date": ["今天","最近"], "date": ["今天","最近"],
"templates": [ "templates": [
#公司 #公司
("{project_name}分包单位详情", ["project_name"]), ("{project_name}具体分包单位详情", ["project_name"]),
("{project_name}分包商情况", ["project_name"]), ("{project_name}分包商情况", ["project_name"]),
("请介绍下{subcontractor}详情", ["subcontractor"]), ("请介绍下{subcontractor}详情", ["subcontractor"]),
("请介绍下{subcontractor}情况", ["subcontractor"]), ("请介绍下具体{subcontractor}情况", ["subcontractor"]),
] ]
}, },

View File

@ -2,17 +2,15 @@ import json
import paddle import paddle
from paddlenlp.datasets import MapDataset from paddlenlp.datasets import MapDataset
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
from paddlenlp.trainer import Trainer, TrainingArguments, EarlyStoppingCallback from paddlenlp.trainer import Trainer, TrainingArguments
from paddlenlp.data import DataCollatorForTokenClassification from paddlenlp.data import DataCollatorForTokenClassification
# === 1. 加载数据 === # === 1. 加载数据 ===
def load_dataset(data_path): def load_dataset(data_path):
with open(data_path, "r", encoding="utf-8") as f: with open(data_path, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
return MapDataset(data) return MapDataset(data)
# === 2. 预处理数据 === # === 2. 预处理数据 ===
def preprocess_function(example, tokenizer): def preprocess_function(example, tokenizer):
# 预定义实体类型列表 # 预定义实体类型列表
@ -62,7 +60,7 @@ def preprocess_function(example, tokenizer):
# === 3. 加载 UIE 预训练模型 === # === 3. 加载 UIE 预训练模型 ===
lsmodel = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=35) # 35 类标签 model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=35) # 35 类标签
tokenizer = ErnieTokenizer.from_pretrained("uie-base") tokenizer = ErnieTokenizer.from_pretrained("uie-base")
# === 4. 加载数据集 === # === 4. 加载数据集 ===
@ -73,6 +71,7 @@ dev_dataset = load_dataset("data/val.json") # 验证数据集
train_dataset = train_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False) train_dataset = train_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
dev_dataset = dev_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False) dev_dataset = dev_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
# === 6. 数据整理 === # === 6. 数据整理 ===
data_collator = DataCollatorForTokenClassification(tokenizer, padding=True) data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)
@ -89,30 +88,13 @@ training_args = TrainingArguments(
save_total_limit=1, # 只保留最新 1 个模型 save_total_limit=1, # 只保留最新 1 个模型
logging_dir="./logs", logging_dir="./logs",
logging_steps=100, logging_steps=100,
eval_steps=5000, #evaluation_strategy="steps"时生效 eval_steps=5000,
save_steps=5000, #save_strategy="steps"时生效 save_steps=5000,
seed=1000, seed=1000,
load_best_model_at_end=True, load_best_model_at_end=True,
) )
# === 8. 创建 EarlyStoppingCallback 实例 === # === 8. 训练 ===
early_stopping_callback = EarlyStoppingCallback(
early_stopping_patience=2, # 连续多少次评估不提升就停
early_stopping_threshold=0.01 # 最小提升幅度例如设为0.01表示至少提升1%
)
def compute_metrics(eval_preds):
predictions, labels = eval_preds
preds = predictions.argmax(axis=-1)
correct = (preds == labels).astype(int)
accuracy = correct.sum() / correct.size
return {"accuracy": accuracy}
training_args.metric_for_best_model = "accuracy"
# === 9. 训练 ===
trainer = Trainer( trainer = Trainer(
model=model, model=model,
args=training_args, args=training_args,
@ -120,12 +102,11 @@ trainer = Trainer(
eval_dataset=dev_dataset, eval_dataset=dev_dataset,
tokenizer=tokenizer, tokenizer=tokenizer,
data_collator=data_collator, data_collator=data_collator,
compute_metrics=compute_metrics,
callbacks=[early_stopping_callback], # 添加 EarlyStopping 回调
) )
trainer.train() trainer.train()
# 为模型定义输入规格 # 为模型定义输入规格
input_spec = [ input_spec = [
paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="input_ids"), paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="input_ids"),
@ -133,3 +114,4 @@ input_spec = [
paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="position_ids"), paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="position_ids"),
paddle.static.InputSpec(shape=[None, 512], dtype="float32", name="attention_mask") paddle.static.InputSpec(shape=[None, 512], dtype="float32", name="attention_mask")
] ]