Pytorch模型复现笔记-STN(空间注意力Transformer网络)讲解+架构搭建(可直接copy运行)+ MNIST数据集视角调整实验

Spatial Transformer Networks

本文了讲述STN的基本架构,空间几何注意力模块的基本原理,冒烟测试以及STN在MNIST数据集用于模型自动调整图片视角的实验,如果大家有不懂或者发现了错误的地方,欢迎讨论。

  • 中文名:空间Transformer网络
  • 论文链接:Arxiv

我更倾向于叫它为Spatio Geometry Transformer Network, 因为它的注意力同时包括了是旋转,平移,仿射等多种几何变换,而不是单纯地裁剪以注意空间里面的重点。

目录

  • Spatial Transformer Networks
    • 模型简介
    • 提出背景
    • 设计思路
    • 达到效果及优势
    • 对后续模型的影响
    • 网络结构
    • Pytorch模型实现+MNIST数据集视角调整实验
      • 准备库、数据集、数据加载器
      • 定义网络结构
      • 定义训练/测试过程
      • 可视化效果

模型简介

  • 作者:Ghassen HAMROUNI
  • 发布年:2015
  • 为什么这么叫:因为它用使用了空间几何注意力
  • 主要成就:第一个使用空间几何注意力的卷积神经网络

STN是对任何空间变化的推广,其允许神经网络学习如何对输入图像进行 空间变换,以增强模型的 几何不变性。例如,其可以对感兴趣的区域进行裁剪,或者缩放和矫正图像的方向。这个机制对CNN很有用,因为其对旋转、缩放、甚至更一般的放射变换并非不变。

在这里插入图片描述

提出背景

之前有哪些相同目的的模型/方法?

在STN出现之前,主要是用纯粹的CNN来对图像进行特征图的提取。但是由于CNN对于图像的几何变换的鲁棒性不强,因此研究人员设计了几种方法来改良CNN对于几何变换的鲁棒性:

  1. 数据增强(Data Augmentation):这是最简单也最常用的方法。通过对训练集中的图片进行随机旋转、缩放、平移或裁剪等操作,来增加数据的多样性。这使得模型在训练时能接触到更多不同几何形态的样本,从而提高其泛化能力
  2. 多尺度或者多角度训练:直接在多个尺度或角度下对同一图像进行训练,迫使网络学习对这些变换不变的特征。
  3. 使用手工设计的特征(如SIFT、ORB等): 在传统的计算机视觉任务中,人们会使用这些对几何变换具有一定不变性的特征来处理问题。例如,SIFT特征描述子在一定程度上对尺度和旋转是不变的。

之前的模型/方法有什么不足?

  1. 数据增强的局限性: 数据增强虽然有效,但它是一种“静态”的、预先定义好的方法。它并不能让网络自适应地学习应该对哪些变换进行处理。换句话说,模型是在被动地接受经过变换的数据,而不是主动地去“寻找”并“矫正”那些重要的区域。
  2. 计算效率低下:多尺度或多角度训练会显著增加计算量和内存消耗,因为需要为每个变换后的版本都进行一次前向传播。
  3. 无法处理复杂的、非预期的变换: 数据增强通常只覆盖简单的变换,如旋转和平移。对于更复杂的、特定于任务的“扭曲”或“不规则”变换,效果不佳。

设计思路

这个模型针对不足提出了什么改进方案?解决了什么问题?有什么人类直觉在里面?

鉴于之前方法的缺陷,作者从人类直觉的角度进行了思考,他认为,人类之所以能够从一个偏移,旋转,或者不同视角下的图片中还原原本的图片(比如把一不同视角下的5,无论仰视还是俯视都可以看出来),是因为我们脑袋中 “自带一个用来进行动态几何变换的机制” (可以理解为自带一个“自适应”的几何变换矩阵),我们能根据注意力自动调整这个矩阵的参数,把图像校正到我们大脑中最容易理解和识别的“理想状态”。
在这里插入图片描述
以下是几何变换矩阵的原理:
在这里插入图片描述
如果对变换矩阵施加如下约束,那么这个矩阵则可以对原图进行旋转,平移,以及仿射变换。
在这里插入图片描述

但是实际的代码实现中并不会对其施加以上约束,因为模型可能通过学习学到更加高级的几何变换,而不仅仅局限于以上三种变换。

达到效果及优势

在这里插入图片描述

  • 性能提升: 相比于其他没有使用STN的模型(如 Cimpoi '15, Simon '15 等),使用了STN的CNN模型在CUB-200-2011鸟类分类数据集上的准确率有了明显的提升。在高分辨率的图片上性能提升更加显著。
  • 可解释的空间变换注意力: 这是STN最直观也最令人兴奋的优势。图表右侧的图片展示了2xST-CNN和4xST-CNN模型中STN模块学习到的空间变换。论文作者在图中特别指出,对于2xST-CNN模型,一个STN模块(红色框)学习定位和放大鸟的头部,而另一个STN模块(绿色框)则学习定位和放大鸟的身体。也就是每个STN模块都注意到了不同的,但是对鸟的分类至关重要东西!
  • 即插即用的模块:STN最大的优势之一就是他能非常容易地插入任何现有的CNN,并且只需要很小的修改!

对后续模型的影响

  1. 开创“可学习的变换”思想
    STN首次将空间变换的能力作为可学习的模块集成到神经网络中。它证明了网络可以自己决定如何对输入数据进行几何变换,而不是依赖于预先设定好的规则(如数据增强)。 这种思想被广泛应用于各种需要处理非刚性、非线性变换的任务中。例如,在医学图像处理中,STN的思想被用来进行图像配准(Image Registration),自动对齐不同时间或不同设备拍摄的病灶图像。

  2. 空间注意力机制的先驱
    尽管STN的关注点是“几何变换”,但它通过定位并变换最关键的区域,实际上实现了一种形式的注意力。它让网络将“注意力”集中在最关键的像素或特征上,并将其“摆正”以方便后续处理。它的成功启发了后续的注意力机制(Attention Mechanism)研究。虽然STN是“空间注意力”,更严谨一定来说叫“空间几何注意力”,但它证明了让网络“有选择地”关注输入中最重要的部分是提高性能的有效手段。这为后来更广泛的通道注意力(Channel Attention)自注意力(Self-Attention)以及Transformer模型的兴起奠定了基础。

  3. “即插即用”模块化设计的典范
    STN模块可以轻松地插入到任何CNN架构中,这极大地降低了其应用门槛,并展示了模块化设计在深度学习中的巨大潜力。* 这种设计理念被广泛采纳。如今,许多深度学习模型都由各种可插拔的模块组成,比如SENet中的“通道注意力”模块、ECA-Net中的“高效通道注意力”模块等等。这些模块都遵循了STN的“即插即用”设计思想,让研究人员可以更容易地进行模型改进和创新。

总而言之,STN的贡献远不止于提高了一点点准确率。它引入的 “可学习变换”“空间注意力”“模块化设计” 等核心思想,深刻地影响了后续的计算机视觉和深度学习研究,成为连接传统CNN和现代Transformer模型的一个重要桥梁。

网络结构

在这里插入图片描述
Spatial Transformer 模组可以分解成如下三个关键组成部分:

  1. 定位网络(localisation net):其中包括两个全连接层,第一个层负责提取图片中的基础几何信息,第二个层负责根据基础几何信息回归出几何变换矩阵
  2. 网格生成器(grid generator): 负责根据生成的变换矩阵生成变换网格,本质上是定义了图像的变换
  3. 采样器(Sampler):利用定义好的grid对原图片进行变换

Pytorch模型实现+MNIST数据集视角调整实验

准备库、数据集、数据加载器

首先我们把库和数据集,以及数据加载器准备好:


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as npplt.ion()   # interactive modefrom six.moves import urllib
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Training dataset
train_loader = torch.utils.data.DataLoader(datasets.MNIST(root='.', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])), batch_size=64, shuffle=True, num_workers=4)
# Test dataset
test_loader = torch.utils.data.DataLoader(datasets.MNIST(root='.', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])), batch_size=64, shuffle=True, num_workers=4)

定义网络结构

接下来就是搭建我们带有一个简单Spatio Transformer 模块的网络了:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5) self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320, 50)# 修复:等号'='而不是'-'self.fc2 = nn.Linear(50, 10)self.flatten = nn.Flatten()# STN模块组成-特征提取localization子网络, 用来先提取能反应图像几何信息的高级特征# 修复:括号'()'而不是')'self.localization = nn.Sequential(nn.Conv2d(1, 8, kernel_size=7),nn.MaxPool2d(2, stride=2),nn.ReLU(True),nn.Conv2d(8, 10, kernel_size=5),nn.MaxPool2d(2, stride=2),nn.ReLU(True))# STN模块组成-全连接定位网络-回归仿射矩阵self.fc_loc = nn.Sequential(nn.Linear(10 * 3 * 3, 32), nn.ReLU(True),# 直接把几何变换矩阵全连接出来nn.Linear(32, 3 * 2))# 网络权重初始化,保证一开始的变换矩阵从什么都不做的单位帧self.fc_loc[2].weight.data.zero_() # 定位网络直接用0初始化权重self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)) """[1, 0, 0][0, 1, 0]"""def stn(self, x):xs = self.localization(x)xs = self.flatten(xs)# 修复:xs_theta = self.fc_loc(xs)# 重塑成矩阵形状theta = theta.view(-1, 2, 3)# 网格生成部分,对图像引用几何变换矩阵产生新图像grid = F.affine_grid(theta, x.size(), align_corners=True)x = F.grid_sample(x, grid, align_corners=True)return xdef forward(self, x):x = self.stn(x)# 传给分类层# Perform the usual forward passx = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))x = self.flatten(x)x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return F.log_softmax(x, dim=1)# 冒烟测试
# 修复:缺失 device 的定义
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
# 修复:缺失 .to(device)
input = torch.randn(1, 1, 28, 28).to(device)
output = model(input)
# 修复:输出的维度是 [1, 10],所以应该使用 output.shape
print(output.shape) # 输出torch.Size([1, 10])

定义训练/测试过程

optimizer = optim.SGD(model.parameters(), lr=0.01)def train(epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if batch_idx % 500 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))
#
# A simple test procedure to measure the STN performances on MNIST.
#def test():with torch.no_grad():model.eval()test_loss = 0correct = 0for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)# sum up batch losstest_loss += F.nll_loss(output, target, size_average=False).item()# get the index of the max log-probabilitypred = output.max(1, keepdim=True)[1]correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

可视化效果

我们这里直接可视化原图像在经过空间注意力模块调整之后会变成啥样:

def convert_image_np(inp):"""Convert a Tensor to numpy image."""inp = inp.numpy().transpose((1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])inp = std * inp + meaninp = np.clip(inp, 0, 1)return inp# We want to visualize the output of the spatial transformers layer
# after the training, we visualize a batch of input images and
# the corresponding transformed batch using STN.def visualize_stn():with torch.no_grad():# Get a batch of training datadata = next(iter(test_loader))[0].to(device)input_tensor = data.cpu()transformed_input_tensor = model.stn(data).cpu()in_grid = convert_image_np(torchvision.utils.make_grid(input_tensor))out_grid = convert_image_np(torchvision.utils.make_grid(transformed_input_tensor))# Plot the results side-by-sidef, axarr = plt.subplots(1, 2)axarr[0].imshow(in_grid)axarr[0].set_title('Dataset Images')axarr[1].imshow(out_grid)axarr[1].set_title('Transformed Images')for epoch in range(1, 20 + 1):train(epoch)test()# Visualize the STN transformation on some input batch
visualize_stn()plt.ioff()
plt.show()

在这里插入图片描述
最终的准确率在epoch20之后达到了99%,并且可以看到STN模块确实对原图进行了几何变换,把图像校正到了对模型,甚至对人类都更容易理解和识别的“理想状态”。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/93788.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/93788.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【LeetCode】16. 最接近的三数之和

文章目录16. 最接近的三数之和题目描述示例 1:示例 2:提示:解题思路算法分析问题本质分析排序双指针法详解双指针移动策略搜索过程可视化各种解法对比算法流程图边界情况处理时间复杂度分析空间复杂度分析关键优化点实际应用场景测试用例设计…

微信小程序实现蓝牙开启自动播放BGM

下面是一个完整的微信小程序实现方案,当蓝牙设备连接时自动播放背景音乐(BGM)。实现思路监听蓝牙设备连接状态当检测到蓝牙设备连接时,自动播放音乐当蓝牙断开时,停止音乐播放处理相关权限和用户交互完整代码实现1. 项目结构text/pages/index…

XML 序列化与操作详解笔记

一、XML 基础概念XML&#xff08;eXtensible Markup Language&#xff0c;可扩展标记语言&#xff09;是一种用于存储和传输数据的标记语言&#xff0c;由 W3C 制定&#xff0c;具有以下特点&#xff1a;可扩展性&#xff1a;允许自定义标记&#xff08;如<Student>、<…

第八十四章:实战篇:图 → 视频:基于 AnimateDiff 的视频合成链路——让你的图片“活”起来,瞬间拥有“电影感”!

AI图生视频前言&#xff1a;从“刹那永恒”到“动态大片”——AnimateDiff&#xff0c;让图片“活”起来&#xff01;第一章&#xff1a;痛点直击——静态图像到视频&#xff0c;不是“幻灯片”那么简单&#xff01;第二章&#xff1a;探秘“时间魔法”&#xff1a;AnimateDiff…

2025深大计算机考研复试经验贴(已上岸)

如果你在初试出分前看到此贴 我建议&#xff1a; 准备机试和简历&#xff0c;即使你不估分&#xff1a;因为如果要准备春招的话&#xff0c;也总要刷题和做简历的。尽早估分&#xff0c;查一下往年的复试线&#xff0c;如果有望进复试&#xff0c;可尽早开始准备。 Preface …

用Pygame开发桌面小游戏:从入门到发布

一、引言 Pygame是一个基于Python的跨平台游戏开发库,它提供了简单易用的图形、声音和输入处理功能,非常适合新手入门游戏开发。本文将以"经典游戏合集"项目为例,带你一步步了解如何使用Pygame开发、打包和发布自己的桌面小游戏。 二、开发环境搭建 安装Python:…

CSS backdrop-filter:给元素背景添加模糊与色调的高级滤镜

在现代网页设计中&#xff0c;半透明元素搭配背景模糊效果已成为流行趋势 —— 从毛玻璃导航栏、模态框遮罩&#xff0c;到卡片悬停效果&#xff0c;这种设计能让界面更具层次感和高级感。实现这一效果的核心 CSS 属性&#xff0c;正是backdrop-filter。它能对元素背后的内容&a…

检索增强生成(RAG) 缓存增强生成(CAG) 生成中检索(RICHES) 知识库增强语言模型(KBLAM)

以下是当前主流的四大知识增强技术方案对比&#xff0c;涵盖核心原理、适用场景及最新发展趋势&#xff0c;为开发者提供清晰的技术选型参考&#xff1a; &#x1f50d; 一、RAG&#xff08;检索增强生成&#xff09;​​ 核心原理​&#xff1a; 动态检索外部知识库&#xff0…

LLM(大语言模型)的工作原理 图文讲解

目录 1. 条件概率&#xff1a;上下文预测的基础 2. LLM 是如何“看着上下文写出下一个词”的&#xff1f; 补充说明&#xff08;重要&#xff09; &#x1f4cc; Step 1: 输入处理 &#x1f4cc; Step 2: 概率计算 &#x1f4cc; Step 3: 决策选择 &#x1f914; 一个有…

Python netifaces 库详解:跨平台网络接口与 IP 地址管理

一、前言 在现代网络编程中&#xff0c;获取本机的网络接口信息和 IP 配置是非常常见的需求。 例如&#xff1a; 开发一个需要选择合适网卡的 网络服务&#xff1b;在多网卡环境下实现 流量路由与控制&#xff1b;在系统诊断工具中展示 IP/MAC 地址、子网掩码、默认网关&#x…

HTML应用指南:利用POST请求获取上海黄金交易所金价数据

上海黄金交易所&#xff08;SGE&#xff09;作为中国唯一经国务院批准、专门从事黄金等贵金属交易的国家级市场平台&#xff0c;自成立以来始终秉持“公开、公平、公正”的原则&#xff0c;致力于构建规范、高效、透明的贵金属交易市场体系。交易所通过完善的交易机制、严格的风…

C++常见面试题-1.C++基础

一、C 基础 1.1 语言特性与区别C 与 C 的主要区别是什么&#xff1f;C 为何被称为 “带类的 C”&#xff1f; 主要区别&#xff1a;C 引入了面向对象编程&#xff08;OOP&#xff09;特性&#xff08;类、继承、多态等&#xff09;&#xff0c;而 C 是过程式编程语言&#xff1…

Tomcat里catalina.sh详解

在 Tomcat 中&#xff0c;catalina.sh&#xff08;Linux/macOS&#xff09;或 catalina.bat&#xff08;Windows&#xff09;是 核心的启动和关闭脚本&#xff0c;用于控制 Tomcat 服务器的运行。它是 Tomcat 的“主控脚本”&#xff0c;负责设置环境变量、启动/关闭 JVM 进程&…

STM32之MCU和GPIO

一、单片机MCU 1.1 单片机和嵌入式 嵌入式系统 以计算机为核心&#xff0c;tips&#xff1a;计算机【处理单元&#xff0c;内存 硬盘】 可以控制的外部设备&#xff0c;传感器&#xff0c;电机&#xff0c;继电器 嵌入式开发 数据源--> 处理器(CPU MCU MPU) --> 执行器 …

22_基于深度学习的桃子成熟度检测系统(yolo11、yolov8、yolov5+UI界面+Python项目源码+模型+标注好的数据集)

目录 项目介绍&#x1f3af; 功能展示&#x1f31f; 一、环境安装&#x1f386; 环境配置说明&#x1f4d8; 安装指南说明&#x1f3a5; 环境安装教学视频 &#x1f31f; 二、数据集介绍&#x1f31f; 三、系统环境&#xff08;框架/依赖库&#xff09;说明&#x1f9f1; 系统环…

数据结构:二叉树oj练习

在讲今天的题目之前&#xff0c;我们还需要讲一下二叉树的以下特点&#xff1a; 对任意一颗二叉树&#xff0c;如果度为0的节点个数是n0&#xff0c;度为2的节点个数是n2&#xff0c;则有n0n21. 证明&#xff1a;二叉树总的节点个数是n&#xff0c;那么有nn0n1n2 二叉树的度为…

RabbitMQ高级特性——TTL、死信队列、延迟队列、事务、消息分发

目录 一、TTL 1.1设置消息的TTL 1.2设置队列的TTL 1.3两者之间的区别 二、死信队列 2.1死信的概念 2.2死信产生的条件&#xff1a; 2.3死信队列的实现 死信队列的工作原理 2.4常⻅⾯试题 三、延迟队列 3.1概念 3.2应用场景 3.3RabbitMQ 实现延迟队列的核心原理 1…

神经网络设计中关于BN归一化(Normalization)的讨论

在神经网络的结构中&#xff0c;我们常常可以看见归一化&#xff08;Normalization&#xff09;如BN的出现&#xff0c;无论是模型的backbone或者是neck的设计都与它有着重大的关系。 因此引发了我对它的思考&#xff0c;接下来我将从 是什么&#xff08;知识领域&#xff0c;诞…

MacOS 安全机制与“文件已损坏”排查完整指南

1. 背景说明macOS 为了保护系统安全&#xff0c;内置了多个安全机制&#xff1a;机制作用是否影响第三方 AppSIP (System Integrity Protection)保护系统关键文件/目录不被篡改高风险 App/驱动可能受限Gatekeeper限制未签名/未认证 App 运行阻止“未知开发者” App文件隔离属性…

package.json文件中的devDependencies和dependencies对象有什么区别?

前端项目的package.json文件中&#xff0c;dependencies和devDependencies对象都用于指定项目所依赖的软件包&#xff0c;但它们在项目的开发和生产环境中的使用有所不同。1.dependencies&#xff1a;dependencies是指定项目在生产环境中运行所需要的依赖项。这些依赖项通常包括…