【深度学习踩坑实录】从 Checkpoint 报错到 TrainingArguments 精通:QNLI 任务微调全流程复盘

作为一名深度学习初学者,最近在基于 Hugging Face Transformers 微调 BERT 模型做 QNLI 任务时,被Checkpoint 保存TrainingArguments 配置这两个知识点卡了整整两天。从磁盘爆满、权重文件加载报错,到不知道如何控制 Checkpoint 数量,每一个问题都让我一度想放弃。好在最终逐一解决,特此整理成博客,希望能帮到同样踩坑的朋友。

一、核心背景:我在做什么?

本次任务是基于 GLUE 数据集的 QNLI(Question Natural Language Inference,问题自然语言推理)任务,用 Hugging Face 的run_glue.py脚本微调bert-base-cased模型。核心需求很简单:

  1. 顺利完成模型微调,避免中途中断;
  2. 控制 Checkpoint(模型快照)的保存数量,防止磁盘爆满;
  3. 后续能正常加载 Checkpoint,用于后续的 TRAK 贡献度分析。

但实际操作中,光是 “Checkpoint 保存” 这一个环节,就暴露出我对TrainingArguments(训练参数配置类)的认知盲区。

二、先搞懂:TrainingArguments 是什么?为什么它很重要?

在解决问题前,必须先理清TrainingArguments的核心作用 —— 它是 Hugging Face Transformers 库中控制训练全流程的 “总开关”,几乎所有与训练相关的配置(如批次大小、学习率、Checkpoint 保存策略)都通过它定义。

1. TrainingArguments 的本质

TrainingArguments是一个数据类(dataclass) ,它将训练过程中需要的所有参数(从优化器设置到日志保存)封装成结构化对象,再传递给Trainer(训练器)实例。无需手动编写训练循环,只需配置好TrainingArgumentsTrainer就能自动完成训练、验证、Checkpoint 保存等操作。

2. 常用核心参数(按功能分类)

我整理了本次任务中最常用的参数,按 “训练基础配置”“Checkpoint 控制”“日志与验证” 三类划分,新手直接套用即可:

类别参数名作用说明常用值示例
训练基础配置output_dir训练结果(Checkpoint、日志、指标)的保存根路径/root/autodl-tmp/bert_qnli
per_device_train_batch_size单设备训练批次大小(GPU 内存不足就调小)8/16/32
learning_rate学习率(BERT 类模型微调常用 5e-5/2e-5)5e-5
num_train_epochs训练总轮次(QNLI 任务 3-5 轮足够)3.0
fp16是否启用混合精度训练(GPU 支持时可加速,减少显存占用)true
Checkpoint 控制save_strategyCheckpoint 保存时机(核心!)"epoch"(按轮次)/"steps"(按步数)
save_steps按步数保存时,每多少步保存一次(需配合save_strategy="steps"2000/5000
save_total_limit最多保存多少个 Checkpoint(超过自动删除最旧的,防磁盘爆满)2/3
save_only_model是否只保存模型权重(不保存优化器、调度器状态,减小文件体积)true
overwrite_output_dir是否覆盖已存在的output_dir(避免 “目录非空” 报错)true
日志与验证do_eval是否在训练中执行验证(判断模型性能)true
eval_strategy验证时机(建议与save_strategy一致)"epoch"/"steps"
logging_dir日志保存路径(TensorBoard 可视化用)/root/autodl-tmp/bert_qnli/logs
logging_steps每多少步记录一次日志(查看训练进度)100/200

3. TrainingArguments 的配置方式

TrainingArguments不支持在代码中硬编码(除非修改脚本),常用两种配置方式,新手推荐第二种:

方式 1:命令行参数(快速调试)

运行run_glue.py时,通过--参数名 参数值的格式传递,示例:

python run_glue.py \--model_name_or_path bert-base-cased \--task_name qnli \--output_dir /root/autodl-tmp/bert_qnli \--do_train \--do_eval \--per_device_train_batch_size 8 \--learning_rate 5e-5 \--num_train_epochs 3 \--save_strategy epoch \--save_total_limit 3 \--overwrite_output_dir
方式 2:JSON 配置文件(固定复用)

将所有参数写入 JSON 文件(如qnli_train_config.json),运行时直接指定文件,适合参数较多或多任务复用:

{"model_name_or_path": "bert-base-cased","task_name": "qnli","do_train": true,"do_eval": true,"max_seq_length": 128,"per_device_train_batch_size": 8,"learning_rate": 5e-5,"num_train_epochs": 3.0,"output_dir": "/root/autodl-tmp/bert_qnli_new","save_strategy": "epoch","save_total_limit": 3,"overwrite_output_dir": true,"logging_dir": "/root/autodl-tmp/bert_qnli_new/logs","logging_steps": 100
}
python run_glue.py qnli_train_config.json

三、我的踩坑实录:3 个经典问题与解决方案

接下来重点复盘我遇到的 3 个核心问题,每个问题都附 “报错现象→原因分析→解决步骤”,新手可直接对号入座。

问题 1:训练中途磁盘爆满,被迫中断

报错现象

训练到约 3 万步时,服务器提示 “磁盘空间不足”,查看output_dir发现有 10 多个 Checkpoint 文件夹,每个文件夹占用数百 MB,累计占用超过 20GB。

原因分析

默认情况下,TrainingArgumentssave_strategy"steps"(每 500 步保存一次),且save_total_limit未设置(不限制保存数量)。QNLI 任务 1 个 epoch 约 1.3 万步,3 个 epoch 会生成 6-8 个 Checkpoint,加上优化器状态文件(optimizer.pt),很容易撑爆磁盘。

解决步骤
  1. TrainingArguments中添加save_total_limit: 3(最多保存 3 个 Checkpoint,超过自动删除最旧的);
  2. 选择合适的save_strategy:若追求稳定,用"epoch"(每轮保存一次,3 个 epoch 仅 3 个 Checkpoint);若需中途恢复,用"steps"并设置较大的save_steps(如 5000 步);
  3. 可选添加save_only_model: true(只保存模型权重,不保存优化器状态,每个 Checkpoint 体积从 500MB 缩减到 300MB 左右)。

问题 2:加载 Checkpoint 时提示 “_pickle.UnpicklingError: invalid load key, '\xe0'”

报错现象

训练中断后,尝试加载已保存的 Checkpoint(路径/root/autodl-tmp/bert_qnli/checkpoint-31000),运行代码:

model.load_state_dict(torch.load(os.path.join(checkpoint, "model.safetensors"), map_location=DEVICE))

报错:_pickle.UnpicklingError: invalid load key, '\xe0'

原因分析
  • model.safetensorsSafetensors 格式的权重文件(更安全,但需专用方法加载);
  • torch.load()是 PyTorch 原生加载函数,更适合加载pytorch_model.bin(PyTorch 二进制格式),用它加载 Safetensors 格式会因 “格式不兼容” 报错。
解决步骤

有两种方案,按需选择:

方案 1:改用 Safetensors 专用加载函数(需先安装safetensors库)

pip install safetensors
from safetensors.torch import load_file
# 用load_file()替代torch.load()
model.load_state_dict(load_file(os.path.join(checkpoint, "model.safetensors"), device=DEVICE))

方案 2:让 Checkpoint 默认保存为pytorch_model.bin格式
TrainingArguments中添加save_safetensors: false,后续生成的 Checkpoint 会默认保存为pytorch_model.bin,直接用torch.load()加载即可:

model.load_state_dict(torch.load(os.path.join(checkpoint, "pytorch_model.bin"), map_location=DEVICE))

问题 3:配置 TrainingArguments 后,Checkpoint 迟迟不生成

报错现象

设置save_strategy: "epoch"后,训练到 4000 步仍未生成任何 Checkpoint,怀疑配置未生效。

原因分析

save_strategy: "epoch"表示每轮训练结束后才保存 Checkpoint,而 QNLI 任务 1 个 epoch 约 1.3 万步(训练集约 10 万样本,batch_size=8时:100000÷8=12500 步)。4000 步仅完成第一个 epoch 的 1/3,未到保存时机,属于正常现象。

解决步骤
  1. 若想快速验证配置是否生效,临时改用save_strategy: "steps"并设置较小的save_steps(如 2000 步),训练到 2000 步时会自动生成checkpoint-2000
  2. 若坚持按 epoch 保存,耐心等待第一个 epoch 结束(约 1.3 万步),日志会打印Saving model checkpoint to xxx,此时output_dir下会出现第一个 Checkpoint;
  3. 查看日志确认配置:搜索save_strategysave_total_limit,确认日志中显示的参数与 JSON 配置一致(避免 JSON 文件未被正确读取)。

四、总结:TrainingArguments 配置 “避坑指南”

经过这次踩坑,我总结出 3 条新手必看的 “避坑原则”,帮你少走弯路:

  1. 优先用 JSON 配置文件:命令行参数容易遗漏,JSON 文件可固化配置,后续复用或修改时更清晰;
  2. Checkpoint 配置 “三要素”:每次训练前必确认save_strategy(保存时机)、save_total_limit(保存数量)、output_dir(保存路径),这三个参数直接决定是否会出现磁盘爆满或 Checkpoint 丢失;
  3. 加载 Checkpoint 前先看格式:先查看 Checkpoint 文件夹中的权重文件名(是model.safetensors还是pytorch_model.bin),再选择对应的加载函数,避免格式不兼容报错。

最后想说,深度学习中的 “环境配置” 和 “参数调试” 虽然繁琐,但每一次踩坑都是对知识点的深化。这次从 “完全不懂 TrainingArguments” 到 “能灵活控制 Checkpoint”,虽然花了两天时间,但后续再做其他 GLUE 任务(如 SST-2、MRPC)时,直接复用配置就能快速上手 —— 这大概就是踩坑的价值吧。

如果你也在做类似任务,欢迎在评论区交流更多踩坑经验~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/96951.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/96951.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java面试小册(3)

21【Q】: 什么是Java的SPI机制?【A】:SPI 是一种插件机制,用于在运行时动态加载服务的实现。它通过定义接口(服务接口)并提供一种可扩展的方式来让服务的提供着(实现类)在运行时注入&#xff0c…

P1150 Peter 的烟

记录20#include <bits/stdc.h> using namespace std; int main(){int n,k;cin>>n>>k;int cnt0;while(n>k){cntk;nn-k1;}cntn;cout<<cnt;return 0; }突破口每吸完一根烟就把烟蒂保存起来&#xff0c;k&#xff08;k>1&#xff09;个烟蒂可以换一个…

Cursor和Hbuilder用5分钟开发微信小程序

分享一个5分钟搞定微信小程序开发的技能&#xff0c;需要用到两个工具&#xff1a;Cursor和Hbuilder。 第1步、下载HBuilder。Hbuilder可以实现一套代码直接生成安卓、苹果、鸿蒙各个平台APP。访问Hbuilder的官方网站&#xff0c;HBuilderX-高效极客技巧&#xff0c;选择适合…

k8s的dashboard

找一个装有docker的机器&#xff0c;在一个rocky linux的虚拟机里弄拉取一个rancher镜像建立一个目录&#xff0c;目的&#xff1a;和里面数据做持久化关联后台运行&#xff0c;让他有权限&#xff0c;8080端口和容器80端口映射&#xff0c;443和443做映射查看一下删掉&#xf…

桥接模式,打造灵活可扩展的日志系统C++

一、为什么用桥接模式在企业开发中&#xff0c;日志系统几乎是标配。常见需求&#xff1a;日志有多种类型&#xff08;Info、Warning、Error 等&#xff09;&#xff1b;日志需要支持多种输出方式&#xff08;控制台输出、写文件、远程上传、数据库存储等&#xff09;。如果把这…

kafka--基础知识点--5.3--producer事务

1 事务简介 Kafka事务是Apache Kafka在流处理场景中实现Exactly-Once语义的核心机制。它允许生产者在跨多个分区和主题的操作中&#xff0c;以原子性&#xff08;Atomicity&#xff09;的方式提交或回滚消息&#xff0c;确保数据处理的最终一致性。例如&#xff0c;在流处理中…

利用DeepSeek实现服务器客户端模式的DuckDB原型

在网上看到韩国公司开发的一款GooseDB&#xff0c;DuckDB™ 的功能扩展分支&#xff0c;具有服务器/客户端、多会话和并发写入支持&#xff0c;使用 PostgreSQL 有线协议&#xff0c;但它是Freeware而不是开源&#xff0c;所以让DeepSeek实现之。 首先把readme页面发给他翻译&a…

麦当劳APP逆向

版本 V 7.0.17.0反调试 梆梆企业加固 frida反调试部分代码 headers {"biz_scenario": "500","biz_from": "1004","User-Agent": "mcdonald_Android/7.0.17.0 (Android)","ct": "102","…

大数据毕业设计选题推荐-基于大数据的结核病数据可视化分析系统-Hadoop-Spark-数据可视化-BigData

✨作者主页&#xff1a;IT毕设梦工厂✨ 个人简介&#xff1a;曾从事计算机专业培训教学&#xff0c;擅长Java、Python、PHP、.NET、Node.js、GO、微信小程序、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇…

Vue3 视频播放器完整指南 – @videojs-player/vue 从入门到精通

前言 在 Vue 3 生态中&#xff0c;视频播放功能是许多应用的核心需求。videojs-player/vue 是一个专门为 Vue 3 设计的视频播放器组件&#xff0c;基于成熟的 Video.js 库构建&#xff0c;提供了简单而强大的视频播放解决方案。 主要特性 Vue 3 组件化&#xff1a;原生 Vue …

【靶场练习】--DVWA第一关Brute Force(暴力破解)全难度分析

注意&#xff0c;这一关必须要使用Burpsuite来抓包 目录Low1.抓包2.发送到爆破模块3.选择爆破模式爆破模式介绍4.添加载荷5.添加字典6.爆破查看查看源码Medium查看源码High1.抓包2.在bp的extensions中找到CSRF Token Tracker&#xff0c;并安装3.构造字典4.成功爆破查看源码Imp…

Java语言——排序算法

一、基本概念排序&#xff1a;将n个数字按一定顺序排列&#xff08;比如&#xff1a;升序&#xff0c;或者降序&#xff09; ^内部排序 &#xff1a;若整个排序过程不需要访问外存便能完成&#xff0c;则称此类排序问题为内部排序 ^外部排序&#xff1a;若参加排序的记录数量很…

【Linux】人事档案——用户及组管理

目录 1 用户及组管理 2 用户及用户组管理命令 2.1 useradd&#xff1a;建立用户 useradd命令用于建立用户&#xff0c;该 2.2 passwd&#xff1a;更改用户密码 2.3 usermod&#xff1a;更改用户信息 2.4 groupadd&#xff1a;建立用户组 2.5 finger&#xff1a;查找并显…

给定一个有序的正数数组arr和一个正数range,如果可以自由选择arr中的数字,想累加得 到 1~range 范围上所有的数,返回arr最少还缺几个数。

给定一个有序的正数数组arr和一个正数range&#xff0c;如果可以自由选择arr中的数字&#xff0c;想累加得 到 1~range 范围上所有的数&#xff0c;返回arr最少还缺几个数。 #include <iostream> #include <vector>using namespace std;void func1(std::vector<…

BigemapPro快速添加历史影像(Arcgis卫星地图历史地图)

这是Esri(Arcgis)官方提供的历史影像数据&#xff0c;可放心使用。https://livingatlas.arcgis.com/wayback如何快速添加到Bigemap Pro软件里&#xff0c;详细步骤如下&#xff1a;复制下面的文本保存为 配置.bmmap,然后拖入软件就可以了{"BmLayerVersion":"1.0…

[免费]基于Python的Django医院管理系统【论文+源码+SQL脚本】

大家好&#xff0c;我是python222_小锋老师&#xff0c;看到一个不错的基于Python的Django医院管理系统&#xff0c;分享下哈。 项目视频演示 https://www.bilibili.com/video/BV1iPH8zmEut/ 项目介绍 随着人民生活水平日益增长&#xff0c;科技日益发达的今天&#xff0c;…

MyBatis 从入门到精通(第三篇)—— 动态 SQL、关联查询与查询缓存

在前两篇博客中&#xff0c;我们掌握了 MyBatis 的基础搭建、核心架构与 Mapper 代理开发&#xff0c;能应对简单的单表 CRUD 场景。但实际项目中&#xff0c;业务往往更复杂 —— 比如 “多条件动态查询”“员工与部门的关联查询”“高频查询的性能优化” 等。本篇将聚焦 MyBa…

Linux内核中IPv4的BEET模式封装机制解析

引言 在Linux网络栈中,IPSec提供了网络层的数据加密和认证服务。传统的IPSec支持两种模式:传输模式(Transport Mode)和隧道模式(Tunnel Mode)。然而,这两种模式各有优缺点:传输模式开销小但无法隐藏原始IP头;隧道模式提供完全封装但增加了开销。 BEET(Bound End-to…

设计模式——创建型模式

什么是设计模式&#xff1f;设计模式是软件工程中解决常见问题的经典方案&#xff0c;它们代表了最佳实践和经验总结。通过使用设计模式&#xff0c;开发者可以创建更加灵活、可维护和可扩展的代码结构。设计模式不是具体的代码实现&#xff0c;而是针对特定问题的通用解决方案…

我爱学算法之—— 位运算(上)

常见位运算 对于位运算&#xff1a; &&#xff1a;按位与&#xff0c;有0则0。 |&#xff1a;按位或&#xff0c;有1则1。 ^&#xff1a;按位异或&#xff0c;相同为0、不同为1。&#xff08;无进位相加&#xff09; ~&#xff1a;二进制位按位取反。 对于位运算的常见使用…