Pytorch-04 搭建神经网络架构工作流

搭建神经网络架构

在pytorch中,神经网络被抽象成由一系列对数据执行特定操作的层或者模块组成,比如下面的Attention实现,每个块都是一个模块或者层。
在这里插入图片描述

如果你想快速搭建网络架构,torch.nn这个命名空间提供了所有很多开箱即用的层/模块/算子:
在这里插入图片描述
如果你想自定义一个模块也是完全可以的。每个模块都是nn.Module的子类,你只需要继承然后复写即可,这个后面有例子。

这种简洁的架构抽象可以让使用pytorch的人们快速搭建并管理精妙的模型架构。

接下来,我们将搭建一个神经网络来分类FashionMNIST数据集,来过一遍搭建网络的工作流。

import os
import torch
from torch import nn
from torch.utils.data import Dataloader
from torchvision import datasets, transforms

1. 获取可能的加速设备

为了在 加速器(accelerator) 上训练我们的模型,例如 CUDAMPSMTIAXPU,我们将遵循以下逻辑:

如果当前设备有可用的加速器,我们就使用它;否则,我们将使用 CPU

device = torch.accelerator.current_accelerator().type if  torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

2. 搭建网络结构

2.1 定义网络类

通过继承nn.Module,我们可以定义我们的神经网络类,并且在__init__里面定义我们要用到的模块或者层。然后实现forward方法来定义对输入模型的数据的实际操作以及操作顺序,并且返回推理结果。

class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.faltten = nn.Faltten() # 展平层self.linear_relu_stack = nn.Sequential( # 定义一个序列模块,被调用时会依次执行所含模块nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),)def forward(self, x):x = self.flatten(x)logit = self.linearr_relu_stack(x)return logits

注意,__init__只负责把需要的块给初始化出来,具体数据是怎么在块间流动由forward实现。

2.2 实例化网络并查看结构

现在我们实例化网络,并且把它搬到device侧,然后打印出他的结构:

model = NeuralNetwork().to(device)
print(model)

在这里插入图片描述

2.3 进行网络“冒烟测试”

搭建好网络结构之后,强烈建议进行一次“冒烟测试”,用一个符合输入shape的tensor看看整个网络能不能跑通。

要给模型传入数据进行推理,直接给模型传入数据即可,千万别直接调用forward方法,因为model(x)还会做一些forward没做的一些其他必要操作。

X = torch.rand(1, 28, 28, device=device)
logits = model(X)
print(logits.shape)
pred_probab = nn.Softmax(dim=1)(logits)
print(pred_probab)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

在这里插入图片描述

给模型输入数据之后,模型返回一个2维的tensor,dim=0的数据是batch中的具体样本idx,dim=1的数据则是输出的这个样本的所属10个不同类别的预测值。最后我们套一层nn.Softmax, 就可以获得每个类别的概率pred_probab了。最后对其使用argmax(1)找到该张量在dim=1维度上的最大值索引,就获得了这一次推理的分类结果。

3. 进阶操作:获取模型当前的参数

如果你想要一点可解释性,你可能得用到这个

神经网络中的许多层都是参数化的,也就是说,它们有相关的权重(weights)偏差(biases),这些值会在训练过程中进行优化。

当你的模型继承自 nn.Module 时,PyTorch 会自动追踪模型对象中定义的所有字段。因此,你可以通过模型的 parameters()named_parameters() 方法来访问所有这些参数。

print(model)for name, param in model.named_parameters():pritn(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n") # 矩阵获取前两行,bias获取前两个

在这个例子中,我们遍历了每一个参数,并打印出它的尺寸(size)和部分值预览。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/91916.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/91916.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

从“碎片化”到“完美重组”:IP报文的分片艺术

前言 在网络通信中,当IP层需要传输的数据包大小超过数据链路层的MTU限制时,就必须进行分片处理。本文将完整解析IP分片的工作机制,包括分片字段的作用、如何减少分片,以及分片报文的组装原理。 IP报头解析请参考&#xff…

[GESP202306 四级] 2023年6月GESP C++四级上机题超详细题解,附带讲解视频!

本文为2023年6月GESP C四级的上机题目的详细题解!觉得写的不错或者有帮助可以点个赞啦! (第一次讲解视频,有问题可以指出,不足之处也可以指出) 目录 题目一讲解视频: 题目二讲解视频: 题目一: 幸运数 题目大意: …

内网穿透 FRP 配置指南

关键词:内网穿透、FRP配置、frps、frpc、远程访问、自建服务器、反向代理、TCP转发、HTTP转发 在开发或部署项目时,我们经常遇到内网设备无法被公网访问的问题,例如你想从外网访问你家里的 NAS、远程调试开发板,或是访问本地测试环…

SpringBoot 信用卡检测、OpenAI gym、OCR结合、DICOM图形处理、知识图谱、农业害虫识别实战

信用卡欺诈检测通常使用公开数据集 数据准备与预处理 信用卡欺诈检测通常使用公开数据集如Kaggle的信用卡交易数据集。数据预处理包括处理缺失值、标准化数值特征、处理类别特征。在Spring Boot中,可以使用pandas或sklearn进行数据预处理。 // 示例:使用Spring Boot读取CS…

使用 Docker 部署 Golang 程序

Docker 是部署 Golang 应用程序的绝佳方式,它可以确保环境一致性并简化部署流程。以下是完整的指南: 1. 准备 Golang 应用程序 首先确保你的 Go 应用程序可以正常构建和运行。一个简单的示例 main.go: package mainimport ("fmt""net/http" )func ha…

从零开始的CAD|CAE开发: LBM源码实现分享

起因:上期我们写了流体仿真的经典案例: 通过LBM,模拟计算涡流的形成,当时承诺: 只要验证通过,就把代码开源出来;ok.验证通过了,那么我也就将代码全都贴出来代码开源并贴出:public class LidDrivenCavityFlow : IDisposable{public LidDrivenCavityFlow(int width 200, int hei…

仓库管理系统-17-前端之物品类型管理

文章目录 1 表设计(goodstype) 2 后端代码 2.1 Goodstype.java 2.2 GoodstypeMapper.java 2.3 GoodstypeService.java 2.4 GoodstypeServiceImpl.java 2.5 GoodstypeController.java 3 前端代码 3.1 goodstype/GoodstypeManage.vue 3.2 添加菜单 3.3 页面显示 1、goodstype表设…

共识算法深度解析:PoS/DPoS/PBFT对比与Python实现

目录 共识算法深度解析:PoS/DPoS/PBFT对比与Python实现 1. 引言:区块链共识的核心挑战 2. 共识算法基础 2.1 核心设计维度 2.2 关键评估指标 3. PoS(权益证明)原理与实现 3.1 核心机制 3.2 Python实现 4. DPoS(委托权益证明)原理与实现 4.1 核心机制 4.2 Python实现 5. P…

3.JVM,JRE和JDK的关系是什么

3.JVM,JRE和JDK的关系是什么 1.JDK(Java Development Kit),是功能齐全的Java SDK,包含JRE和一些开发工具(比如java.exe,运行工具javac.exe编译工具,生成.class文件,javaw.exe,大多用…

深度学习技术发展思考笔记 || 一项新技术的出现,往往是为了解决先前范式中所暴露出的特定局限

深度学习领域的技术演进,遵循着一个以问题为导向的迭代规律。一项新技术的出现,往往是为了解决先前范式中所暴露出的特定局限。若将这些新技术看作是针对某个问题的“解决方案”,便能勾勒出一条清晰的技术发展脉络。 例如,传统的前…

Promise的reject处理: then的第二个回调 与 catch回调 笔记250804

Promise的reject处理: then的第二个回调 与 catch回调 笔记250804 Promise 错误处理深度解析:then 的第二个回调 vs catch 在 JavaScript 的 Promise 链式调用中,错误处理有两种主要方式:.then() 的第二个回调函数和 .catch() 方法。这两种方…

Maven模块化开发与设计笔记

1. 模块化开发模块化开发是将大型应用程序拆分成多个小模块的过程,每个模块负责不同的功能。这有助于降低系统复杂性,提高代码的可维护性和可扩展性。2. 聚合模块聚合模块(父模块)用于组织和管理多个子模块。它定义了项目的全局配…

sqli-labs:Less-21关卡详细解析

1. 思路🚀 本关的SQL语句为: $sql"SELECT * FROM users WHERE username($cookee) LIMIT 0,1";注入类型:字符串型(单引号、括号包裹)、GET操作提示:参数需以)闭合关键参数:cookee p…

大模型+垂直场景:技术纵深、场景适配与合规治理全景图

大模型垂直场景:技术纵深、场景适配与合规治理全景图​​核心结论​:2025年大模型落地已进入“深水区”,技术价值需通过 ​领域纵深(Domain-Deep)​、数据闭环(Data-Driven)​、部署友好&#x…

Kotlin Daemon 简介

Kotlin Daemon 是 Kotlin 编译器的一个后台进程,旨在提高编译性能。它通过保持编译环境的状态来减少每次编译所需的启动时间,从而加快增量编译的速度。 Kotlin Daemon 的主要功能增量编译: 只编译自上次编译以来发生更改的文件,节…

鸿蒙南向开发 编写一个简单子系统

文章目录 前言给设备,编写一个简单子系统总结 一、前言 对于应用层的开发,搞了十几年,其实已经有点开发腻的感觉了,翻来覆去,就是调用api,页面实现,最多就再加个性能优化,但对底层…

超详细:2026年博士申请时间线

博士申请是一场持久战,需要提前规划。那么,如何科学安排2026年博士申请时间线?SCI论文发表的最佳时间节点是什么?今天给所有打算申博的同学们,详细解析下,每个时间节点的重点内容。2025年4月:是…

Python爬虫实战:研究tproxy代理工具,构建电商数据采集系统

1. 引言 1.1 研究背景 在大数据与人工智能技术快速发展的背景下,网络数据已成为企业决策、学术研究、舆情监控的核心资源。据 Statista 统计,2024 年全球互联网数据总量突破 180ZB,其中 80% 为非结构化数据,需通过爬虫技术提取与转化。Python 凭借其简洁语法与丰富的爬虫…

HighgoDB查询慢SQL和阻塞SQL

文章目录环境文档用途详细信息环境 系统平台:N/A 版本:6.0,5.6.5,5.6.4,5.6.3,5.6.1,4.5.2,4.5,4.3.4.9,4.3.4.8,4.3.4.7,4.3.4.6,4.3.4.5,4.3.4.4,4.3.4.3,4.3.4.2,4.3.4,4.7.8,4.7.7,4.7.6,4.7.5,4.3.2 文档用途 本文介绍了如何对数据库日志进行分析…

day15 SPI

1串行外设接口概述1.1基本概念SPI(串行外设接口)是一种高速、全双工、同步的串行通信协议。串行外设接口一般是需要4根线来进行通信(NSS、MISO、MOSI、SCK),但是如果打算实现单向通信(最少3根线&#xff09…