知识点回顾
- 图像数据的格式:灰度和彩色数据
- 模型的定义
- 显存占用的4种地方
- 模型参数+梯度参数
- 优化器参数
- 数据批量所占显存
- 神经元输出中间状态
- batchisize和训练的关系
作业:今日代码较少,理解内容即可
在 PyTorch 中,图像数据的形状通常遵循 (通道数, 高度, 宽度) 的格式(即 Channel First 格式),这与常见的 (高度, 宽度, 通道数)(Channel Last,如 NumPy 数组)不同。---注意顺序关系,
注意点:
- 如果用matplotlib库来画图,需要转换下顺序 image = np.transpose(image.numpy(), (1, 2, 0)
- 模型输入通常需要批次维度(Batch Size),形状变为 (批次大小, 通道数, 高度, 宽度)。例如,批量输入 10 张 MNIST 图像时,形状为 (10, 1, 28, 28)
对于图像数据集比如MNIST构建神经网络来训练的话,比起之前的结构化数据多了一个展平操作:
# 定义两层MLP神经网络
class MLP(nn.Module):def __init__(self, input_size=784, hidden_size=128, num_classes=10):super().__init__()self.flatten = nn.Flatten() # 将28x28的图像展平为784维向量self.layer1 = nn.Linear(input_size, hidden_size) # 第一层:784个输入,128个神经元self.relu = nn.ReLU() # 激活函数self.layer2 = nn.Linear(hidden_size, num_classes) # 第二层:128个输入,10个输出(对应10个数字类别)def forward(self, x):x = self.flatten(x) # 展平图像x = self.layer1(x) # 第一层线性变换x = self.relu(x) # 应用ReLU激活函数x = self.layer2(x) # 第二层线性变换,输出logitsreturn x# 初始化模型
model = MLP()
MLP的输入层要求输入是一维向量,但 MNIST 图像是二维结构(28×28 像素),形状为 [1, 28, 28](通道 × 高 × 宽)。nn.Flatten() 展平操作将二维图像 “拉成” 一维向量(784=28×28 个元素),使其符合全连接层的输入格式
在面对数据集过大的情况下,由于无法一次性将数据全部加入到显存中,所以采取了分批次加载这种方式。所以实际应用中,输入图像还存在batch_size这一维度,但在PyTorch中,模型定义和输入尺寸的指定不依赖于batch_size,无论设置多大的batch_size,模型结构和输入尺寸的写法都是不变的,batch_size是在数据加载阶段定义的(之前提过这是DataLoader的参数)
那么显存设置多少合适呢?如果设置的太小,那么每个batch_size的训练不足以发挥显卡的能力,浪费计算资源;如果设置的太大,会出现OOM(out of memory)显存一般被以下内容占用:
- 模型参数与梯度:模型的权重和对应的梯度会占用显存,尤其是深度神经网络(如 Transformer、ResNet 等),一个 1 亿参数的模型(如 BERT-base),单精度(float32)参数占用约 400MB(1e8×4Byte),加上梯度则翻倍至 800MB(每个权重参数都有其对应的梯度)
- 部分优化器(如 Adam)会为每个参数存储动量(Momentum)和平方梯度(Square Gradient),进一步增加显存占用(通常为参数大小的 2-3 倍)
- 其他开销
- 单张图像尺寸:1×28×28(通道×高×宽),归一化转换为张量后为float32类型,显存占用:1×28×28×4 Byte = 3,136 Byte ≈ 3 KB
- 批量数据占用:batch_size × 单张图像占用,例如batch_size=64时,数据占用为64×3 KB ≈ 192 KB
对于batch_size的设置,大规模数据时,通常从16开始测试,然后逐渐增加,确保代码运行正常且不报错,直到出现内存不足(OOM)报错或训练效果下降,此时选择略小于该值的 batch_size。训练时候搭配 nvidia-smi 监控显存占用,合适的 batch_size = 硬件显存允许的最大值 × 0.8(预留安全空间),并通过训练效果验证调整
@浙大疏锦行