Python训练打卡Day38

Dataset和Dataloader类

知识点回顾:

  1. Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
  2. Dataloader类
  3. minist手写数据集的了解

        在遇到大规模数据集时,显存常常无法一次性存储所有数据,所以需要使用分批训练的方法。为此,PyTorch提供了DataLoader类,该类可以自动将数据集切分为多个批次batch,并支持多线程加载数据。此外,还存在Dataset类,该类可以定义数据集的读取方式和预处理方式。

1. DataLoader类:决定数据如何加载

2. Dataset类:告诉程序去哪里找数据,如何读取单个样本,以及如何预处理。

使用的数据集为MNIST手写数字数据集。该数据集包含60000张训练图片和10000张测试图片,每张图片大小为28*28像素,共包含10个类别。因为每个数据的维度比较小,所以既可以视为结构化数据,用机器学习、MLP训练,也可以视为图像数据,用卷积神经网络训练。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)
# 1. 数据预处理,该写法非常类似于管道pipeline
# transforms 模块提供了一系列常用的图像预处理操作# 先归一化,再标准化
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差,这个值很出名,所以直接使用
])# 2. 加载MNIST数据集,如果没有会自动下载
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)

Dataset类

import matplotlib.pyplot as plt# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签

为什么train_dataset[sample_idx]可以获取到图片和标签,是因为 datasets.MNIST这个类继承了torch.utils.data.Dataset类,这个类中有一个方法__getitem__,这个方法会返回一个tuple,tuple中第一个元素是图片,第二个元素是标签。

torch.utils.data.Dataset类

一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:

- __len__():返回数据集的样本总数。

- __getitem__(idx):根据索引idx返回对应样本的数据和标签。

PyTorch 要求所有数据集必须实现__getitem__和__len__,这样才能被DataLoader等工具兼容。这是一种接口约定,类似函数参数的规范。这意味着,如果你要创建一个自定义数据集,你需要实现这两个方法,否则PyTorch将无法识别你的数据集。

 __getitem__和__len__ 是类的特殊方法(也叫魔术方法 ),它们不是像普通函数那样直接使用,而是需要在自定义类中进行定义,来赋予类特定的行为。

__getitem__方法

用于让对象支持索引操作,当使用[]语法访问对象元素时,Python 会自动调用该方法。

# 示例代码
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]# 创建类的实例
my_list_obj = MyList()
# 此时可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2])  # 输出:30

__len__方法

用于返回对象中元素的数量,当使用内置函数len()作用于对象时,Python 会自动调用该方法。

class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __len__(self):return len(self.data)# 创建类的实例
my_list_obj = MyList()
# 使用len()函数获取元素数量,这会自动调用__len__方法
print(len(my_list_obj))  # 输出:5
# minist数据集的简化版本
class MNIST(Dataset):def __init__(self, root, train=True, transform=None):# 初始化:加载图片路径和标签self.data, self.targets = fetch_mnist_data(root, train) # 这里假设 fetch_mnist_data 是一个函数,用于加载 MNIST 数据集的图片路径和标签self.transform = transform # 预处理操作def __len__(self): return len(self.data)  # 返回样本总数def __getitem__(self, idx): # 获取指定索引的样本# 获取指定索引的图像和标签img, target = self.data[idx], self.targets[idx]# 应用图像预处理(如ToTensor、Normalize)if self.transform is not None: # 如果有预处理操作img = self.transform(img) # 转换图像格式# 这里假设 img 是一个 PIL 图像对象,transform 会将其转换为张量并进行归一化return img, target  # 返回处理后的图像和标签

 通俗地类比:

Dataset = 厨师(准备单个菜品)

DataLoader = 服务员(将菜品按订单组合并上桌)

预处理(如切菜、调味)属于厨师的工作,而非服务员。所以在dataset就需要添加预处理步骤。

# 可视化原始图像(需要反归一化)
def imshow(img):img = img * 0.3081 + 0.1307  # 反标准化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray') # 显示灰度图像plt.show()print(f"Label: {label}")
imshow(image)

Dataloader类

# 3. 创建数据加载器
train_loader = DataLoader(train_dataset,batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关shuffle=True # 随机打乱数据
)test_loader = DataLoader(test_dataset,batch_size=1000 # 每个批次1000张图片# shuffle=False # 测试时不需要打乱数据
)

总结:

Dataset 类:定义数据的内容和格式(即 “如何获取单个样本”),包括:
◦ 数据存储路径 / 来源(如文件路径、数据库查询)。
◦ 原始数据的读取方式(如图像解码为 PIL 对象、文本读取为字符串)。
◦ 样本的预处理逻辑(如裁剪、翻转、归一化等,通常通过 transform 参数实现)。
◦ 返回值格式(如 (image_tensor, label))。
• DataLoader 类:定义数据的加载方式和批量处理逻辑(即 “如何高效批量获取数据”),包括:
◦ 批量大小(batch_size)。
◦ 是否打乱数据顺序(shuffle)。

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/bicheng/83691.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

web3-区块链基础:从区块添加机制到哈希加密与默克尔树结构

区块链基础:从区块添加机制到哈希加密与默克尔树结构 什么是区块链 抽象的回答: 区块链提供了一种让多个参与方在没有一个唯一可信方的情况下达成合作 若有可信第三方 > 不需要区块链 [金融系统中常常没有可信的参与方] 像股票市场,或者一个国家的…

MySQL 索引:为使用 B+树作为索引数据结构,而非 B树、哈希表或二叉树?

在数据库的世界里,性能是永恒的追求。而索引,作为提升查询速度的利器,其底层数据结构的选择至关重要。如果你深入了解过 MySQL(尤其是其主流存储引擎 InnoDB),你会发现它不约而同地选择了 B树 作为索引的主…

Kafka broker 写消息的过程

Producer → Kafka Broker → Replication → Consumer|Partition chosen (by key or round-robin)|Message appended to end of log (commit log)上面的流程是kafka 写操作的大体流程。 kafka 不会特意保留message 在内存中,而是直接写入了disk。 那么消费的时候&…

leetcode hot100(两数之和、字母异位词分组、最长连续序列)

两数之和 题目链接 参考链接&#xff1a; 题目描述&#xff1a; 暴力法 双重循环查找目标值 class Solution {public int[] twoSum(int[] nums, int target) {int[] res new int[2];for(int i 0 ; i < nums.length ; i){boolean isFind false;for(int j i 1 ; j …

SkyWalking架构深度解析:分布式系统监控的利器

一、SkyWalking概述 SkyWalking是一款开源的APM(应用性能监控)系统&#xff0c;专门为微服务、云原生和容器化架构设计。它由Apache软件基金会孵化并毕业&#xff0c;已成为分布式系统监控领域的明星项目。 核心特性 ‌分布式追踪‌&#xff1a;跨服务调用链路的完整追踪‌服务…

Matlab程序设计基础

matlab程序设计基础 程序设计函数文件1.函数文件的基本结构2.创建并使用函数文件的示例3.带多个输出的函数示例4.包含子函数的函数文件 流程控制1. if 条件语句2. switch 多分支选择语句3. try-catch 异常处理语句ME与lasterr 4. while 循环语句5. for 循环语句break和continue…

Client-Side Path Traversal 漏洞学习笔记

近年来,随着Web前端技术的飞速发展,越来越多的数据请求和处理逻辑被转移到客户端(浏览器)执行。这大大提升了用户体验,但也带来了新的安全威胁。其中,Client-Side Path Traversal(客户端路径穿越,CSPT)作为一种新兴的漏洞类型,逐渐受到安全研究者和攻击者的关注。本文…

基于Socketserver+ThreadPoolExecutor+Thread构造的TCP网络实时通信程序

目录 介绍&#xff1a; 源代码&#xff1a; Socketserver-服务端代码 Socketserver客户端代码&#xff1a; 介绍&#xff1a; socketserver是一种传统的传输层网络编程接口&#xff0c;相比WebSocket这种应用层的协议来说&#xff0c;socketserver比较底层&#xff0c;soc…

【无标题】平面图四色问题P类归属的严格论证——基于拓扑收缩与动态调色算法框架

平面图四色问题P类归属的严格论证——基于拓扑收缩与动态调色算法框架 --- #### **核心定理** 任意平面图 \(G (V, E)\) 的四色着色问题可在多项式时间 \(O(|V|^2)\) 内求解&#xff0c;且算法正确性由以下三重保证&#xff1a; 1. **拓扑不变性**&#xff08;Kuratowsk…

HALCON 深度学习训练 3D 图像的几种方式优缺点

HALCON 深度学习训练 3D 图像的几种方式优缺点 ** 在计算机视觉和工业检测等领域&#xff0c;3D 图像数据的处理和分析变得越来越重要&#xff0c;HALCON 作为一款强大的机器视觉软件&#xff0c;提供了多种深度学习训练 3D 图像的方式。每种方式都有其独特的设计思路和应用场…

pytest中的元类思想与实战应用

在Python编程世界里&#xff0c;元类是一种强大而高级的特性&#xff0c;它能在类定义阶段深度定制类的创建与行为。而pytest作为热门的测试框架&#xff0c;虽然没有直接使用元类&#xff0c;但在设计机制上&#xff0c;却暗含了许多与元类思想相通的地方。接下来&#xff0c;…

以太网帧结构和封装【三】-- TCP/UDP头部信息

TCP头部用于建立可靠连接、流量控制及数据完整性校验。 Ipv4封装tcp报&#xff1a; Ipv6封装tcp报&#xff1a; UDP头部信息 UDP关键协议特性&#xff1a; 1&#xff09;无连接&#xff1a;无需握手&#xff0c;直接发送数据。 2&#xff09;不可靠性&#xff1a;不保证数据…

MySQL补充知识点学习

书接上文&#xff1a;MySQL关系型数据库学习&#xff0c;继续看书补充MySQL知识点学习。 1. 基本概念学习 1.1 游标&#xff08;Cursor&#xff09; MySQL 游标是一种数据库对象&#xff0c;它允许应用程序逐行处理查询结果集&#xff0c;而不是一次性获取所有结果。游标在需…

基于InternLM的情感调节大师FunGPT

基于书生系列大模型&#xff0c;社区用户不断创造出令人耳目一新的项目&#xff0c;从灵感萌发到落地实践&#xff0c;每一个都充满智慧与价值。“与书生共创”将陆续推出一系列文章&#xff0c;分享这些项目背后的故事与经验。欢迎订阅并积极投稿&#xff0c;一起分享经验与成…

【拓扑】1639.拓扑排序

题目描述 这是 2018 2018 2018 年研究生入学考试中给出的一个问题&#xff1a; 以下哪个选项不是从给定的有向图中获得的拓扑序列&#xff1f; 现在&#xff0c;请你编写一个程序来测试每个选项。 输入格式 第一行包含两个整数 N N N 和 M M M&#xff0c;分别表示有向图…

macOS 上使用 Homebrew 安装redis-cli

在 macOS 上使用 Homebrew 安装 redis-cli&#xff08;Redis 命令行工具&#xff09;非常简单&#xff0c;以下是详细步骤&#xff1a; 1. 安装 Redis&#xff08;包含 redis-cli&#xff09; 运行以下命令安装 Redis&#xff1a; brew install redis这会安装完整的 Redis 服…

Scratch节日 | 六一儿童节射击游戏

六一儿童节快乐&#xff01;这款超有趣的 六一儿童节射击游戏&#xff0c;让你变身小猫弓箭手&#xff0c;守护节日的快乐时光&#xff01; &#x1f3ae; 游戏玩法 上下方向键&#xff1a;控制小猫的位置&#xff0c;自由移动&#xff0c;瞄准目标&#xff01; 空格键&#…

[AI Claude] 软件测试2

好的&#xff0c;我现在为你准备一份预填充好大部分内容的测试报告和PPT内容。这里面的数据是我根据项目结构和常见的测试场景推理和编造的&#xff0c;你需要根据你的实际操作结果&#xff08;包括截图、实际数据、发现的缺陷等&#xff09;进行替换和修改。 我将按照之前定义…

程序代码篇---face_recognition库实现的人脸检测系统

以下是一个基于face_recognition库的人脸管理系统,支持从文件夹加载人脸数据、实时识别并显示姓名,以及动态添加新人脸。系统采用模块化设计,代码结构清晰,易于扩展。 一、系统架构 face_recognition_system/ ├── faces/ # 人脸数据库(按姓名命名子…

Cursor 工具项目构建指南:Java 21 环境下的 Spring Boot Prompt Rules 约束

简简单单 Online zuozuo: 简简单单 Online zuozuo 简简单单 Online zuozuo 简简单单 Online zuozuo 简简单单 Online zuozuo :本心、输入输出、结果 简简单单 Online zuozuo : 文章目录 Cursor 工具项目构建指南:Java 21 环境下的 Spring Boot Prompt Rules 约束前言项目简…