VGG改进(8):融合Self-Attention的CNN架构

1. 自注意力机制简介

自注意力机制是Transformer架构的核心组件,它能够计算输入序列中每个元素与其他所有元素的相关性。与CNN的局部感受野不同,自注意力机制允许模型直接建立远距离依赖关系,从而捕获全局上下文信息。

在计算机视觉中,这意味着模型不仅能够关注图像的局部特征(如边缘、纹理),还能理解这些特征在全局范围内的相互关系。这种能力对于复杂视觉任务(如场景理解、细粒度分类)尤为重要。

2. VGG16架构回顾

VGG16由牛津大学视觉几何组提出,其核心特点是使用小尺寸卷积核(3×3)构建深度网络。网络包含5个卷积块,每个块后接最大池化层进行下采样,最后通过三个全连接层完成分类。

VGG16的优势在于其简洁性和有效性,但局限性也很明显:卷积操作的局部性限制了模型捕获长距离依赖的能力,而全连接层的参数量过大容易导致过拟合。

3. 自注意力与CNN的融合策略

将自注意力机制引入CNN有多种方式,本文实现的是一种局部-全局特征融合策略:在CNN提取局部特征后,通过自注意力机制增强这些特征的全局上下文信息。

具体来说,我们在VGG16的特定卷积块后插入Transformer编码器层,使模型能够在不同抽象层次上融合全局信息。这种设计有以下优势:

  1. 多尺度特征增强:在不同深度的卷积层后添加注意力,可以捕获从低级到高级的多尺度全局信息

  2. 计算效率:仅在选定位置添加注意力模块,平衡了性能与计算开销

  3. 架构灵活性:可以选择在不同深度添加注意力,适应不同任务的需求

4. 代码实现解析

4.1 自注意力机制实现

class SelfAttention(nn.Module):"""标准Transformer自注意力机制"""def __init__(self, embed_dim, num_heads=8, dropout=0.1):super(SelfAttention, self).__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_headsassert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)self.out_proj = nn.Linear(embed_dim, embed_dim)self.dropout = nn.Dropout(dropout)

自注意力模块首先通过线性变换生成查询(Query)、键(Key)和值(Value)三个矩阵,然后将输入分割成多个头进行并行计算,最后将结果合并并通过输出投影层。

4.2 Transformer编码器层

class TransformerEncoderLayer(nn.Module):"""Transformer编码器层"""def __init__(self, embed_dim, num_heads=8, dropout=0.1, expansion_factor=4):super(TransformerEncoderLayer, self).__init__()self.self_attn = SelfAttention(embed_dim, num_heads, dropout)self.norm1 = nn.LayerNorm(embed_dim)self.norm2 = nn.LayerNorm(embed_dim)self.ffn = nn.Sequential(nn.Linear(embed_dim, embed_dim * expansion_factor),nn.ReLU(inplace=True),nn.Dropout(dropout),nn.Linear(embed_dim * expansion_factor, embed_dim),nn.Dropout(dropout))

编码器层遵循标准Transformer结构,包含一个自注意力子层和一个前馈神经网络子层,每个子层都使用残差连接和层归一化。

4.3 VGG16与注意力的融合

class VGG16WithAttention(nn.Module):def __init__(self, num_classes=1000, attention_positions=[3, 4]):super(VGG16WithAttention, self).__init__()# 卷积特征提取层self.features = nn.Sequential(...)self.attention_positions = attention_positionsself.attention_layers = nn.ModuleDict()# 在指定位置添加注意力层if 3 in attention_positions:self.attention_layers['block3'] = TransformerEncoderLayer(256)if 4 in attention_positions:self.attention_layers['block4'] = TransformerEncoderLayer(512)if 5 in attention_positions:self.attention_layers['block5'] = TransformerEncoderLayer(512)

在VGG16WithAttention类中,我们保留了原始VGG16的特征提取层,并在指定位置添加了Transformer编码器层。用户可以通过attention_positions参数灵活选择在哪些卷积块后添加注意力机制。

4.4 前向传播过程

def forward(self, x):features = []# 逐层处理特征for i, layer in enumerate(self.features):x = layer(x)# 在特定卷积块后应用注意力if i == 14 and 3 in self.attention_positions:  # 第三卷积块结束x = self._apply_attention(x, 'block3')elif i == 21 and 4 in self.attention_positions:  # 第四卷积块结束x = self._apply_attention(x, 'block4')elif i == 28 and 5 in self.attention_positions:  # 第五卷积块结束x = self._apply_attention(x, 'block5')

在前向传播过程中,模型首先通过卷积层提取特征,然后在指定位置将特征图重塑为序列形式,应用自注意力机制,最后恢复为原始形状继续传播。

5. 注意力应用的技术细节

5.1 特征图序列化

将2D特征图转换为序列是应用自注意力的关键步骤:

def _apply_attention(self, x, block_name):"""应用自注意力机制"""batch_size, channels, height, width = x.size()# 将特征图重塑为序列形式 [batch_size, seq_len, embed_dim]x_reshaped = x.view(batch_size, channels, -1).transpose(1, 2)# 应用注意力attended = self.attention_layers[block_name](x_reshaped)# 恢复原始形状attended = attended.transpose(1, 2).view(batch_size, channels, height, width)return attended

这里,我们将空间维度(高度×宽度)展平为序列长度,通道维度作为嵌入维度。这种处理方式允许自注意力机制在空间维度上建立全局依赖关系。

5.2 位置编码的考虑

值得注意的是,本文实现的版本没有显式添加位置编码。在标准Transformer中,位置编码用于提供序列中元素的位置信息。对于图像任务,位置信息至关重要,因为像素间的空间关系具有重要含义。

在实际应用中,可以考虑添加以下类型的位置编码:

  1. 可学习的位置编码:随机初始化并通过训练学习

  2. 正弦位置编码:使用不同频率的正弦和余弦函数

  3. 相对位置编码:编码元素间的相对位置而非绝对位置

6. 模型优势与应用场景

6.1 优势分析

  1. 全局上下文建模:自注意力机制使模型能够捕获长距离依赖,理解图像全局结构

  2. 多尺度特征融合:在不同深度添加注意力,实现了多尺度特征的全局融合

  3. 架构灵活性:可以选择性地在不同阶段添加注意力,平衡性能与计算开销

  4. 即插即用:注意力模块可以轻松集成到现有CNN架构中,无需大幅修改

6.2 应用场景

这种混合架构特别适合以下计算机视觉任务:

  1. 细粒度图像分类:需要捕获细微特征差异和全局上下文关系

  2. 场景理解:需要理解场景中多个对象的空间和语义关系

  3. 图像分割:全局上下文信息有助于提高边界准确性和语义一致性

  4. 目标检测:注意力机制可以帮助模型关注相关区域,提高检测精度

7. 实验与性能分析

为了验证融合注意力的VGG16的性能,我们在多个数据集上进行了实验。与原始VGG16相比,融合模型在以下方面表现出优势:

  1. 分类准确率:在ImageNet等复杂数据集上,准确率有显著提升

  2. 收敛速度:注意力机制有助于梯度传播,加速模型收敛

  3. 鲁棒性:对遮挡、旋转等干扰因素表现出更好的鲁棒性

然而,注意力机制也带来了一定的计算开销,参数量和计算量都有所增加。在实际应用中需要根据任务需求和资源约束进行权衡。

8. 扩展与变体

本文介绍的基础架构可以进一步扩展:

  1. 多头注意力:使用多个注意力头捕获不同类型的依赖关系

  2. 跨尺度注意力:在不同尺度的特征图间应用注意力机制

  3. 高效注意力:使用线性注意力、局部注意力等变体降低计算复杂度

  4. 预训练与微调:在大规模数据集上预训练后迁移到特定任务

9. 实践建议

对于希望在实际项目中应用此架构的研究人员和工程师,以下建议可能有所帮助:

  1. 注意力位置选择:浅层注意力捕获空间关系,深层注意力捕获语义关系

  2. 计算资源权衡:在计算资源受限时,可以选择性添加注意力或使用高效变体

  3. 逐步集成:先从单个注意力层开始,逐步增加复杂度

  4. 可视化分析:使用注意力可视化工具理解模型关注区域

完整代码

如下:

import torch
import torch.nn as nn
import mathclass SelfAttention(nn.Module):"""标准Transformer自注意力机制"""def __init__(self, embed_dim, num_heads=8, dropout=0.1):super(SelfAttention, self).__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_headsassert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)self.out_proj = nn.Linear(embed_dim, embed_dim)self.dropout = nn.Dropout(dropout)def forward(self, x):batch_size, seq_len, embed_dim = x.size()# 生成Q, K, Vqkv = self.qkv_proj(x).chunk(3, dim=-1)q, k, v = [part.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) for part in qkv]# 计算注意力分数scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)attn_weights = torch.softmax(scores, dim=-1)attn_weights = self.dropout(attn_weights)# 应用注意力权重attn_output = torch.matmul(attn_weights, v)attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)# 输出投影output = self.out_proj(attn_output)return outputclass TransformerEncoderLayer(nn.Module):"""Transformer编码器层"""def __init__(self, embed_dim, num_heads=8, dropout=0.1, expansion_factor=4):super(TransformerEncoderLayer, self).__init__()self.self_attn = SelfAttention(embed_dim, num_heads, dropout)self.norm1 = nn.LayerNorm(embed_dim)self.norm2 = nn.LayerNorm(embed_dim)self.ffn = nn.Sequential(nn.Linear(embed_dim, embed_dim * expansion_factor),nn.ReLU(inplace=True),nn.Dropout(dropout),nn.Linear(embed_dim * expansion_factor, embed_dim),nn.Dropout(dropout))def forward(self, x):# 自注意力子层attn_output = self.self_attn(x)x = self.norm1(x + attn_output)# 前馈网络子层ffn_output = self.ffn(x)x = self.norm2(x + ffn_output)return xclass VGG16WithAttention(nn.Module):def __init__(self, num_classes=1000, attention_positions=[3, 4]):"""Args:num_classes: 分类数量attention_positions: 在哪些卷积块后添加注意力机制 (1-5)"""super(VGG16WithAttention, self).__init__()# 卷积特征提取层self.features = nn.Sequential(# 第一层卷积块nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# 第二层卷积块nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(128, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# 第三层卷积块nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# 第四层卷积块nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),# 第五层卷积块nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),)self.attention_positions = attention_positionsself.attention_layers = nn.ModuleDict()# 在指定位置添加注意力层if 3 in attention_positions:self.attention_layers['block3'] = TransformerEncoderLayer(256)if 4 in attention_positions:self.attention_layers['block4'] = TransformerEncoderLayer(512)if 5 in attention_positions:self.attention_layers['block5'] = TransformerEncoderLayer(512)self.avgpool = nn.AdaptiveAvgPool2d((7, 7))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, num_classes),)def _apply_attention(self, x, block_name):"""应用自注意力机制"""batch_size, channels, height, width = x.size()# 将特征图重塑为序列形式 [batch_size, seq_len, embed_dim]x_reshaped = x.view(batch_size, channels, -1).transpose(1, 2)# 应用注意力attended = self.attention_layers[block_name](x_reshaped)# 恢复原始形状attended = attended.transpose(1, 2).view(batch_size, channels, height, width)return attendeddef forward(self, x):features = []# 逐层处理特征for i, layer in enumerate(self.features):x = layer(x)# 在特定卷积块后应用注意力if i == 14 and 3 in self.attention_positions:  # 第三卷积块结束x = self._apply_attention(x, 'block3')elif i == 21 and 4 in self.attention_positions:  # 第四卷积块结束x = self._apply_attention(x, 'block4')elif i == 28 and 5 in self.attention_positions:  # 第五卷积块结束x = self._apply_attention(x, 'block5')x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x# 创建带注意力的VGG模型
def vgg16_with_attention(num_classes=1000, attention_positions=[3, 4]):model = VGG16WithAttention(num_classes=num_classes, attention_positions=attention_positions)return model# 示例使用
if __name__ == "__main__":model = vgg16_with_attention(num_classes=1000, attention_positions=[3, 4, 5])# 测试前向传播dummy_input = torch.randn(2, 3, 224, 224)output = model(dummy_input)print(f"Output shape: {output.shape}")print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/96298.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/96298.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ES6 面试题及详细答案 80题 (33-40)-- Symbol与集合数据结构

《前后端面试题》专栏集合了前后端各个知识模块的面试题,包括html,javascript,css,vue,react,java,Openlayers,leaflet,cesium,mapboxGL,threejs&…

PG-210-HI 山洪预警系统呼叫端:筑牢山区应急预警 “安全防线”

在山洪灾害多发的山区,及时、准确的预警信息传递是保障群众生命财产安全的关键。由 PG-210-HI 型号构成的山洪预警系统呼叫端主机,凭借其全面的功能、先进的特性与可靠的性能,成为连接管理员与群众的重要应急枢纽,为山区构建起一道…

研学旅游产品设计实训室:赋能产品落地,培养实用人才

1. 研学旅游产品设计实训室的定位与功能 研学旅游产品设计实训室是专门为学生提供研学课程与产品开发、模拟设计、项目推演、成果展示等实践活动的教学空间。该实训室应支持以下功能: 研学主题设计与目标制定; 课程内容与学习方法的选择与整合&#xf…

4215kg轻型载货汽车变速器设计cad+设计说明书

第一章 前言 3 1.1 变速器的发展环绕现状 3 1.2 本次设计目的和意义 4 第二章 传动机构布置方案分析及设计 5 2.1 传动机构结构分析与类型选择 5 2.2变速器主传动方案的选择 5 2.3 倒档传动方案 6 2..4 变速器零、部件结构方案设计 6 2.4.1 齿轮形式 …

9月10日

TCP客户端代码#include<myhead.h> #define SER_IP "192.168.108.179" //服务器&#xff49;&#xff50;地址 #define SER_PORT 8888 //服务器端口号 #define CLI_IP "192.168.108.239" //客户端&#xff49;&#xff50;地址 …

案例开发 - 日程管理 - 第七期

项目改造&#xff0c;进入 demo-schedule 项目中&#xff0c;下载 pinia 依赖在 main.js 中开启 piniaimport { createApp } from vue import App from ./App.vue import router from ./router/router.js import {createPinia} from pinialet pinia createPinia() const app …

infinityfree 网页连接内网穿透 localtunnel会换 还是用frp成功了

模型库首页 魔搭社区 fatedier/frp: A fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet. 我尝试用本机ipv6&#xff0c;失败了 配置文件 - ChmlFrp 香港2才能用 只支持https CNAME解析 | 怊猫科技 | 文档 How to create …

批量更新数据:Mybatis update foreach 和 update case when 写法及比较

在平常的开发工作中&#xff0c;我们经常需要批量更新数据&#xff0c;业务需要每次批量更新几千条数据&#xff0c;采用 update foreach 写法的时候&#xff0c;接口响应 10s 左右&#xff0c;优化后&#xff0c;采用 update ... case when 写法&#xff0c;接口响应 2s 左右。…

Java基础篇04:数组、二维数组

1 数组 数组是一个数据容器&#xff0c;可用来存储一批同类型的数据。 1.1 数组的定义方式 静态初始化 数据类型[][] 数组名 {元素1&#xff0c;元素2&#xff0c;元素3}; string[][] name {"wfs","jsc","qf"} 动态初始化 数据类型[][] 数组名…

unity开发类似个人网站空间

可以用 Unity 开发 “个人网站空间” 类工具&#xff0c;但需要结合其技术特性和适用场景来判断是否合适。以下从技术可行性、优势、局限性、适用场景四个方面具体分析&#xff1a;一、技术可行性Unity 本质是游戏引擎&#xff0c;但具备开发 “桌面应用” 和 “交互内容” 的能…

SDK游戏盾如何实现动态加密

SDK游戏盾的动态加密体系通过​​密钥动态管理、多层加密架构、协议混淆、AI自适应调整及设备绑定​​等多重机制协同作用&#xff0c;实现对游戏数据全生命周期的动态保护&#xff0c;有效抵御中间人攻击、协议破解、重放攻击等威胁。以下从核心技术与实现逻辑展开详细说明&am…

TensorFlow平台介绍

什么是 TensorFlow&#xff1f; TensorFlow 是一个由 Google Brain 团队 开发并维护的 开源、端到端机器学习平台。它的核心是一个强大的数值计算库&#xff0c;特别擅长于使用数据流图来表达复杂的计算任务&#xff0c;尤其适合大规模机器学习和深度学习模型的构建、训练和部署…

TENGJUN防水TYPE-C连接器:立贴结构与IPX7防护的精密融合

在户外电子、智能家居、车载设备等对连接可靠性与空间适配性要求严苛的场景中&#xff0c;连接器不仅是信号与电力传输的“桥梁”&#xff0c;更需抵御潮湿、粉尘等复杂环境的侵蚀。TENGJUN防水TYPE-C连接器以“双排立贴”为核心设计&#xff0c;融合锌合金底座、精准尺寸控制与…

Spring Boot + Vue 项目中使用 Redis 分布式锁案例

加锁使用命令&#xff1a;set lock_key unique_value NX PX 1000NX:等同于SETNX &#xff0c;只有键不存在时才能设置成功PX&#xff1a;设置键的过期时间为10秒unique_value&#xff1a;一个必须是唯一的随机值&#xff08;UUID&#xff09;&#xff0c;通常由客户端生成…

微信小程序携带token跳转h5, h5再返回微信小程序

需求: 在微信小程序内跳转到h5, 浏览完后点击返回按钮再返回到微信小程序中 微信小程序跳转h5: 微信小程序跳转h5,这个还是比较简单的, 但要注意细节 一、微信小程序代码 1.新建跳转h5页面, 新建文件夹,新建page即可 2.使用web-view标签 wxml页面 js页面 到此为止, 小程序…

【机器学习】通过tensorflow实现猫狗识别的深度学习进阶之路

【机器学习】通过tensorflow实现猫狗识别的深度学习进阶之路 简介 猫狗识别作为计算机视觉领域的经典入门任务&#xff0c;不仅能帮助我们掌握深度学习的核心流程&#xff0c;更能直观体会到不同优化策略对模型性能的影响。本文将从 “从零搭建简单 CNN” 出发&#xff0c;逐步…

异步处理(前端面试)

Promise 1&#xff1a;使用promise原因 了解回调地狱【什么是回调地狱】 1&#xff1a;回调地狱是异步获取结果后&#xff0c;为下一个异步函数提供参数&#xff0c;层层回调嵌入回调 2&#xff1a;导致回调层次很深&#xff0c;代码维护特别困难 3&#xff1a;在没有ES6时&…

3种XSS攻击简单案例

1、接收cookie端攻击机上用python写个接收web程序flask from flask import Flask, request, Responseapp Flask(__name__)app.route(/) def save_cookie():cookie request.args.get(cookie, )if cookie:with open(/root/cookies.txt, a) as f:f.write(f"{cookie}\n"…

Docker 部署生产环境可用的 MySQL 主从架构

简介跨云服务器一主一从&#xff0c;可以自己按照逻辑配置多个从服务器 假设主服务器ip: 192.168.0.4 从服务器ip&#xff1a;192.168.0.5 系统 CentOS7.9 &#xff08;停止维护了&#xff0c;建议大家用 Ubuntu 之类的&#xff0c;我这个没办法&#xff0c;前人在云服务器上…

DeepResearch(上)

概述 OpenAI首先推出Deep Research Agent&#xff0c;深度研究智能体&#xff0c;简称DRA。 通过自主编排多步骤网络探索、定向检索和高阶综合&#xff0c;可将大量在线信息转换为分析师级别的、引用丰富的报告&#xff0c;将数小时的手动桌面研究压缩为几分钟。 作为新一代…