In computer vision, UNet is widely used for image segmentation thanks to its strong performance. This article introduces an improved UNet architecture, UNetWithCrossAttention, which adds a cross-attention mechanism to strengthen the model's ability to fuse features.
1. Cross-Attention Mechanism
Cross-attention is a mechanism that lets a model dynamically extract relevant information from an auxiliary feature in order to enhance a primary feature. In our implementation, the CrossAttention class provides this functionality:
class CrossAttention(nn.Module):
    def __init__(self, channels):
        super(CrossAttention, self).__init__()
        self.query_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x1, x2):
        batch_size, C, height, width = x1.size()
        # Project to query, key and value
        proj_query = self.query_conv(x1).view(batch_size, -1, height * width).permute(0, 2, 1)
        proj_key = self.key_conv(x2).view(batch_size, -1, height * width)
        proj_value = self.value_conv(x2).view(batch_size, -1, height * width)
        # Compute the attention map
        energy = torch.bmm(proj_query, proj_key)
        attention = torch.softmax(energy / math.sqrt(proj_key.size(-1)), dim=-1)
        # Apply attention
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, height, width)
        # Residual connection
        out = self.gamma * out + x1
        return out
The module works as follows:
- The primary feature x1 is projected to the query, and the auxiliary feature x2 is projected to the key and value
- The similarity between the query and the key is computed to obtain the attention weights
- The attention weights are used to take a weighted sum over the value
- A residual connection fuses the result with the original primary feature
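To make the expected tensor shapes concrete, here is a minimal shape check, assuming a torch import and the CrossAttention class defined above; it shows that the primary and auxiliary features must have identical shapes:

import torch

attn = CrossAttention(channels=64)
x1 = torch.randn(2, 64, 32, 32)  # primary feature
x2 = torch.randn(2, 64, 32, 32)  # auxiliary feature, same shape as x1
out = attn(x1, x2)
print(out.shape)  # torch.Size([2, 64, 32, 32])

Because gamma is initialised to zero, the module starts out as an identity mapping on x1 and gradually learns how strongly to mix in the attended auxiliary information during training.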
2. Double Convolution Module
DoubleConv is the basic building block of UNet. It contains two consecutive convolution layers and can optionally apply cross-attention:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(DoubleConv, self).__init__()
        self.use_cross_attention = use_cross_attention
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        if use_cross_attention:
            self.cross_attention = CrossAttention(out_channels)

    def forward(self, x, aux_feature=None):
        x = self.conv1(x)
        x = self.conv2(x)
        if self.use_cross_attention and aux_feature is not None:
            x = self.cross_attention(x, aux_feature)
        return x
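A brief usage sketch (assuming the classes and imports above): when cross-attention is enabled, the auxiliary feature has to match the output of the two convolutions, i.e. it needs out_channels channels and the same spatial size:

block = DoubleConv(in_channels=3, out_channels=64, use_cross_attention=True)
x = torch.randn(1, 3, 128, 128)
aux = torch.randn(1, 64, 128, 128)  # must have out_channels channels and matching spatial size
y = block(x, aux)
print(y.shape)  # torch.Size([1, 64, 128, 128])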
3. Downsampling and Upsampling Modules
The downsampling module Down combines max pooling with a double convolution:
class Down(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(Down, self).__init__()
        self.downsampling = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            DoubleConv(in_channels, out_channels, use_cross_attention)
        )

    def forward(self, x, aux_feature=None):
        return self.downsampling[1](self.downsampling[0](x), aux_feature)
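For reference, a Down block halves the spatial resolution before applying the double convolution (sketch under the same assumptions):

down = Down(in_channels=64, out_channels=128)
x = torch.randn(1, 64, 64, 64)
y = down(x)  # aux_feature defaults to None, so cross-attention is skipped
print(y.shape)  # torch.Size([1, 128, 32, 32])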
The upsampling module Up uses a transposed convolution for upsampling and concatenates the skip-connection features:
class Up(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(Up, self).__init__()
        self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels, use_cross_attention)

    def forward(self, x1, x2, aux_feature=None):
        x1 = self.upsampling(x1)
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x, aux_feature)
        return x
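The transposed convolution halves the channel count, and the concatenation with the skip connection restores it, so DoubleConv again receives in_channels channels (sketch under the same assumptions):

up = Up(in_channels=128, out_channels=64)
x1 = torch.randn(1, 128, 16, 16)  # feature from the deeper level
x2 = torch.randn(1, 64, 32, 32)   # skip connection from the encoder
y = up(x1, x2)
print(y.shape)  # torch.Size([1, 64, 32, 32])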
4. The Complete UNetWithCrossAttention Architecture
Combining the modules above gives the complete UNetWithCrossAttention:
class UNetWithCrossAttention(nn.Module):
    def __init__(self, in_channels=1, num_classes=1, use_cross_attention=False):
        super(UNetWithCrossAttention, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.use_cross_attention = use_cross_attention
        # Encoder
        self.in_conv = DoubleConv(in_channels, 64, use_cross_attention)
        self.down1 = Down(64, 128, use_cross_attention)
        self.down2 = Down(128, 256, use_cross_attention)
        self.down3 = Down(256, 512, use_cross_attention)
        self.down4 = Down(512, 1024, use_cross_attention)
        # Decoder
        self.up1 = Up(1024, 512, use_cross_attention)
        self.up2 = Up(512, 256, use_cross_attention)
        self.up3 = Up(256, 128, use_cross_attention)
        self.up4 = Up(128, 64, use_cross_attention)
        self.out_conv = OutConv(64, num_classes)

    def forward(self, x, aux_feature=None):
        # Encoding path
        x1 = self.in_conv(x, aux_feature)
        x2 = self.down1(x1, aux_feature)
        x3 = self.down2(x2, aux_feature)
        x4 = self.down3(x3, aux_feature)
        x5 = self.down4(x4, aux_feature)
        # Decoding path
        x = self.up1(x5, x4, aux_feature)
        x = self.up2(x, x3, aux_feature)
        x = self.up3(x, x2, aux_feature)
        x = self.up4(x, x1, aux_feature)
        x = self.out_conv(x)
        return x
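Note that OutConv is referenced but never defined in the post; a standard UNet 1x1 output convolution is assumed below. Also, because aux_feature is passed unchanged to every level, it must match the channel count and spatial size of the feature map at any level where cross-attention is actually applied. The following sketch therefore runs the model without an auxiliary feature, using an input whose sides are divisible by 16 so the four downsampling steps divide evenly:

class OutConv(nn.Module):
    # Assumed definition of the output head: a 1x1 convolution to num_classes channels.
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


model = UNetWithCrossAttention(in_channels=1, num_classes=1, use_cross_attention=True)
x = torch.randn(1, 1, 256, 256)
out = model(x)  # aux_feature defaults to None, so the attention modules are bypassed
print(out.shape)  # torch.Size([1, 1, 256, 256])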
5. Use Cases and Advantages
This cross-attention UNet is particularly well suited to the following scenarios:
- Multi-modal image segmentation: when auxiliary information from a different imaging modality is available, cross-attention helps the model fuse it effectively
- Temporal image analysis: for video sequences, features from the previous frame can serve as auxiliary features to enhance segmentation of the current frame
- Weakly supervised learning: additional weak supervision signals can be injected into the main network through cross-attention
Compared with the vanilla UNet, this architecture offers the following advantages:
- It dynamically attends to the most relevant parts of the auxiliary feature
- The attention mechanism enables finer-grained feature fusion
- It preserves UNet's original multi-scale feature extraction
- The residual connection avoids information loss
6. Summary
This article presented an enhanced UNet architecture that uses a cross-attention mechanism to let the model exploit auxiliary features more effectively. The design preserves UNet's original strengths while adding flexible feature fusion, making it well suited to complex vision tasks that need to integrate information from multiple sources.
In practice, you can choose at which levels to enable cross-attention according to the task, and you can adjust the complexity of the attention module to balance accuracy against computational cost; a hypothetical sketch of such per-level control is shown below.
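As a purely hypothetical illustration of per-level control (this variant is not part of the original code), the encoder blocks could be rebuilt with individual flags; in a real setup the auxiliary feature would also have to be projected and resized to match each enabled level, for example by a small parallel encoder:

class UNetWithSelectiveAttention(UNetWithCrossAttention):
    # Hypothetical variant: enable cross-attention only at chosen encoder levels.
    def __init__(self, in_channels=1, num_classes=1,
                 down_attn=(False, False, True, True)):
        super().__init__(in_channels, num_classes, use_cross_attention=False)
        # Rebuild the Down blocks, turning cross-attention on only where requested.
        specs = [(64, 128), (128, 256), (256, 512), (512, 1024)]
        self.down1, self.down2, self.down3, self.down4 = [
            Down(c_in, c_out, use_cross_attention=flag)
            for (c_in, c_out), flag in zip(specs, down_attn)
        ]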
I hope this article helps you understand how cross-attention can be applied in UNet. If you have any questions or suggestions, feel free to leave a comment!
Complete code
The full implementation is as follows:
import torch.nn as nn
import torch
import math


class CrossAttention(nn.Module):
    def __init__(self, channels):
        super(CrossAttention, self).__init__()
        self.query_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x1, x2):
        """
        x1: primary feature (batch, channels, height, width)
        x2: auxiliary feature (batch, channels, height, width)
        """
        batch_size, C, height, width = x1.size()
        # Project to query, key, value
        proj_query = self.query_conv(x1).view(batch_size, -1, height * width).permute(0, 2, 1)  # (B, N, C')
        proj_key = self.key_conv(x2).view(batch_size, -1, height * width)  # (B, C', N)
        proj_value = self.value_conv(x2).view(batch_size, -1, height * width)  # (B, C, N)
        # Compute the attention map
        energy = torch.bmm(proj_query, proj_key)  # (B, N, N)
        attention = torch.softmax(energy / math.sqrt(proj_key.size(-1)), dim=-1)
        # Apply attention
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # (B, C, N)
        out = out.view(batch_size, C, height, width)
        # Residual connection
        out = self.gamma * out + x1
        return out


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(DoubleConv, self).__init__()
        self.use_cross_attention = use_cross_attention
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        if use_cross_attention:
            self.cross_attention = CrossAttention(out_channels)

    def forward(self, x, aux_feature=None):
        x = self.conv1(x)
        x = self.conv2(x)
        if self.use_cross_attention and aux_feature is not None:
            x = self.cross_attention(x, aux_feature)
        return x


class Down(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(Down, self).__init__()
        self.downsampling = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            DoubleConv(in_channels, out_channels, use_cross_attention)
        )

    def forward(self, x, aux_feature=None):
        return self.downsampling[1](self.downsampling[0](x), aux_feature)


class Up(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(Up, self).__init__()
        self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels, use_cross_attention)

    def forward(self, x1, x2, aux_feature=None):
        x1 = self.upsampling(x1)
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x, aux_feature)
        return x


class OutConv(nn.Module):
    # Assumed definition: the original post references OutConv without defining it;
    # the standard UNet 1x1 output convolution is used here.
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNetWithCrossAttention(nn.Module):
    def __init__(self, in_channels=1, num_classes=1, use_cross_attention=False):
        super(UNetWithCrossAttention, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.use_cross_attention = use_cross_attention
        # Encoder
        self.in_conv = DoubleConv(in_channels, 64, use_cross_attention)
        self.down1 = Down(64, 128, use_cross_attention)
        self.down2 = Down(128, 256, use_cross_attention)
        self.down3 = Down(256, 512, use_cross_attention)
        self.down4 = Down(512, 1024, use_cross_attention)
        # Decoder
        self.up1 = Up(1024, 512, use_cross_attention)
        self.up2 = Up(512, 256, use_cross_attention)
        self.up3 = Up(256, 128, use_cross_attention)
        self.up4 = Up(128, 64, use_cross_attention)
        self.out_conv = OutConv(64, num_classes)

    def forward(self, x, aux_feature=None):
        # Encoding path
        x1 = self.in_conv(x, aux_feature)
        x2 = self.down1(x1, aux_feature)
        x3 = self.down2(x2, aux_feature)
        x4 = self.down3(x3, aux_feature)
        x5 = self.down4(x4, aux_feature)
        # Decoding path
        x = self.up1(x5, x4, aux_feature)
        x = self.up2(x, x3, aux_feature)
        x = self.up3(x, x2, aux_feature)
        x = self.up4(x, x1, aux_feature)
        x = self.out_conv(x)
        return x