残差神经网络(Residual Neural Network,简称 ResNet)是深度学习领域的里程碑式模型,由何凯明等人在 2015 年提出,成功解决了深层神经网络训练中的梯度消失 / 爆炸问题,使训练超深网络(如 152 层)成为可能。以下从核心原理、结构设计、优势与应用等方面进行详解。
一、核心问题:深层网络的训练困境
在 ResNet 提出前,随着网络层数增加,模型性能会先提升,然后迅速下降 —— 这种下降并非由过拟合导致,而是因为深层网络的梯度难以有效传递到浅层,导致浅层参数无法被充分训练(梯度消失 / 爆炸)。
ResNet 通过引入 “残差连接”(Residual Connection)解决了这一问题。
二、核心原理:残差连接与恒等映射
1. 传统网络的映射方式
传统深层网络中,每一层的目标是学习一个 “直接映射”(Direct Mapping):
设输入为x,经过多层非线性变换后,输出为H(x),即网络需要学习H(x)。
2. 残差网络的映射方式
ResNet 提出:不直接学习H(x),而是学习 “残差”F(x)=H(x)−x。
此时,原映射可表示为:H(x)=F(x)+x
其中,F(x)是残差函数(由若干卷积层 / 激活函数组成),x通过 “跳跃连接”(Skip Connection)直接与F(x)相加,形成最终输出。
3. 为什么残差连接有效?
- 梯度传递更顺畅:反向传播时,梯度可通过x直接传递到浅层(避免梯度消失)。例如,若F(x)=0,则H(x)=x,形成 “恒等映射”,网络可轻松学习到这种简单映射,再在此基础上优化残差。
- 简化学习目标:学习残差F(x)比直接学习H(x)更简单。例如,当目标映射接近恒等映射时,F(x)接近 0,网络只需微调即可,无需重新学习复杂的映射。
三、ResNet 的基本结构:残差块(Residual Block)
残差块是 ResNet 的基本单元,分为两种类型:
1. 基本残差块(Basic Block,用于 ResNet-18/34)
由 2 个卷积层组成,结构如下:x→Conv2d(64,3x3)→BN→ReLU→Conv2d(64,3x3)→BN→(+x)→ReLU
- 输入x先经过两个 3x3 卷积层(带批归一化 BN 和 ReLU 激活),得到残差F(x)。
- 若输入x与F(x)的维度相同(通道数、尺寸一致),则直接相加(恒等映射);若维度不同(如 stride > 1 或通道数变化),则需通过 1x1 卷积调整x的维度(称为 “投影捷径”,Projection Shortcut):x→Conv2d(out_channels,1x1,stride)→BN→(+F(x))
2. 瓶颈残差块(Bottleneck Block,用于 ResNet-50/101/152)
为减少计算量,用 3 个卷积层(1x1 + 3x3 + 1x1)组成,结构如下:x→Conv2d(C,1x1)→BN→ReLU→Conv2d(C,3x3)→BN→ReLU→Conv2d(4C,1x1)→BN→(+x′)→ReLU
- 1x1 卷积用于 “降维”(减少通道数),3x3 卷积用于提取特征,最后 1x1 卷积 “升维”(恢复通道数),显著降低计算量。
- 同样支持投影捷径(当维度不匹配时)。
四、完整 ResNet 的网络架构
ResNet 通过堆叠残差块形成深层网络,不同层数的 ResNet 结构如下表:
网络类型 | 残差块类型 | 卷积层配置(每个阶段的残差块数量) | 总层数 |
---|---|---|---|
ResNet-18 | 基本块 | [2, 2, 2, 2] | 18 |
ResNet-34 | 基本块 | [3, 4, 6, 3] | 34 |
ResNet-50 | 瓶颈块 | [3, 4, 6, 3] | 50 |
ResNet-101 | 瓶颈块 | [3, 4, 23, 3] | 101 |
ResNet-152 | 瓶颈块 | [3, 8, 36, 3] | 152 |
- 整体流程:输入图像 → 7x7 卷积(步长 2)+ 最大池化 → 4 个阶段的残差块堆叠(每个阶段通道数翻倍,尺寸减半) → 全局平均池化 → 全连接层(输出分类结果)。
五、ResNet 的优势
- 解决深层网络训练难题:通过残差连接实现梯度有效传递,可训练数百层甚至上千层的网络。
- 性能优异:在 ImageNet 等数据集上,ResNet 的错误率显著低于 VGG、GoogLeNet 等模型。
- 泛化能力强:残差结构可迁移到其他任务(如目标检测、语义分割),成为许多深度学习模型的基础组件(如 Faster R-CNN、U-Net++)。
六、ResNet 的变体与延伸
- ResNeXt:引入 “分组卷积”(Group Convolution),在保持性能的同时减少参数。
- DenseNet:将残差连接的 “相加” 改为 “拼接”(Concatenate),强化特征复用。
- Res2Net:在残差块中引入多尺度特征融合,提升细粒度特征提取能力。
- 应用扩展:从图像分类扩展到目标检测(如 FPN)、视频分析(如 I3D)、自然语言处理(如残差 LSTM)等领域。
七、总结
ResNet 通过残差连接的创新设计,突破了深层网络的训练瓶颈,不仅推动了计算机视觉的发展,也为其他领域的深层模型设计提供了重要思路。其核心思想 ——通过简化学习目标(学习残差)提升模型性能—— 已成为深度学习的经典范式。
import torch
import torch.nn as nn
import torch.nn.functional as Fclass BasicBlock(nn.Module):"""基本残差块,用于ResNet-18/34"""expansion = 1 # 输出通道数是输入的多少倍def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(BasicBlock, self).__init__()# 第一个卷积层self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)# 第二个卷积层(步长固定为1)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.downsample = downsample # 用于调整输入x的维度以匹配残差def forward(self, x):identity = x # 保存输入用于残差连接# 计算残差F(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)# 如果需要调整维度,则对输入x进行下采样if self.downsample is not None:identity = self.downsample(x)# 残差连接:H(x) = F(x) + xout += identityout = self.relu(out)return outclass Bottleneck(nn.Module):"""瓶颈残差块,用于ResNet-50/101/152"""expansion = 4 # 输出通道数是中间层的4倍def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(Bottleneck, self).__init__()# 1x1卷积降维self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)# 3x3卷积提取特征self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)# 1x1卷积升维self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = x# 计算残差F(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)# 调整输入维度if self.downsample is not None:identity = self.downsample(x)# 残差连接out += identityout = self.relu(out)return outclass ResNet(nn.Module):"""ResNet主网络"""def __init__(self, block, layers, num_classes=1000):super(ResNet, self).__init__()self.in_channels = 64 # 初始输入通道数# 初始卷积层self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.in_channels)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 四个阶段的残差块self.layer1 = self._make_layer(block, 64, layers[0], stride=1)self.layer2 = self._make_layer(block, 128, layers[1], stride=2)self.layer3 = self._make_layer(block, 256, layers[2], stride=2)self.layer4 = self._make_layer(block, 512, layers[3], stride=2)# 分类头self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)# 初始化权重for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)def _make_layer(self, block, out_channels, blocks, stride=1):"""构建一个由多个残差块组成的层"""downsample = None# 如果步长不为1或输入输出通道数不匹配,需要下采样调整维度if stride != 1 or self.in_channels != out_channels * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels * block.expansion,kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * block.expansion),)layers = []# 添加第一个残差块(可能包含下采样)layers.append(block(self.in_channels, out_channels, stride, downsample))self.in_channels = out_channels * block.expansion# 添加剩余的残差块(步长固定为1)for _ in range(1, blocks):layers.append(block(self.in_channels, out_channels))return nn.Sequential(*layers)def forward(self, x):# 初始特征提取x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)# 经过四个残差层x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)# 分类x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return x# 定义不同层数的ResNet
def resnet18(num_classes=1000):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)def resnet34(num_classes=1000):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)def resnet50(num_classes=1000):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)def resnet101(num_classes=1000):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)def resnet152(num_classes=1000):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)# 测试代码
if __name__ == "__main__":# 创建ResNet-18模型model = resnet18(num_classes=10)# 随机生成一个3通道输入(模拟224x224图像)x = torch.randn(2, 3, 224, 224) # batch_size=2# 前向传播output = model(x)print(f"输入形状: {x.shape}")print(f"输出形状: {output.shape}") # 应输出(2, 10)