Refactor model training
This commit is contained in:
parent 9e1182f766
commit a20f513d38
api/intentRecognition.py
@@ -0,0 +1,42 @@
import paddle
import numpy as np
from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
import paddle.nn.functional as F


class IntentRecognition:
    def __init__(self, model_path: str, labels: list):
        # Initialize the model and tokenizer
        self.model = ErnieForSequenceClassification.from_pretrained(model_path)
        self.tokenizer = ErnieTokenizer.from_pretrained(model_path)
        self.labels = labels

    def predict(self, query: str):
        """
        Run intent recognition on the input query and return the predicted label and probability.

        :param query: text to classify
        :return: (predicted_label, predicted_probability, predicted_id)
        """
        # Tokenize the input text
        inputs = self.tokenizer(query, max_length=256, truncation=True, padding='max_length', return_tensors="pd")

        # Convert the tokenized input ids to a Paddle tensor
        input_ids = paddle.to_tensor(inputs["input_ids"])

        # Run the model to obtain logits
        logits = self.model(input_ids)

        # Convert the logits into a probability distribution with Softmax
        probabilities = F.softmax(logits, axis=-1)

        # Take the label index with the highest probability and its value
        max_prob_idx = np.argmax(probabilities.numpy(), axis=-1)
        max_prob_value = np.max(probabilities.numpy(), axis=-1)

        # Map the predicted index to its class name
        predicted_label = self.labels[max_prob_idx[0]]      # label with the highest probability
        predicted_probability = float(max_prob_value[0])    # highest probability value
        predicted_id = int(max_prob_idx[0])                 # index of the predicted label

        return predicted_label, predicted_probability, predicted_id
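A minimal usage sketch of the class above (the checkpoint path is a placeholder; the label list is the one defined in api/mian.py):

# Hypothetical usage, not part of the commit; the checkpoint path is a placeholder.
labels = [
    "天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
    "日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答"
]
recognizer = IntentRecognition("./ernie/output/checkpoint-4160", labels)
label, prob, label_id = recognizer.predict("今天合肥的天气怎么样?")
print(label, prob, label_id)  # e.g. 天气查询 0.97 0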
api/mian.py (275 changed lines)
@@ -1,30 +1,49 @@
 import json
 import pydantic
 from flask import Flask, jsonify, request
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import HTTPException
-from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
+from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer, ErnieForSequenceClassification
 import paddle
 import numpy as np
 import paddle.nn.functional as F  # for Softmax
+from typing import List, Dict
+from pydantic import ValidationError

-# 1. Load the model and tokenizer
-model_path = R"E:\workingSpace\PycharmProjects\Intention_dev\uie\uie_ner\checkpoint-4320"  # path to your model
-model = ErnieForTokenClassification.from_pretrained(model_path)
-tokenizer = ErnieTokenizer.from_pretrained(model_path)
+from api.intentRecognition import IntentRecognition
+from api.slotRecognition import SlotRecognition

+# Constants
+MODEL_ERNIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\ernie\output\checkpoint-4160"
+MODEL_UIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\uie\uie_ner\checkpoint-4320"

+# List of intent class names
+labels = [
+    "天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
+    "日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答"
+]

 # Label mapping
 label_map = {
     0: 'O', 1: 'B-date', 11: 'I-date',
-    2: 'B-project_name', 12: 'I-project_name',
-    3: 'B-project_type', 13: 'I-project_type',
-    4: 'B-construction_unit', 14: 'I-construction_unit',
-    5: 'B-implementation_organization', 15: 'I-implementation_organization',
-    6: 'B-project_department', 16: 'I-project_department',
-    7: 'B-project_manager', 17: 'I-project_manager',
+    2: 'B-projectName', 12: 'I-projectName',
+    3: 'B-projectType', 13: 'I-projectType',
+    4: 'B-constructionUnit', 14: 'I-constructionUnit',
+    5: 'B-implementationOrganization', 15: 'I-implementationOrganization',
+    6: 'B-projectDepartment', 16: 'I-projectDepartment',
+    7: 'B-projectManager', 17: 'I-projectManager',
     8: 'B-subcontractor', 18: 'I-subcontractor',
-    9: 'B-team_leader', 19: 'I-team_leader',
-    10: 'B-risk_level', 20: 'I-risk_level'
+    9: 'B-teamLeader', 19: 'I-teamLeader',
+    10: 'B-riskLevel', 20: 'I-riskLevel'
 }

-app = Flask(__name__)
+# Initialize the intent-recognition helper
+intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)

+# Initialize the slot-recognition helper
+slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
+# Set up the Flask application
+app = Flask(__name__)

 # Unified exception handler
 @app.errorhandler(Exception)
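For reference, a small sketch of how this BIO label map is read when decoding token-level predictions (the token and tag values here are illustrative, not taken from the commit):

# Illustrative only: consecutive B-xxx / I-xxx tags group tokens into one entity.
tokens = ["胡", "彬", "上", "周"]
tag_ids = [7, 17, 1, 11]                 # B-projectManager, I-projectManager, B-date, I-date
tags = [label_map[i] for i in tag_ids]
# Joining the spans yields {"projectManager": "胡彬", "date": "上周"},
# which is the shape of dictionary that SlotRecognition.recognize() returns below.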
@@ -46,63 +65,201 @@ def handle_exception(e):
     }), 500


 @app.route('/')
 def hello_world():
     """Sample route that returns Hello World."""
     return jsonify({"message": "Hello, world!"})

+def validate_user(data):
+    """Validate the user ID."""
+    if data.get("user_id") != '3bb66776-1722-4c36-b14a-73dd210fe750':
+        return jsonify(
+            code=401,
+            msg='权限验证失败,请联系接口开发人员',
+            label=-1,
+            probability=-1
+        ), 401
+    return None


-@app.route('/predict', methods=['POST'])
-def predict():
-    """Handle prediction requests."""
-    data = request.get_json()
+class LabelMessage(BaseModel):
+    text: str = Field(..., description="消息内容")
+    user_id: str = Field(..., description="消息内容")

-    # Extract the text
-    text = data.get("text", "")
-    if not text:
-        return jsonify({"error": "No text provided"}), 400

-    # Process the input text
-    inputs = tokenizer(text, max_len=512, return_tensors="pd")
-    model.eval()
+# Structure of a single message
+class Message(BaseModel):
+    role: str = Field(..., description="消息内容")
+    content: str = Field(..., description="消息内容")
+    # timestamp: str = Field(..., description="消息时间戳")

-    with paddle.no_grad():
-        logits = model(**inputs)
-        predictions = paddle.argmax(logits, axis=-1)

-    # Parse the predictions
-    predicted_labels = predictions.numpy()[0]
-    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy())
+# Structure of the request payload
+class RequestData(BaseModel):
+    messages: List[Message] = Field(..., description="消息列表")
+    user_id: str = Field(..., description="用户ID")

-    entities = {}
-    current_entity = None
-    current_label = None

-    for token, label_id in zip(tokens, predicted_labels):
-        label = label_map.get(label_id, "O")
+# Intent recognition
+@app.route('/intent_reco', methods=['POST'])
+def intent_reco():
+    """Intent recognition."""
+    try:
+        # Read the JSON payload from the request
+        data = request.get_json()
+        request_data = LabelMessage(**data)  # Pydantic validates the structure
+        text = request_data.text
+        user_id = request_data.user_id
+        # Check required fields
+        if not text:
+            return jsonify({"error": "text is required"}), 400
+        if not user_id:
+            return jsonify({"error": "user_id is required"}), 400

-        if label.startswith("B-"):  # start a new entity
-            if current_entity:
-                entities[current_label] = "".join(current_entity)
-            current_entity = [token]
-            current_label = label[2:]  # strip the B- prefix
+        # Validate the user ID
+        user_validation_error = validate_user(data)
+        if user_validation_error:
+            return user_validation_error

-        elif label.startswith("I-") and current_entity and label[2:] == current_label:
-            current_entity.append(token)  # keep merging the same entity
+        # Run intent recognition via the predict method
+        predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(text)

-        else:  # not part of an entity
-            if current_entity:
-                entities[current_label] = "".join(current_entity)
-            current_entity = None
-            current_label = None
+        return jsonify(
+            code=200,
+            msg="成功",
+            int=predicted_id,
+            label=predicted_label,
+            probability=float(predicted_probability)
+        )

-    # Handle the last entity
-    if current_entity:
-        entities[current_label] = "".join(current_entity)
+    except Exception as e:
+        return jsonify({"error": str(e)}), 500

-    # Return the final entities as JSON
-    return jsonify(entities)

+# Slot extraction
+@app.route('/slot_reco', methods=['POST'])
+def slot_reco():
+    """Slot recognition."""
+    try:
+        # Read the JSON payload from the request
+        data = request.get_json()
+        request_data = LabelMessage(**data)  # Pydantic validates the structure
+        text = request_data.text
+        user_id = request_data.user_id

+        # Check required fields
+        if not text:
+            return jsonify({"error": "text is required"}), 400
+        if not user_id:
+            return jsonify({"error": "user_id is required"}), 400

+        # Validate the user ID
+        user_validation_error = validate_user(data)
+        if user_validation_error:
+            return user_validation_error

+        # Run slot recognition via the recognize method
+        entities = slot_recognizer.recognize(text)

+        return jsonify(
+            code=200,
+            msg="成功",
+            slot=entities)

+    except Exception as e:
+        return jsonify({"error": str(e)}), 500


+@app.route('/agent', methods=['POST'])
+def agent():
+    try:
+        data = request.get_json()
+        # Validate the structure with Pydantic
+        request_data = RequestData(**data)  # Pydantic validates the structure
+        messages = request_data.messages
+        user_id = request_data.user_id

+        # Check that the required fields are present
+        if not messages:
+            return jsonify({"error": "messages is required"}), 400
+        if not user_id:
+            return jsonify({"error": "user_id is required"}), 400

+        # Validate the user ID (assumes this function is defined above)
+        user_validation_error = validate_user(data)
+        if user_validation_error:
+            return user_validation_error
+        if len(messages) == 1:  # first turn
+            query = messages[0].content  # use the .content attribute of the Message object
+            # Intent recognition first
+            predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)
+            # Then slot extraction
+            entities = slot_recognizer.recognize(query)
+            status, sk = check_lost(predicted_label, entities)

+            # Return the intent and slot recognition results
+            return jsonify({
+                "code": 200,
+                "msg": "成功",
+                "answer": {
+                    "int": predicted_id,
+                    "label": predicted_label,
+                    "probability": predicted_probability,
+                    "slot": entities
+                },
+            })

+        # Later turns (multi-turn dialogue): placeholder handling, to be refined as needed
+        else:
+            query = messages[0].content  # use the .content attribute of the Message object
+            return jsonify({
+                "user_id": user_id,
+                "query": query,
+                "message_count": len(messages)
+            })

+    except ValidationError as e:
+        return jsonify({"error": e.errors()}), 400  # catch Pydantic validation errors
+    except Exception as e:
+        return jsonify({"error": str(e)}), 500  # catch everything else


+def check_lost(int_res, slot):
+    # mapping = {
+    #     "页面切换": [['页面', '应用']],
+    #     "作业计划数量查询": [['时间']],
+    #     "周计划查询": [['时间']],
+    #     "作业内容": [['时间']],
+    #     "施工人数": [['时间']],
+    #     "作业考勤人数": [['时间']],
+    # }
+    mapping = {
+        1: [['date', 'area']],
+        3: [['page'], ['app'], ['module']],
+        4: [['date']],
+        5: [['date']],
+        6: [['date']],
+        7: [['date']],
+        8: [[]],
+        9: [[]],
+    }
+    if int_res not in mapping:
+        return 0, []
+    cur_k = list(slot.keys())
+    idx = -1
+    idx_len = 99
+    for i in range(len(mapping[int_res])):
+        sk = mapping[int_res][i]
+        left = [x for x in sk if x not in cur_k]   # required slots that are still missing
+        more = [x for x in cur_k if x not in sk]   # extracted slots beyond the required set
+        if len(more) >= 0 and len(left) == 0:
+            idx = i
+            idx_len = 0
+            break
+        if len(left) < idx_len:
+            idx = i
+            idx_len = len(left)

+    if idx_len == 0:  # all required slots are present
+        return 0, cur_k
+    left = [x for x in mapping[int_res][idx] if x not in cur_k]
+    return 1, left  # mapping[int_res][idx]


 if __name__ == '__main__':
-    app.run(host='0.0.0.0', port=5000, debug=True)  # start the API in debug mode on the given port
+    app.run(host='0.0.0.0', port=5000, debug=True)
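A hedged client-side sketch for exercising the new endpoints (host and port follow the app.run() call above; the user_id value is the one checked by validate_user()):

# Hypothetical client calls, not part of the commit.
import requests

BASE = "http://127.0.0.1:5000"
USER_ID = "3bb66776-1722-4c36-b14a-73dd210fe750"

# Intent recognition
r = requests.post(f"{BASE}/intent_reco", json={"text": "今天的天气怎么样?", "user_id": USER_ID})
print(r.json())  # expected shape: {"code": 200, "msg": "成功", "int": ..., "label": ..., "probability": ...}

# First-turn /agent call: intent + slots + missing-slot check in one request
r = requests.post(f"{BASE}/agent", json={
    "messages": [{"role": "user", "content": "胡彬项目经理上一周作业内容是什么?"}],
    "user_id": USER_ID,
})
print(r.json())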
api/slotRecognition.py
@@ -0,0 +1,61 @@
import paddle
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer


class SlotRecognition:
    def __init__(self, model_path: str, label_map: dict):
        """
        Initialize the slot-recognition model and tokenizer.
        :param model_path: path to the model
        :param label_map: label-id to tag-name mapping
        """
        self.model = ErnieForTokenClassification.from_pretrained(model_path)
        self.tokenizer = ErnieTokenizer.from_pretrained(model_path)
        self.label_map = label_map

    def recognize(self, text: str):
        """
        Run slot recognition on the input text and return the recognized entities.
        :param text: input text
        :return: entities dict containing the recognized slot entities
        """
        # Process the input text
        inputs = self.tokenizer(text, max_length=512, return_tensors="pd")

        # Run inference without gradients
        with paddle.no_grad():
            logits = self.model(**inputs)
            predictions = paddle.argmax(logits, axis=-1)

        # Parse the predictions
        predicted_labels = predictions.numpy()[0]
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy())

        entities = {}
        current_entity = None
        current_label = None

        for token, label_id in zip(tokens, predicted_labels):
            label = self.label_map.get(label_id, "O")

            if label.startswith("B-"):  # start a new entity
                if current_entity:
                    entities[current_label] = "".join(current_entity)
                current_entity = [token]
                current_label = label[2:]  # strip the B- prefix

            elif label.startswith("I-") and current_entity and label[2:] == current_label:
                current_entity.append(token)  # keep merging the same entity

            else:  # not part of an entity
                if current_entity:
                    entities[current_label] = "".join(current_entity)
                current_entity = None
                current_label = None

        # Handle the last entity
        if current_entity:
            entities[current_label] = "".join(current_entity)
        # Clean up each entity: strip the '#' characters left over from '##' wordpiece tokens
        for key, value in entities.items():
            entities[key] = value.replace('#', '')
        return entities
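A minimal usage sketch of SlotRecognition (the checkpoint path is a placeholder; label_map is the mapping defined in api/mian.py, excerpted here):

# Hypothetical usage, not part of the commit.
label_map = {0: 'O', 1: 'B-date', 11: 'I-date', 7: 'B-projectManager', 17: 'I-projectManager'}  # excerpt
slot_recognizer = SlotRecognition("./uie/uie_ner/checkpoint-4320", label_map)
entities = slot_recognizer.recognize("胡彬项目经理上一周作业内容是什么?")
print(entities)  # e.g. {"projectManager": "胡彬项目经理", "date": "上一周"}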
@@ -1,13 +1,20 @@
 import paddle
 import numpy as np
-from paddlenlp.transformers import ErnieTokenizer
-import paddle.nn.functional as F  # for Softmax
+from paddlenlp.transformers import ErnieTokenizer, ErnieForSequenceClassification
+import paddle.nn.functional as F

+# List of intent class names
+labels = [
+    "天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
+    "日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答"
+]

 # Load the model and tokenizer
-model = paddle.jit.load("trained_model_static")  # load the saved static-graph model
-tokenizer = ErnieTokenizer.from_pretrained("E:/workingSpace/PycharmProjects/Intention/models/ernie-3.0-tiny-base-v2-zh")
+model = ErnieForSequenceClassification.from_pretrained(R"E:\workingSpace\PycharmProjects\Intention_dev\ernie\output\checkpoint-4160")  # use the text-classification model
+tokenizer = ErnieTokenizer.from_pretrained(R"E:\workingSpace\PycharmProjects\Intention_dev\ernie\output\checkpoint-4160")

 # Build a sample input
-text = "今天送变电二公司有?"
+text = "胡彬项目经理上一周作业内容是什么?"
 inputs = tokenizer(text, max_length=256, truncation=True, padding='max_length', return_tensors="pd")

 # Convert the input data into Paddle tensors
@@ -18,10 +25,12 @@ model.eval()  # make sure the model is in inference mode
 logits = model(input_ids)  # run the model to obtain logits

 # Convert the logits into probabilities with Softmax
-probabilities = F.softmax(logits, axis=1)  # normalize the logits into a probability distribution
-# Take the label with the highest probability
-max_prob_idx = np.argmax(probabilities.numpy(), axis=1)
-max_prob_value = np.max(probabilities.numpy(), axis=1)
-# Print the prediction
-print(f"Predicted label: {max_prob_idx}")
-print(f"Predicted label: {max_prob_value}")
+probabilities = F.softmax(logits, axis=-1)  # normalize the logits into a probability distribution

+# Take the label with the highest probability (the intent of the whole sentence)
+max_prob_idx = np.argmax(probabilities.numpy(), axis=-1)   # index of the most probable label
+max_prob_value = np.max(probabilities.numpy(), axis=-1)    # highest probability value

+# Map the predicted index to its class name
+predicted_label = labels[max_prob_idx[0]]        # label name for the predicted index
+predicted_probability = max_prob_value[0]        # highest probability value
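To make the softmax/argmax step above concrete, a small self-contained sketch with fabricated logits:

# Illustrative only: made-up logits for a 3-class case.
import numpy as np
import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([[1.0, 3.0, 0.5]])               # shape [batch=1, num_classes]
probabilities = F.softmax(logits, axis=-1)                  # probability distribution over classes
max_prob_idx = np.argmax(probabilities.numpy(), axis=-1)    # -> array([1])
max_prob_value = np.max(probabilities.numpy(), axis=-1)     # -> array([~0.82])
# labels[max_prob_idx[0]] then gives the predicted intent name.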
ernie/train.py (143 changed lines)
@@ -11,10 +11,14 @@ from paddlenlp.trainer import Trainer, TrainingArguments
 import os
 from sklearn.metrics import precision_score, recall_score, f1_score


 def load_config(config_path):
     """Load the YAML configuration file."""
-    with open(config_path, "r", encoding="utf-8") as f:
-        return yaml.safe_load(f)
+    try:
+        with open(config_path, "r", encoding="utf-8") as f:
+            return yaml.safe_load(f)
+    except Exception as e:
+        raise ValueError(f"读取配置文件时出错: {str(e)}")


 def generate_label_mappings(labels):
@@ -34,22 +38,29 @@ def preprocess_function(examples, tokenizer, max_length, is_test=False):
 def read_local_dataset(path, label2id=None, is_test=False):
     """Read a local dataset."""
-    with open(path, "r", encoding="utf-8") as f:
-        data = json.load(f)
-        for item in data:
-            if is_test:
-                if "text" in item:
-                    yield {"text": item["text"]}
-            else:
-                if "text" in item and "label" in item:
-                    yield {"text": item["text"], "label": label2id.get(item["label"], -1)}
+    try:
+        with open(path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+            for item in data:
+                if is_test:
+                    if "text" in item:
+                        yield {"text": item["text"]}
+                else:
+                    if "text" in item and "label" in item:
+                        yield {"text": item["text"], "label": label2id.get(item["label"], -1)}
+    except Exception as e:
+        raise ValueError(f"读取数据集时出错: {str(e)}")


 def load_and_preprocess_dataset(path, label2id, tokenizer, max_length, is_test=False):
     """Load and preprocess a dataset."""
-    dataset = load_dataset(read_local_dataset, path=path, label2id=label2id, lazy=False, is_test=is_test)
-    trans_func = functools.partial(preprocess_function, tokenizer=tokenizer, max_length=max_length, is_test=is_test)
-    return dataset.map(trans_func)
+    try:
+        dataset = load_dataset(read_local_dataset, path=path, label2id=label2id, lazy=False, is_test=is_test)
+        trans_func = functools.partial(preprocess_function, tokenizer=tokenizer, max_length=max_length, is_test=is_test)
+        return dataset.map(trans_func)
+    except Exception as e:
+        raise ValueError(f"加载和预处理数据集时出错: {str(e)}")


 def export_model(trainer, export_model_dir):
     """Export the model and tokenizer."""
@@ -59,6 +70,7 @@ def export_model(trainer, export_model_dir):
     paddle.jit.save(model_to_export, os.path.join(export_model_dir, 'model'), input_spec=input_spec)
     trainer.tokenizer.save_pretrained(export_model_dir)

+    # Save the id2label and label2id files
     id2label_file = os.path.join(export_model_dir, 'id2label.json')
     label2id_file = os.path.join(export_model_dir, 'label2id.json')
     with open(id2label_file, 'w', encoding='utf-8') as f:
@@ -71,68 +83,75 @@ def export_model(trainer, export_model_dir):
 def compute_metrics(p):
     """Compute evaluation metrics."""
     predictions, labels = p
-    pred_labels = np.argmax(predictions, axis=1)
+    pred_labels = np.argmax(predictions, axis=1) + 1
     accuracy = np.sum(pred_labels == labels) / len(labels)
     precision = precision_score(labels, pred_labels, average='macro')
     recall = recall_score(labels, pred_labels, average='macro')
     f1 = f1_score(labels, pred_labels, average='macro')

     metrics = {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
-    print("Computed metrics:", metrics)  # Debug statement
+    print("Computed metrics:", metrics)  # print the computed metrics
     return metrics


 def main():
-    # Read the configuration
-    config = load_config("data.yaml")
-    label_id, id_label = generate_label_mappings(config["labels"])
+    try:
+        # Read the configuration
+        config = load_config("data.yaml")
+        label_id, id_label = generate_label_mappings(config["labels"])

-    # Load the datasets
-    tokenizer = ErnieTokenizer.from_pretrained(config["model_path"])
-    train_ds = load_and_preprocess_dataset(config["train"], label_id, tokenizer, max_length=256)
-    test_ds = load_and_preprocess_dataset(config["test"], label_id, tokenizer, max_length=256, is_test=True)
+        # Load the datasets
+        tokenizer = ErnieTokenizer.from_pretrained(config["model_path"])
+        train_ds = load_and_preprocess_dataset(config["train"], label_id, tokenizer, max_length=256)
+        test_ds = load_and_preprocess_dataset(config["test"], label_id, tokenizer, max_length=256, is_test=True)

-    # Load the model
-    model = ErnieForSequenceClassification.from_pretrained(config["model_path"], num_classes=config["nc"],
-                                                           label2id=label_id, id2label=id_label)
+        # Load the model
+        model = ErnieForSequenceClassification.from_pretrained(config["model_path"], num_classes=len(label_id),
+                                                               label2id=label_id, id2label=id_label)

-    # Define the DataLoader
-    data_collator = DataCollatorWithPadding(tokenizer)
+        # Define the DataLoader
+        data_collator = DataCollatorWithPadding(tokenizer)

-    # Define the training arguments
-    training_args = TrainingArguments(
-        output_dir="./output",
-        evaluation_strategy="steps",       # evaluate by steps
-        eval_steps=100,                    # evaluate every 100 steps
-        save_steps=500,
-        logging_dir="./logs",
-        logging_steps=50,                  # log every 50 steps
-        num_train_epochs=10,               # number of training epochs
-        per_device_train_batch_size=16,
-        per_device_eval_batch_size=16,
-        gradient_accumulation_steps=1,
-        learning_rate=5e-5,
-        weight_decay=0.01,
-        disable_tqdm=False,
-        metric_for_best_model="accuracy",  # select the best model by accuracy
-        greater_is_better=True,            # higher accuracy is better
-    )
+        # Define the training arguments
+        training_args = TrainingArguments(
+            output_dir="./output",
+            evaluation_strategy="epoch",
+            save_strategy="epoch",
+            eval_steps=100,                # evaluate every 100 steps
+            save_steps=500,
+            logging_dir="./logs",
+            logging_steps=50,              # log every 50 steps
+            num_train_epochs=10,           # number of training epochs
+            per_device_train_batch_size=16,
+            per_device_eval_batch_size=16,
+            gradient_accumulation_steps=1,
+            learning_rate=5e-5,
+            weight_decay=0.01,
+            disable_tqdm=False,
+            greater_is_better=True,        # higher accuracy is better
+        )

-    # Create the Trainer
-    trainer = Trainer(
-        model=model,
-        tokenizer=tokenizer,
-        args=training_args,
-        criterion=CrossEntropyLoss(),
-        train_dataset=train_ds,
-        eval_dataset=test_ds,
-        data_collator=data_collator,
-        compute_metrics=compute_metrics,   # use the custom evaluation metrics
-    )
+        # Create the Trainer
+        trainer = Trainer(
+            model=model,
+            tokenizer=tokenizer,
+            args=training_args,
+            criterion=CrossEntropyLoss(),
+            train_dataset=train_ds,
+            eval_dataset=test_ds,
+            data_collator=data_collator,
+            compute_metrics=compute_metrics,  # use the custom evaluation metrics
+        )

-    # Train the model
-    trainer.train()
+        # Train the model
+        trainer.train()

+        # Save the model
+        trainer.save_model("./saved_model_static")

+    except Exception as e:
+        print(f"训练过程中出错: {str(e)}")

     # Export the model
     export_model(trainer, './output/export')

 if __name__ == "__main__":
     main()
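For context, a hedged sketch of the data.yaml layout that main() appears to expect (key names are inferred from the config[...] accesses above; paths and values are placeholders):

# Hypothetical config matching the keys read in main(): labels, model_path, train, test (and nc in the old version).
import yaml

example_yaml = """
labels: ["天气查询", "互联网查询", "页面切换", "知识问答"]
model_path: ./models/ernie-3.0-tiny-base-v2-zh
train: ./data/train.json
test: ./data/test.json
nc: 4
"""
config = yaml.safe_load(example_yaml)
print(config["labels"], config["model_path"])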
@@ -35,7 +35,7 @@ construction_units = ["国网安徽省电力有限公司建设分公司", "国
 project_departments = ["第九项目管理部(马鞍山)", "第十一项目管理部(马鞍山)", "第八项目管理部(芜湖)",
                        "第五项目管理部(阜阳)", "第六项目管理部(滁州)", "第十二项目管理部(陕皖)",
                        "第十三项目管理部(黄山)", "第四项目管理部(安庆)"]
-project_managers = ["陈少平", "范文立", "何东洋", "胡彬", "黄东林", "姜松竺", "刘闩", "柳杰"]
+project_managers = ["陈少平项目经理", "范文立项目经理", "何东洋项目经理", "胡彬项目经理", "黄东林项目经理", "姜松竺项目经理", "刘闩项目经理", "柳杰项目经理"]
 subcontractors = ["安徽远宏电力工程有限公司", "安徽京硚建设有限公司", "武汉久林电力建设有限公司",
                   "安徽省鸿钢建设发展有限公司", "安徽星联建筑安装有限公司", "福建文港建设工程有限公司",
                   "芜湖冉电电力安装工程有限责任公司", "合肥市胜峰建筑安装有限公司", "安徽劦力建筑装饰有限责任公司",