打开方案,规程规范和图纸的功能的训练和标准化
This commit is contained in:
parent
5413911e64
commit
f2cab2701b
|
|
@ -0,0 +1,107 @@
|
|||
# 意图识别和槽位抽取模型训练和推理
|
||||
|
||||
|
||||
## 项目概述
|
||||
本项目
|
||||
1)通过训练百度的ernie和uie模型,实现意图识别和槽位抽取的目的。
|
||||
2)通过大模型提示词实现将用户的多轮问答,还原出用户问题
|
||||
意图识别:即根据用户问题 返回用户的意图类型
|
||||
槽位抽取:即根据用户问题 返回用户的问题中的槽位信息(对应api中的参数信息)
|
||||
多轮问题还原:即根据用户的多轮问题和回复 还原用户最终问题
|
||||
|
||||
## 文件结构
|
||||
|
||||
```
|
||||
项目根目录/
|
||||
├── api/
|
||||
│ ├── logs/ # 运行日志
|
||||
│ └── standard_data/ #标准化数据本地文件(每天凌晨三天通过redis同步)
|
||||
│ └── main.py #主程序
|
||||
│ └── config.py #配置文件
|
||||
│ └── constants.py #常量定义文件
|
||||
│ └── globalData.py #全局数据同步文件类
|
||||
│ └── intentRecognition.py #意图识别
|
||||
│ └── logger_util.py #写日志文件
|
||||
│ └── slotRecognition.py #槽位抽取
|
||||
│ └── utis.py #通用具体类
|
||||
│ └── standard_test.py #单元测试
|
||||
│
|
||||
├── ernie/ #意图训练
|
||||
│ ├── data/ #意图训练数据集
|
||||
│ └── data.yaml #意图训练配置
|
||||
│ └── train.py #意图训练脚本
|
||||
│
|
||||
├── uie/ #槽位抽取训练
|
||||
│ └── data/ #槽位抽取训练数据集
|
||||
│ └── train.py #槽位抽取训练脚本
|
||||
│
|
||||
├── generated_data/ #生成训练数据集目录
|
||||
│ └── data/ #训练数据集的中间文件
|
||||
│ │ └──
|
||||
│ │ └──
|
||||
│ └──output/
|
||||
│ │ └── ernie/ #最终生成的用于意图训练的数据集
|
||||
│ │ └── uie/ #最终生成的用于槽位抽取训练的数据集
|
||||
│ └── generated.py #根据意图和槽位数据模版 向本级data目录里产生数据集文件
|
||||
│ └── 合并数据.py #将本级data目录里的数据合并到output 的merged_data.json
|
||||
│ └── 按比例分配ernie数据.py #将本级data目录里的数据合并到output 的merged_data.json 按比例抽取意图训练数据集到output/ernie
|
||||
│ └── 按比例分配uie数据.py #将本级data目录里的数据合并到output 的merged_data.json 按比例抽取意图训练数据集到output/uie
|
||||
```
|
||||
注意:第一次训练在联网的情况下不需要提前下载模型ernie和uie模型到服务器
|
||||
|
||||
## 项目接口
|
||||
意图识别和槽位抽取api(在下面两个api的功能基础上增加了多轮问题的还原):
|
||||
POST:http://192.168.0.37:18074/agent
|
||||
{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750",
|
||||
"messages":[{"role":"user","content":"送一分公司"},{"role":"assistant","content":"送电一分公司第一项目管理部(金上)风险等级为2级的有0项,3级的有0项,4级的有1项,5级的有0项,有1项作业计划"},{"role":"user","content":"今天有多少项作业计划"},{"role":"assistant","content":"2025-04-25 公司风险等级为2级的有15项,3级的有144项,4级的有262项,5级的有0项,有421项作业计划"},{"role":"user","content":"具体作业计划是什么"}]}
|
||||
|
||||
意图识别 api:
|
||||
POST:http://192.168.0.37:18074/intent_reco
|
||||
Body:{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750","text":"打开众兴-草庙乡牵引站220kV电缆线路工程(土建部分)(PROJ-2023-0328)一般跨越施工措施"}
|
||||
|
||||
|
||||
槽位抽取api:
|
||||
POST:http://192.168.0.37:18074/slot_reco
|
||||
Body:{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750","text":"今天芦集变电站有多少作业计划"}
|
||||
|
||||
|
||||
## 多轮问题还原类实例
|
||||
需要通过http://192.168.0.37:18074/agent 这个接口去实现多轮还原和意图及槽位的提取
|
||||
### 追问类
|
||||
{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750",
|
||||
"messages":[{"role":"user","content":"今天有多少项作业计划"},{"role":"assistant","content":"2025-04-25公司一共有421项作业计划,分别如下:风险等级为2级的有15项,3级的有144项,4级的有262项,5级的有0项"},{"role":"user","content":"作业内容"}]}
|
||||
|
||||
{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750",
|
||||
"messages":[{"role":"user","content":"“送一分公司第一项目部金上今天有多少项作业计划"},{"role":"assistant","content":"今天送电一分公司第一项目管理部(金上)有21项作业计划"},{"role":"user","content":"班组详情呢"}]}
|
||||
|
||||
|
||||
### 补充类
|
||||
{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750",
|
||||
"messages":[{"role":"user","content":"第10项目管理部特高压部门有多少作业计划?"},{"role":"assistant","content":"非常抱歉,请问你想查询什么时间的日计划数量查询?"},{"role":"user","content":"今天"},{"role":"assistant","content":"请补充该项目部所属的分公司名称"},{"role":"user","content":"变电分公司"}]}
|
||||
|
||||
{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750",
|
||||
"messages":[{"role":"user","content":"第10项目管理部特高压有多少作业计划?"},{"role":"assistant","content":"非常抱歉,请问你想查询什么时间的日计划数量查询?"},{"role":"user","content":"今天"},{"role":"assistant","content":"请补充该项目部所属的分公司名称"},{"role":"user","content":"变电分公司"}]}
|
||||
|
||||
{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750",
|
||||
"messages":[{"role":"user","content":"第一项目部有多少作业计划?"},{"role":"assistant","content":"非常抱歉,请问你想查询什么时间的日计划数量查询?"},{"role":"user","content":"今天的"}]}
|
||||
|
||||
{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750",
|
||||
"messages":[{"role":"user","content":"第一项目部有多少作业计划?"},{"role":"assistant","content":"非常抱歉,请问你想查询什么时间的日计划数量查询?"},{"role":"user","content":"今天"},{"role":"assistant","content":"请补充该项目部所属的分公司名称"},{"role":"user","content":"送一分公司"}]}
|
||||
|
||||
|
||||
### 序号词指代类
|
||||
{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750",
|
||||
"messages":[{"role":"user","content":"送一分公司第一项目部今天有多少作业计划"},{"role":"assistant","content":"您说的项目名可能是:第1个: 第一项目管理部(金上),第2个: 第一项目管理部(天津),请确认您要选择哪一个?"},{"role":"user","content":"第二个"}]}
|
||||
|
||||
{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750",
|
||||
"messages":[{"role":"user","content":"打开合州换流站摄像头"},{"role":"assistant","content":"您说的工程名可能是:第1个:合州±800千伏换流站电气安装A包(PROJ-2025-0056),第2个:陕北-安徽±800千伏特高压直流输电工程合州换流站交流系统调试及特殊试验(PROJ-2025-0159),第3个:合州±800千伏换流站工程大件运输工程(PROJ-2024-1194),第4个:合州换流站-文都500千伏线路工程(PROJ-2024-1089),第5个:合州±800kV换流站电气A包工程(调试部分)(PROJ-2025-0102),第6个:陕北-安徽直流工程合州±800千伏换流站土建A包(PROJ-2024-0312),第7个:文都-官山改接入合州换流站500千伏线路工程(PROJ-2024-1090),请确认您要选择哪一个?"},{"role":"user","content":"第1个"}]}
|
||||
|
||||
{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750",
|
||||
"messages":[{"role":"user","content":"打开金牛紫蓬工程视频"},{"role":"assistant","content":" 您说的工程名可能是:第1个:金牛500kV变电站新建工程(PROJ-2024-0866),第2个:换流站-金牛500kV线路工程(PROJ-2024-1249),第3个:金牛-紫蓬500kV线路工程(PROJ-2023-0503),第4个:郑蒲-金河口110kV线路工程(PROJ-2024-0223),请确认您要选择哪一个个?"},{"role":"user","content":"第3个"}]}
|
||||
|
||||
### 综合类(综合以上)
|
||||
{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750",
|
||||
"messages":[{"role":"user","content":"宋1分公司第一项目部今天有多少作业计划"},{"role":"assistant","content":"您说的项目名可能是:第1个: 第一项目管理部(金上),第2个: 第2个:</strong第一项目管理部(天津),请确认您要选择哪一个?"},{"role":"user","content":"第1个"},{"role":"assistant","content":"2025-03-28 送电一分公司第一项目管理部金上)有2项作业计划"},{"role":"user","content":"2项作业计划是什么"},{"role":"assistant","content":"今天送电一分公司第一项目管理部(金上)有1项作业计划:第1项:工程名称:金上-湖北线路工程(川12标)(PROJ-2023-0152),项目经理:姚永先,班组名称:李洋班组,班组人数:12,工作负责人:李洋,安全监护员:张节艳,班组质检员:刘永远,特种作业人员:许龙彪,宋伟云,况旻伦,刘付金,杨博文,徐云孝,,一般作业人员:方友昌,朱刘文,吴怀胜,作业部位:2601~2864,作业内容:验收消缺,风险等级:4级,风险可能导致的后果:高处坠落、物体打击、触由,8+2工况:无"},{"role":"user","content":"班组详情"}]}
|
||||
|
||||
{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750",
|
||||
"messages":[{"role":"user","content":"第10项目管理部特高压部门有多少作业计划?"},{"role":"assistant","content":"非常抱歉,请问你想查询什么时间的日计划数量查询?"},{"role":"user","content":"今天"},{"role":"assistant","content":"请补充该项目部所属的分公司名称"},{"role":"user","content":"变电分公司"},{"role":"assistant","content":"变电分公司第十项目管理部(特高压)有2项作业计划"},{"role":"user","content":"具体的作业计划内容"}]}
|
||||
|
|
@ -15,8 +15,8 @@ from config import *
|
|||
from globalData import GlobalData
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
|
||||
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-14672"
|
||||
MODEL_UIE_PATH = R"../uie/output_temp/checkpoint-18774"
|
||||
MODEL_ERNIE_PATH = R"../ernie/output/checkpoint-8408"
|
||||
MODEL_UIE_PATH = R"../uie/output/checkpoint-18774"
|
||||
|
||||
|
||||
# 类别名称列表
|
||||
|
|
@ -393,7 +393,7 @@ def extract_multi_chat(messages):
|
|||
|
||||
函数 替换新属性(文本,新查询属性):
|
||||
先删除文本中的"有多少"等类似的表达数量表达,
|
||||
再将文本里的查询属性替换为新查询属性,并保持其他内容不变并返回 且保持新查询属性的语气
|
||||
再将文本里的查询属性替换为新查询属性,并保持其他内容不变,不要增加“有多少”,“是什么”类似的这些疑问词,且保持新查询属性的语气,并返回
|
||||
示例:替换新属性("今天送一分公司有多少作业计划", "作业内容") 返回 "今天送一分公司的作业内容"
|
||||
|
||||
函数 有完整的句意(新问题):
|
||||
|
|
|
|||
|
|
@ -1,81 +0,0 @@
|
|||
# from langchain_openai import OpenAIEmbeddings
|
||||
# from utils import CheckResult, StandardType, load_standard_name
|
||||
#
|
||||
# standard_program_name_list = load_standard_name('./standard_data/standard_program.txt')
|
||||
#
|
||||
# params = {'model': 'bge-large-zh-v1.5',
|
||||
# 'openai_api_base': 'http://218.23.122.14:63015/v1-openai/',
|
||||
# 'openai_api_key': 'gpustack_baacebfd27bb3d01_092ce528ae05cb7d05acb052e6490090',
|
||||
# 'openai_proxy': ''}
|
||||
#
|
||||
# try:
|
||||
# embedding = OpenAIEmbeddings(**params)
|
||||
# result = embedding.embed_documents(standard_program_name_list,chunk_size=500)
|
||||
#
|
||||
# print(f"mbedding.embed_documents 结果:{result}")
|
||||
#
|
||||
#
|
||||
# except Exception as e:
|
||||
# print(f"failed to create Embeddings for model. {e}")
|
||||
|
||||
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from utils import CheckResult, StandardType, load_standard_name
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
import numpy as np
|
||||
|
||||
# 加载标准项目部名称列表
|
||||
standard_program_name_list = load_standard_name('./standard_data/standard_program.txt')
|
||||
# 模型参数
|
||||
params = {'model': 'bge-large-zh-v1.5',
|
||||
'openai_api_base': 'http://127.0.0.1:9997/v1',
|
||||
'openai_api_key': 'EMPTY',
|
||||
'openai_proxy': ''}
|
||||
# 创建嵌入模型
|
||||
embedding = OpenAIEmbeddings(**params)
|
||||
|
||||
# 获取标准项目部名称的嵌入向量
|
||||
standard_embeddings = embedding.embed_documents(standard_program_name_list, chunk_size=500)
|
||||
|
||||
|
||||
def fuzzy_match(query):
|
||||
try:
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
import numpy as np
|
||||
# 查询名称
|
||||
query_embedding = embedding.embed_query(query)
|
||||
|
||||
# 计算相似度
|
||||
similarities = cosine_similarity([query_embedding], standard_embeddings)[0]
|
||||
|
||||
# 找到最相似的项目部名称
|
||||
most_similar_index = np.argmax(similarities)
|
||||
most_similar_name = standard_program_name_list[most_similar_index]
|
||||
print(f"输入名称: {query}")
|
||||
print(f"最相似的项目部名称: {most_similar_name}")
|
||||
print(f"相似度: {similarities[most_similar_index]:.4f}")
|
||||
return most_similar_name, similarities[most_similar_index]
|
||||
except Exception as e:
|
||||
print(f"相似性判断错误{e}")
|
||||
|
||||
# try:
|
||||
# # 查询名称
|
||||
# query = "定西第一项目部"
|
||||
# query_embedding = embedding.embed_query(query)
|
||||
#
|
||||
# # 计算相似度
|
||||
# similarities = cosine_similarity([query_embedding], standard_embeddings)[0]
|
||||
#
|
||||
# # 找到最相似的项目部名称
|
||||
# most_similar_index = np.argmax(similarities)
|
||||
# most_similar_name = standard_program_name_list[most_similar_index]
|
||||
#
|
||||
# print(f"输入名称: {query}")
|
||||
# print(f"最相似的项目部名称: {most_similar_name}")
|
||||
# print(f"相似度: {similarities[most_similar_index]:.4f}")
|
||||
#
|
||||
# except Exception as e:
|
||||
# print(f"Failed to create embeddings or compute similarity: {e}")
|
||||
|
||||
match_program, match_possibility = fuzzy_match("第一项目部定西")
|
||||
print(f"fuzzy_match program result:{match_program}, {match_possibility}")
|
||||
48
api/utils.py
48
api/utils.py
|
|
@ -700,54 +700,6 @@ def check_standard_name_slot_probability(int_res, slot) -> tuple:
|
|||
# return CheckResult.NO_MATCH, ""
|
||||
|
||||
|
||||
def standardize_implement_company(slot_item) -> tuple:
|
||||
if IMPLEMENTATION_ORG in slot_item:
|
||||
value = slot_item[IMPLEMENTATION_ORG]
|
||||
logger.info(f"standardize_specification_design_pic 原始分公司名 : {value}")
|
||||
match_results = standardize_sub_company(value, GlobalData.simply_to_standard_company_name_map,
|
||||
GlobalData.pinyin_simply_to_standard_company_name_map, 70, 90)
|
||||
logger.info(f"standardize_specification_design_pic 匹配后分公司名: result:{match_results}")
|
||||
if match_results and len(match_results) == 1:
|
||||
slot_item[IMPLEMENTATION_ORG] = match_results[0]
|
||||
else:
|
||||
prompt = generate_project_prompt_with_key(match_results, original_name=slot_item[IMPLEMENTATION_ORG],
|
||||
slot_key=IMPLEMENTATION_ORG)
|
||||
return CheckResult.NEEDS_MORE_ROUNDS, prompt
|
||||
return CheckResult.NO_MATCH, ""
|
||||
|
||||
|
||||
def standardize_project(slot_item) -> tuple:
|
||||
if PROJECT_NAME in slot_item:
|
||||
value = slot_item[PROJECT_NAME]
|
||||
logger.info(f"standardize_specification_design_pic 原始工程名 : {slot_item[PROJECT_NAME]}")
|
||||
match_results = standardize_project_name(value, GlobalData.simply_to_standard_project_name_map,
|
||||
GlobalData.pinyin_simply_to_standard_project_name_map, 70, 90)
|
||||
logger.info(f"standardize_specification_design_pic 匹配后工程名 :result:{match_results}")
|
||||
|
||||
if match_results and len(match_results) == 1:
|
||||
slot_item[PROJECT_NAME] = match_results[0]
|
||||
else:
|
||||
prompt = generate_project_prompt(match_results, original_name=slot_item[PROJECT_NAME], type="工程名")
|
||||
return CheckResult.NEEDS_MORE_ROUNDS, prompt
|
||||
return CheckResult.NO_MATCH, ""
|
||||
|
||||
|
||||
def standardize_design(slot_item) -> tuple:
|
||||
if PROJECT_NAME in slot_item:
|
||||
value = slot_item[PROJECT_NAME]
|
||||
logger.info(f"standardize_specification_design_pic 原始工程名 : {slot_item[PROJECT_NAME]}")
|
||||
match_results = standardize_project_name(value, GlobalData.simply_to_standard_project_name_map,
|
||||
GlobalData.pinyin_simply_to_standard_project_name_map, 70, 90)
|
||||
logger.info(f"standardize_specification_design_pic 匹配后工程名 :result:{match_results}")
|
||||
|
||||
if match_results and len(match_results) == 1:
|
||||
slot_item[PROJECT_NAME] = match_results[0]
|
||||
else:
|
||||
prompt = generate_project_prompt(match_results, original_name=slot_item[PROJECT_NAME], type="工程名")
|
||||
return CheckResult.NEEDS_MORE_ROUNDS, prompt
|
||||
return CheckResult.NO_MATCH, ""
|
||||
|
||||
|
||||
def standardize_specification_design_pic(slot) -> tuple:
|
||||
# #分公司名标准化
|
||||
# result_type, prompt = standardize_implement_company(slot)
|
||||
|
|
|
|||
|
|
@ -120,9 +120,9 @@ def main():
|
|||
save_steps=2000, # 每2000步保存一次,save_strategy="steps"时生效
|
||||
logging_dir="./logs",
|
||||
logging_steps=100, # 每100步输出一次日志
|
||||
num_train_epochs=10, # 训练轮数
|
||||
per_device_train_batch_size=32,
|
||||
per_device_eval_batch_size=32,
|
||||
num_train_epochs=8, # 训练轮数
|
||||
per_device_train_batch_size=64,
|
||||
per_device_eval_batch_size=64,
|
||||
gradient_accumulation_steps=1,
|
||||
learning_rate=5e-5,
|
||||
weight_decay=0.01,
|
||||
|
|
|
|||
|
|
@ -97,7 +97,11 @@ BASE_DATA = {
|
|||
"design_specification_names": [
|
||||
"《35kV电力电缆交流耐压试验方案》","220kV南蒙2753线拆线、拆塔施工方案","悬索封网实验方案","灌注桩承台基础施工方案"
|
||||
"一般跨越施工措施", "省道专项施工方案","吊车组立角钢塔施工方案","承台基础及接地施工措施","项目管理实施规划","电力电缆方案","线路拆旧跨越110kV线路施工方案",
|
||||
"断面悬浮抱杆组塔施工方案","灌注桩基础及接地施工措施"
|
||||
"断面悬浮抱杆组塔施工方案","灌注桩基础及接地施工措施",
|
||||
"安徽马鞍山锁库500千伏变电站工程《构支架安装施工方案》","安徽金牛~福渡500千伏线路工程G56-G57停电封拆网不停电跨越110kV金山655、656线方案", "子期-濠州220kV架空线路工程灌注桩基础及接地施工措施",
|
||||
"国网六安供电公司500kV皋铭5357线/皋传5358线2025年度停电检修 补充三措一案","文都-官山改接入合州换流站500千伏线路工程口700断面悬浮抱杆组塔施工方案(初稿)",
|
||||
"游乐南岗湖光路线路施工方案"
|
||||
|
||||
|
||||
"110kV-750kV架空输电线路铁塔基础施工工艺导则","国网(基建2)112-2022 国家电网有限公司输变电工程建设质量管理规定","1000kV架空输电线路施工质量检验及评定规程","《国家电网有限公司施工项目部标准化管理手册线路工程分册》",
|
||||
"国家电网有限公司输变电工程标准工艺(电缆工程分册)2022版","架空输电线路螺旋锚基础施工及质量验收规范","国家电网有限公司安全生产反违章工作管理办法"],
|
||||
|
|
@ -785,19 +789,21 @@ TEMPLATE_CONFIG = {
|
|||
#方案和规程规范
|
||||
("打开{design_specification_name}<方案>", ["project_name", "design_specification_name"]),
|
||||
("打开{design_specification_name}", ["project_name", "design_specification_name"]),
|
||||
("打开{project_name}{design_specification_name}<方案>", ["project_name", "design_specification_name"]),
|
||||
("打开{project_name}的{design_specification_name}<方案>", ["project_name", "design_specification_name"]),
|
||||
|
||||
("打开{project_name}的{design_specification_name}", ["project_name", "design_specification_name"]),
|
||||
("打开{project_name}下的{design_specification_name}", ["project_name", "design_specification_name"]),
|
||||
|
||||
("打开{implementation_organization}{project_name}的{design_specification_name}",
|
||||
|
||||
("打开{implementation_organization}{project_name}下的{design_specification_name}",
|
||||
["implementation_organization", "project_name", "design_specification_name"]),
|
||||
|
||||
#图纸
|
||||
("打开{pic_name}", ["pic_name"]),
|
||||
("打开{project_name}{pic_name}", ["project_name", "pic_name"]),
|
||||
("打开{project_name}下的{pic_name}", ["project_name", "pic_name"]),
|
||||
("打开{project_name}{pic_name}<图纸>", ["project_name", "pic_name"]),
|
||||
("打开{project_name}的{pic_name}", ["project_name", "pic_name"]),
|
||||
("打开{implementation_organization}{project_name}的{pic_name}",
|
||||
("打开{implementation_organization}{project_name}下的{pic_name}",
|
||||
["implementation_organization", "project_name", "pic_name"]),
|
||||
]
|
||||
},
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ def preprocess_function(example, tokenizer):
|
|||
'implementation_organization', 'project_department', 'project_manager',
|
||||
'subcontractor', 'team_leader', 'risk_level', 'page', 'operating', 'team_name',
|
||||
'construction_area', 'person_name', 'person_query_type', 'project_status',
|
||||
"sky_net", "program_navigation"
|
||||
"pic_name", "design_specification_name"
|
||||
]
|
||||
|
||||
# 文本 Tokenization
|
||||
|
|
@ -81,8 +81,8 @@ training_args = TrainingArguments(
|
|||
output_dir="./output_temp",
|
||||
evaluation_strategy="epoch",
|
||||
save_strategy="epoch",
|
||||
per_device_train_batch_size=32, # 你的显存较大,可调整 batch_size
|
||||
per_device_eval_batch_size=32,
|
||||
per_device_train_batch_size=64, # 你的显存较大,可调整 batch_size
|
||||
per_device_eval_batch_size=64,
|
||||
learning_rate=2e-5,
|
||||
num_train_epochs=10, # 训练轮数
|
||||
weight_decay=0.01,
|
||||
|
|
|
|||
Loading…
Reference in New Issue