【数据准备】——深度学习.全连接神经网络

目录

1 数据加载器

1.1 构建数据类

1.1.1 Dataset类

1.1.2 TensorDataset类

1.2 数据加载器

2 数据加载案例

2.1 加载csv数据集

2.2 加载图片数据集

2.3 加载官方数据集

2.4 pytorch实现线性回归


1 数据加载器

分数据集和加载器2个步骤~

1.1 构建数据类

1.1.1 Dataset类

Dataset是一个抽象类,是所有自定义数据集应该继承的基类。它定义了数据集必须实现的方法。

必须实现的方法

  1. __len__: 返回数据集的大小

  2. __getitem__: 支持整数索引,返回对应的样本

在 PyTorch 中,构建自定义数据加载类通常需要继承 torch.utils.data.Dataset 并实现以下几个方法:

  1. __init__ 方法 用于初始化数据集对象:通常在这里加载数据,或者定义如何从存储中获取数据的路径和方法。
def __init__(self, data, labels):self.data = dataself.labels = labels

      2.__len__ 方法 返回样本数量:需要实现,以便 Dataloader加载器能够知道数据集的大小。

def __len__(self):return len(self.data)

      3.__getitem__ 方法 根据索引返回样本:将从数据集中提取一个样本,并可能对样本进行预处理或变换。

def __getitem__(self, index):sample = self.data[index]label = self.labels[index]return sample, label

如果你需要进行更多的预处理或数据变换,可以在 __getitem__ 方法中添加额外的逻辑。

import torch
from torch.utils.data import Dataset,DataLoader,TensorDataset
from sklearn.datasets import make_regression
from torch import nn,optim
# 自定义数据集类
# 1.继承dataset类
# 2.实现__init__方法,初始化外部的数据
# 3.实现__len__方法,用来返回数据集的长度
# 4.实现__getitem__方法,根据索引获取对应位置的数据
class MyDataset(Dataset):def __init__(self,data,labels):self.data=dataself.labels=labelsdef __len__(self):return len(self.data)def __getitem__(self, index):sample=self.data[index]label=self.labels[ index]return sample,labeldef test01():x=torch.randn(100,20)y=torch.randn(100,1)dataset=MyDataset(x,y)print( dataset[0])if __name__=='__main__':test01()

1.1.2 TensorDataset类

TensorDatasetDataset的一个简单实现,它封装了张量数据,适用于数据已经是张量形式的情况。

特点

  1. 简单快捷:当数据已经是张量形式时,无需自定义Dataset类

  2. 多张量支持:可以接受多个张量作为输入,按顺序返回

  3. 索引一致:所有张量的第一个维度必须相同,表示样本数量

def test03():torch.manual_seed(0)# 创建特征张量和标签张量features = torch.randn(100, 5)  # 100个样本,每个样本5个特征labels = torch.randint(0, 2, (100,))  # 100个二进制标签# 创建TensorDatasetdataset = TensorDataset(features, labels)# 使用方式与自定义Dataset相同print(len(dataset))  # 输出: 100print(dataset[0])  # 输出: (tensor([...]), tensor(0))

1.2 数据加载器

在训练或者验证的时候,需要用到数据加载器批量的加载样本。

DataLoader 是一个迭代器,用于从 Dataset 中批量加载数据。它的主要功能包括:

  • 批量加载:将多个样本组合成一个批次。

  • 打乱数据:在每个 epoch 中随机打乱数据顺序。

  • 多线程加载:使用多线程加速数据加载。

创建DataLoader:

# 创建 DataLoader
dataloader = DataLoader(dataset,          # 数据集batch_size=10,    # 批量大小shuffle=True,     # 是否打乱数据num_workers=2     # 使用 2 个子进程加载数据
)

遍历:

# 遍历 DataLoader
# enumerate返回一个枚举对象(iterator),生成由索引和值组成的元组
for batch_idx, (samples, labels) in enumerate(dataloader):print(f"Batch {batch_idx}:")print("Samples:", samples)print("Labels:", labels)

案例:

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader# 定义数据加载类
class CustomDataset(Dataset):#略......def test01():# 简单的数据集准备data_x = torch.randn(666, 20, requires_grad=True, dtype=torch.float32)data_y = torch.randn(data_x.size(0), 1, dtype=torch.float32)dataset = CustomDataset(data_x, data_y)# 构建数据加载器data_loader = DataLoader(dataset, batch_size=8, shuffle=True)for i, (batch_x, batch_y) in enumerate(data_loader):print(batch_x, batch_y)breakif __name__ == "__main__":test01()

2 数据加载案例

通过一些数据集的加载案例,真正了解数据类及数据加载器。

2.1 加载csv数据集

import torch
from torch import nn
from torch.utils.data import TensorDataset,DataLoader
import pandas as pd
from torchvision import datasets,transformsdef build_csv_data(filepath):df=pd.read_csv(filepath)df.drop(["学号","姓名"],axis=1,inplace=True)# print(df.head())samples=df.iloc[...,:-1]labels=df.iloc[...,-1]# print(samples.head())# print(labels.head())samples=torch.tensor(samples.values)labels=torch.tensor(labels.values)# print(samples)# print(labels)return samples,labelsdef load_csv_data():filepath="./datasets/大数据答辩成绩表.csv"samples,labels=build_csv_data(filepath)dataset=TensorDataset(samples,labels)dataloader=DataLoader(dataset=dataset,batch_size=1,shuffle=True)for sample,label in dataloader:print(sample)print(label)breakif __name__=="__main__":load_csv_data()

2.2 加载图片数据集

import torch
from torch import nn
from torch.utils.data import TensorDataset,DataLoader
import pandas as pd
from torchvision import datasets,transformsdef load_img_data():path="./datasets/animals"transform=transforms.Compose([# 图片缩放 把所有图片缩放到同一尺寸transforms.Resize((224,224)),# 把PIL图片或numpy数组转为张量transforms.ToTensor(),])dataset=datasets.ImageFolder(root=path,transform=transform)dataloader=DataLoader(dataset=dataset,batch_size=4,shuffle=True)for x,y in dataloader:print(x.shape)print(x)print(y)breakif __name__=="__main__":load_img_data()

2.3 加载官方数据集

在 PyTorch 中官方提供了一些经典的数据集,如 CIFAR-10、MNIST、ImageNet 等,可以直接使用这些数据集进行训练和测试。

数据集:Datasets — Torchvision 0.22 documentation

常见数据集:

  • MNIST: 手写数字数据集,包含 60,000 张训练图像和 10,000 张测试图像。

  • CIFAR10: 包含 10 个类别的 60,000 张 32x32 彩色图像,每个类别 6,000 张图像。

  • CIFAR100: 包含 100 个类别的 60,000 张 32x32 彩色图像,每个类别 600 张图像。

  • COCO: 通用对象识别数据集,包含超过 330,000 张图像,涵盖 80 个对象类别。

torchvision.transforms 和 torchvision.datasets 是 PyTorch 中处理计算机视觉任务的两个核心模块,它们为图像数据的预处理和标准数据集的加载提供了强大支持。

transforms 模块提供了一系列用于图像预处理的工具,可以将多个变换组合成处理流水线。

datasets 模块提供了多种常用计算机视觉数据集的接口,可以方便地下载和加载。

参考如下:

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, datasetsdef test():transform = transforms.Compose([transforms.ToTensor(),])# 训练数据集data_train = datasets.MNIST(root="./data",train=True,download=True,transform=transform,)trainloader = DataLoader(data_train, batch_size=8, shuffle=True)for x, y in trainloader:print(x.shape)print(y)break# 测试数据集data_test = datasets.MNIST(root="./data",train=False,download=True,transform=transform,)testloader = DataLoader(data_test, batch_size=8, shuffle=True)for x, y in testloader:print(x.shape)print(y)breakdef test006():transform = transforms.Compose([transforms.ToTensor(),])# 训练数据集data_train = datasets.CIFAR10(root="./data",train=True,download=True,transform=transform,)trainloader = DataLoader(data_train, batch_size=4, shuffle=True, num_workers=2)for x, y in trainloader:print(x.shape)print(y)break# 测试数据集data_test = datasets.CIFAR10(root="./data",train=False,download=True,transform=transform,)testloader = DataLoader(data_test, batch_size=4, shuffle=False, num_workers=2)for x, y in testloader:print(x.shape)print(y)breakif __name__ == "__main__":test()test006()

2.4 pytorch实现线性回归

import torch
from torch.utils.data import Dataset,DataLoader,TensorDataset
from sklearn.datasets import make_regression
from torch import nn,optim# pytorch实现线性回归
def build_data(in_features,out_features):bias=14.5# 生成的数据需要转换成tensorx,y,coef=make_regression(n_samples=1000,n_features=in_features,n_targets=out_features,coef=True,bias=bias,noise=0.1,random_state=42)x=torch.tensor(x,dtype=torch.float32)y=torch.tensor(y,dtype=torch.float32).view(-1,1) # 注意要把y转换成二维数组(本来是一维) 否则会报警告coef=torch.tensor(coef,dtype=torch.float32)bias=torch.tensor(bias,dtype=torch.float32)return x,y,coef,biasdef train():# 数据准备in_features=10out_features=1x,y,coef,bias=build_data(in_features,out_features)dataset=TensorDataset(x,y)dataloader=DataLoader(dataset=dataset,batch_size=100,shuffle=True)# 定义网络模型model=nn.Linear(in_features,out_features)# 定义损失函数criterion=nn.MSELoss()# 优化器opt=optim.SGD(model.parameters(),lr=0.1)epochs=20for epoch in range(epochs):for tx,ty in dataloader:y_pred=model(tx)loss=criterion(y_pred,ty)opt.zero_grad()loss.backward()opt.step()print(f'epoch:{epoch},loss:{loss.item()}')# detach()、data:作用是将计算图中的weight参数值获取出来print(f"真实权重:{coef.numpy()},训练权重:{model.weight.detach().numpy()}") # datach()相当于把weight从计算图中抽离出来print(f"真实偏置:{bias},训练偏置:{model.bias.item()}")if __name__=='__main__':train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/91552.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/91552.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

健康生活,从细节开始

健康生活,从细节开始在当今快节奏的生活中,健康逐渐成为人们关注的焦点。拥有健康的身体,才能更好地享受生活、追求梦想。那么,如何才能拥有健康呢?这就需要我们从生活中的点滴细节入手,培养良好的生活习惯…

javax.servlet.http.HttpServletResponse;API导入报错解决方案

javax.servlet.http.HttpServletResponse;API导入报错解决方案与Postman上传下载文件验证 1. 主要错误:缺少 Servlet API 依赖 错误信息显示 javax.servlet.http 包不存在。这是因为你的项目缺少 Servlet API 依赖。 解决方案: 如果你使用的是 Maven&…

reids依赖删除,但springboot仍然需要redis参数才能启动

背景:项目需要删除redis。我删除完项目所有配置redis的依赖,启动报错。[2025-07-17 15:08:37:561] [DEBUG] [restartedMain] DEBUG _.s.w.s.H.Mappings - [detectHandlerMethods,295] [] - o.s.b.a.w.s.e.BasicErrorController:{ [/error]}: error(HttpS…

【前端】CSS类命名规范指南

在 CSS 中,合理且规范的 class 命名格式对项目的可维护性和协作效率至关重要。以下是主流的 class 命名规范和方法论:一、核心命名原则语义化命名:描述功能而非样式 ✅ .search-form(描述功能)❌ .red-text&#xff08…

C++网络编程 4.UDP套接字(socket)编程示例程序

以下是基于UDP协议的完整客户端和服务器代码。UDP与TCP的核心区别在于无连接特性&#xff0c;因此代码结构会更简单&#xff08;无需监听和接受连接&#xff09;。 UDP服务器代码&#xff08;udp_server.cpp&#xff09; #include <iostream> #include <sys/socket.h&…

King’s LIMS:实验室数字化转型的智能高效之选

实验室数字化转型不仅是技术升级&#xff0c;更是管理理念和工作方式的革新。LIMS系统作为这一转型的核心工具&#xff0c;能够将分散的实验数据转化为可分析、可复用的资产&#xff0c;为科研决策提供支持&#xff1b;规范检测流程&#xff0c;减少人为干预&#xff0c;确保结…

【力扣 中等 C】97. 交错字符串

目录 题目 解法一 题目 待添加 解法一 bool isInterleave(char* s1, char* s2, char* s3) {const int len1 strlen(s1);const int len2 strlen(s2);const int len3 strlen(s3);if (len1 len2 ! len3) {return false;}if (len1 < len2) {return isInterleave(s2, s1,…

Class9简洁实现

Class9简洁实现 %matplotlib inline import torch from torch import nn from d2l import torch as d2l# 初始化训练样本、测试样本、样本特征维度和批量大小 n_train,n_test,num_inputs,batch_size 20,100,200,5 # 设置真实权重和偏置 true_w,true_b torch.ones((num_inputs…

ELK日志分析,涉及logstash、elasticsearch、kibana等多方面应用,必看!

目录 ELK日志分析 1、下载lrzsc 2、下载源包 3、解压文件,下载elasticsearch、kibana、 logstash 4、配置elasticsearch 5、配种域名解析 6、配置kibana 7、配置logstash 8、进行测试 ELK日志分析 1、下载lrzsc [rootlocalhost ~]# hostnamectl set-hostname elk ##…

终极剖析HashMap:数据结构、哈希冲突与解决方案全解

文章目录 引言 一、HashMap底层数据结构&#xff1a;三维存储架构 1. 核心存储层&#xff08;硬件优化设计&#xff09; 2. 内存布局对比 二、哈希冲突的本质与数学原理 1. 冲突产生的必然性 2. 冲突概率公式 三、哈希冲突解决方案全景图 1. 链地址法&#xff08;Hash…

1.1.5 模块与包——AI教你学Django

1.1.5 模块与包&#xff08;Django 基础学习细节&#xff09; 模块和包是 Python 项目组织和代码复用的基础。Django 项目本质上就是由多个模块和包组成。理解和灵活运用模块与包机制&#xff0c;是写好大型项目的关键。 一、import、from-import、as 的用法 1. import 用于导入…

UE5 相机后处理材质与动态参数修改

一、完整实现流程1. 创建后处理材质材质设置&#xff1a;在材质编辑器中&#xff0c;将材质域(Material Domain)设为后处理(Post Process)设置混合位置(Blendable Location)&#xff08;如After Tonemapping&#xff09;创建标量/向量参数&#xff08;如Intensity, ColorTint&a…

Django基础(三)———模板

前言 在之前的文章中&#xff0c;视图函数只是直接返回文本&#xff0c;而在实际生产环境中其实很少这样用&#xff0c;因为实际的页面大多是带有样式的HTML代码&#xff0c;这可以让浏览器渲染出非常漂亮的页面。目前市面上有非常多的模板系统&#xff0c;其中最知名最好用的…

mysql6表清理跟回收空间

mysql6表清理跟回收空间 文章目录mysql6表清理跟回收空间1.清理表2.备份表或者备份库3.回收表空间4.查看5.验证业务1.清理表 ## 登录 C:\Program Files\MySQL\MySQL Server 5.6\bin>mysql -uroot -p Enter password: ****** Welcome to the MySQL monitor. Commands end w…

Java-74 深入浅出 RPC Dubbo Admin可视化管理 安装使用 源码编译、Docker启动

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; &#x1f680; AI篇持续更新中&#xff01;&#xff08;长期更新&#xff09; AI炼丹日志-30-新发布【1T 万亿】参数量大模型&#xff01;K…

VSCode同时支持Vue2和Vue3开发的插件指南

引言 随着Vue生态系统的演进&#xff0c;许多开发者面临着在同一开发环境中同时处理Vue 2和Vue 3项目的需求。Visual Studio Code (VSCode)作为最受欢迎的前端开发工具之一&#xff0c;其插件生态对Vue的支持程度直接影响开发效率。本文将深入探讨如何在VSCode中配置插件组合&a…

卷积神经网络CNN的Python实现

一、环境准备与库导入 在开始实现卷积神经网络之前&#xff0c;需要确保开发环境已正确配置&#xff0c;并导入必要的Python库。常用的深度学习框架有TensorFlow和PyTorch&#xff0c;本示例将基于Keras&#xff08;可使用TensorFlow后端&#xff09;进行实现&#xff0c;因为K…

js是实现记住密码自动填充功能

记住密码自动填充使用js实现记住密码功能&#xff0c;在下次打开登陆页面的时候进行获取并自动填充到页面【cookie和localStorage】使用js实现记住密码功能&#xff0c;在下次打开登陆页面的时候进行获取并自动填充到页面【cookie和localStorage】 //添加功能----记住上一个登陆…

【Java】文件编辑器

代码&#xff1a;&#xff08;SimpleEditor.java&#xff09;import java.awt.Color; import java.awt.Font; import java.awt.Insets; import java.awt.BorderLayout;import java.awt.event.ActionEvent; import java.awt.event.ActionListener;import java.io.BufferedReader…

PyTorch中torch.topk()详解:快速获取最大值索引

torch.topk(similarities, k=2).indices 是什么意思 torch.topk(similarities, k=2).indices 是 PyTorch 中用于获取张量中最大值元素及其索引的函数。在你的代码中,它的作用是从 similarities 向量里找出得分最高的2个元素的位置索引。 1. 核心功能:找出张量中最大的k个值…