【动手学深度学习】4.2~4.3 多层感知机的实现


目录

    • 4.2. 多层感知机的从零开始实现
      • 1)初始化模型参数
      • 2)激活函数
      • 3)模型
      • 4)损失函数
      • 5)训练
    • 4.3. 多层感知机的简洁实现
      • 1)模型
      • 2)小结


.

4.2. 多层感知机的从零开始实现

现在让我们实现一个多层感知机。 为了与之前softmax回归获得的结果进行比较, 我们将继续使用Fashion-MNIST图像分类数据集。

import torch
from torch import nn
from d2l import torch as d2lbatch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

.

1)初始化模型参数

回想一下,Fashion-MNIST中的每个图像由 28 x 28 = 784 个灰度像素值组成。 所有图像共分为10个类别。 忽略像素之间的空间结构, 我们可以将每个图像视为具有784个输入特征 和10个类的简单分类数据集。

首先,我们将实现一个具有单隐藏层的多层感知机, 它包含256个隐藏单元。

注意,我们可以将这两个变量都视为超参数。 通常,我们选择2的若干次幂作为层的宽度。 因为内存在硬件中的分配和寻址方式,这么做往往可以在计算上更高效。

我们用几个张量来表示我们的参数。 注意,对于每一层我们都要记录一个权重矩阵和一个偏置向量。 跟以前一样,我们要为损失关于这些参数的梯度分配内存。

num_inputs, num_outputs, num_hiddens = 784, 10, 256W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True) * 0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))params = [W1, b1, W2, b2]

.

2)激活函数

为了确保我们对模型的细节了如指掌, 我们将实现ReLU激活函数, 而不是直接调用内置的relu函数。

def relu(X):a = torch.zeros_like(X)  # 创建与输入X形状相同的全零张量return torch.max(X, a)   # 逐元素比较X和零张量,返回较大值

.

3)模型

因为我们忽略了空间结构, 所以我们使用reshape将每个二维图像转换为一个长度为num_inputs的向量。 只需几行代码就可以实现我们的模型。

def net(X):X = X.reshape((-1, num_inputs))H = relu(X@W1 + b1)  # 这里“@”代表矩阵乘法return (H@W2 + b2)

.

4)损失函数

由于我们已经从零实现过softmax函数, 因此在这里我们直接使用高级API中的内置函数来计算softmax和交叉熵损失。

loss = nn.CrossEntropyLoss(reduction='none')

.

5)训练

幸运的是,多层感知机的训练过程与softmax回归的训练过程完全相同。 可以直接调用d2l包的train_ch3函数(参见3.6节 ), 将迭代周期数设置为10,并将学习率设置为0.1.

num_epochs, lr = 10, 0.1
updater = torch.optim.SGD(params, lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

在这里插入图片描述

为了对学习到的模型进行评估,我们将在一些测试数据上应用这个模型。

d2l.predict_ch3(net, test_iter)

在这里插入图片描述


4.3. 多层感知机的简洁实现

本节将介绍通过高级API更简洁地实现多层感知机。

import torch
from torch import nn
from d2l import torch as d2l

.

1)模型

与softmax回归的简洁实现相比, 唯一的区别是我们添加了2个全连接层(之前我们只添加了1个全连接层)。 第一层是隐藏层,它包含256个隐藏单元,并使用了ReLU激活函数。 第二层是输出层。

net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights);

训练过程的实现与我们实现softmax回归时完全相同, 这种模块化设计使我们能够将与模型架构有关的内容独立出来。

batch_size, lr, num_epochs = 256, 0.1, 10
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=lr)train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

在这里插入图片描述

.

2)小结

  • 我们可以使用高级API更简洁地实现多层感知机。

  • 对于相同的分类问题,多层感知机的实现与softmax回归的实现相同,只是多层感知机的实现里增加了带有激活函数的隐藏层。

.


声明:资源可能存在第三方来源,若有侵权请联系删除!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/84473.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/84473.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

54-Oracle 23 ai DBMS_HCHECK新改变-从前的hcheck.sql

Oracle Hcheck(Health Check)是Oracle数据库内置的健康监测工具,自动化检查数据库的核心问题,包括数据字典一致性、性能瓶颈、空间使用及安全隐患。本质是数据字典的CT扫描仪,其核心价值在于将“字典逻辑错误”这类灰色…

AI 产品的“嵌点”(Embedded Touchpoints)

核心主题: AI 产品的成功不在于功能的强大与独立,而在于其能否作为“嵌点”(Embedded Touchpoints)无缝融入用户现有的行为流(Flow),消除微小摩擦,在用户真正需要的时机和场景中“无…

如何在WordPress中添加导航菜单?

作为一个用了很多年 WordPress 的用户,我特别清楚导航菜单有多重要。一个清晰的导航菜单能让访问者快速找到他们想要的信息,同时也能提升网站的用户体验。而对于WordPress用户来说,学会如何添加和自定义导航菜单是构建高质量网站的第一步。今…

【pdf】Java代码生成PDF

目录 依赖 创建单元格 表格数据行辅助添加方法 创建表头单元格 创建下划线 创建带下划线的文字 创建PDF 依赖 <dependency><groupId>com.itextpdf</groupId><artifactId>itextpdf</artifactId><version>5.4.2</version> <…

Vite 的“心脏移植”:Rolldown

1. 现状&#xff1a;你搁这儿玩双截棍呢&#xff1f; 现在Vite这逼样&#xff1a;开发用esbuild&#xff0c;生产用Rollup&#xff0c;精分现场是吧&#xff1f;大型项目尼玛启动慢成狗&#xff0c;请求多到炸穿地心&#xff0c;生产/dev环境差异能让你debug到原地升天&#x…

【网络安全】文件上传型XSS攻击解析

引言 文件上传功能作为现代Web应用的核心交互模块&#xff0c;其安全防护水平直接关系到系统的整体安全性。本文基于OWASP、CVE等权威研究&#xff0c;结合2024-2025年最新漏洞案例&#xff0c;系统剖析了文件上传场景下的XSS攻击技术演进路径。研究揭示&#xff1a;云原生架构…

Java 集合框架底层数据结构实现深度解析

Java 集合框架&#xff08;Java Collections Framework, JCF&#xff09;是支撑高效数据处理的核心组件&#xff0c;其底层数据结构的设计直接影响性能与适用场景。本文从线性集合、集合、映射三大体系出发&#xff0c;系统解析ArrayList、LinkedList、HashMap、TreeSet等核心类…

Dify动手实战教程(进阶-知识库:新生入学指南)

目录 进阶-知识库&#xff1a;新生入学指南 1.创建知识库 2.创建Agent 去年agent智能体爆火&#xff0c;我自己也使用了多款智能体产品来搭建agent解决生活中的实际问题&#xff0c;如dify、coze等等。dify作为一个开源的框架得到了大量的应用&#xff0c;如一些需要隐私保护…

Vue3+TypeScript+ Element Plus 从Excel文件导入数据,无后端(点击按钮,选择Excel文件,由前端解析数据)

在 Vue 3 TypeScript Element Plus 中实现文件导入功能&#xff0c;可以通过以下步骤完成&#xff1a; 1. 安装依赖 bash 复制 下载 npm install xlsx # 用于解析Excel文件 npm install types/xlsx -D # TypeScript类型声明 2. 组件实现 vue 复制 下载 <templ…

一些torch函数用法总结

1.torch.nonzero(input, *, as_tupleFalse) 作用&#xff1a;在PyTorch中用于返回输入张量中非零元素的位置索引。 返回值&#xff1a;返回一个张量&#xff0c;每行代表一个非零元素的索引。 参数含义&#xff1a; &#xff08;1&#xff09;input:输入的PyTorch 张量。 …

moments_object_model_3d这么理解

这篇文章是我对这个算子的理解,和三个输出结果分别用在什么地方 算子本身 moments_object_model_3d( : : ObjectModel3D, MomentsToCalculate : Moments) MomentsToCalculate:对应三个可选参数,分别是 1, mean_points: 就是点云在xyz方向上坐标的平均值 2, central_m…

性能测试|数据说话!在SimForge平台上用OpenRadioss进行汽车碰撞仿真,究竟多省时?

Radioss是碰撞仿真领域中十分成熟的有限元仿真软件&#xff0c;可以对工程中许多非线性问题进行求解&#xff0c;例如汽车碰撞、产品跌落、导弹爆炸、流固耦合分析等等。不仅可以提升产品的刚度、强度、碰撞的安全性能等&#xff0c;还可以在降低产品研发成本的同时提升研发效率…

数据结构学习——KMP算法

//KMP算法 #include <iostream> #include <string> #include <vector> #include <cstdlib>using namespace std;//next数组值的推导void getNext(string &str, vector<int>& next){int strlong str.size();//next数组的0位为0next[0]0;…

博士,超28岁,出局!

近日&#xff0c;长沙市望城区《2025年事业引才博士公开引进公告》引发轩然大波——博士岗位年龄要求28周岁及以下&#xff0c;特别优秀者也仅放宽至30周岁。 图源&#xff1a;网络 这份规定让众多"高龄"博士生直呼不合理&#xff0c;并在社交平台掀起激烈讨论。 图源…

使用Nuitka打包Python程序,编译为C提高执行效率

在 Python 的世界里&#xff0c;代码打包与发布一直是开发者关注的重要话题。前面我们介绍了Pyinstaller的使用&#xff0c;尽管 PyInstaller 是最常用的工具之一&#xff0c;但对于性能、安全性、兼容性有更高要求的项目&#xff0c;Nuitka 正迅速成为更优的选择。本文将全面介…

基于机器学习的恶意请求检测

好久没写文章了&#xff0c;忙毕业设计ING&#xff0c;终于做好了发出来。 做了针对恶意URL的检测&#xff0c;改进了杨老师这篇参考文献的恶意请求检测的方法 [网络安全自学篇] 二十三.基于机器学习的恶意请求识别及安全领域中的机器学习-CSDN博客 选择使用了XGBoost算法进…

深入理解XGBoost(何龙 著)学习笔记(五)

深入理解XGBoost&#xff08;何龙 著&#xff09;学习笔记&#xff08;五&#xff09; 本文接上一篇&#xff0c;内容为线性回归&#xff0c;介绍三部分&#xff0c;首先介绍了"模型评估”&#xff0c;然后分别提供了线性回归的模型代码&#xff1a;scikit-learn的Linear…

工业级MySQL基准测试专家指南

工业级MySQL基准测试专家指南 一、深度风险识别增强版 风险类型典型表现进阶检测方案K8s存储性能抖动PVC卷IOPS骤降50%使用kubestone进行CSI驱动压力测试HTAP读写冲突OLAP查询导致OLTP事务超时用TPCH+Sysbench混合负载测试冷热数据分层失效压缩表查询耗时激增10倍监控INNODB_C…

Spring WebFlux和Spring MVC的对比

原文网址&#xff1a;Spring WebFlux和Spring MVC的对比-CSDN博客 简介 本文介绍Spring WebFlux和Spring MVC的区别。 Webflux&#xff1a;是异步非阻塞的&#xff08;IO多路复用&#xff09;&#xff0c;基于Netty。适合网络转发类的应用&#xff0c;比如&#xff1a;网关。…

解析401 Token过期自动刷新机制:Kotlin全栈实现指南

在现代Web应用中&#xff0c;Token过期导致的401错误是影响用户体验的关键问题。本文将手把手实现一套完整的Token自动刷新机制&#xff0c;覆盖从原理到实战的全过程。 一、为什么需要Token自动刷新&#xff1f; 当用户使用应用时&#xff0c;会遇到两种典型场景&#xff1a;…