PyTorch 数据加载全攻略:从自定义数据集到模型训练

目录

一、为什么需要数据加载器?

二、自定义 Dataset 类

1. 核心方法解析

2. 代码实现

三、快速上手:TensorDataset

1. 代码示例

2. 适用场景

四、DataLoader:批量加载数据的利器

1. 核心参数说明

2. 代码示例

五、实战:用数据加载器训练线性回归模型

1. 完整代码

2. 代码解析

六、总结与拓展


在深度学习实践中,数据加载是模型训练的第一步,也是至关重要的一环。高效的数据加载不仅能提高训练效率,还能让代码更具可维护性。本文将结合 PyTorch 的核心 API,通过实例详解数据加载的全过程,从自定义数据集到批量训练,带你快速掌握 PyTorch 数据处理的精髓。

一、为什么需要数据加载器?

在处理大规模数据时,我们不可能一次性将所有数据加载到内存中。PyTorch 提供了DatasetDataLoader两个核心类来解决这个问题:

  • Dataset:负责数据的存储和索引
  • DataLoader:负责批量加载、打乱数据和多线程处理

简单来说,Dataset就像一个 "仓库",而DataLoader是 "搬运工",负责把数据按批次运送到模型中进行训练。

二、自定义 Dataset 类

当我们需要处理特殊格式的数据(如自定义标注文件、特殊预处理)时,就需要自定义数据集。自定义数据集需继承torch.utils.data.Dataset,并实现三个核心方法:

1. 核心方法解析

  • __init__:初始化数据集,加载数据路径或原始数据
  • __len__:返回数据集的样本数量
  • __getitem__:根据索引返回单个样本(特征 + 标签)

2. 代码实现

import torch
from torch.utils.data import Datasetclass MyDataset(Dataset):def __init__(self, data, labels):# 初始化数据和标签self.data = dataself.labels = labelsdef __len__(self):# 返回样本总数return len(self.data)def __getitem__(self, index):# 根据索引返回单个样本sample = self.data[index]label = self.labels[index]return sample, label# 使用示例
if __name__ == "__main__":# 生成随机数据x = torch.randn(1000, 100, dtype=torch.float32)  # 1000个样本,每个100个特征y = torch.randn(1000, 1, dtype=torch.float32)   # 对应的标签# 创建自定义数据集dataset = MyDataset(x, y)print(f"数据集大小:{len(dataset)}")print(f"第一个样本:{dataset[0]}")  # 查看第一个样本

三、快速上手:TensorDataset

如果你的数据已经是 PyTorch 张量(Tensor),且不需要复杂的预处理,那么TensorDataset会是更好的选择。它是 PyTorch 内置的数据集类,能快速将特征和标签绑定在一起。

1. 代码示例

from torch.utils.data import TensorDataset, DataLoader# 生成张量数据
x = torch.randn(1000, 100, dtype=torch.float32)
y = torch.randn(1000, 1, dtype=torch.float32)# 使用TensorDataset包装数据
dataset = TensorDataset(x, y)  # 特征和标签按索引对应# 查看样本
print(f"样本数量:{len(dataset)}")
print(f"第一个样本特征:{dataset[0][0].shape}")
print(f"第一个样本标签:{dataset[0][1]}")

2. 适用场景

  • 数据已转换为 Tensor 格式
  • 不需要复杂的预处理逻辑
  • 快速搭建训练流程(如验证代码可行性)

四、DataLoader:批量加载数据的利器

有了数据集,还需要高效的批量加载工具。DataLoader可以实现:

  • 批量读取数据(batch_size
  • 打乱数据顺序(shuffle
  • 多线程加载(num_workers

1. 核心参数说明

参数作用
dataset要加载的数据集
batch_size每批样本数量(常用 32/64/128)
shuffle每个 epoch 是否打乱数据(训练时设为 True)
num_workers加载数据的线程数(加速数据读取)

2. 代码示例

# 创建DataLoader
dataloader = DataLoader(dataset=dataset,batch_size=32,      # 每批32个样本shuffle=True,       # 训练时打乱数据num_workers=2       # 2个线程加载
)# 遍历数据
for batch_idx, (batch_x, batch_y) in enumerate(dataloader):print(f"第{batch_idx}批:")print(f"特征形状:{batch_x.shape}")  # (32, 100)print(f"标签形状:{batch_y.shape}")  # (32, 1)if batch_idx == 2:  # 只看前3批break

五、实战:用数据加载器训练线性回归模型

下面结合一个完整案例,展示如何使用TensorDatasetDataLoader训练模型。我们将实现一个线性回归任务,预测生成的随机数据。

1. 完整代码

from sklearn.datasets import make_regression
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import nn, optim# 生成回归数据
def build_data():bias = 14.5# 生成1000个样本,100个特征x, y, coef = make_regression(n_samples=1000,n_features=100,n_targets=1,bias=bias,coef=True,random_state=0  # 固定随机种子,保证结果可复现)# 转换为Tensor并调整形状x = torch.tensor(x, dtype=torch.float32)y = torch.tensor(y, dtype=torch.float32).view(-1, 1)  # 转为列向量bias = torch.tensor(bias, dtype=torch.float32)coef = torch.tensor(coef, dtype=torch.float32)return x, y, coef, bias# 训练函数
def train():x, y, true_coef, true_bias = build_data()# 构建数据集和数据加载器dataset = TensorDataset(x, y)dataloader = DataLoader(dataset=dataset,batch_size=100,  # 每批100个样本shuffle=True     # 训练时打乱数据)# 定义模型、损失函数和优化器model = nn.Linear(in_features=x.size(1), out_features=y.size(1))  # 线性层criterion = nn.MSELoss()  # 均方误差损失optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降# 训练50个epochepochs = 50for epoch in range(epochs):for batch_x, batch_y in dataloader:# 前向传播y_pred = model(batch_x)loss = criterion(batch_y, y_pred)# 反向传播和参数更新optimizer.zero_grad()  # 清空梯度loss.backward()        # 计算梯度optimizer.step()       # 更新参数# 打印结果print(f"真实权重:{true_coef[:5]}...")  # 只显示前5个print(f"预测权重:{model.weight.detach().numpy()[0][:5]}...")print(f"真实偏置:{true_bias}")print(f"预测偏置:{model.bias.item()}")if __name__ == "__main__":train()

2. 代码解析

  1. 数据生成:用make_regression生成带噪声的回归数据,并转换为 PyTorch 张量。
  2. 数据集构建:用TensorDataset将特征和标签绑定,方便后续加载。
  3. 批量加载DataLoader按批次读取数据,每次训练用 100 个样本。
  4. 模型训练:线性回归模型通过梯度下降优化,最终输出预测的权重和偏置,与真实值对比。

六、总结与拓展

本文介绍了 PyTorch 中数据加载的核心工具:

  • 自定义 Dataset:灵活处理特殊数据格式
  • TensorDataset:快速包装张量数据
  • DataLoader:高效批量加载,支持多线程和数据打乱

在实际项目中,你可以根据数据类型选择合适的工具:

  • 处理图片:用ImageFolder(PyTorch 内置,支持按文件夹分类)
  • 处理文本:自定义 Dataset 读取文本文件并转换为张量
  • 大规模数据:结合num_workerspin_memory(针对 GPU 加速)

掌握数据加载是深度学习的基础,用好这些工具能让你的训练流程更高效、更易维护。快去试试用它们处理你的数据吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/89341.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/89341.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python--plist文件的读取

Python练习:读取Apple Plist文件 Plist文件简介 ​​定义​​:Apple公司创建的基于XML结构的文件格式​​特点​​:采用XML语法组织数据,可存储键值对、数组等结构化信息文件扩展名​​:.plist应用场景: ​​iOS系统:​…

JAVA几个注解记录

在Java中,Data、AllArgsConstructor和NoArgsConstructor是Lombok库提供的注解,用于自动生成Java类中的样板代码(如getter、setter、构造函数等),从而减少冗余代码,提高开发效率。以下是它们的详细功能和使用…

js对象简介、内置对象

对象、内置对象 jarringslee 对象 对象(object)是js的一种引用数据类型,是一种无序的数据集合“ul”(类比于数组,有序的数据集合“ol”)。 基本上等于结构体。 对象的声明 //基本方法 let 对象名 {声…

【工程篇】07:如何打包conda环境并拷贝到另一台服务器上

这是一份以名为 qwen2.5-vl 的 Conda 环境为例的详细操作手册,指导您如何将其打包并迁移至另一台服务器。操作手册:迁移 Conda 环境 qwen2.5-vl 至新服务器 本文档将提供两种有效的方法来迁移您的 qwen2.5-vl 环境。请根据您的具体需求和服务器条件选择最…

rustdesk远控电脑替代todesk,平替向日葵等软件

rustdesk网页端远控电脑docker run --restart always \ --privileged \ -p 9000:9000 \ -p 21114:21114 \ -p 21115:21115 \ -p 21116:21116 \ -p 21116:21116/udp \ -p 21117:21117 \ -p 21118:21118 \ -p 21119:21119 \ -e KEYj8muHpzr2HK00zm9D94b1UFkaJ1bEiWsyA1qxb1nOA \ …

板凳-------Mysql cookbook学习 (十二--------1)

第9章 存储例程,触发器和计划事件 326 9.0 概述 326 9.1 创建复合语句对象 329 mysql> -- 恢复默认分隔符 mysql> DELIMITER ; mysql>mysql> DROP FUNCTION IF EXISTS avg_mail_size; Query OK, 0 rows affected (0.02 sec)mysql> DELIMITER $$ mysq…

密码学系列文(3)--分组密码

一、分组密码概述分组密码是许多系统安全的一个重要组成部分,可用于构造:拟随机数生成器流密码消息认证码(MAC)和杂凑函数消息认证技术、数据完整性机构、实体认证协议以及单钥数字签字体制的核心组成部分应用中对于分组密码的要求:安全性运行…

WCDB soci 查询语句

测试代码 #pragma once #include <string> #include <vector>// Assume OperationLog is a struct representing a row in the table struct OperationLog {int id;std::string op_type;std::string op_subtype;std::string details;std::string timestamp; };clas…

lesson16:Python函数的认识

目录 一、为什么需要函数&#xff1f; 1. 拒绝重复造轮子 2. 让代码像句子一样可读 3. 隔离变化&#xff0c;降低维护成本 二、函数的定义&#xff1a;编写高质量函数的5个要素 基本语法框架 1. 函数命名的黄金法则&#xff08;PEP8规范&#xff09; 2. 不可或缺的文档…

通过轮询方式使用LoRa DTU有什么缺点?

在物联网系统中&#xff0c;DTU&#xff08;Data Transfer Unit&#xff09;通常用于通过485或M-Bus等接口抄读子设备的数据&#xff0c;并将这些数据传输到平台侧。然而&#xff0c;如果DTU采用轮询方式与平台通信&#xff0c;会带来一系列问题&#xff0c;尤其是在功耗和系统…

Syntax Error: Error: PostCSS received undefined instead of CSS string

报错&#xff1a;Syntax Error: Error: PostCSS received undefined instead of CSS string npm rebuild node-sass报错&#xff1a;npm i canvas 报错 canvas2.11.2 run install node-pre-gyp install --fallback-to-build --update-binary npm install canvas --canvas_binar…

人工智能之数学基础:概率论和数理统计在机器学习的地位

概率和统计的概念概率统计是各类学科中唯一一门专门研究随机现象的规律性的学科&#xff0c;随机现象的广泛性决定了这一学科的重要性。概率论是数学的分支&#xff0c;它研究的是如何定量描述随机现象及其规律。我们之前经常在天气软件上看到&#xff1a;“今天下雨的概率是95…

第十四章 Stream API

JAVA语言引入了一个流式Stream API,这个API对集合数据进行操作&#xff0c;类似于使用SQL执行的数据库查询&#xff0c;同样可以使用Stream API并行执行操作。Stream和Collection的区别Collection:静态的内存数据结构&#xff0c;强调的是数据。Stream API:和集合相关的计算操作…

Oracle数据库各版本间的技术迭代详解

今天我想和大家聊聊一个我们可能每天都在用&#xff0c;但未必真正了解的技术——Oracle数据库的版本。如果你是企业的IT工程师&#xff0c;可能经历过“升级数据库”的头疼&#xff1b;如果你是业务负责人&#xff0c;可能疑惑过“为什么一定要换新版本”&#xff1b;甚至如果…

论文reading学习记录3 - weekly - 模块化视觉端到端ST-P3

文章目录前言一、摘要与引言二、Related Word2.1 可解释的端到端架构2.2 鸟瞰图2.3 未来预测2.4 规划三、方法3.1 感知bev特征积累3.1.1 空间融合&#xff08;帧的对齐&#xff09;3.1.2 时间融合3.2 预测&#xff1a;双路径未来建模3.3 规划&#xff1a;先验知识的整合与提炼4…

crawl4ai--bitcointalk爬虫实战项目

&#x1f4cc; 项目目标本项目旨在自动化抓取 Bitcointalk 论坛中指定板块的帖子数据&#xff08;包括主贴和所有回复&#xff09;&#xff0c;并提取出结构化信息如标题、作者、发帖时间、用户等级、活跃度、Merit 等&#xff0c;以便进一步分析或使用。本项目只供科研学习使用…

调用 System.gc() 的弊端及修复方式

弊端分析不可控的执行时机System.gc() 仅是 建议 JVM 执行垃圾回收&#xff0c;但 JVM 可自由忽略该请求&#xff08;尤其是高负载时&#xff09;。实际回收时机不确定&#xff0c;无法保证内存及时释放。严重的性能问题Stop-The-World 停顿&#xff1a;触发 Full GC 时会暂停所…

git merge 和 git rebase 的区别

主要靠一张图&#xff1a;区别 git merge git checkout feature git merge master此时在feature上git会自动产生一个新的commit 修改的是当前分支 feature。 git rebase git checkout feature git rebase master&#xff08;在feature分支上执行&#xff0c;修改的是master分支…

Java学习--JVM(2)

JVM提供垃圾回收机制&#xff0c;其也是JVM的核心机制&#xff0c;其主要是实现自动回收不再被引用的对象所占用的内存&#xff1b;对内存进行整理&#xff0c;防止内存碎片化&#xff1b;以及对内存分配配进行管理。JVM 通过两种主要算法判断对象是否可回收&#xff1a;引用计…

用大模型(qwen)提取知识三元组并构建可视化知识图谱:从文本到图谱的完整实现

引言 知识图谱作为一种结构化的知识表示方式&#xff0c;在智能问答、推荐系统、数据分析等领域有着广泛应用。在信息爆炸的时代&#xff0c;如何从非结构化文本中提取有价值的知识并进行结构化展示&#xff0c;是NLP领域的重要任务。知识三元组&#xff08;Subject-Relation-O…