Refactor model training
This commit is contained in:
parent 9e1182f766
commit a20f513d38

api/intentRecognition.py (new file)
@@ -0,0 +1,42 @@
import paddle
import numpy as np
from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
import paddle.nn.functional as F


class IntentRecognition:
    def __init__(self, model_path: str, labels: list):
        # Initialize the model and tokenizer
        self.model = ErnieForSequenceClassification.from_pretrained(model_path)
        self.tokenizer = ErnieTokenizer.from_pretrained(model_path)
        self.labels = labels

    def predict(self, query: str):
        """
        Run intent recognition on the input query.

        :param query: text to classify
        :return: (predicted_label, predicted_probability, predicted_id)
        """
        # Tokenize the input text
        inputs = self.tokenizer(query, max_length=256, truncation=True, padding='max_length', return_tensors="pd")

        # Convert the tokenized input ids to a paddle tensor
        input_ids = paddle.to_tensor(inputs["input_ids"])

        # Run the model to obtain logits
        logits = self.model(input_ids)

        # Turn the logits into a probability distribution with softmax
        probabilities = F.softmax(logits, axis=-1)

        # Index and value of the highest-probability class
        max_prob_idx = np.argmax(probabilities.numpy(), axis=-1)
        max_prob_value = np.max(probabilities.numpy(), axis=-1)

        # Map the predicted index back to its label name
        predicted_label = self.labels[max_prob_idx[0]]    # label name with the highest probability
        predicted_probability = float(max_prob_value[0])  # highest probability value
        predicted_id = int(max_prob_idx[0])               # index of the predicted label

        return predicted_label, predicted_probability, predicted_id
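A minimal usage sketch for the new class follows; the checkpoint path is a placeholder, and the label list is the one defined in api/mian.py later in this commit.

# Usage sketch (hypothetical checkpoint path)
from api.intentRecognition import IntentRecognition

labels = [
    "天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
    "日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答"
]

recognizer = IntentRecognition("path/to/ernie/checkpoint", labels)  # placeholder path
label, prob, label_id = recognizer.predict("今天的施工人数是多少?")
print(label_id, label, round(prob, 4))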
api/mian.py (275 changed lines)
@@ -1,30 +1,49 @@
-import json
+import pydantic
 from flask import Flask, jsonify, request
+from pydantic import BaseModel, Field
 from werkzeug.exceptions import HTTPException
-from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
+from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer, ErnieForSequenceClassification
 import paddle
+import numpy as np
+import paddle.nn.functional as F  # for softmax
+from typing import List, Dict
+from pydantic import ValidationError
+
-# 1. Load the model and tokenizer
+from api.intentRecognition import IntentRecognition
-model_path = R"E:\workingSpace\PycharmProjects\Intention_dev\uie\uie_ner\checkpoint-4320"  # path to your model
+from api.slotRecognition import SlotRecognition
-model = ErnieForTokenClassification.from_pretrained(model_path)
-tokenizer = ErnieTokenizer.from_pretrained(model_path)
+
+# Constants
+MODEL_ERNIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\ernie\output\checkpoint-4160"
+MODEL_UIE_PATH = R"E:\workingSpace\PycharmProjects\Intention_dev\uie\uie_ner\checkpoint-4320"
+
+# List of class names
+labels = [
+    "天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
+    "日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答"
+]
+
 # Label map
 label_map = {
     0: 'O', 1: 'B-date', 11: 'I-date',
-    2: 'B-project_name', 12: 'I-project_name',
+    2: 'B-projectName', 12: 'I-projectName',
-    3: 'B-project_type', 13: 'I-project_type',
+    3: 'B-projectType', 13: 'I-projectType',
-    4: 'B-construction_unit', 14: 'I-construction_unit',
+    4: 'B-constructionUnit', 14: 'I-constructionUnit',
-    5: 'B-implementation_organization', 15: 'I-implementation_organization',
+    5: 'B-implementationOrganization', 15: 'I-implementationOrganization',
-    6: 'B-project_department', 16: 'I-project_department',
+    6: 'B-projectDepartment', 16: 'I-projectDepartment',
-    7: 'B-project_manager', 17: 'I-project_manager',
+    7: 'B-projectManager', 17: 'I-projectManager',
     8: 'B-subcontractor', 18: 'I-subcontractor',
-    9: 'B-team_leader', 19: 'I-team_leader',
+    9: 'B-teamLeader', 19: 'I-teamLeader',
-    10: 'B-risk_level', 20: 'I-risk_level'
+    10: 'B-riskLevel', 20: 'I-riskLevel'
 }

-app = Flask(__name__)
+
+# Initialize the intent recognition utility
+intent_recognizer = IntentRecognition(MODEL_ERNIE_PATH, labels)
+
+# Initialize the slot recognition utility
+slot_recognizer = SlotRecognition(MODEL_UIE_PATH, label_map)
+
+# Set up the Flask application
+app = Flask(__name__)
+
 # Unified exception handler
 @app.errorhandler(Exception)
@@ -46,63 +65,201 @@ def handle_exception(e):
     }), 500


-@app.route('/')
-def hello_world():
-    """Example route that returns Hello World"""
-    return jsonify({"message": "Hello, world!"})
+def validate_user(data):
+    """Validate the user id"""
+    if data.get("user_id") != '3bb66776-1722-4c36-b14a-73dd210fe750':
+        return jsonify(
+            code=401,
+            msg='权限验证失败,请联系接口开发人员',
+            label=-1,
+            probability=-1
+        ), 401
+    return None
+
+
+class LabelMessage(BaseModel):
+    text: str = Field(..., description="消息内容")
+    user_id: str = Field(..., description="消息内容")
+
+
+# Structure of a single message
+class Message(BaseModel):
+    role: str = Field(..., description="消息内容")
+    content: str = Field(..., description="消息内容")
+    # timestamp: str = Field(..., description="消息时间戳")
+
+
+# Structure of the request payload
+class RequestData(BaseModel):
+    messages: List[Message] = Field(..., description="消息列表")
+    user_id: str = Field(..., description="用户ID")
+
+
-@app.route('/predict', methods=['POST'])
-def predict():
-    """Handle a prediction request"""
-    data = request.get_json()
-
-    # Extract the text
-    text = data.get("text", "")
-    if not text:
-        return jsonify({"error": "No text provided"}), 400
-
-    # Process the input text
-    inputs = tokenizer(text, max_len=512, return_tensors="pd")
-    model.eval()
-
-    with paddle.no_grad():
-        logits = model(**inputs)
-        predictions = paddle.argmax(logits, axis=-1)
-
-    # Parse the predictions
-    predicted_labels = predictions.numpy()[0]
-    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy())
-
-    entities = {}
-    current_entity = None
-    current_label = None
-
-    for token, label_id in zip(tokens, predicted_labels):
-        label = label_map.get(label_id, "O")
-
-        if label.startswith("B-"):  # start of a new entity
-            if current_entity:
-                entities[current_label] = "".join(current_entity)
-            current_entity = [token]
-            current_label = label[2:]  # strip the "B-" prefix
-
-        elif label.startswith("I-") and current_entity and label[2:] == current_label:
-            current_entity.append(token)  # keep merging the same entity
-
-        else:  # not part of an entity
-            if current_entity:
-                entities[current_label] = "".join(current_entity)
-            current_entity = None
-            current_label = None
-
-    # Handle the last entity
-    if current_entity:
-        entities[current_label] = "".join(current_entity)
-
-    # Return the final entities as JSON
-    return jsonify(entities)
+# Intent recognition
+@app.route('/intent_reco', methods=['POST'])
+def intent_reco():
+    """Intent recognition"""
+    try:
+        # Read the JSON body of the request
+        data = request.get_json()
+        request_data = LabelMessage(**data)  # Pydantic validates the payload structure
+        text = request_data.text
+        user_id = request_data.user_id
+        # Check the required fields
+        if not text:
+            return jsonify({"error": "text is required"}), 400
+        if not user_id:
+            return jsonify({"error": "user_id is required"}), 400
+
+        # Validate the user id
+        user_validation_error = validate_user(data)
+        if user_validation_error:
+            return user_validation_error
+
+        # Run intent recognition via the predict method
+        predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(text)
+
+        return jsonify(
+            code=200,
+            msg="成功",
+            int=predicted_id,
+            label=predicted_label,
+            probability=float(predicted_probability)
+        )
+
+    except Exception as e:
+        return jsonify({"error": str(e)}), 500
+
+
+# Slot extraction
+@app.route('/slot_reco', methods=['POST'])
+def slot_reco():
+    """Slot recognition"""
+    try:
+        # Read the JSON body of the request
+        data = request.get_json()
+        request_data = LabelMessage(**data)  # Pydantic validates the payload structure
+        text = request_data.text
+        user_id = request_data.user_id

+        # Check the required fields
+        if not text:
+            return jsonify({"error": "text is required"}), 400
+        if not user_id:
+            return jsonify({"error": "user_id is required"}), 400
+
+        # Validate the user id
+        user_validation_error = validate_user(data)
+        if user_validation_error:
+            return user_validation_error
+
+        # Run slot recognition via the recognize method
+        entities = slot_recognizer.recognize(text)
+
+        return jsonify(
+            code=200,
+            msg="成功",
+            slot=entities)
+
+    except Exception as e:
+        return jsonify({"error": str(e)}), 500
+
+
+@app.route('/agent', methods=['POST'])
+def agent():
+    try:
+        data = request.get_json()
+        # Validate the payload structure with Pydantic
+        request_data = RequestData(**data)  # Pydantic validates the payload structure
+        messages = request_data.messages
+        user_id = request_data.user_id
+
+        # Check that the required fields are present
+        if not messages:
+            return jsonify({"error": "messages is required"}), 400
+        if not user_id:
+            return jsonify({"error": "user_id is required"}), 400
+
+        # Validate the user id (assumes validate_user is defined above)
+        user_validation_error = validate_user(data)
+        if user_validation_error:
+            return user_validation_error
+        if len(messages) == 1:  # first turn
+            query = messages[0].content  # use the .content attribute of the Message object
+            # Intent recognition first
+            predicted_label, predicted_probability, predicted_id = intent_recognizer.predict(query)
+            # Then slot extraction
+            entities = slot_recognizer.recognize(query)
+            status, sk = check_lost(predicted_label, entities)
+
+            # Return the intent and slot recognition results
+            return jsonify({
+                "code": 200,
+                "msg": "成功",
+                "answer": {
+                    "int": predicted_id,
+                    "label": predicted_label,
+                    "probability": predicted_probability,
+                    "slot": entities
+                },
+            })
+
+        # Later turns (multi-turn dialogue): placeholder only, adapt to the actual requirements
+        else:
+            query = messages[0].content  # use the .content attribute of the Message object
+            return jsonify({
+                "user_id": user_id,
+                "query": query,
+                "message_count": len(messages)
+            })
+
+    except ValidationError as e:
+        return jsonify({"error": e.errors()}), 400  # catch and return Pydantic validation errors
+    except Exception as e:
+        return jsonify({"error": str(e)}), 500  # catch and return any other error
+
+
+def check_lost(int_res, slot):
+    # mapping = {
+    #     "页面切换": [['页面', '应用']],
+    #     "作业计划数量查询": [['时间']],
+    #     "周计划查询": [['时间']],
+    #     "作业内容": [['时间']],
+    #     "施工人数": [['时间']],
+    #     "作业考勤人数": [['时间']],
+    # }
+    mapping = {
+        1: [['date', 'area']],
+        3: [['page'], ['app'], ['module']],
+        4: [['date']],
+        5: [['date']],
+        6: [['date']],
+        7: [['date']],
+        8: [[]],
+        9: [[]],
+    }
+    if not mapping.__contains__(int_res):
+        return 0, []
+    cur_k = list(slot.keys())
+    idx = -1
+    idx_len = 99
+    for i in range(len(mapping[int_res])):
+        sk = mapping[int_res][i]
+        left = [x for x in sk if x not in cur_k]
+        more = [x for x in cur_k if x not in sk]
+        if len(more) >= 0 and len(left) == 0:
+            idx = i
+            idx_len = 0
+            break
+        if len(left) < idx_len:
+            idx = i
+            idx_len = len(left)
+
+    if idx_len == 0:  # all required slots are present
+        return 0, cur_k
+    left = [x for x in mapping[int_res][idx] if x not in cur_k]
+    return 1, left  # mapping[int_res][idx]
+
 if __name__ == '__main__':
-    app.run(host='0.0.0.0', port=5000, debug=True)  # start the API in debug mode on the given port
+    app.run(host='0.0.0.0', port=5000, debug=True)
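For reference, a minimal sketch of how a client could call the new /intent_reco route once the service is running; the user_id below is the one hard-coded in validate_user, and the expected response fields mirror the jsonify call in intent_reco.

# Client-side sketch (assumes the Flask app above is running locally on port 5000)
import requests

resp = requests.post(
    "http://127.0.0.1:5000/intent_reco",
    json={"text": "今天的施工人数是多少?", "user_id": "3bb66776-1722-4c36-b14a-73dd210fe750"},
)
print(resp.json())  # expected keys per the handler: code, msg, int, label, probability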
api/slotRecognition.py (new file)
@@ -0,0 +1,61 @@
import paddle
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer


class SlotRecognition:
    def __init__(self, model_path: str, label_map: dict):
        """
        Initialize the slot recognition model and tokenizer.

        :param model_path: path to the model checkpoint
        :param label_map: mapping from label ids to BIO tag names
        """
        self.model = ErnieForTokenClassification.from_pretrained(model_path)
        self.tokenizer = ErnieTokenizer.from_pretrained(model_path)
        self.label_map = label_map

    def recognize(self, text: str):
        """
        Run slot recognition on the input text and return the recognized entities.

        :param text: input text
        :return: entities dict containing the recognized slot entities
        """
        # Process the input text
        inputs = self.tokenizer(text, max_length=512, return_tensors="pd")

        # Run inference without gradients
        with paddle.no_grad():
            logits = self.model(**inputs)
            predictions = paddle.argmax(logits, axis=-1)

        # Parse the predictions
        predicted_labels = predictions.numpy()[0]
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy())

        entities = {}
        current_entity = None
        current_label = None

        for token, label_id in zip(tokens, predicted_labels):
            label = self.label_map.get(label_id, "O")

            if label.startswith("B-"):  # start of a new entity
                if current_entity:
                    entities[current_label] = "".join(current_entity)
                current_entity = [token]
                current_label = label[2:]  # strip the "B-" prefix

            elif label.startswith("I-") and current_entity and label[2:] == current_label:
                current_entity.append(token)  # keep merging the same entity

            else:  # not part of an entity
                if current_entity:
                    entities[current_label] = "".join(current_entity)
                current_entity = None
                current_label = None

        # Handle the last entity
        if current_entity:
            entities[current_label] = "".join(current_entity)
        # Strip the wordpiece '#' markers from every recognized entity
        for key, value in entities.items():
            entities[key] = value.replace('#', '')
        return entities
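A minimal usage sketch; the checkpoint path is a placeholder, label_map is the dict defined in api/mian.py (shown truncated here), and the printed result is illustrative only.

# Usage sketch (hypothetical checkpoint path; label_map as in api/mian.py)
from api.slotRecognition import SlotRecognition

label_map = {
    0: 'O', 1: 'B-date', 11: 'I-date',
    7: 'B-projectManager', 17: 'I-projectManager',
    # ... remaining entries as in api/mian.py
}

slot_recognizer = SlotRecognition("path/to/uie_ner/checkpoint", label_map)  # placeholder path
entities = slot_recognizer.recognize("胡彬项目经理上一周作业内容是什么?")
print(entities)  # e.g. {"projectManager": "胡彬", "date": "上一周"} -- illustrative output only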
@@ -1,13 +1,20 @@
 import paddle
 import numpy as np
-from paddlenlp.transformers import ErnieTokenizer
+from paddlenlp.transformers import ErnieTokenizer, ErnieForSequenceClassification
-import paddle.nn.functional as F  # for softmax
+import paddle.nn.functional as F
+
+# List of class names
+labels = [
+    "天气查询", "互联网查询", "页面切换", "日计划数量查询", "周计划数量查询",
+    "日计划作业内容", "周计划作业内容", "施工人数", "作业考勤人数", "知识问答"
+]
+
 # Load the model and tokenizer
-model = paddle.jit.load("trained_model_static")  # load the saved static-graph model
+model = ErnieForSequenceClassification.from_pretrained(R"E:\workingSpace\PycharmProjects\Intention_dev\ernie\output\checkpoint-4160")  # use the sequence classification model
-tokenizer = ErnieTokenizer.from_pretrained("E:/workingSpace/PycharmProjects/Intention/models/ernie-3.0-tiny-base-v2-zh")
+tokenizer = ErnieTokenizer.from_pretrained(R"E:\workingSpace\PycharmProjects\Intention_dev\ernie\output\checkpoint-4160")

 # Build an input example
-text = "今天送变电二公司有?"
+text = "胡彬项目经理上一周作业内容是什么?"
 inputs = tokenizer(text, max_length=256, truncation=True, padding='max_length', return_tensors="pd")

 # Convert the inputs to Paddle tensor format
@@ -18,10 +25,12 @@ model.eval()  # make sure the model is in inference mode
 logits = model(input_ids)  # run the model to get logits

 # Convert the logits into probabilities with softmax
-probabilities = F.softmax(logits, axis=1)  # normalize the logits into a probability distribution
+probabilities = F.softmax(logits, axis=-1)  # normalize the logits into a probability distribution
-# Get the label with the highest probability
-max_prob_idx = np.argmax(probabilities.numpy(), axis=1)
-max_prob_value = np.max(probabilities.numpy(), axis=1)
-# Print the predictions
-print(f"Predicted label: {max_prob_idx}")
-print(f"Predicted label: {max_prob_value}")
+
+# Get the label with the highest probability (the intent of the whole sentence)
+max_prob_idx = np.argmax(probabilities.numpy(), axis=-1)  # index of the most likely label
+max_prob_value = np.max(probabilities.numpy(), axis=-1)   # highest probability value
+
+# Map the predicted index to its class name
+predicted_label = labels[max_prob_idx[0]]      # label name for the predicted index
+predicted_probability = max_prob_value[0]      # highest probability value
ernie/train.py (143 changed lines)
@@ -11,10 +11,14 @@ from paddlenlp.trainer import Trainer, TrainingArguments
 import os
 from sklearn.metrics import precision_score, recall_score, f1_score


 def load_config(config_path):
     """Load the YAML configuration file"""
-    with open(config_path, "r", encoding="utf-8") as f:
-        return yaml.safe_load(f)
+    try:
+        with open(config_path, "r", encoding="utf-8") as f:
+            return yaml.safe_load(f)
+    except Exception as e:
+        raise ValueError(f"读取配置文件时出错: {str(e)}")


 def generate_label_mappings(labels):
@@ -34,22 +38,29 @@ def preprocess_function(examples, tokenizer, max_length, is_test=False):

 def read_local_dataset(path, label2id=None, is_test=False):
     """Read the local dataset"""
-    with open(path, "r", encoding="utf-8") as f:
-        data = json.load(f)
-        for item in data:
-            if is_test:
-                if "text" in item:
-                    yield {"text": item["text"]}
-            else:
-                if "text" in item and "label" in item:
-                    yield {"text": item["text"], "label": label2id.get(item["label"], -1)}
+    try:
+        with open(path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+            for item in data:
+                if is_test:
+                    if "text" in item:
+                        yield {"text": item["text"]}
+                else:
+                    if "text" in item and "label" in item:
+                        yield {"text": item["text"], "label": label2id.get(item["label"], -1)}
+    except Exception as e:
+        raise ValueError(f"读取数据集时出错: {str(e)}")


 def load_and_preprocess_dataset(path, label2id, tokenizer, max_length, is_test=False):
     """Load and preprocess the dataset"""
-    dataset = load_dataset(read_local_dataset, path=path, label2id=label2id, lazy=False, is_test=is_test)
-    trans_func = functools.partial(preprocess_function, tokenizer=tokenizer, max_length=max_length, is_test=is_test)
-    return dataset.map(trans_func)
+    try:
+        dataset = load_dataset(read_local_dataset, path=path, label2id=label2id, lazy=False, is_test=is_test)
+        trans_func = functools.partial(preprocess_function, tokenizer=tokenizer, max_length=max_length, is_test=is_test)
+        return dataset.map(trans_func)
+    except Exception as e:
+        raise ValueError(f"加载和预处理数据集时出错: {str(e)}")


 def export_model(trainer, export_model_dir):
     """Export the model and tokenizer"""
@@ -59,6 +70,7 @@ def export_model(trainer, export_model_dir):
     paddle.jit.save(model_to_export, os.path.join(export_model_dir, 'model'), input_spec=input_spec)
     trainer.tokenizer.save_pretrained(export_model_dir)

+    # Save the id2label and label2id files
     id2label_file = os.path.join(export_model_dir, 'id2label.json')
     label2id_file = os.path.join(export_model_dir, 'label2id.json')
     with open(id2label_file, 'w', encoding='utf-8') as f:
@@ -71,68 +83,75 @@ def export_model(trainer, export_model_dir):
 def compute_metrics(p):
     """Compute the evaluation metrics"""
     predictions, labels = p
-    pred_labels = np.argmax(predictions, axis=1)
+    pred_labels = np.argmax(predictions, axis=1) + 1
     accuracy = np.sum(pred_labels == labels) / len(labels)
     precision = precision_score(labels, pred_labels, average='macro')
     recall = recall_score(labels, pred_labels, average='macro')
     f1 = f1_score(labels, pred_labels, average='macro')

     metrics = {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
-    print("Computed metrics:", metrics)  # Debug statement
+    print("Computed metrics:", metrics)  # print the computed metrics
     return metrics


 def main():
-    # Read the configuration
-    config = load_config("data.yaml")
-    label_id, id_label = generate_label_mappings(config["labels"])
+    try:
+        # Read the configuration
+        config = load_config("data.yaml")
+        label_id, id_label = generate_label_mappings(config["labels"])

         # Load the datasets
         tokenizer = ErnieTokenizer.from_pretrained(config["model_path"])
         train_ds = load_and_preprocess_dataset(config["train"], label_id, tokenizer, max_length=256)
         test_ds = load_and_preprocess_dataset(config["test"], label_id, tokenizer, max_length=256, is_test=True)

         # Load the model
-        model = ErnieForSequenceClassification.from_pretrained(config["model_path"], num_classes=config["nc"],
+        model = ErnieForSequenceClassification.from_pretrained(config["model_path"], num_classes=len(label_id),
                                                                 label2id=label_id, id2label=id_label)

         # Define the data collator
         data_collator = DataCollatorWithPadding(tokenizer)

         # Define the training arguments
         training_args = TrainingArguments(
             output_dir="./output",
-            evaluation_strategy="steps",  # evaluate every eval_steps steps
+            evaluation_strategy="epoch",
+            save_strategy="epoch",
             eval_steps=100,  # evaluate every 100 steps
             save_steps=500,
             logging_dir="./logs",
             logging_steps=50,  # log every 50 steps
             num_train_epochs=10,  # number of training epochs
             per_device_train_batch_size=16,
             per_device_eval_batch_size=16,
             gradient_accumulation_steps=1,
             learning_rate=5e-5,
             weight_decay=0.01,
             disable_tqdm=False,
-            metric_for_best_model="accuracy",  # select the best model by accuracy
             greater_is_better=True,  # higher accuracy is better
         )

         # Create the Trainer
         trainer = Trainer(
             model=model,
             tokenizer=tokenizer,
             args=training_args,
             criterion=CrossEntropyLoss(),
             train_dataset=train_ds,
             eval_dataset=test_ds,
             data_collator=data_collator,
             compute_metrics=compute_metrics,  # use the custom evaluation metrics
         )

         # Train the model
         trainer.train()

-    # Export the model
-    export_model(trainer, './output/export')
+        # Save the model
+        trainer.save_model("./saved_model_static")  # saved under './saved_model_static'
+
+    except Exception as e:
+        print(f"训练过程中出错: {str(e)}")


 if __name__ == "__main__":
     main()
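The training entry point reads its settings from data.yaml; a minimal sketch of the expected layout, inferred from the config keys accessed in main() ("labels", "model_path", "train", "test"), with placeholder paths:

# Hypothetical data.yaml (keys inferred from main(); values are placeholders)
model_path: path/to/pretrained/ernie
train: data/train.json
test: data/test.json
labels:
  - 天气查询
  - 互联网查询
  - 页面切换
  # ... remaining intent labels as listed in api/mian.py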
@@ -35,7 +35,7 @@ construction_units = ["国网安徽省电力有限公司建设分公司", "国
 project_departments = ["第九项目管理部(马鞍山)", "第十一项目管理部(马鞍山)", "第八项目管理部(芜湖)",
                        "第五项目管理部(阜阳)", "第六项目管理部(滁州)", "第十二项目管理部(陕皖)",
                        "第十三项目管理部(黄山)", "第四项目管理部(安庆)"]
-project_managers = ["陈少平", "范文立", "何东洋", "胡彬", "黄东林", "姜松竺", "刘闩", "柳杰"]
+project_managers = ["陈少平项目经理", "范文立项目经理", "何东洋项目经理", "胡彬项目经理", "黄东林项目经理", "姜松竺项目经理", "刘闩项目经理", "柳杰项目经理"]
 subcontractors = ["安徽远宏电力工程有限公司", "安徽京硚建设有限公司", "武汉久林电力建设有限公司",
                   "安徽省鸿钢建设发展有限公司", "安徽星联建筑安装有限公司", "福建文港建设工程有限公司",
                   "芜湖冉电电力安装工程有限责任公司", "合肥市胜峰建筑安装有限公司", "安徽劦力建筑装饰有限责任公司",