文章目录
- 什么是gradient checkpoint
- 原理
- 使用场景
- 注意事项
什么是gradient checkpoint
- gradient checkpoint是一种优化深度学习模型内存使用的技术,尤其在训练大型模型时非常有用。它通过牺牲计算时间为代价来减少显存占用。
- 大多数情况下,transformers库中的gradient checkpoint粒度是“一个Transformer Block(也叫layer)为单位。
原理
-在标准的反向传播中,为了计算梯度,需要保存所有中间激活值(activations),这会占用大量显存。
- Gradient Checkpointing 的核心思想是只保留部分层的激活值,其余层在反向传播时重新计算,从而节省显存。【一般只保存transformer block的输入输出,这样节省了大量的存储】
使用场景
- 显存受限时(如训练大模型)
- batch size 需要增大但受显存限制
- 模型层数较多(如Transformer)
注意事项
- 会增加训练时间(因为需要重复计算激活值)【如果计算是瓶颈,那么这个方法会增加训练时长。】
- 不适用于所有模型结构,建议先测试是否有效
- 可能与某些优化器或混合精度训练有兼容性问题