Pytorch-02数据集和数据加载器的基本原理和基本操作

1. 为什么要有数据集类和数据加载器类?

一万个人会有一万种获取并处理原始数据样本的代码,这会导致对数据的操作代码标准不一,并且很难复用。
在这里插入图片描述

为了解决这个问题,Pytorch提供了两种最基本的数据相关类:

  • torch.utils.data.Dataset: 一个数据集对象,包含每个数据样本路径以及对应标签
  • torch.utils.data.DataLoader:持有一个对Dataloader的迭代器,通过调用Dataset__getitem__函数方便地获取实际的样本-标签对

PyTorch 为不同的任务类型提供了方便的预加载数据集,例如 torchvision.datasets、torchaudio.datasets 等。这些数据集都是 torch.utils.data.Dataset 的子类,可以直接通过dataset.数据集名称的方式来方便的下载经典的数据集,在下面你会看到它的使用例。

2. Dataset类的使用方法

2.1 加载一个Fashion-MNIST数据集

Fashion-MNIST 是一个来自 Zalando 的文章图像数据集,包含 60,000 个训练样本和 10,000 个测试样本。每个样本由一张 28×28 的灰度图像和其对应的 10 个类别中的一个标签组成。

这是一个使用TorchVision预加载数据集类加载Fashion-MNIST 数据集的例子,如下是每个参数代表的意思:

  • root:是存储训练/测试数据的路径。
  • train:指定是训练数据集还是测试数据集。
  • download=True:如果数据在 root 路径下不可用,则从互联网下载。
  • transform 和 target_transform:分别指定特征和标签的转换。
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data", # 指定数据集实际存放的路径(相对于本代码文件)train=True, # 指定这是训练集还是测试集download=True, # 如果在root下没有数据,从网络上自动下载transform=ToTensor() # 给每一张图片转换为Tensor的数据类型
)test_data = datasets.FashionMNIST(root="data", # 指定数据集实际存放的路径(相对于本代码文件)train=False, # 指定这是训练集还是测试集download=True, # 如果在root下没有数据,从网络上自动下载transform=ToTensor() # 给每一张图片转换为Tensor的数据类型
)

在这里插入图片描述

2.2 遍历并可视化数据集

我们可以简单的使用training_data[index]来获取Datasets类中对应index的样本。通常可以用matplotlib来可视化我们的一些训练数据集:

labels_map = { # 定义一个标签映射字典0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}figure = plt.figure(figsize=(8, 8)) # 创建一个新的画布,大小为8x8英寸
cols, rows = 3, 3 # 定义展示网格尺寸 3x3的展示网格,每个网格展示i一个图片for i in range(1, cols * rows + 1): # plt的索引从1开始,配合一下sample_idx = torch.randint(len(training_data), size=(1,)).item() # 生成一个包含1个元素的张量,item()回python数据类型之后为0到数据集大小-1的随机整数img, label = training_data[sample_idx] # 本质上是在调用__getitem__函数figure.add_subplot(rows, cols, i) # 在之前创建的图形窗口中,添加一个子图(subplot),并将当前的画笔操作对象设置为当前子图plt.title(labels_map[label]) # 子图的标题设置为对应的标签字符串plt.axis("off") # 不显示坐标轴plt.imshow(img.squeeze(), cmap="gray") # 把当前网格画好
plt.show() # 展示画布

这里我并不知道为啥要使用img.squeeze()这个方法, 直到我把img的shape的打印出来:
在这里插入图片描述
现在img是一个3维的tensor,但是plt.imshow需要输入二维的tensor,所以使用squeeze的目的是把所有的尺寸为1的维度给挤压掉,将img维度降维到2维,然后就可以用plt可视化了。

在这里插入图片描述

2.3 进阶:如何制作一个自己的数据集类

自定义的 Dataset 类必须实现三个函数:__init____len____getitem__。请看下面的实现示例:FashionMNIST 图像存储在 img_dir 目录中,而它们的标签则单独保存在 annotations_file 的 CSV 文件里。

import os
import pandas as pd
from torchvision.io import decode_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitemm__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) # iloc全写为“integer location”, 表明你要通过数据的行和列的整数索引来选择数据image = decode_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

在接下来的部分将详细解释每个方法的作用。

__init__

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transform

这个方法会在初始化数据集的时候调用。其主要完成如下工作:

  1. 读取标签文件
  2. 指定图片文件夹路径
  3. 指定样本和标签的transform(这个下面细讲)

一个Fashion-MNIST是一个分类任务,其标签文件annotations大概长这样:

tshirt1.jpg, 0 # 样本-标签对
tshirt2.jpg, 0
......
ankleboot999.jpg, 9

__len__

这个方法是简单返回数据集的样本数量:

def __len__(self):return len(self.img_labels)

__getitem__

这个方法是Dataset类的核心,当此方法被Dataloader调用,请求特定idx的数据时,Dataset会根据idx,读取对应的图片和标签,并对它们做出各自的transform之后,返回给Dataloader,让它把图片和标签搬运到内存.

def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

3. Dataloader类的使用方法

3.1 对数据集对象配置Dataloader

Dataset类的__getitem__方法被调用的时候,他会返回一个样本-标签对。

但是在实际的模型训练中,我们还有一些别的要求,例如:

  1. 以“小批量(minibatches)”的方式传递样本。(减少单样本噪声带来的震荡,让梯度更新的方向更加稳定)
  2. 在每个周期(epoch)对数据进行重新洗牌(reshuffle),以减少模型过拟合。
  3. 使用 Python 的多进程(multiprocessing)来加快数据检索速度。

以上的要求可以通过如下的参数设定来满足:

from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True, num_workers=5)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True, num_workers=5)
  • batch_size=64 设定批量大小为64
  • shuffle=True 指定一个epoch之后dataloader持有的索引要重新洗牌
  • num_workers=5 指定dataloader会同时开启5个进程去调用dataset的__getitem__方法

以上是Dataloader最基本的用法,不过,当你有GPU的时候,我推荐你也把下面两个参数打开:
pin_memory=True 开启锁页内存,减少CPU到GPU的数据传递延迟
persistent_workers=True 每个epoch结束后不销毁dataloader所开启的worker进程,而是接着用,这样剩下了worker的初始化时间

3.2 使用Dataloader遍历数据集

给Dataset配置好对应的Dataloader后,就可以开始用dataloader遍历它了。每次遍历都会返回一个batch_size的训练图片和训练标签对(这里就是64个)。

# Display image and label.
train_features, train_labels = next(iter(train_dataloader)) # 先从train_dataloader中获得一个迭代器,然后调用next获取其下一个元素
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

在这里插入图片描述

由于开启了shuffle=True,所以每次遍历完整个数据集后train_dataloader持有的索引会被打乱。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/91740.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/91740.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

无图形界面的CentOS 7网络如何配置

进入虚拟机输入ip addr命令:从 ip addr命令的输出可以明确看出 ​​lo和 ens33是两个不同的网络接口(网卡)lo(回环接口)​​​​作用​​:虚拟的本地回环网卡,用于本机内部通信(如 1…

机器学习之线性回归的入门学习

线性回归是一种监督学习算法,用于解决回归问题。它的目标是找到一个线性关系(一条直线或一个超平面),能够最好地描述一个或多个自变量(特征)与一个因变量(目标)之间的关系。利用回归…

2-5 Dify案例实践—利用RAG技术构建企业私有知识库

目录 一、RAG技术的定义与作用 二、RAG技术的关键组件 三、RAG技术解决的问题 四、RAG技术的核心价值与应用场景 五、如何实现利用RAG技术构建企业私有知识库 六、Dify知识库实现详解 七、创建知识库 1、创建知识库 2、上传文档 3、文本分段与清洗 4、索引方式 5、…

断路器瞬时跳闸曲线数据获取方式

断路器瞬时短路电流时,时间是在60ms内的,仿真器去直接捕获电流有效值很难。按照电流互感器的电流曲线特性,电流越大,由于互感器饱和,到达一定电流值的时候,电流会趋于平稳不再上升,ADC-I曲线由线…

技巧|SwanLab记录混淆矩阵攻略

绘制混淆矩阵(Confusion Matrix),用于评估分类模型的性能。混淆矩阵展示了模型预测结果与真实标签之间的对应关系,能够直观地显示各类别的预测准确性和错误类型。 混淆矩阵是评估分类模型性能的基础工具,特别适用于多…

HTTPS的工作原理

文章目录HTTP有什么问题?1. 明文传输,容易被窃听2. 无法验证通信方身份3. 数据完整性无法保证HTTPS是如何解决这些问题的?HTTPS的工作原理1. SSL/TLS握手2. 数据加密传输3. 完整性保护4. 连接关闭总结HTTP有什么问题? 1. 明文传输…

ECMAScript2020(ES11)新特性

概述 ECMAScript2020于2020年6月正式发布, 本文会介绍ECMAScript2020(ES11),即ECMAScript的第11个版本的新特性。 以下摘自官网:ecma-262 ECMAScript 2020, the 11th edition, introduced the matchAll method for Strings, to produce an …

机器视觉引导机器人修磨加工系统助力芯片封装

芯片制造中,劈刀同轴度精度对封装质量至关重要。传统加工在精度、效率、稳定性、良率及操作便捷性上存在不足:精度不足:劈刀同轴度需控在 0.003mm 内,传统手段难达标,致芯片封装良率低;效率良率低 &#xf…

Python编程基础与实践:Python模块与包入门实践

Python模块与包的深度探索 学习目标 通过本课程的学习,学员将掌握Python中模块和包的基本概念,了解如何导入和使用标准库中的模块,以及如何创建和组织自己的模块和包。本课程将通过实际操作,帮助学员加深对Python模块化编程的理解…

【Django】-4- 数据库存储和管理

一、关于ORM ORM 是啥呀ORM 就是用 面向对象 的方式,把数据库里的数据还有它们之间的关系映射起来~就好像给数据库和面向对象之间搭了一座小桥梁🎀对应关系大揭秘面向对象和数据库里的东西,有超有趣的对应呢👇类 → 数…

深入 Go 底层原理(四):GMP 模型深度解析

1. 引言在上一篇文章中,我们宏观地了解了 Go 的调度策略。现在,我们将深入到构成这个调度系统的三大核心组件:G、M、P。理解 GMP 模型是彻底搞懂 Go 并发调度原理的关键。本文将详细解析 G、M、P 各自的职责以及它们之间是如何协同工作的。2.…

AI赋能测试:技术变革与应用展望

AI 在测试中的应用:技术赋能与未来展望 目录 AI 在测试中的应用:技术赋能与未来展望 1. 引言 1.1 测试在软件开发中的重要性 1.2 AI 技术如何改变传统测试模式 1.3 文章结构概述 2. AI 在测试中的核心应用场景 2.1 自动化测试优化 2.1.1 智能测…

Mujoco(MuJoCo,全称Multi - Joint dynamics with Contact)一种高性能的物理引擎

Mujoco(MuJoCo,全称Multi - Joint dynamics with Contact)是一种高性能的物理引擎,主要用于模拟多体动力学系统,广泛应用于机器人仿真、运动学研究、人工智能等领域。以下是关于Mujoco仿真的一些详细介绍: …

winform-窗体应用的功能介绍(部分)

1--Point实现在窗口(Form)中一个按钮(控件)的固定位置(所在位置)一个按钮(控件)的位置一般是固定的,另一个按钮在窗口中位置是随机产生的Location属性:Location new Point(X,Y);在C#的Winform应用程序里,Button控件的鼠标悬标悬浮事件是不存在内置延迟时间的。当鼠标指针进入按…

最新Windows11系统镜像,23H2 64位ISO镜像

Windows 11 主要分为 Consumer Editions(消费者版)和 Business Editions(商业版)两大类别 。消费者版主要面向家庭和个人用户,商业版则侧重于企业和商业用户。这两大类别中存在部分重叠的版本,比如专业版和…

linux基本系统服务——DNS服务

一、DNS域名解析原理DNS&#xff0c;Domain Name System&#xff0c;域名系统&#xff1a;在互联网中由大量域名解析服务器共同提供的一整套关于“域名 <--> IP地址”信息查询的数据系统!!!! C/S架构&#xff1a;DNS服务端监听UDP 53端口&#xff08;处理客户端查询&…

数据处理和统计分析——08 apply自定义函数

1 apply()函数 1.1 apply()函数简介 Pandas提供了很多数据处理的API&#xff0c;但当提供的API不能满足需求的时候&#xff0c;需要自己编写数据处理函数, 这个时候可以使用apply()函数&#xff1b;apply()函数可以接收一个自定义函数&#xff0c;可以将DataFrame的行或列数据传…

C++冰箱管理实战代码

基于C++的冰箱管理实例 以下是一些基于C++的冰箱管理实例示例,涵盖不同功能场景,每个示例聚焦特定实现点,代码可直接扩展或整合到项目中。 示例1:基础冰箱类定义 class Refrigerator { private:int capacity;std::vector<std::string> items; public:Refrigerator(…

【Python】【数据分析】Python 数据分析与可视化:全面指南

目录1. 环境准备2. 数据处理与清洗2.1 导入数据2.2 数据清洗示例&#xff1a;处理缺失值示例&#xff1a;处理异常值2.3 数据转换3. 数据分析3.1 描述性统计3.2 分组分析示例&#xff1a;按年龄分组计算工资的平均值3.3 时间序列分析4. 数据可视化4.1 基本绘图示例&#xff1a;…

【AI】AIService(基本使用与指令定制)

【AI】AIService(基本使用与指令定制) 文章目录【AI】AIService(基本使用与指令定制)1. 简介2. AIService2.1 引入依赖2.2 编写AIService接口2.3 测试代码3. 指令定制3.1 系统提示词3.2 用户提示词1. 简介 AIService可以被视为应用程序服务层的一个组件&#xff0c;提供对应的…