ConvMixer模型:纯卷积为何能够媲美Transformer架构?深入浅出原理与Pytorch代码逐行讲解实现

        ConvMixer 是一个简洁的视觉模型,仅使用标准的卷积层,达到与基于自注意力机制的视觉 Transformer(ViT)相似的性能,由此证明纯卷积架构依然很强大。

核心原理:极简的卷积设计:

        它摒弃了复杂的自注意力模块,只依赖于两种基础的卷积操作:深度卷积(Depthwise Convolution)逐点卷积(Pointwise Convolution)

       制作一杯混合果汁。我们不会把整个水果直接扔进搅拌机,而是先切成小块(分块)。然后,搅拌机有两个关键动作:第一,刀片高速旋转,让每种水果块自己先碎掉(空间混合);第二,整个杯子里的碎块因为搅动而互相融合在一起(通道混合)。

        ConvMixer 的设计与此相似。它认为,复杂的图像特征提取,可以被分解为这两个最基本、最核心的“搅拌”动作,而不需要像 Vision Transformer 那样引入复杂的自注意力机制。

我们来一步步看这个模型是如何工作的。

1. 分块嵌入 (Patch Embedding):

传统卷积的起点:

        传统的卷积网络(如 VGG)通常在开头使用小的卷积核(比如 3x3),步长为1或2。这意味着网络一开始的视野非常小,它是在逐个像素地、非常局部地观察图像。它需要堆叠很多层,才能慢慢地将局部信息组合起来,形成对一个更大区域的理解。

ConvMixer 的革新:

ConvMixer 借鉴了 Vision Transformer (ViT) 的一个核心思想:不要一开始就纠结于像素细节,而是直接把图像切成一块块(Patches),把每一块作为一个基本处理单元。

它如何用卷积实现这一点呢?请看代码:

nn.Conv2d(in_channels=3, out_channels=dim, kernel_size=7, stride=7)
#当卷积核的大小和移动步长相同时,效果就是卷积核在图像上进行不重叠的滑动。
#每滑动一次,这个 7x7 的卷积核就完整地覆盖了一个 7x7 的图像块(Patch)。
#它将这个块内的所有像素信息(3个通道的 7x7=49 个像素)进行一次计算,然后“压缩”成 dim 个通道的 一个 像素点。

这一步的意义:

  1. 降维与提炼:瞬间将高分辨率的图像(如 224x224x3)转换成一个低分辨率的特征图(如 32x32x768)。这大大减少了后续计算量。

  2. 视角转变:强迫模型从一开始就从一个“区域”(Patch)的层面去理解图像,而不是从单个像素。这与人类的视觉习惯更相似,我们看一张图也是先看整体布局和各个区域,再看细节。

  3. 信息嵌入out_channels=dim 这个参数(例如 dim=768)意味着每个图像块被转换成了一个包含 768 个特征的向量。这个过程被称为“嵌入”(Embedding),它将原始的像素信息转化成了更利于模型处理的、高维的抽象特征

2. ConvMixer 层:

        这是模型的核心,它由 深度卷积 (Depthwise Convolution)逐点卷积 (Pointwise Convolution) 构成。这种组合也被称为 深度可分离卷积 (Depthwise Separable Convolution),是 MobileNet 等轻量级网络的基石。

深度卷积 (Depthwise Conv):空间混合

        经过分块嵌入后,我们得到了一个 dim 通道(比如 768 个通道)的特征图。每个通道都可以看作是图像在某个特定方面的特征表达(比如某个通道可能对轮廓敏感,另一个对纹理敏感)。一个 9x9 的普通卷积核,在计算输出特征图的一个点时,会同时查看输入特征图上 9x9 区域内的 所有 768 个通道的信息,然后把它们加权求和。这是“空间混合”和“通道混合”同时进行的,计算开销巨大。

深度卷积却将这两个过程分离开。深度卷积只负责空间混合。

具体过程:

  1. 一个通道,一个专属卷积核:如果输入有 C 个通道,深度卷积就会使用 C 个扁平的(2D)卷积核(例如 3x3x1)。

  2. 独立工作:第1个卷积核只负责在第1个输入通道上滑动,第2个卷积核只负责第2个通道……以此类推。

  3. 保持通道数:处理完成后,输出的通道数仍然是 C。它只在每个通道内部进行了空间特征提取,但通道之间还是完全隔离的。

核心目的:用极低的计算成本,在每个特征通道内部有效地捕捉空间模式。

逐点卷积 (Pointwise Convolution):通道混合:

        深度卷积完成了空间特征整理,但留下了致命问题:通道之间完全没有信息交流。这就像一个公司里,销售、技术、市场三个部门都各自完成了自己的KPI,但他们之间从不开会,公司无法形成合力。逐点卷积就是来主持这场“跨部门会议”的。它只专注于第二步:通道混合。它的工作方式非常简单,就是一次 1x1 的卷积。

        

具体过程

  1. 微型卷积核:它的卷积核大小是 1x1。这意味着它在空间上看的范围只有一个像素点,所以它完全不做空间混合。

  2. 贯穿所有通道:这个 1x1 的卷积核是立体的(例如 1x1xC,C是深度卷积的输出通道数,;比如768个通道)。在特征图的每一个像素点上,它都会同时考虑所有 768个通道的值,然后进行加权求和,输出一个新值。

  3. 重组特征:通过使用 N 个这样的 1x1xC 卷积核,它就可以将输入的 C 个通道的信息,重新组合成 N 个全新的、更有意义的特征通道。

核心目的:在不同通道之间建立联系,让模型学习如何将从不同通道提取出的空间特征(比如“有笔直的轮廓”、“有红色的纹理”)组合成更高级的概念(比如“这是一支笔”)。

当 深度卷积 和 逐点卷积 按顺序组合在一起时,就构成了大名鼎鼎的 深度可分离卷积。

流程:输入 -> 深度卷积 (空间混合) -> 逐点卷积 (通道混合) -> 输出

这个结构可以成功的原因来自于它背后的假设:空间相关性(一个区域内的像素关系)和通道相关性(不同特征之间的关系)是可以被分开处理的,事实证明,这种解耦思想很成功。

3. 数据参数对比:

假设我们有如下任务:

  • 输入特征图: 16x16x256 (高 x 宽 x 通道数)

  • 输出特征图: 16x16x512

  • 卷积核大小: 3x3

方案一:标准卷积

  • 需要 5123x3x256 的立体卷积核。

  • 总参数量 = 3×3×256×512=1,179,648

方案二:深度可分离卷积

  1. 深度卷积 (空间混合):

    • 需要 2563x3x1 的扁平卷积核。

    • 参数量 = 3×3×256=2,304

    • 得到一个 16x16x256 的中间特征图。

  2. 逐点卷积 (通道混合):

    • 需要 5121x1x256 的卷积核,将 256 通道变为 512 通道。

    • 参数量 = 1×1×256×512=131,072

  • 总参数量 = 2,304+131,072=133,376

结果对比: 标准卷积需要约 118 万 参数,而深度可分离卷积只需要约 13 万 参数,参数量减少到了原来的 11% 左右

这就是为什么深度可分离卷积成为了 MobileNet、Xception、ConvMixer 等高效模型的基石。它用极低的成本,实现了与标准卷积非常接近的特征提取能力。

4. Pytorch代码逐行讲解实现:

我们回顾一下结构:

1. 核心组件:ConvMixerLayer

我们先构建模型最小、也是最核心的重复单元——ConvMixerLayer。它包含了我们详细讨论过的 深度卷积逐点卷积残差连接

        

import torch
import torch.nn as nnclass ConvMixerLayer(nn.Module):"""ConvMixer 的核心重复层。包含一个深度卷积和一个逐点卷积,并通过残差连接。"""def __init__(self, dim, kernel_size=9):# 初始化 PyTorch 模块super().__init__()# --- 定义层的各个组件 ---# 1. 深度卷积 (Depthwise Convolution)#    负责在每个通道内部进行空间信息混合。self.depthwise_conv = nn.Conv2d(dim,                      # 输入通道数。dim,                      # 输出通道数与输入相同。kernel_size=kernel_size,  # 使用一个较大的卷积核(如9x9)来获取大感受野。groups=dim,               # 分组数=通道数,这是实现“深度卷积”的关键技巧。padding="same"            # 'same' 填充可以确保卷积后特征图的高和宽不变。)# 2. 激活函数 (Activation)#    为模型引入非线性,GELU 是 Transformer 中常用激活函数。self.activation = nn.GELU()# 3. 批归一化 (Batch Normalization)#    在网络层之间稳定和加速训练。self.norm = nn.BatchNorm2d(dim)# 4. 逐点卷积 (Pointwise Convolution)#    负责在通道之间混合信息,它本质上就是一个 1x1 的标准卷积。self.pointwise_conv = nn.Conv2d(dim,                      # 输入通道数。dim,                      # 输出通道数。kernel_size=1             # **核大小为1x1,是实现“逐点卷积”的关键**。)def forward(self, x):# 定义数据如何“流过”这个层 (前向传播)# 输入 x 的维度: [批次大小, 通道数, 高, 宽]# 1. 保存原始输入,用于最后的残差连接residual = x# 2. 应用第一个处理块:深度卷积 -> 激活 -> 归一化x = self.depthwise_conv(x)x = self.activation(x)x = self.norm(x)# 3. 应用第二个处理块:逐点卷积 -> 激活 -> 归一化x = self.pointwise_conv(x)x = self.activation(x)x = self.norm(x)# 4. 完成残差连接return x + residual

2. 整体架构:ConvMixer 模型

现在,我们把 ConvMixerLayer 堆叠起来,并加上开头的“分块嵌入”和结尾的“分类头”,构成完整的 ConvMixer 模型。

class ConvMixer(nn.Module):"""完整的 ConvMixer 模型架构。"""def __init__(self, dim, depth, kernel_size=9, patch_size=7, num_classes=1000):super().__init__()# --- 1. 分块嵌入 (Patch Embedding) ---# 使用一个卷积层同时实现图像分块和特征嵌入。self.patch_embedding = nn.Sequential(nn.Conv2d(3,                        # 输入是RGB图像,所以有3个通道。dim,                      # 输出通道数,即我们想要的嵌入维度。kernel_size=patch_size,   # 卷积核大小等于块大小。stride=patch_size         # 步长等于核大小,确保分块不重叠。),nn.GELU(),                    # 同样使用 GELU 激活函数。nn.BatchNorm2d(dim)           # 批归一化。)# --- 2. 堆叠 ConvMixer 层 ---self.mixer_layers = nn.Sequential(*[ConvMixerLayer(dim=dim, kernel_size=kernel_size) for _ in range(depth)])# --- 3. 分类头 (Classification Head) ---# a. 全局平均池化#    将每个通道的 HxW 特征图压缩成一个 1x1 的值。self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))# b. 全连接层 (分类器)#    将池化后的向量映射到最终的类别数量上。self.classifier = nn.Linear(dim, num_classes)def forward(self, x):# 定义数据在整个模型中的流动路径# 初始输入 x 维度: [批次大小, 3, 224, 224] (以ImageNet为例)# 1. 应用分块嵌入#    x 维度变为 -> [批次大小, dim, 32, 32] (224 / 7 = 32)x = self.patch_embedding(x)# 2. 通过所有 ConvMixer 层#    维度保持不变 -> [批次大小, dim, 32, 32]x = self.mixer_layers(x)# 3. 应用全局平均池化#    x 维度变为 -> [批次大小, dim, 1, 1]x = self.global_avg_pool(x)# 4. 展平张量以适应全连接层#    `torch.flatten(x, 1)` 会将从第1个维度(通道维)开始的所有维度拍平。#    x 维度变为 -> [批次大小, dim]x = torch.flatten(x, 1)# 5. 通过分类器得到最终输出#    x 维度变为 -> [批次大小, num_classes]return self.classifier(x)

3. 实例化与测试 

最后,让我们创建模型的一个实例,并用一个假的图像数据来测试它,看看整个流程是否能跑通。

# --- 实例化一个 ConvMixer-1536/20 模型 ---
# 这是论文中提出的一个高性能版本配置
# dim=1536, depth=20, kernel_size=9, patch_size=7
model = ConvMixer(dim=1536,depth=20,kernel_size=9,patch_size=7,num_classes=1000  # ImageNet 数据集的类别数
)# 打印模型结构,可以清晰地看到我们定义的每一层
# print(model)# --- 创建一个假的输入图像张量进行测试 ---
# 模拟一个批次包含4张 224x224 的3通道彩色图像
dummy_images = torch.randn(4, 3, 224, 224)# 将假图像输入模型,得到输出
output = model(dummy_images)# 打印输出张量的形状
# 预期输出: torch.Size([4, 1000]),代表每张图片都得到了1000个类别的得分
print(f"输入张量形状: {dummy_images.shape}")
print(f"输出张量形状: {output.shape}")

OK,结束,希望可以帮助大家学会这个轻量化模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/90474.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/90474.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

教程:如何通过代理服务在国内高效使用 Claude API 并集成到 VSCode

对于许多开发者来说,直接访问 Anthropic 的 Claude API 存在网络障碍。本文将介绍一个第三方代理服务,帮助你稳定、高效地利用 Claude 的强大能力,并将其无缝集成到你的开发工作流中。 一、服务介绍 我们使用的是 open.xiaojingai.com 这个…

从零开始:Vue 3 + TypeScript 项目创建全记录

一次完整的现代前端项目搭建经历,踩坑与收获并存 📖 前言 最近创建了一个新的 Vue 3 项目,整个过程中遇到了不少有趣的选择和决策点。作为一个技术复盘,我想把这次经历分享出来,希望能帮助到其他开发者,特别是那些刚接触 Vue 3 生态的朋友们。 🛠️ 项目初始化:选择…

[spring6: @EnableWebSocket]-源码解析

注解 EnableWebSocket Retention(RetentionPolicy.RUNTIME) Target(ElementType.TYPE) Documented Import(DelegatingWebSocketConfiguration.class) public interface EnableWebSocket {}DelegatingWebSocketConfiguration Configuration(proxyBeanMethods false) public …

Nacos 封装与 Docker 部署实践

Nacos 封装与 Docker 部署指南 0 准备工作 核心概念​ 命名空间:用于隔离不同环境(如 dev、test、prod)或业务线,默认命名空间为public。​ 数据 ID:配置集的唯一标识,命名规则推荐为{服务名}-{profile}.{扩…

Vue2——4

组件的样式冲突 scoped默认情况:写在组件中的样式会 全局生效 → 因此很容易造成多个组件之间的样式冲突问题。1. 全局样式: 默认组件中的样式会作用到全局2. 局部样式: 可以给组件加上 scoped 属性, 可以让样式只作用于当前组件原理:当前组件内标签都被…

30天打好数模基础-逻辑回归讲解

案例代码实现一、代码说明本案例针对信用卡欺诈检测二分类问题,完整实现逻辑回归的数据生成→预处理→模型训练→评估→阈值调整→决策边界可视化流程。数据生成:模拟1000条交易数据,其中欺诈样本占20%(类不平衡)&…

CDH yarn 重启后RM两个备

yarn rmadmin -transitionToActive --forcemanual rm1 cd /opt/cloudera/parcels/CDH/lib/zookeeper/bin/ ./zkCli.sh -server IT-CDH-Node01:2181 查看是否存在残留的ActiveBreadCrumb节点 ls /yarn-leader-election/yarnRM #若输出只有[ActiveBreadCrumb](正常应…

HTML5音频技术及Web Audio API深入解析

本文还有配套的精品资源&#xff0c;点击获取 简介&#xff1a;音频处理在IT行业中的多媒体、游戏开发、在线教育和音乐制作等应用领域中至关重要。本文详细探讨了HTML5中的 <audio> 标签和Web Audio API等技术&#xff0c;涉及音频的嵌入、播放、控制以及优化。特别…

每日面试题13:垃圾回收器什么时候STW?

STW是什么&#xff1f;——深入理解JVM垃圾回收中的"Stop-The-World"在Java程序运行过程中&#xff0c;JVM会通过垃圾回收&#xff08;GC&#xff09;自动管理内存&#xff0c;释放不再使用的对象以腾出空间。但你是否遇到过程序突然卡顿的情况&#xff1f;这可能与G…

【系统全面】常用SQL语句大全

一、基本查询语句 查询所有数据&#xff1a; SELECT * FROM 表名;查询特定列&#xff1a; SELECT 列名1, 列名2 FROM 表名;条件查询&#xff1a; SELECT * FROM 表名 WHERE 条件;模糊查询&#xff1a; SELECT * FROM 表名 WHERE 列名 LIKE 模式%;排序查询&#xff1a; SELECT *…

Spring之SSM整合流程详解(Spring+SpringMVC+MyBatis)

Spring之SSM整合流程详解-SpringSpringMVCMyBatis一、SSM整合的核心思路二、环境准备与依赖配置2.1 开发环境2.2 Maven依赖&#xff08;pom.xml&#xff09;三、整合配置文件&#xff08;核心步骤&#xff09;3.1 数据库配置&#xff08;db.properties&#xff09;3.2 Spring核…

C++STL系列之set和map系列

前言 set和map都是关联式容器&#xff0c;stl中树形结构的有四种&#xff0c;set&#xff0c;map&#xff0c;multiset,multimap.本次主要是讲他们的模拟实现和用法。 一、set、map、multiset、multimap set set的中文意思是集合&#xff0c;集合就说明不允许重复的元素 1……

Linux 磁盘挂载,查看uuid

lsblk -o NAME,FSTYPE,LABEL,UUID,MOUNTPOINT,SIZEsudo ntfsfix /dev/nvme1n1p1sudo mount -o remount,rw /dev/nvme1n1p1 /media/yake/Datasudo ntfsfix /dev/sda2sudo mount -o remount,rw /dev/sda2 /media/yake/MyData

【AJAX】XMLHttpRequest、Promise 与 axios的关系

目录 一、AJAX原理 —— XMLHttpRequest 1.1 使用XMLHttpRequest 二、 XMLHttpRequest - 查询参数 &#xff08;就是往服务器后面拼接要查询的字符串&#xff09; 三、 地区查询 四、 XMLHttpRequest - 数据提交 五、 认识Promise 5.1 为什么 JavaScript 需要异步&#…

C++中的stack和queue

C中的stack和queue 前言 这一节的内容对于stack和queue的使用介绍会比较少&#xff0c;主要是因为stack和queue的使用十分简单&#xff0c;而且他们的功能主要也是在做题的时候才会显现。这一栏目暂时不会写关于做题的内容&#xff0c;后续我会额外开一个做题日记的栏目的。 这…

Spring Bean生命周期七步曲:定义、实例化、初始化、使用、销毁

各位小猿&#xff0c;程序员小猿开发笔记&#xff0c;希望大家共同进步。 引言 1.整体流程图 2.各阶段分析 1️⃣定义阶段 1.1 定位资源 Spring 扫描 Component、Service、Controller 等注解的类或解析 XML/Java Config 中的 Bean 定义 1.2定义 BeanDefinition 解析类信息…

API安全监测工具:数字经济的免疫哨兵

&#x1f4a5; 企业的三重致命威胁 1. 漏洞潜伏的定时炸弹 某支付平台未检测出API的批量数据泄露漏洞&#xff0c;导致230万用户信息被盗&#xff0c;面临GDPR 1.8亿欧元罚单&#xff08;IBM X-Force 2024报告&#xff09;。传统扫描器对逻辑漏洞漏检率超40%&#xff08;OWASP基…

Matplotlib详细教程(基础介绍,参数调整,绘图教程)

目录 一、初识Matploblib 1.1 安装 Matplotlib 1.2、Matplotlib 的两种接口风格 1.3、Figure 和 Axes 的深度理解 1.4 设置画布大小 1.5 设置网格线 1.6 设置坐标轴 1.7 设置刻度和标签 1.8 添加图例和标题 1.9 设置中文显示 1.10 调整子图布局 二、常用绘图教程 2…

Redis高可用架构演进面试笔记

1. 主从复制架构 核心概念Redis单节点并发能力有限&#xff0c;通过主从集群实现读写分离提升性能&#xff1a; Master节点&#xff1a;负责写操作Slave节点&#xff1a;负责读操作&#xff0c;从主节点同步数据 主从同步流程 全量同步&#xff08;首次同步&#xff09;建立连接…

无人机保养指南

定期清洁无人机在使用后容易积累灰尘、沙砾等杂物&#xff0c;需及时清洁。使用软毛刷或压缩空气清除电机、螺旋桨和机身缝隙中的杂质。避免使用湿布直接擦拭电子元件&#xff0c;防止短路。电池维护锂电池是无人机的核心部件&#xff0c;需避免过度放电或充电。长期存放时应保…