第6节 torch.nn介绍

6.1 torch.nn.Module介绍

        torch.nn.Module是 PyTorch 中构建神经网络的基础类,所有的神经网络模块都应该继承这个类。它提供了一种便捷的方式来组织和管理网络中的各个组件,包括层、参数等,同时还内置了许多用于模型训练和推理的功能。

官网:torch.nn — PyTorch 1.8.1 documentation

核心功能

(1)、网络构建:通过继承torch.nn.Module类,我们可以自定义自己的神经网络结构。在__init__方法中定义网络的各个层,在forward方法中定义数据的前向传播过程。

(2)、参数管理:torch.nn.Module会自动跟踪和管理网络中的参数(如权重和偏置)。我们可以通过parameters()方法获取网络的所有参数,方便进行优化器的配置和参数的更新。

(3)、设备转换:可以使用to()方法将模型转移到指定的设备(如 CPU 或 GPU)上,以利用不同设备的计算能力。​

(4)、状态切换:提供了train()和eval()方法来切换模型的训练和评估状态。在训练状态下,一些具有随机性的层(如 Dropout、BatchNorm)会正常工作;在评估状态下,这些层会采用确定性的行为。

6.2 torch.nn.Module常用方法

        __init__(self):构造函数,用于初始化网络的各个层和参数。在自定义网络时,需要在该方法中调用super().__init__()来初始化父类。​

        forward(self, x):前向传播方法,定义了数据在网络中的流动过程。当对模型进行调用时(如model(x)),实际上是调用了该方法。​

        parameters(self):返回一个迭代器,包含网络中的所有可学习参数。​

        named_parameters(self):返回一个迭代器,包含网络中参数的名称和对应的参数值。​

        to(self, device):将模型转移到指定的设备上。例如,model.to('cuda')将模型转移到 GPU 上。​

        train(self, mode=True):将模型设置为训练模式。​

        eval(self):将模型设置为评估模式,相当于train(mode=False)。​

        save_state_dict(self, path):保存模型的参数状态字典到指定路径。​

        load_state_dict(self, state_dict):从参数状态字典中加载模型的参数。

6.3 程序演示

6.3.1 官网提供的例子

import torch.nn as nn
import torch.nn.functional as Fclass Model(nn.Module):   #搭建的神经网络 Model继承了 Module类(父类)def __init__(self):   #初始化函数super(Model, self).__init__()   #必须要这一步,调用父类的初始化函数self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x):   #前向传播(为输入和输出中间的处理过程),x为输入x = F.relu(self.conv1(x))   #conv为卷积,relu为非线性处理return F.relu(self.conv2(x))

注意:前向传播 forward(在所有子类中进行重写)

6.3.2 自定义Model

import torch
from torch import nn# 定义一个自定义模型类Custom_Model,继承自nn.Module
# 所有的神经网络模型都应该继承nn.Module,以利用其提供的参数管理、设备转换等功能
class Custom_Model(nn.Module):# 构造函数,用于初始化模型的层和参数def __init__(self):# 调用父类nn.Module的构造函数,确保模型能够正确初始化super().__init__()# 前向传播方法,定义数据在模型中的流动和计算过程# 当对模型实例传入输入数据时,会自动调用该方法def forward(self, input):# 定义模型的计算逻辑:输入数据加1output = input + 1# 返回计算结果return outputCustom_Model = Custom_Model()
# 创建一个张量x,值为1.0,作为模型的输入数据
x = torch.tensor(1.0)
# 将输入数据x传入模型,模型会自动调用forward方法进行计算,得到输出结果
output = Custom_Model(x)
# 打印输出结果,此时输出应为2.0(1.0 + 1)
print(output)

6.4 torch.nn.functional.conv2d介绍

        torch.nn.functional.conv2d是 PyTorch 中用于执行二维卷积操作的函数,在卷积神经网络(CNN)中扮演着至关重要的角色,用于提取图像等二维数据的特征。以下是对它的详细介绍:

参数说明:

  • input (Tensor):输入张量,形状为(N, C_in, H_in, W_in)。其中,N是批量大小(batch size),表示一次处理的样本数量;C_in是输入通道数,例如对于灰度图像C_in=1,对于彩色图像(RGB 格式)C_in=3;H_in和W_in分别是输入特征图的高度和宽度。
  • weight (Tensor):卷积核(过滤器)张量,形状为(C_out, C_in, H_k, W_k) 。C_out是输出通道数,决定了经过卷积操作后生成的特征图数量;C_in必须与输入张量的通道数一致;H_k和W_k分别是卷积核的高度和宽度。
  • bias (Tensor,可选):偏置张量,形状为(C_out) ,为每个输出通道添加一个可学习的偏置值,默认值为None。
  • stride (int或tuple,默认值:1 ):卷积核在输入特征图上滑动的步长。如果是一个整数,表示在高度和宽度方向上的步长相同;如果是一个元组(stride_h, stride_w),则分别指定高度和宽度方向上的步长。
  • padding (int或tuple,默认值:0 ):在输入特征图的边缘添加填充(padding)像素。同样,整数表示在高度和宽度方向上添加相同数量的填充;元组(padding_h, padding_w)分别指定高度和宽度方向上的填充数量。填充可以用来控制输出特征图的大小,使其与输入大小相同或满足特定的尺寸要求。
  • dilation (int或tuple,默认值:1 ):卷积核元素之间的间距。dilation=1表示正常的卷积核;dilation=2时,卷积核元素之间会间隔一个位置,相当于扩大了卷积核的感受野。
  • groups (int,默认值:1 ):分组卷积的组数。当groups=1时,就是普通的卷积操作;当groups > 1时,输入通道会被分成groups组,卷积核也会相应分组,每组卷积核只与对应的一组输入通道进行卷积操作,常用于减少计算量或实现特定的网络结构,比如 AlexNet 中的分组卷积。

应用场景:

        torch.nn.functional.conv2d广泛应用于各类基于卷积神经网络的任务,如:

  • 图像分类:从输入图像中提取各种层次的特征,用于判断图像所属的类别。
  • 目标检测:提取图像特征来定位和识别目标物体。
  • 语义分割:对图像中的每个像素进行分类,以实现对图像内容的精细分割。

        总的来说,torch.nn.functional.conv2d是构建深度学习视觉模型的基础组件之一,通过合理设置其参数,可以灵活地调整卷积操作,以适应不同的任务需求。

6.4.1 卷积操作原理

6.4.2 实战演示

import torch
import torch.nn.functional as F
# 将二维矩阵转化为tensor数据类型
input = torch.tensor([[1, 2, 0, 3, 1],[0, 1, 2, 3, 1],[1, 2, 1, 0, 0],[5, 2, 3, 1, 1],[2, 1, 0, 1, 1]])
# 卷积核
kernel = torch.tensor([[1, 2, 1],[0, 1, 0],[2, 1, 0]])
# 尺寸只有高宽,不符合要求
print(input.shape)    # 5*5
print(kernel.shape)   # 3*#
input = torch.reshape(input, (1, 1, 5, 5))
kernel = torch.reshape(kernel, (1, 1, 3, 3))
print(input.shape)
print(kernel.shape)output = F.conv2d(input, kernel, stride=1)
print(output)

运行结果:

参数修改:

1)、将stride修改

        stride (int或tuple,默认值:1 ):卷积核在输入特征图上滑动的步长。如果是一个整数,表示在高度和宽度方向上的步长相同;如果是一个元组(stride_h, stride_w),则分别指定高度和宽度方向上的步长。

                output = F.conv2d(input, kernel, stride=2)

(2)、修改Padding

        padding (int或tuple,默认值:0 ):在输入特征图的边缘添加填充(padding)像素。同样,整数表示在高度和宽度方向上添加相同数量的填充;元组(padding_h, padding_w)分别指定高度和宽度方向上的填充数量。填充可以用来控制输出特征图的大小,使其与输入大小相同或满足特定的尺寸要求。

        padding=1:将输入图像左右上下两边都拓展一个像素,空的地方默认为0

                        output = F.conv2d(input, kernel, stride=1, padding=1)

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/95379.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/95379.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python自学笔记7 可视化初步

图像的组成工具库 Matplotlib:绘制静态图 Plotly: 可以绘制交互式图片 图像的绘制(Matplotlib) 创建图形,轴对象 创造等差数列 # 包含后端点 arr np.linspace(0, 1, num11) # 不包含后端点 arr_no_endpoint np.linspace(0, 1, n…

GIS 常用的矢量与栅格分析工具

矢量处理工具作用典型应用缓冲区分析Buffer环境影响区域,空间邻近度分析等,例如道路周围一公里内的学校,噪音污染影响的范围裁剪Clip例如使用A市图层裁剪全国道路数据,获取A市道路数据交集Intersect识别与LUCC、分区洪水区、基础设…

http与https协议区别;vue3本地连接https地址接口报500

文章目录问题解决方案一、问题原因分析二、解决方案详解1. 保持当前配置(推荐临时方案)2. 更安全的方案(推荐)3. 环境区分配置(最佳实践)三、为什么开发环境不用配置?问题 问题:本地…

C语言——深入理解指针(三)

C语言——深入理解指针(三) 1.回调函数是什么? 首先我们来回顾一下函数的直接调用:而回调函数就是通过函数指针调用的函数。我们将函数的指针(地址)作为参数传递给另一个函数,当这个指针被用来调…

kettle 8.2 ETL项目【四、加载数据】

一、dim_store表结构,数据来源于业务表,且随时间会有增加,属于缓慢变化维(SCD)类型二 转换步骤如下 详细步骤如下

【测试报告】SoundWave(Java+Selenium+Jmeter自动化测试)

一、项目背景 随着数字音乐内容的爆炸式增长,用户对于便捷、高效的音乐管理与播放需求日益增强。传统的本地音乐管理方式已无法满足多设备同步、在线分享与个性化推荐等现代需求。为此,我们设计并开发了一款基于Spring Boot框架的SoundWave,旨…

C++ 类和对象详解(1)

类和对象是 C 面向对象编程的核心概念,它们为代码提供了更好的封装性、可读性和可维护性。本文将从类的定义开始,逐步讲解访问限定符、类域、实例化、对象大小计算、this 指针等关键知识,并对比 C 语言与 C 在实现数据结构时的差异&#xff0…

奈飞工厂:算法优化实战

推荐系统的算法逻辑与优化技巧在流媒体行业的 “用户注意力争夺战” 中,推荐系统是决定成败的核心武器。对于拥有2.3 亿全球付费用户的奈飞(Netflix)而言,其推荐系统每天处理数十亿次用户交互,最终实现了一个惊人数据&…

【人工智能99问】BERT的训练过程和推理过程是怎么样的?(24/99)

文章目录BERT的训练过程与推理过程一、预训练过程:学习通用语言表示1. 数据准备2. MLM任务训练(核心)3. NSP任务训练4. 预训练优化二、微调过程:适配下游任务1. 任务定义与数据2. 输入处理3. 模型结构调整4. 微调训练三、推理过程…

[TryHackMe]Challenges---Game Zone游戏区

这个房间将涵盖 SQLi(手动利用此漏洞和通过 SQLMap),破解用户的哈希密码,使用 SSH 隧道揭示隐藏服务,以及使用 metasploit payload 获取 root 权限。 1.通过SQL注入获得访问权限 手工注入 输入用户名 尝试使用SQL注入…

北京JAVA基础面试30天打卡09

1.MySQL存储引擎及区别特性MyISAMMemoryInnoDBB 树索引✅ Yes✅ Yes✅ Yes备份 / 按时间点恢复✅ Yes✅ Yes✅ Yes集群数据库支持❌ No❌ No❌ No聚簇索引❌ No❌ No✅ Yes压缩数据✅ Yes❌ No✅ Yes数据缓存❌ NoN/A✅ Yes加密数据✅ Yes✅ Yes✅ Yes外键支持❌ No❌ No✅ Yes…

AI时代的SD-WAN异地组网如何落地?

在全球化运营与数字化转型浪潮下,企业分支机构、数据中心与云服务的跨地域互联需求激增。传统专线因成本高昂、部署缓慢、灵活性差等问题日益凸显不足。SD-WAN以其智能化调度、显著降本、敏捷部署和云网融合的核心优势,成为实现高效、可靠、安全异地组网…

css中的color-mix()函数

color-mix() 是 CSS 颜色模块(CSS Color Module Level 5)中引入的一个强大的颜色混合函数,用于在指定的颜色空间中混合两种或多种颜色,生成新的颜色值。它解决了传统颜色混合(如通过透明度叠加)在视觉一致性…

Github desktop介绍(GitHub官方推出的一款图形化桌面工具,旨在简化Git和GitHub的使用流程)

文章目录**1. 简化 Git 操作****2. 代码版本控制****3. 团队协作****4. 代码托管与共享****5. 集成与扩展****6. 跨平台支持****7. 适合的使用场景****总结**GitHub Desktop 是 GitHub 官方推出的一款图形化桌面工具,旨在简化 Git 和 GitHub 的使用流程,…

整数规划-分支定界

内容来自:b站数学建模老哥 如:3.4,先找小于3的,再找大于4的 逐个

JetPack系列教程(六):Paging——让分页加载不再“秃”然

前言 在Android开发的世界里,分页加载就像是一场永无止境的马拉松,每次滚动到底部,都仿佛在提醒你:“嘿,朋友,还有更多数据等着你呢!”但别担心,Google大佬们早就看透了我们的烦恼&a…

扎实基础!深入理解Spring框架,解锁Java开发新境界

大家好,今天想和大家聊聊Java开发路上绕不开的一个重要基石——Spring框架。很多朋友在接触SpringBoot、SpringCloud这些现代化开发工具时,常常会感到吃力。究其原因,往往是对其底层的Spring核心机制理解不够透彻。Spring是构建这些高效框架的…

Heterophily-aware Representation Learning on Heterogeneous Graphs

Heterophily-Aware Representation Learning on Heterogeneous Graphs (TPAMI 2025) 计算机科学 1区 I:18.6 top期刊 📌 摘要 现实世界中的图结构通常非常复杂,不仅具有全局结构上的异质性,还表现出局部邻域内的强异质相似性(heterophily)。虽然越来越多的研究揭示了图…

计算机视觉(7)-纯视觉方案实现端到端轨迹规划(思路梳理)

基于纯视觉方案实现端到端轨迹规划,需融合开源模型、自有数据及系统工程优化。以下提供一套从模型选型到部署落地的完整方案,结合前沿开源技术与工业实践: 一、开源模型选型与组合策略 1. 感知-预测一体化模型 ViP3D(清华&#…

Nginx 屏蔽服务器名称与版本信息(源码级修改)

Nginx 屏蔽服务器名称与版本信息(源码级修改) 一、背景与目的 在生产环境部署 Nginx 时,默认配置会在 Server 响应头中暴露服务类型(如 nginx)和版本号(如 nginx/1.25.4)。这些信息可能被攻击者…