新增10种意图数据
This commit is contained in:
parent
fe34128d83
commit
2f4dff403d
|
|
@ -7,7 +7,7 @@ import numpy as np
|
|||
import functools
|
||||
from paddle.nn import CrossEntropyLoss
|
||||
from paddlenlp.data import DataCollatorWithPadding
|
||||
from paddlenlp.trainer import Trainer, TrainingArguments, EarlyStoppingCallback
|
||||
from paddlenlp.trainer import Trainer, TrainingArguments
|
||||
import os
|
||||
from sklearn.metrics import precision_score, recall_score, f1_score
|
||||
|
||||
|
|
@ -140,7 +140,6 @@ def main():
|
|||
eval_dataset=test_ds,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics, # 使用自定义的评估指标
|
||||
callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
|
||||
)
|
||||
|
||||
# 训练模型
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ BASE_DATA = {
|
|||
"person_query_types": ["班组", "工程", "分公司", "实时组织", "项目部", "项目管理部"],
|
||||
|
||||
# 工程状态
|
||||
"project_status_s": ["在建", "在作业", "在施工"]
|
||||
"project_status_s": ["在建", "在作业", "在施工",""]
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -1137,8 +1137,8 @@ TEMPLATE_CONFIG = {
|
|||
("实施组织数量", []),
|
||||
("分公司详情", []),
|
||||
("实施组织详情", []),
|
||||
("公司有哪些分公司", []),
|
||||
("公司有哪些实施组织", []),
|
||||
("公司具体有哪些分公司", []),
|
||||
("公司具体有哪些实施组织", []),
|
||||
("{implementation_organization}详情", ["implementation_organization"]),
|
||||
("{implementation_organization}情况", ["implementation_organization"]),
|
||||
("请帮我查一下具体分公司详情", []),
|
||||
|
|
@ -1147,139 +1147,141 @@ TEMPLATE_CONFIG = {
|
|||
]
|
||||
},
|
||||
|
||||
# "工程数量查询": {
|
||||
# "date": ["今日", "昨日", "2024年5月24日", "5月24日", "今天", "昨天"],
|
||||
# "templates": [
|
||||
# #公司
|
||||
# ("{date}公司有多少工程", ["date"]),
|
||||
# ("{date}安徽送变电公司有多少工程{project_status}", ["date", "project_status"]),
|
||||
# #分公司和项目部
|
||||
# ("{implementation_organization}{date}有多少工程{project_status}",
|
||||
# ["implementation_organization", "date", "project_status"]),
|
||||
# ("{implementation_organization}{project_department}{date}有多少工程{project_status}",
|
||||
# ["implementation_organization", "project_department", "date", "project_status"]),
|
||||
# #建管区域和单位
|
||||
# ("{construction_area}地区{date}风险等级为{risk_level}有多少工程?", ["construction_area", "date", "risk_level"]),
|
||||
#
|
||||
# ("{construction_area}地区{date}有多少工程{project_status}?", ["construction_area", "date", "project_status"]),
|
||||
#
|
||||
# ("{construction_unit}{date}有多少工程{project_status}?", ["construction_unit", "date","project_status"]),
|
||||
#
|
||||
# #分包商
|
||||
# ("{date}{subcontractor}有多少工程{project_status}", ["date", "subcontractor", "project_status"]),
|
||||
# ("{date}送变电公司{project_department}有多少工程?", ["date", "project_department"]),
|
||||
# #项目经理
|
||||
# ("{date}{project_manager}有多少工程{project_status}", ["date", "project_manager","project_status"]),
|
||||
# #班组名称
|
||||
# ("{team_leader}{date}有多少工程", ["team_leader", "date"]),
|
||||
# #工程性质
|
||||
# ("公司{date}{project_type}的工程有多少?", ["date", "project_type"]),
|
||||
# #风险等级
|
||||
# ("公司{date}{risk_level}风险的{project_status}工程有多少?", ["date", "risk_level", "project_status"]),
|
||||
# #询问工程数量时有工程性质和风险等级吗
|
||||
# ]
|
||||
# },
|
||||
#
|
||||
# "工程详情查询": {
|
||||
# "date": ["今日", "昨日", "2024年5月24日", "5月24日", "今天", "昨天"],
|
||||
# "templates": [
|
||||
# #公司
|
||||
# ("{date}公司有哪些工程", ["date"]),
|
||||
# ("{date}安徽送变电公司有哪些工程{project_status}", ["date", "project_status"]),
|
||||
# #分公司和项目部
|
||||
# ("{implementation_organization}{date}工程详情{project_status}",
|
||||
# ["implementation_organization", "date", "project_status"]),
|
||||
# ("{implementation_organization}{project_department}{date}有哪些工程{project_status}",
|
||||
# ["implementation_organization", "project_department", "date", "project_status"]),
|
||||
# #建管区域和单位
|
||||
# ("{construction_area}地区{date}风险等级为{risk_level}工程具体情况?", ["construction_area", "date", "risk_level"]),
|
||||
#
|
||||
# ("{construction_area}地区{date}有哪些工程{project_status}?", ["construction_area", "date", "project_status"]),
|
||||
#
|
||||
# ("{construction_unit}{date}有哪些工程{project_status}?", ["construction_unit", "date","project_status"]),
|
||||
#
|
||||
# #分包商
|
||||
# ("{date}{subcontractor}有多少工程{project_status}", ["date", "subcontractor", "project_status"]),
|
||||
# ("{date}送变电公司{project_department}工程详情?", ["date", "project_department"]),
|
||||
# #项目经理
|
||||
# ("{date}{project_manager}有多少工程{project_status}", ["date", "project_manager","project_status"]),
|
||||
# #班组名称
|
||||
# ("{team_leader}{date}工程具体情况", ["team_leader", "date"]),
|
||||
# #工程性质
|
||||
# ("公司{date}{project_type}的工程有哪些?", ["date", "project_type"]),
|
||||
# #风险等级
|
||||
# ("公司{date}{risk_level}风险的{project_status}工程有那些?", ["date", "risk_level", "project_status"]),
|
||||
# #询问工程数量时有工程性质和风险等级吗
|
||||
# ]
|
||||
# },
|
||||
|
||||
"工程数量查询": {
|
||||
"date": ["今天","最近"],
|
||||
"date": ["今日","今天", ""],
|
||||
"templates": [
|
||||
#公司
|
||||
("公司有多少工程", []),
|
||||
("安徽送变电公司有多少工程{project_status}", ["project_status"]),
|
||||
("{date}公司有多少工程", ["date"]),
|
||||
("{date}安徽送变电公司有多少个工程{project_status}", ["date", "project_status"]),
|
||||
#分公司和项目部
|
||||
("{implementation_organization}有多少工程{project_status}",
|
||||
["implementation_organization", "project_status"]),
|
||||
("{implementation_organization}{project_department}有多少工程{project_status}",
|
||||
["implementation_organization", "project_department", "project_status"]),
|
||||
("{implementation_organization}{date}有多少工程{project_status}",
|
||||
["implementation_organization", "date", "project_status"]),
|
||||
("{implementation_organization}{project_department}{date}有多少个工程{project_status}",
|
||||
["implementation_organization", "project_department", "date", "project_status"]),
|
||||
#建管区域和单位
|
||||
("{construction_area}地区风险等级为{risk_level}有多少工程?", ["construction_area", "risk_level"]),
|
||||
("{construction_area}地区{date}风险等级为{risk_level}有多少工程?", ["construction_area", "date", "risk_level"]),
|
||||
|
||||
("{construction_area}地区有多少工程{project_status}?", ["construction_area", "project_status"]),
|
||||
("{construction_area}地区{date}有多少个工程{project_status}?", ["construction_area", "date", "project_status"]),
|
||||
|
||||
("{construction_unit}有多少工程{project_status}?", ["construction_unit","project_status"]),
|
||||
("{construction_unit}{date}有多少工程{project_status}?", ["construction_unit", "date","project_status"]),
|
||||
|
||||
#分包商
|
||||
("{subcontractor}有多少工程{project_status}", ["subcontractor", "project_status"]),
|
||||
("安徽送变电公司{project_department}有多少工程?", ["project_department"]),
|
||||
("{date}{subcontractor}有多少工程{project_status}", ["date", "subcontractor", "project_status"]),
|
||||
("{date}送变电公司{project_department}有多少个工程?", ["date", "project_department"]),
|
||||
#项目经理
|
||||
("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]),
|
||||
("{date}{project_manager}有多少工程{project_status}", ["date", "project_manager","project_status"]),
|
||||
("{date}{project_manager}负责多少个工程{project_status}", ["date", "project_manager","project_status"]),
|
||||
#班组名称
|
||||
("{team_leader}有多少{project_status}工程", ["team_leader", "project_status"]),
|
||||
("{team_leader}{date}有多少工程", ["team_leader", "date"]),
|
||||
#工程性质
|
||||
("公司{project_type}类的工程有多少?", ["project_type"]),
|
||||
("公司{date}{project_type}的工程有多少个?", ["date", "project_type"]),
|
||||
#风险等级
|
||||
("公司{risk_level}风险的{project_status}工程有多少?", ["risk_level", "project_status"]),
|
||||
("公司{date}{risk_level}风险的{project_status}工程有多少?", ["date", "risk_level", "project_status"]),
|
||||
#询问工程数量时有工程性质和风险等级吗
|
||||
]
|
||||
},
|
||||
|
||||
"工程详情查询": {
|
||||
"date": ["今天","最近"],
|
||||
"date": ["今日","今天", ""],
|
||||
"templates": [
|
||||
#公司
|
||||
("公司有哪些工程", []),
|
||||
("截止目前公司有哪些{project_status}工程", ["project_status"]),
|
||||
("安徽送变电公司有哪些工程{project_status}", ["project_status"]),
|
||||
("{date}公司有哪些工程", ["date"]),
|
||||
("{date}安徽送变电公司有哪些工程{project_status}", ["date", "project_status"]),
|
||||
#分公司和项目部
|
||||
("{implementation_organization}工程详情{project_status}",
|
||||
["implementation_organization", "project_status"]),
|
||||
("{implementation_organization}{project_department}有哪些工程{project_status}",
|
||||
["implementation_organization", "project_department", "project_status"]),
|
||||
("{implementation_organization}{date}工程详情{project_status}",
|
||||
["implementation_organization", "date", "project_status"]),
|
||||
("{implementation_organization}{project_department}{date}有哪些工程{project_status}",
|
||||
["implementation_organization", "project_department", "date", "project_status"]),
|
||||
#建管区域和单位
|
||||
("{construction_area}地区风险等级为{risk_level}工程具体情况?", ["construction_area", "risk_level"]),
|
||||
("{construction_area}地区{date}风险等级为{risk_level}工程具体情况?", ["construction_area", "date", "risk_level"]),
|
||||
|
||||
("{construction_area}地区有哪些{project_status}工程?", ["construction_area", "project_status"]),
|
||||
("{construction_area}地区{date}有哪些工程{project_status}?", ["construction_area", "date", "project_status"]),
|
||||
|
||||
("{construction_unit}有哪些工程{project_status}?", ["construction_unit","project_status"]),
|
||||
("{construction_unit}{date}有哪些工程{project_status}?", ["construction_unit", "date","project_status"]),
|
||||
|
||||
#分包商
|
||||
("{subcontractor}有多少{project_status}工程", ["subcontractor", "project_status"]),
|
||||
("送变电公司{project_department}工程详情?", ["project_department"]),
|
||||
("{date}{subcontractor}有哪些工程{project_status}", ["date", "subcontractor", "project_status"]),
|
||||
("{date}送变电公司{project_department}工程详情?", ["date", "project_department"]),
|
||||
#项目经理
|
||||
("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]),
|
||||
("{date}{project_manager}负责哪些工程{project_status}", ["date", "project_manager","project_status"]),
|
||||
#班组名称
|
||||
("{team_leader}工程具体情况", ["team_leader"]),
|
||||
("{team_leader}{date}工程具体情况", ["team_leader", "date"]),
|
||||
#工程性质
|
||||
("公司{project_type}类的工程有哪些?", ["project_type"]),
|
||||
("公司{date}{project_type}的工程有哪些?", ["date", "project_type"]),
|
||||
#风险等级
|
||||
("公司{risk_level}风险的{project_status}工程有那些?", ["risk_level", "project_status"]),
|
||||
("公司{date}{risk_level}风险的{project_status}工程有哪些?", ["date", "risk_level", "project_status"]),
|
||||
#询问工程数量时有工程性质和风险等级吗
|
||||
]
|
||||
},
|
||||
|
||||
# "工程数量查询": {
|
||||
# "date": ["今天","今日", ""],
|
||||
# "templates": [
|
||||
# #公司
|
||||
# ("{date}公司有多少工程", ["date"]),
|
||||
#
|
||||
# ("安徽送变电公司有多少工程{project_status}", ["project_status"]),
|
||||
# #分公司和项目部
|
||||
# ("{date}{implementation_organization}有多少工程{project_status}",
|
||||
# ["date", "implementation_organization", "project_status"]),
|
||||
# ("{implementation_organization}{date}{project_department}有多少工程{project_status}",
|
||||
# ["implementation_organization", "date", "project_department", "project_status"]),
|
||||
# #建管区域和单位
|
||||
# ("{date}{construction_area}地区风险等级为{risk_level}有多少工程?", ["""construction_area", "risk_level"]),
|
||||
#
|
||||
# ("{construction_area}地区有多少工程{project_status}?", ["construction_area", "project_status"]),
|
||||
#
|
||||
# ("{construction_unit}有多少工程{project_status}?", ["construction_unit","project_status"]),
|
||||
#
|
||||
# #分包商
|
||||
# ("{subcontractor}有多少工程{project_status}", ["subcontractor", "project_status"]),
|
||||
# ("安徽送变电公司{project_department}有多少工程?", ["project_department"]),
|
||||
# #项目经理
|
||||
# ("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]),
|
||||
# #班组名称
|
||||
# ("{team_leader}有多少{project_status}工程", ["team_leader", "project_status"]),
|
||||
# #工程性质
|
||||
# ("公司{project_type}类的工程有多少?", ["project_type"]),
|
||||
# #风险等级
|
||||
# ("公司{risk_level}风险的{project_status}工程有多少?", ["risk_level", "project_status"]),
|
||||
# #询问工程数量时有工程性质和风险等级吗
|
||||
# ]
|
||||
# },
|
||||
#
|
||||
# "工程详情查询": {
|
||||
# "date": ["今天","今日",""],
|
||||
# "templates": [
|
||||
# #公司
|
||||
# ("{date}公司有哪些工程", ["date"]),
|
||||
# ("截止目前公司有哪些{project_status}工程", ["project_status"]),
|
||||
# ("安徽送变电公司{date}有哪些工程{project_status}", ["date", "project_status"]),
|
||||
# #分公司和项目部
|
||||
# ("{implementation_organization}{date}工程详情{project_status}",
|
||||
# ["implementation_organization", "date", "project_status"]),
|
||||
# ("{date}{implementation_organization}{project_department}有哪些工程{project_status}",
|
||||
# ["date", "implementation_organization", "project_department", "project_status"]),
|
||||
# #建管区域和单位
|
||||
# ("{date}{construction_area}地区风险等级为{risk_level}工程具体情况?", ["date", "construction_area", "risk_level"]),
|
||||
#
|
||||
# ("{construction_area}{date}地区有哪些{project_status}工程?", ["construction_area", "date", "project_status"]),
|
||||
#
|
||||
# ("{construction_unit}有哪些工程{project_status}?", ["construction_unit","project_status"]),
|
||||
#
|
||||
# #分包商
|
||||
# ("{subcontractor}有多少{project_status}工程", ["subcontractor", "project_status"]),
|
||||
# ("送变电公司{project_department}工程详情?", ["project_department"]),
|
||||
# #项目经理
|
||||
# ("{project_manager}有多少工程{project_status}", ["project_manager","project_status"]),
|
||||
# #班组名称
|
||||
# ("{team_leader}工程具体情况", ["team_leader"]),
|
||||
# #工程性质
|
||||
# ("公司{project_type}类的工程有哪些?", ["project_type"]),
|
||||
# #风险等级
|
||||
# ("公司{risk_level}风险的{project_status}工程有那些?", ["risk_level", "project_status"]),
|
||||
# #询问工程数量时有工程性质和风险等级吗
|
||||
# ]
|
||||
# },
|
||||
|
||||
"项目部数量查询": {
|
||||
"date": ["今天","最近"],
|
||||
"templates": [
|
||||
|
|
@ -1301,10 +1303,13 @@ TEMPLATE_CONFIG = {
|
|||
"date": ["今天","最近"],
|
||||
"templates": [
|
||||
#公司
|
||||
("公司具体项目部", []),
|
||||
("公司有哪些项目部", []),
|
||||
("安徽送变电公司项目管理部详情", []),
|
||||
("安徽送变电公司具体项目部", []),
|
||||
#分公司
|
||||
("{implementation_organization}项目部详情", ["implementation_organization"]),
|
||||
("{implementation_organization}具体项目部", ["implementation_organization"]),
|
||||
("{implementation_organization}有哪些项目管理部", ["implementation_organization"]),
|
||||
#请帮我查一下
|
||||
("请帮我查一下公司项目部详情", []),
|
||||
|
|
@ -1328,6 +1333,7 @@ TEMPLATE_CONFIG = {
|
|||
#公司
|
||||
("{project_name}建管单位情况", ["project_name"]),
|
||||
("{project_name}建管单位详情", ["project_name"]),
|
||||
("{project_name}具体建管单位", ["project_name"]),
|
||||
("请介绍下{construction_unit}详情", ["construction_unit"]),
|
||||
("请介绍下{construction_unit}情况", ["construction_unit"]),
|
||||
]
|
||||
|
|
@ -1348,10 +1354,10 @@ TEMPLATE_CONFIG = {
|
|||
"date": ["今天","最近"],
|
||||
"templates": [
|
||||
#公司
|
||||
("{project_name}分包单位详情", ["project_name"]),
|
||||
("{project_name}具体分包单位详情", ["project_name"]),
|
||||
("{project_name}分包商情况", ["project_name"]),
|
||||
("请介绍下{subcontractor}详情", ["subcontractor"]),
|
||||
("请介绍下{subcontractor}情况", ["subcontractor"]),
|
||||
("请介绍下具体{subcontractor}情况", ["subcontractor"]),
|
||||
]
|
||||
},
|
||||
|
||||
|
|
|
|||
34
uie/train.py
34
uie/train.py
|
|
@ -2,17 +2,15 @@ import json
|
|||
import paddle
|
||||
from paddlenlp.datasets import MapDataset
|
||||
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
|
||||
from paddlenlp.trainer import Trainer, TrainingArguments, EarlyStoppingCallback
|
||||
from paddlenlp.trainer import Trainer, TrainingArguments
|
||||
from paddlenlp.data import DataCollatorForTokenClassification
|
||||
|
||||
|
||||
# === 1. 加载数据 ===
|
||||
def load_dataset(data_path):
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return MapDataset(data)
|
||||
|
||||
|
||||
# === 2. 预处理数据 ===
|
||||
def preprocess_function(example, tokenizer):
|
||||
# 预定义实体类型列表
|
||||
|
|
@ -62,7 +60,7 @@ def preprocess_function(example, tokenizer):
|
|||
|
||||
|
||||
# === 3. 加载 UIE 预训练模型 ===
|
||||
lsmodel = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=35) # 3 类 (O, B, I)
|
||||
model = ErnieForTokenClassification.from_pretrained("uie-base", num_classes=35)  # 35 label classes (num_classes=35); presumably BIO tags over the predefined entity types — TODO confirm
|
||||
tokenizer = ErnieTokenizer.from_pretrained("uie-base")
|
||||
|
||||
# === 4. 加载数据集 ===
|
||||
|
|
@ -73,6 +71,7 @@ dev_dataset = load_dataset("data/val.json") # 验证数据集
|
|||
train_dataset = train_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
|
||||
dev_dataset = dev_dataset.map(lambda x: preprocess_function(x, tokenizer), lazy=False)
|
||||
|
||||
|
||||
# === 6. 数据整理 ===
|
||||
data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)
|
||||
|
||||
|
|
@ -89,30 +88,13 @@ training_args = TrainingArguments(
|
|||
save_total_limit=1, # keep at most 1 saved checkpoint (the earlier note said 2, which did not match the value)
|
||||
logging_dir="./logs",
|
||||
logging_steps=100,
|
||||
eval_steps=5000, #evaluation_strategy="steps"时生效
|
||||
save_steps=5000, #save_strategy="steps"时生效
|
||||
eval_steps=5000,
|
||||
save_steps=5000,
|
||||
seed=1000,
|
||||
load_best_model_at_end=True,
|
||||
)
|
||||
|
||||
# === 8. 创建 EarlyStoppingCallback 实例 ===
|
||||
early_stopping_callback = EarlyStoppingCallback(
|
||||
early_stopping_patience=2, # 连续多少次评估不提升就停
|
||||
early_stopping_threshold=0.01 # 最小提升幅度(例如设为0.01表示至少提升1%)
|
||||
)
|
||||
|
||||
|
||||
def compute_metrics(eval_preds):
|
||||
predictions, labels = eval_preds
|
||||
preds = predictions.argmax(axis=-1)
|
||||
correct = (preds == labels).astype(int)
|
||||
accuracy = correct.sum() / correct.size
|
||||
return {"accuracy": accuracy}
|
||||
|
||||
|
||||
training_args.metric_for_best_model = "accuracy"
|
||||
|
||||
# === 9. 训练 ===
|
||||
# === 8. 训练 ===
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
|
|
@ -120,12 +102,11 @@ trainer = Trainer(
|
|||
eval_dataset=dev_dataset,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics,
|
||||
callbacks=[early_stopping_callback], # 添加 EarlyStopping 回调
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
|
||||
# 为模型定义输入规格
|
||||
input_spec = [
|
||||
paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="input_ids"),
|
||||
|
|
@ -133,3 +114,4 @@ input_spec = [
|
|||
paddle.static.InputSpec(shape=[None, 512], dtype="int64", name="position_ids"),
|
||||
paddle.static.InputSpec(shape=[None, 512], dtype="float32", name="attention_mask")
|
||||
]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue