UNet改进(4):交叉注意力(Cross Attention)-多模态/多特征交互

在计算机视觉领域,UNet因其优异的性能在图像分割任务中广受欢迎。本文将介绍一种改进的UNet架构——UNetWithCrossAttention,它通过引入交叉注意力机制来增强模型的特征融合能力。

1. 交叉注意力机制

交叉注意力(Cross Attention)是一种让模型能够动态地从辅助特征中提取相关信息来增强主特征的机制。在我们的实现中,CrossAttention类实现了这一功能:

class CrossAttention(nn.Module):def __init__(self, channels):super(CrossAttention, self).__init__()self.query_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)self.key_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)self.value_conv = nn.Conv2d(channels, channels, kernel_size=1)self.gamma = nn.Parameter(torch.zeros(1))def forward(self, x1, x2):batch_size, C, height, width = x1.size()# 投影到query, key, value空间proj_query = self.query_conv(x1).view(batch_size, -1, height * width).permute(0, 2, 1)proj_key = self.key_conv(x2).view(batch_size, -1, height * width)proj_value = self.value_conv(x2).view(batch_size, -1, height * width)# 计算注意力图energy = torch.bmm(proj_query, proj_key)attention = torch.softmax(energy / math.sqrt(proj_key.size(-1)), dim=-1)# 应用注意力out = torch.bmm(proj_value, attention.permute(0, 2, 1))out = out.view(batch_size, C, height, width)# 残差连接out = self.gamma * out + x1return out

该模块的工作原理是:

  1. 将主特征x1投影为query,辅助特征x2投影为key和value

  2. 计算query和key的相似度得到注意力权重

  3. 使用注意力权重对value进行加权求和

  4. 通过残差连接将结果与原始主特征融合

2. 双卷积模块

DoubleConv是UNet中的基础构建块,包含两个连续的卷积层,并可选择性地加入交叉注意力:

class DoubleConv(nn.Module):def __init__(self, in_channels, out_channels, use_cross_attention=False):super(DoubleConv, self).__init__()self.use_cross_attention = use_cross_attentionself.conv1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))self.conv2 = nn.Sequential(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))if use_cross_attention:self.cross_attention = CrossAttention(out_channels)def forward(self, x, aux_feature=None):x = self.conv1(x)x = self.conv2(x)if self.use_cross_attention and aux_feature is not None:x = self.cross_attention(x, aux_feature)return x

3. 下采样和上采样模块

下采样模块Down结合了最大池化和双卷积:

class Down(nn.Module):def __init__(self, in_channels, out_channels, use_cross_attention=False):super(Down, self).__init__()self.downsampling = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2),DoubleConv(in_channels, out_channels, use_cross_attention))def forward(self, x, aux_feature=None):return self.downsampling[1](self.downsampling[0](x), aux_feature)

上采样模块Up使用转置卷积进行上采样并拼接特征:

class Up(nn.Module):def __init__(self, in_channels, out_channels, use_cross_attention=False):super(Up, self).__init__()self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels, use_cross_attention)def forward(self, x1, x2, aux_feature=None):x1 = self.upsampling(x1)x = torch.cat([x2, x1], dim=1)x = self.conv(x, aux_feature)return x

4. 完整的UNetWithCrossAttention架构

将上述模块组合起来,我们得到了完整的UNetWithCrossAttention:

class UNetWithCrossAttention(nn.Module):def __init__(self, in_channels=1, num_classes=1, use_cross_attention=False):super(UNetWithCrossAttention, self).__init__()self.in_channels = in_channelsself.num_classes = num_classesself.use_cross_attention = use_cross_attention# 编码器self.in_conv = DoubleConv(in_channels, 64, use_cross_attention)self.down1 = Down(64, 128, use_cross_attention)self.down2 = Down(128, 256, use_cross_attention)self.down3 = Down(256, 512, use_cross_attention)self.down4 = Down(512, 1024, use_cross_attention)# 解码器self.up1 = Up(1024, 512, use_cross_attention)self.up2 = Up(512, 256, use_cross_attention)self.up3 = Up(256, 128, use_cross_attention)self.up4 = Up(128, 64, use_cross_attention)self.out_conv = OutConv(64, num_classes)def forward(self, x, aux_feature=None):# 编码过程x1 = self.in_conv(x, aux_feature)x2 = self.down1(x1, aux_feature)x3 = self.down2(x2, aux_feature)x4 = self.down3(x3, aux_feature)x5 = self.down4(x4, aux_feature)# 解码过程x = self.up1(x5, x4, aux_feature)x = self.up2(x, x3, aux_feature)x = self.up3(x, x2, aux_feature)x = self.up4(x, x1, aux_feature)x = self.out_conv(x)return x

5. 应用场景与优势

这种带有交叉注意力的UNet架构特别适合以下场景:

  1. 多模态图像分割:当有来自不同成像模态的辅助信息时,交叉注意力可以帮助模型有效地融合这些信息

  2. 时序图像分析:对于视频序列,前一帧的特征可以作为辅助特征来增强当前帧的分割

  3. 弱监督学习:当有额外的弱监督信号时,可以通过交叉注意力将其融入主网络

相比于传统UNet,这种架构的优势在于:

  • 能够动态地关注辅助特征中最相关的部分

  • 通过注意力机制实现更精细的特征融合

  • 保留了UNet原有的多尺度特征提取能力

  • 通过残差连接避免了信息丢失

6. 总结

本文介绍了一种增强版的UNet架构,通过引入交叉注意力机制,使模型能够更有效地利用辅助特征。这种设计既保留了UNet原有的优势,又增加了灵活的特征融合能力,特别适合需要整合多源信息的复杂视觉任务。

在实际应用中,可以根据具体任务需求选择在哪些层级启用交叉注意力,也可以调整注意力模块的复杂度来平衡模型性能和计算开销。

希望这篇文章能帮助你理解交叉注意力在UNet中的应用。如果你有任何问题或建议,欢迎在评论区留言讨论!

完整代码

如下:

import torch.nn as nn
import torch
import mathclass CrossAttention(nn.Module):def __init__(self, channels):super(CrossAttention, self).__init__()self.query_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)self.key_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)self.value_conv = nn.Conv2d(channels, channels, kernel_size=1)self.gamma = nn.Parameter(torch.zeros(1))def forward(self, x1, x2):"""x1: 主特征 (batch, channels, height, width)x2: 辅助特征 (batch, channels, height, width)"""batch_size, C, height, width = x1.size()# 投影到query, key, value空间proj_query = self.query_conv(x1).view(batch_size, -1, height * width).permute(0, 2, 1)  # (B, N, C')proj_key = self.key_conv(x2).view(batch_size, -1, height * width)  # (B, C', N)proj_value = self.value_conv(x2).view(batch_size, -1, height * width)  # (B, C, N)# 计算注意力图energy = torch.bmm(proj_query, proj_key)  # (B, N, N)attention = torch.softmax(energy / math.sqrt(proj_key.size(-1)), dim=-1)# 应用注意力out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # (B, C, N)out = out.view(batch_size, C, height, width)# 残差连接out = self.gamma * out + x1return outclass DoubleConv(nn.Module):def __init__(self, in_channels, out_channels, use_cross_attention=False):super(DoubleConv, self).__init__()self.use_cross_attention = use_cross_attentionself.conv1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True)self.conv2 = nn.Sequential(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))if use_cross_attention:self.cross_attention = CrossAttention(out_channels)def forward(self, x, aux_feature=None):x = self.conv1(x)x = self.conv2(x)if self.use_cross_attention and aux_feature is not None:x = self.cross_attention(x, aux_feature)return xclass Down(nn.Module):def __init__(self, in_channels, out_channels, use_cross_attention=False):super(Down, self).__init__()self.downsampling = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2),DoubleConv(in_channels, out_channels, use_cross_attention))def forward(self, x, aux_feature=None):return self.downsampling[1](self.downsampling[0](x), aux_feature)class Up(nn.Module):def __init__(self, in_channels, out_channels, use_cross_attention=False):super(Up, self).__init__()self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels, use_cross_attention)def forward(self, x1, x2, aux_feature=None):x1 = self.upsampling(x1)x = torch.cat([x2, x1], dim=1)x = self.conv(x, aux_feature)return xclass UNetWithCrossAttention(nn.Module):def __init__(self, in_channels=1, num_classes=1, use_cross_attention=False):super(UNetWithCrossAttention, self).__init__()self.in_channels = in_channelsself.num_classes = num_classesself.use_cross_attention = use_cross_attention# 编码器self.in_conv = DoubleConv(in_channels, 64, use_cross_attention)self.down1 = Down(64, 128, use_cross_attention)self.down2 = Down(128, 256, use_cross_attention)self.down3 = Down(256, 512, use_cross_attention)self.down4 = Down(512, 1024, use_cross_attention)# 解码器self.up1 = Up(1024, 512, use_cross_attention)self.up2 = Up(512, 256, use_cross_attention)self.up3 = Up(256, 128, use_cross_attention)self.up4 = Up(128, 64, use_cross_attention)self.out_conv = OutConv(64, num_classes)def forward(self, x, aux_feature=None):# 编码过程x1 = self.in_conv(x, aux_feature)x2 = self.down1(x1, aux_feature)x3 = self.down2(x2, aux_feature)x4 = self.down3(x3, aux_feature)x5 = self.down4(x4, aux_feature)# 解码过程x = self.up1(x5, x4, aux_feature)x = self.up2(x, x3, aux_feature)x = self.up3(x, x2, aux_feature)x = self.up4(x, x1, aux_feature)x = self.out_conv(x)return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/88062.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/88062.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C#里从CSV文件加载BLOB数据字段到数据库的处理

大量的数据保存在CSV文件, 当需要把这些数据加载到数据库,然后使用数据库来共享出去。 就需要把CSV文件导入数据库, 怎么样快速地把CSV文件导入数据库呢? 这个就需要使用类MySqlBulkLoader,它是mariadb数据库快速导入的方式。 一般使用SQL语句导入是10秒,那么使用这种方…

【后端】负载均衡

长期不定期更新补充。 定义 负载均衡(Load Balancing)是指将来自客户端的请求合理分发到多个服务器或服务节点,以提高系统性能、可用性与可靠性。 分工 前端不做负载均衡,前端只发请求,不知道请求去哪台服务器。 负…

记录一次:Java Web 项目 CSS 样式/图片丢失问题:一次深度排查与根源分析

记录一次:Java Web 项目 CSS 样式/图片丢失问题:一次深度排查与根源分析 **记录一次:Java Web 项目 CSS 样式丢失问题:一次深度排查与根源分析****第一层分析:资源路径问题****第二层分析:服务端跳转逻辑**…

torchmd-net开源程序是训练神经网络潜力

​一、软件介绍 文末提供程序和源码下载 TorchMD-NET 提供最先进的神经网络电位 (NNP) 和训练它们的机制。如果有多个 NNP,它可提供高效、快速的实现,并且它集成在 GPU 加速的分子动力学代码中,如 ACEMD、OpenMM 和 …

在Docker上安装Mongo及Redis-NOSQL数据库

应用环境 Ubuntu 20.04.6 LTS (GNU/Linux 5.15.0-139-generic x86_64) Docker version 28.1.1, build 4eba377 文章目录 一、部署Mongo1. 拉取容器镜像2. 生成Run脚本2.1 准备条件2.2 参数解读2.3 实例脚本 3. 实例操作3.1 Mongo bash控制台3.2 库表操作 4. MongoDB Compass (G…

Java 编程之责任链模式

一、什么是责任链模式? 责任链模式(Chain of Responsibility Pattern) 是一种行为型设计模式,它让多个对象都有机会处理请求,从而避免请求的发送者和接收者之间的耦合关系。将这些对象连成一条链,沿着这条…

1、做中学 | 一年级上期 Golang简介和安装环境

一、什么是golang Golang,通常简称 Go,是由 Google 公司的 Robert Griesemer、Rob Pike 和 Ken Thompson 于 2007 年创建的一种开源编程语言,并在 2009 年正式对外公布。 已经有了很多编程语言,为什么还要创建一种新的编程语言&…

Linux--迷宫探秘:从路径解析到存储哲学

上一篇博客我们说完了文件系统在硬件层面的意义,今天我们来说说文件系统在软件层是怎么管理的。 Linux--深入EXT2文件系统:数据是如何被组织、存储与访问的?-CSDN博客 🌌 引言:文件系统的宇宙观 "在Linux的宇宙中…

淘宝商品数据实时获取方案|API 接口开发与安全接入

在电商数据获取领域,除了官方 API,第三方数据 API 接入也是高效获取淘宝商品数据的重要途径。第三方数据 API 凭借丰富的功能、灵活的服务,为企业和开发者提供了多样化的数据解决方案。本文将聚焦第三方数据 API 接入,详细介绍其优…

什么是防抖和节流?它们有什么区别?

文章目录 一、防抖(Debounce)1.1 什么是防抖?1.2 防抖的实现 二、节流(Throttle)2.1 什么是节流?2.2 节流的实现方式 三、防抖与节流的对比四、总结 在前端开发中,我们经常会遇到一些高频触发的…

Springboot集成阿里云OSS上传

Springboot集成阿里云OSS上传 API 接口描述 DEMO提供的四个API接口,支持不同方式的文件和 JSON 数据上传: 1. 普通文件上传接口 上传任意类型的文件 2. JSON 字符串上传接口 上传 JSON 字符串 3. 单个 JSON 压缩上传接口 上传并压缩 JSON 字符串…

删除大表数据注意事项

数据库是否会因删除操作卡死,没有固定的 “安全删除条数”,而是受数据库配置、表结构、操作方式、当前负载等多种因素影响。以下是关键影响因素及实践建议: 一、导致数据库卡死的核心因素 硬件与数据库配置 CPU / 内存瓶颈:删除…

Redis 是单线程模型?|得物技术

一、背景 使用过Redis的同学肯定都了解过一个说法,说Redis是单线程模型,那么实际情况是怎样的呢? 其实,我们常说Redis是单线程模型,是指Redis采用单线程的事件驱动模型,只有并且只会在一个主线程中执行Re…

[特殊字符] AIGC工具深度实战:GPT与通义灵码如何彻底重构企业开发流程

🔍 第一模块:理念颠覆——为什么AIGC不是“玩具”而是“效能倍增器”? ▍企业开发的核心痛点图谱(2025版) ​​研发效能瓶颈​​:需求膨胀与交付时限矛盾持续尖锐,传统敏捷方法论已触天花板​…

(LeetCode 面试经典 150 题) 169. 多数元素(哈希表 || 二分查找)

题目&#xff1a;169. 多数元素 方法一&#xff1a;二分法&#xff0c;最坏的时间复杂度0(nlogn)&#xff0c;但平均0(n)即可。空间复杂度为0(1)。 C版本&#xff1a; int nnums.size();int l0,rn-1;while(l<r){int mid(lr)/2;int ans0;for(auto x:nums){if(xnums[mid]) a…

(17)java+ selenium->自动化测试-元素定位大法之By css上

1.简介 CSS定位方式和xpath定位方式基本相同,只是CSS定位表达式有其自己的格式。CSS定位方式拥有比xpath定位速度快,且比CSS稳定的特性。下面详细介绍CSS定位方式的使用方法。相对CSS来说,具有语法简单,定位速度快等优点。 2.CSS定位优势 CSS定位是平常使用过程中非常重要…

【软考高级系统架构论文】企业集成平台的技术与应用

论文真题 企业集成平台是一个支持复杂信息环境下信息系统开发、集成和协同运行的软件支撑环境。它基于各种企业经营业务的信息特征,在异构分布环境(操作系统、网络、数据库)下为应用提供一致的信息访问和交互手段,对其上运行的应用进行管理,为应用提供服务,并支持企业信息…

i.MX8MP LVDS 显示子系统全解析:设备树配置与 DRM 架构详解

&#x1f525; 推荐&#xff1a;《Yocto项目实战教程&#xff1a;高效定制嵌入式Linux系统》 京东正版促销&#xff0c;欢迎支持原创&#xff01; 链接&#xff1a;https://item.jd.com/15020438.html i.MX8MP LVDS 显示子系统全解析&#xff1a;设备树配置与 DRM 架构详解 在…

keep-alive实现原理及Vue2/Vue3对比分析

一、keep-alive基本概念 keep-alive是Vue的内置组件&#xff0c;用于缓存组件实例&#xff0c;避免重复渲染。它具有以下特点&#xff1a; 抽象组件&#xff1a;自身不会渲染DOM&#xff0c;也不会出现在父组件链中包裹动态组件&#xff1a;缓存不活动的组件实例&#xff0c;…

安卓jetpack compose学习笔记-Navigation基础学习

目录 一、Navigation 二、BottomNavigation Compose是一个偏向静态刷新的UI组件&#xff0c;如果不想要自己管理页面切换的复杂状态&#xff0c;可以以使用Navigation组件。 页面间的切换可以NavHost&#xff0c;使用底部页面切换栏&#xff0c;可以使用脚手架的bottomBarNav…