基于CNN的FashionMNIST数据集识别6——DenseNet模型

源码

import torch
from torch import nn
from torchsummary import summary"""
DenseNet的核心组件:稠密层(DenseLayer)
实现特征复用机制,每个层的输出会与所有前序层的输出在通道维度拼接
"""class DenseLayer(nn.Module):def __init__(self, input_channels, growth_rate):super().__init__()# 批归一化 + ReLU + 1x1卷积 (瓶颈层,减少计算量)self.bn1 = nn.BatchNorm2d(input_channels)self.conv1 = nn.Conv2d(input_channels, 4 * growth_rate, kernel_size=1)# 批归一化 + ReLU + 3x3卷积 (特征提取层)self.bn2 = nn.BatchNorm2d(4 * growth_rate)self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1)self.relu = nn.ReLU()def forward(self, x):# 前向传播:BN->ReLU->Conv(1x1)->BN->ReLU->Conv(3x3)out = self.conv1(self.relu(self.bn1(x)))out = self.conv2(self.relu(self.bn2(out)))# 将新特征与输入特征在通道维度拼接(实现特征复用)return torch.cat([x, out], 1)"""
稠密块(DenseBlock):由多个稠密层组成
每个稠密层的输入包含前面所有层的特征图
"""class DenseBlock(nn.Module):def __init__(self, num_layers, input_channels, growth_rate):super().__init__()layers = []# 构建num_layers个稠密层for i in range(num_layers):# 每层的输入通道数 = 初始通道数 + 已添加的特征图数layers.append(DenseLayer(input_channels + i * growth_rate, growth_rate))self.block = nn.Sequential(*layers)def forward(self, x):return self.block(x)"""
过渡层(TransitionLayer):用于压缩特征图尺寸和通道数
包含1x1卷积和平均池化
"""class TransitionLayer(nn.Module):def __init__(self, input_channels, output_channels):super().__init__()# 压缩通道数的1x1卷积self.bn = nn.BatchNorm2d(input_channels)self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=1)# 下采样用的平均池化self.pool = nn.AvgPool2d(2, stride=2)self.relu = nn.ReLU()def forward(self, x):# 前向传播:BN->ReLU->Conv(1x1)->AvgPoolout = self.conv(self.relu(self.bn(x)))return self.pool(out)"""
完整的DenseNet网络结构
包含初始卷积层、多个稠密块+过渡层、分类层
"""class DynamicDenseNet(nn.Module):def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=5):super().__init__()# 初始卷积层(标准CNN开始结构)self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),  # 下采样nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # 进一步下采样)# 构建稠密块和过渡层num_channels = 64  # 初始通道数for i, num_layers in enumerate(block_config):# 添加稠密块block = DenseBlock(num_layers, num_channels, growth_rate)self.features.add_module(f'denseblock{i + 1}', block)# 更新通道数(每个稠密层增加growth_rate个通道)num_channels += num_layers * growth_rate# 不是最后一个块时添加过渡层if i != len(block_config) - 1:trans = TransitionLayer(num_channels, num_channels // 2)self.features.add_module(f'transition{i + 1}', trans)num_channels = num_channels // 2  # 过渡层压缩通道数# 分类层self.classifier = nn.Sequential(nn.BatchNorm2d(num_channels),nn.ReLU(),nn.AdaptiveAvgPool2d((1, 1)),  # 全局平均池化nn.Flatten(),nn.Linear(num_channels, num_classes)  # 全连接输出分类结果)def forward(self, x):features = self.features(x)return self.classifier(features)# 测试代码
if __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = DynamicDenseNet().to(device)# 打印网络结构和参数统计(输入尺寸为3x224x224)print(summary(model, (3, 224, 224)))

流程图

设计理念

密集连接机制

在DenseNet中,每个层都与其后续的所有层直接连接。这意味着:

  • 第l层的输入 = 所有前序层(0到l-1)的特征图拼接
  • 数学表示:x_l = H_l([x_0, x_1, ..., x_{l-1}])
  • 与传统架构相比,缓解了梯度消失问题,增强了特征传播

特征复用机制

  • 每个层都可以访问所有前序层的特征图
  • 网络自动学习在不同层级复用特征
  • 减少了特征冗余,提高了参数效率

瓶颈层设计

每个DenseLayer包含:

  1. BN-ReLU-Conv(1×1)层‌:作为瓶颈层,减少特征图数量和计算量
    • 将输入通道压缩到4×growth_rate
  2. BN-ReLU-Conv(3×3)层‌:主特征提取层
    • 输出growth_rate个特征图(通常growth_rate=12-48)

增长率(growth_rate)参数

  • 控制每个层添加到特征图的通道数
  • 较小的growth_rate也能获得优异性能(如k=12 vs ResNet k=64)
  • 决定模型容量和参数效率的关键超参数

过渡层设计

  • 1×1卷积‌:压缩特征通道数(通常减少50%)
  • 2×2平均池化‌:下采样特征图尺寸
  • 公式:θ = 压缩因子(通常0.5)
    • output_channels = θ × input_channels

 充电:BatchNorm2d的用法

batchnorm2d是PyTorch中用于2D输入的批归一化(Batch Normalization)层。

参数类型默认值说明
num_featuresint-输入通道数C
epsfloat1e-5数值稳定项
momentumfloat0.1运行统计量更新系数
affineboolTrue是否启用γ/β可学习参数
track_running_statsboolTrue是否记录运行统计量

通常只需要设置输入通道数即可。比如:

conv = nn.Conv2d(in_c, out_c, 3)
bn = nn.BatchNorm2d(out_c)  # 注意与卷积输出通道一致
relu = nn.ReLU()
output = relu(bn(conv(input)))

bn层可以做初始化设置,比如:

bn = nn.BatchNorm2d(64)
# 初始化缩放因子为1,偏移为0
nn.init.constant_(bn.weight, 1)  
nn.init.constant_(bn.bias, 0)

 需要注意的是,当批次数量太小时,使用bn层可能表现不稳定。当batch<16时,建议使用GroupNorm方法做替代

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/85152.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/85152.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL 中 INSERT ... ON DUPLICATE KEY UPDATE 为什么会导致主键自增失效?

最近开发的过程中&#xff0c;使用ai生成代码&#xff0c;写了一条这样的SQL&#xff1a;INSERT … ON DUPLICATE KEY UPDATE&#xff0c;然后发现一个奇怪的现象&#xff1a; 为什么使用这个语法后&#xff0c;自增主键&#xff08;AUTO_INCREMENT&#xff09;的值会跳跃甚至…

jenkins流水线打包vue无权限

jenkins在使用npm命令进行拉取依赖时,创建目录会报错无权限&#xff0c;如下如所示 这是因为npm 出于安全考虑不支持以 root 用户运行&#xff0c;即使你用 root 用户身份运行了&#xff0c;npm 会自动转成一个叫 nobody 的用户来运行&#xff0c;而这个用户权限非常低 若需要…

快速实现golang的grpc服务

文章目录 1、安装服务2、检查安装版本情况3、编写proto文件4、生成代码5、实现业务逻辑6、创建provider7、测试调用 1、安装服务 1、protoc安装 需去官网下载 protobuf 2、命令行安装protoc-gen-go和protoc-gen-go-grpc $ go install google.golang.org/protobuf/cmd/protoc-…

C++ 学习 多线程 2025年6月17日18:41:30

多线程(标准线程库 <thread>) 创建线程 #include <iostream> #include <thread>void hello() {std::cout << "Hello from thread!\n"; }int main() {// 创建线程并执行 hello() std::thread t(hello); //线程对象&#xff0c;传入可调用对…

常见的测试工具及分类

Web测试工具是保障Web应用质量的核心支撑&#xff0c;根据测试类型&#xff08;功能、性能、安全、自动化等&#xff09;和场景需求&#xff0c;可分为多个类别。以下从​​八大核心测试类型​​出发&#xff0c;梳理常见工具及其特点、适用场景&#xff1a; ​​一、功能测试工…

七牛存储sdk在springboot完美集成和应用 七牛依赖 自动化配置

文章目录 概要依赖配置属性配置类配置文件业务层控制层运行结果亮点 概要 七牛存储很便宜的&#xff0c;在使用项目的用好官方封装好的sdk&#xff0c;结合springboot去使用很方便&#xff0c;我本地用的是springoot3spring-boot-autoconfigure 依赖 <dependency><…

Java相关-链表-设计链表-力扣707

你可以选择使用单链表或者双链表&#xff0c;设计并实现自己的链表。 单链表中的节点应该具备两个属性&#xff1a;val 和 next 。val 是当前节点的值&#xff0c;next 是指向下一个节点的指针/引用。 如果是双向链表&#xff0c;则还需要属性 prev 以指示链表中的上一个节点…

C# 关于LINQ语法和类型的使用

常用语法&#xff0c;具体问题具体分析 1. Select2. SelectMany3. Where4. Take5. TakeWhile6. SkipWhile7. Join8. GroupJoin9. OrderBy10. OrderByDescending11. ThenBy12. Concat13. Zip14. Distinct15. Except16. Union17. Intersect18. Concat19. Reverse20. SequenceEqua…

华为OD-2024年E卷-小明周末爬山[200分] -- python

问题描述&#xff1a; 题目描述 周末小明准备去爬山锻炼&#xff0c;0代表平地&#xff0c;山的高度使用1到9来表示&#xff0c;小明每次爬山或下山高度只能相差k及k以内&#xff0c;每次只能上下左右一个方向上移动一格&#xff0c;小明从左上角(0,0)位置出发 输入描述 第一行…

Android:使用OkHttp

1、权限&#xff1a; <uses-permission android:name"android.permission.INTERNET" /> implementation com.squareup.okhttp3:okhttp:3.4.1 2、GET&#xff1a; new XXXTask ().execute("http://192.168.191.128:9000/xx");private class XXXTask…

Vue3+Element Plus动态表格列宽设置

在 Vue3 Element Plus 中实现动态设置表格列宽&#xff0c;可以通过以下几种方式实现&#xff1a; 方法 1&#xff1a;动态绑定 width 属性&#xff08;推荐&#xff09; vue 复制 下载 <template><el-table :data"tableData" style"width: 100%…

【JVM目前使用过的参数总结】

JVM参数总结 笔记记录 JVM-栈相关JVM-方法区(元空间)相关JVM-堆相关 JVM-栈相关 .-XX:ThreadStackSize1M -Xss1m 上面的简写形式【设置栈的大小】 JVM-方法区(元空间)相关 -XX:MaxMetaspaceSize10m 【设置最大元空间大小】 JVM-堆相关 -XX:MaxHeapSize10m -Xmx10m 上面的简写形…

AI辅助高考志愿填报-专业全景解析与报考指南

高考志愿填报&#xff0c;这可是关系到孩子未来的大事儿&#xff01;最近&#xff0c;我亲戚家的孩子也面临着这个难题&#xff0c;昨晚一个电话就跟我聊了好久&#xff0c;问我报啥专业好。说实话&#xff0c;这问题真不好回答&#xff0c;毕竟每个孩子情况不一样&#xff0c;…

Android Studio Windows安装与配置指南

Date: 2025-06-14 20:07:12 author: lijianzhan 内容简介 文章中&#xff0c;主要是为了初次接触 Android 开发的用户提供详细的关于 Android Studio 安装以及配置教程&#xff0c;涵盖环境准备、软件下载、安装配置全流程&#xff0c;重点解决路径命名、组件选择、工作空间设置…

SpringAI+DeepSeek-了解AI和大模型应用

一、认识AI 1.人工智能发展 AI&#xff0c;人工智能&#xff08;Artificial Intelligence&#xff09;&#xff0c;使机器能够像人类一样思考、学习和解决问题的技术。 AI发展至今大概可以分为三个阶段&#xff1a; 其中&#xff0c;深度学习领域的自然语言处理(Natural Lan…

IP5362至为芯支持无线充的22.5W双C口双向快充移动电源方案芯片

英集芯IP5362是一款应用于移动电源&#xff0c;充电宝&#xff0c;手机&#xff0c;平板电脑等支持无线充模式的22.5W双向快充移动电源方案SOC芯片,集成同步升降压转换器、锂电池充电管理、电池电量指示等功能。兼容全部快充协议&#xff0c;同步开关放电支持最大22.5W输出功率…

手游刚开服就被攻击怎么办?如何防御DDoS?

手游新上线时遭遇DDoS攻击是常见现象&#xff0c;可能导致服务器瘫痪、玩家流失甚至项目失败。面对突如其来的攻击&#xff0c;开发者与运营商需要迅速响应并建立长效防御机制。本文提供应急处理步骤与防御策略&#xff0c;助力游戏稳定运营。 一、手游开服遭攻击的应急响应 快…

秋招是开发算法一起准备,还是只准备一个

THE LAST TIME 昨天晚上半夜有个星球的26届的同学&#xff0c;私信问我。说目前是只准备开发还是开发算法一起准备&#xff08;两者技术知识都挺欠缺的&#xff09; 看到这里&#xff0c;肯定有很多同学会说。马上都该秋招了&#xff0c;还什么多线程开工&#xff0c;赶紧能住编…

web项目部署配置HTTPS遇到的问题解决方法

今天使用nginxtomcatssl完成了web项目的部署&#xff0c;本以为没有什么问题&#xff0c;但是在页面测试的时候又蹦出了这么一个问题&#xff0c;大致是说由于配置了HTTPS&#xff0c;但是之前的请求是通过HTTP请求的&#xff0c;所以现在被拦截&#xff0c;由于缺少某些权限信…

理解与建模弹性膜-AI云计算数值分析和代码验证

弹性膜在连接生物学理解和工程创新方面至关重要&#xff0c;因为它们能够模拟软组织力学、实现先进的细胞培养系统和促进柔性设备&#xff0c;广泛应用于软组织生物力学、细胞培养、生物膜建模和生物医学工程等领域。 ☁️AI云计算数值分析和代码验证 弹性膜在连接生物学理解和…