pytorch版本densenet代码讲解

DenseNet 模型代码详解

下面是 DenseNet 模型代码的逐部分详细解析:

1. 导入模块

import re
from collections import OrderedDict
from functools import partial
from typing import Any, Optionalimport torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import Tensorfrom ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface
  • re: 正则表达式模块,用于处理权重名称的转换
  • OrderedDict: 有序字典,用于按顺序构建网络层
  • partial: 创建部分函数,用于预设图像转换参数
  • torch.nn: PyTorch 的神经网络模块
  • torch.utils.checkpoint: 内存优化技术,减少训练时的内存占用
  • ImageClassification: 图像分类的预处理转换
  • register_model: 注册模型的装饰器
  • Weights/WeightsEnum: 预训练权重相关类
  • _IMAGENET_CATEGORIES: ImageNet 数据集类别标签
  • 模型工具函数: 覆盖参数、处理旧版接口等

2. DenseNet 基础层 (_DenseLayer)

class _DenseLayer(nn.Module):def __init__(self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False) -> None:super().__init__()# 第一个卷积块 (1x1 卷积)self.norm1 = nn.BatchNorm2d(num_input_features)self.relu1 = nn.ReLU(inplace=True)self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)# 第二个卷积块 (3x3 卷积)self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)self.relu2 = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)self.drop_rate = float(drop_rate)self.memory_efficient = memory_efficient
  • Bottleneck 结构: 由两个卷积层组成,减少计算量
  • 1x1 卷积: 降维,输出通道数为 bn_size * growth_rate
  • 3x3 卷积: 主卷积层,输出通道数为 growth_rate
  • memory_efficient: 是否使用梯度检查点节省内存

前向传播逻辑

    def bn_function(self, inputs: list[Tensor]) -> Tensor:# 拼接所有输入特征concated_features = torch.cat(inputs, 1)# 通过第一个卷积块bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))return bottleneck_outputdef forward(self, input: Tensor) -> Tensor:if isinstance(input, Tensor):prev_features = [input]else:prev_features = input# 内存高效模式处理if self.memory_efficient and self.any_requires_grad(prev_features):bottleneck_output = self.call_checkpoint_bottleneck(prev_features)else:bottleneck_output = self.bn_function(prev_features)# 通过第二个卷积块new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))# 应用Dropoutif self.drop_rate > 0:new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)return new_features
  • 特征拼接: 将前面所有层的输出拼接在一起
  • 梯度检查点: 在内存高效模式下,使用检查点减少内存占用
  • Dropout: 随机丢弃部分神经元,防止过拟合

3. Dense 块 (_DenseBlock)

class _DenseBlock(nn.ModuleDict):def __init__(self,num_layers: int,num_input_features: int,bn_size: int,growth_rate: int,drop_rate: float,memory_efficient: bool = False,) -> None:super().__init__()# 创建多个密集层for i in range(num_layers):layer = _DenseLayer(num_input_features + i * growth_rate,growth_rate=growth_rate,bn_size=bn_size,drop_rate=drop_rate,memory_efficient=memory_efficient,)self.add_module("denselayer%d" % (i + 1), layer)
  • 模块字典: 存储多个密集层
  • 输入特征计算: 每增加一层,输入特征增加 growth_rate 个通道

前向传播

    def forward(self, init_features: Tensor) -> Tensor:features = [init_features]# 逐层处理并收集输出for name, layer in self.items():new_features = layer(features)features.append(new_features)# 拼接所有层的输出return torch.cat(features, 1)
  • 特征累积: 每一层的输出都添加到特征列表中
  • 特征拼接: 将所有层的输出沿通道维度拼接

4. 过渡层 (_Transition)

class _Transition(nn.Sequential):def __init__(self, num_input_features: int, num_output_features: int) -> None:super().__init__()# 压缩特征维度self.norm = nn.BatchNorm2d(num_input_features)self.relu = nn.ReLU(inplace=True)self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)# 空间下采样self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
  • 特征压缩: 1x1 卷积减少通道数(通常减半)
  • 空间降维: 平均池化减小特征图尺寸

5. DenseNet 主模型

class DenseNet(nn.Module):def __init__(self,growth_rate: int = 32,block_config: tuple[int, int, int, int] = (6, 12, 24, 16),num_init_features: int = 64,bn_size: int = 4,drop_rate: float = 0,num_classes: int = 1000,memory_efficient: bool = False,) -> None:super().__init__()_log_api_usage_once(self)  # 记录API使用情况# 初始卷积层self.features = nn.Sequential(OrderedDict([("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),("norm0", nn.BatchNorm2d(num_init_features)),("relu0", nn.ReLU(inplace=True)),("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),]))# 构建多个Dense块和过渡层num_features = num_init_featuresfor i, num_layers in enumerate(block_config):# 添加Dense块block = _DenseBlock(num_layers=num_layers,num_input_features=num_features,bn_size=bn_size,growth_rate=growth_rate,drop_rate=drop_rate,memory_efficient=memory_efficient,)self.features.add_module("denseblock%d" % (i + 1), block)num_features += num_layers * growth_rate# 添加过渡层(最后一个块除外)if i != len(block_config) - 1:trans = _Transition(num_features, num_features // 2)self.features.add_module("transition%d" % (i + 1), trans)num_features = num_features // 2# 最终批归一化self.features.add_module("norm5", nn.BatchNorm2d(num_features))# 分类器self.classifier = nn.Linear(num_features, num_classes)# 参数初始化for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.constant_(m.bias, 0)
  • 初始卷积层: 快速下采样输入图像
  • 块配置: 控制每个Dense块中的层数
  • 通道管理: 通过过渡层压缩通道数
  • Kaiming初始化: 卷积层的权重初始化
  • 批归一化初始化: 权重设为1,偏置设为0

前向传播

    def forward(self, x: Tensor) -> Tensor:features = self.features(x)out = F.relu(features, inplace=True)out = F.adaptive_avg_pool2d(out, (1, 1))  # 全局平均池化out = torch.flatten(out, 1)  # 展平特征out = self.classifier(out)  # 分类return out
  • 特征提取: 通过多个Dense块和过渡层
  • 全局平均池化: 将特征图转换为特征向量
  • 全连接层: 输出分类结果

6. 权重加载函数

def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None:# 匹配旧版权重名称模式pattern = re.compile(r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$")state_dict = weights.get_state_dict(progress=progress, check_hash=True)# 转换权重名称for key in list(state_dict.keys()):res = pattern.match(key)if res:new_key = res.group(1) + res.group(2)state_dict[new_key] = state_dict[key]del state_dict[key]# 加载权重model.load_state_dict(state_dict)
  • 权重名称转换: 适配旧版权重命名方式
  • 哈希校验: 确保下载的权重文件完整无误

7. 模型工厂函数

def _densenet(growth_rate: int,block_config: tuple[int, int, int, int],num_init_features: int,weights: Optional[WeightsEnum],progress: bool,**kwargs: Any,
) -> DenseNet:# 根据权重调整输出类别数if weights is not None:_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))# 创建模型model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)# 加载预训练权重if weights is not None:_load_state_dict(model=model, weights=weights, progress=progress)return model
  • 参数覆盖: 根据预训练权重调整输出类别数
  • 灵活配置: 支持不同DenseNet变体

8. 预训练权重定义

_COMMON_META = {"min_size": (29, 29),  # 最小输入尺寸"categories": _IMAGENET_CATEGORIES,  # ImageNet类别"recipe": "https://github.com/pytorch/vision/pull/116",  # 训练方法
}class DenseNet121_Weights(WeightsEnum):IMAGENET1K_V1 = Weights(url="https://download.pytorch.org/models/densenet121-a639ec97.pth",transforms=partial(ImageClassification, crop_size=224),  # 图像预处理meta={**_COMMON_META,"num_params": 7978856,  # 参数量"_metrics": {  # 性能指标"ImageNet-1K": {"acc@1": 74.434,  # top-1准确率"acc@5": 91.972,  # top-5准确率}},"_ops": 2.834,  # 计算量 (GFLOPs)"_file_size": 30.845,  # 文件大小 (MB)},)DEFAULT = IMAGENET1K_V1  # 默认权重
  • 权重元数据: 包含模型性能和资源信息
  • 预处理定义: 指定图像分类任务的预处理流程
  • 性能指标: 提供在ImageNet上的评估结果

9. 模型变体实现

@register_model()  # 注册模型
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:weights = DenseNet121_Weights.verify(weights)  # 验证权重return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
  • DenseNet121: 增长率32,块配置[6,12,24,16],初始特征64
  • DenseNet169: 增长率32,块配置[6,12,32,32],初始特征64
  • DenseNet201: 增长率32,块配置[6,12,48,32],初始特征64
  • DenseNet161: 增长率48,块配置[6,12,36,24],初始特征96

DenseNet 关键特点

  1. 密集连接: 每一层都接收前面所有层的特征图作为输入
  2. 特征重用: 通过拼接实现多层次特征融合
  3. 瓶颈设计: 1×1卷积减少计算量
  4. 过渡层: 压缩特征维度和空间尺寸
  5. 高效内存: 可选的内存优化模式

DenseNet通过密集连接促进了特征重用,减少了梯度消失问题,提高了参数效率,在各种视觉任务中表现出色。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/87639.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/87639.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端常见设计模式深度解析

# 前端常见设计模式深度解析一、设计模式概述 设计模式是解决特定问题的经验总结,前端开发中常用的设计模式可分为三大类: 创建型模式:处理对象创建机制(单例、工厂等)结构型模式:处理对象组合(…

React 学习(3)

核心API——React.creatElement()方法优点:将创建元素、添加属性和事件、添加内容和子元素等使用原生dom需要进行复杂操作才能实现的功能集成在一个API中。1.该方法接收三个参数第一个是要创建的元素的名称(小写是因为如果,大写开头会被react…

倾斜摄影无人机飞行航线规划流程详解

在倾斜摄影测量项目中,航线规划的严谨性直接决定了最终三维模型的质量与完整性。照片覆盖不全、模型空洞、纹理模糊或分辨率不达标等问题,往往源于规划阶段对关键细节的疏忽。本文将系统梳理倾斜摄影无人机航线规划的核心流程与关键要点,旨在…

Minio大文件分片上传

一、引入依赖 <dependency><groupId>io.minio</groupId><artifactId>minio</artifactId><version>8.3.3</version></dependency> 二、自定义Minio客户端 package com.gstanzer.video.controller;import com.google.common.c…

Jenkins 插件深度应用:让你的CI/CD流水线如虎添翼 [特殊字符]

Jenkins 插件深度应用&#xff1a;让你的CI/CD流水线如虎添翼 &#x1f680; 嘿&#xff0c;各位开发小伙伴&#xff01;今天咱们来聊聊Jenkins的插件生态系统。如果说Jenkins是一台强大的引擎&#xff0c;那插件就是让这台引擎发挥最大威力的各种零部件。准备好了吗&#xff1…

密码学(斯坦福)

密码学笔记 \huge{密码学笔记} 密码学笔记 斯坦福大学密码学的课程笔记 课程网址&#xff1a;https://www.bilibili.com/video/BV1Rf421o79E/?spm_id_from333.337.search-card.all.click&vd_source5cc05a038b81f6faca188e7cf00484f6 概述 密码学的使用背景 安全信息保护…

代码随想录算法训练营第四十六天|动态规划part13

647. 回文子串 题目链接&#xff1a;647. 回文子串 - 力扣&#xff08;LeetCode&#xff09; 文章讲解&#xff1a;代码随想录 思路&#xff1a; 以dp【i】表示以s【i】结尾的回文子串的个数&#xff0c;发现递推公式推导不出来此路不通 以dp【i】【j】表示s【i】到s【j】的回…

基于四种机器学习算法的球队数据分析预测系统的设计与实现

文章目录 有需要本项目的代码或文档以及全部资源&#xff0c;或者部署调试可以私信博主项目介绍项目展示随机森林模型XGBoost模型逻辑回归模型catboost模型每文一语 有需要本项目的代码或文档以及全部资源&#xff0c;或者部署调试可以私信博主 项目介绍 本项目旨在设计与实现…

http、SSL、TLS、https、证书

一、基础概念 1.HTTP HTTP (超文本传输协议) 是一种用于客户端和服务器之间传输超媒体文档的应用层协议&#xff0c;是万维网的基础。 简而言之&#xff1a;一种获取和发送信息的标准协议 2.SSL 安全套接字层&#xff08;SSL&#xff09;是一种通信协议或一组规则&#xf…

在 C++ 中,判断 `std::string` 是否为空字符串

在 C 中&#xff0c;判断 std::string 是否为空字符串有多种方法&#xff0c;以下是最常用的几种方式及其区别&#xff1a; 1. 使用 empty() 方法&#xff08;推荐&#xff09; #include <string>std::string s; if (s.empty()) {// s 是空字符串 }特性&#xff1a; 时间…

【Harmony】鸿蒙企业应用详解

【HarmonyOS】鸿蒙企业应用详解 一、前言 1、应用类型定义速览&#xff1a; HarmonyOS目前针对应用分为三种类型&#xff1a;普通应用&#xff0c;游戏应用&#xff0c;企业应用。 而企业应用又分为&#xff0c;企业普通应用和设备管理应用MDM&#xff08;Mobile Device Man…

Linux云计算基础篇(8)

VIM 高级特性插入模式按 i 进入插入模式。按 o 在当前行下方插入空行并进入插入模式。按 O 在当前行上方插入空行并进入插入模式。命令模式:set nu 显示行号。:set nonu 取消显示行号。:100 光标跳转到第 100 行。G 光标跳转到文件最后一行。gg 光标跳转到文件第一行。30G 跳转…

Linux进程单例模式运行

Linux进程单例模式运行 #include <iostream> #include <stdlib.h> #include <unistd.h> #include <string.h> #include <stdio.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h>int write_pid(const cha…

【Web 后端】部署服务到服务器

文章目录 前言一、如何启动服务二、挂载和开机启动服务1. 配置systemctl 服务2. 创建server用户3. 启动服务 总结 前言 如果你的后端服务写好了如果部署到你的服务器呢&#xff0c;本次通过fastapi写的服务实例&#xff0c;示范如何部署到服务器&#xff0c;并做服务管理。 一…

国产MCU学习Day5——CW32F030C8T6:窗口看门狗功能全解析

每日更新教程&#xff0c;评论区答疑解惑&#xff0c;小白也能变大神&#xff01;" 目录 一.窗口看门狗&#xff08;WWDG&#xff09;简介 二.窗口看门狗寄存器列表 三.窗口看门狗复位案例 一.窗口看门狗&#xff08;WWDG&#xff09;简介 CW32F030C8T6 内部集成窗口看…

2025年文件加密软件分享:守护数字世界的核心防线

在数字化时代&#xff0c;数据已成为个人与企业的宝贵资产&#xff0c;文件加密软件通过复杂的算法&#xff0c;确保信息在存储、传输与共享过程中的保密性、完整性与可用性。一、文件加密软件的核心原理文件加密软件算法以其高效性与安全性广泛应用&#xff0c;通过对文件数据…

node.js下载教程

1.项目环境文档 语雀 2.nvm安装教程与nvm常见命令,超详细!-阿里云开发者社区 C:\Windows\System32>nvm -v 1.2.2 C:\Windows\System32>nvm list available Error retrieving "http://npm.taobao.org/mirrors/node/index.json": HTTP Status 404 C:\Window…

(AI如何解决问题)在一个项目,跳转到外部html页面,页面布局

问题描述目前&#xff0c;ERP后台有很多跳转外部链接的地方&#xff0c;会直接打开一个tab显示。因为有些页面是适配手机屏幕显示&#xff0c;放在浏览器会超级大。不美观&#xff0c;因此提出优化。修改前&#xff1a;修改后&#xff1a;思考过程1、先看下代码&#xff1a;log…

网络通信协议与虚拟网络技术相关整理(上)

#作者&#xff1a;程宏斌 文章目录 tcp协议udp协议arp协议icmp协议dhcp协议BGP协议OSPF协议BGP vs OSPF 对比表VLAN&#xff08;Virtual LAN&#xff09;VXLAN&#xff08;Virtual Extensible LAN&#xff09;IPIP&#xff08;IP-in-IP&#xff09;vxlan/vlan/ipip网桥/veth网…

物联网软件层面的核心技术体系

物联网软件层面的核心技术体系 物联网(IoT)软件技术栈是一个多层次的复杂体系&#xff0c;涵盖从设备端到云平台的完整解决方案。以下是物联网软件层面的关键技术分类及详细说明&#xff1a; 一、设备端软件技术 1. 嵌入式操作系统 实时操作系统(RTOS)&#xff1a; FreeRTO…