小白的进阶之路系列之四----人工智能从初步到精通pytorch自定义数据集下

本篇涵盖的内容

在之前的文章中,我们已经讨论了如何获取数据,转换数据以及如何准备自定义数据集,本篇文章将涵盖更加深入的问题,希望通过详细的代码示例,帮助大家了解PyTorch自定义数据集是如何应对各种复杂实际情况中,数据处理的。

更加详细的,我们将讨论下面一些内容:

主题内容
7 Model 0:没有数据增强的TinyVGG到这个阶段,我们已经准备好了数据,让我们建立一个能够拟合数据的模型。我们还将创建一些训练和测试函数来训练和评估我们的模型。
8 探索损失曲线损失曲线是观察你的模型如何训练/改进的好方法。它们也是一种很好的方法来判断你的模型是过拟合还是欠拟合。
9 Model 1:带数据增强功能的TinyVGG到目前为止,我们已经尝试了一个没有数据增强的模型?
10 比较模型结果让我们比较不同模型的损失曲线,看看哪个表现更好,并讨论一些改进性能的选项。
11 对自定义图像进行预测我们的模型是在披萨、牛排和寿司图像的数据集上训练的。在本节中,我们将介绍如何使用我们训练好的模型来预测现有数据集之外的图像。

7 Model 0:没有数据增强的TinyVGG

好了,我们已经看到了如何把数据从文件夹里的图像变成变换后的张量。

现在让我们构建一个计算机视觉模型,看看我们是否可以将图像分类为披萨、牛排或寿司。

首先,我们将从一个简单的变换开始,仅将图像大小调整为(64,64)并将它们转换为张量。

7.1 为模型0创建转换和加载数据

# Create simple transform
simple_transform = transforms.Compose([ transforms.Resize((64, 64)),transforms.ToTensor(),
])

很好,现在我们有了一个简单的变换,让我们

  • 加载数据,首先使用torchvision.datasets.ImageFolder()将每个训练和测试文件夹转换为Dataset

  • 然后使用torch.utils.data.DataLoader())转换为数据加载器。

  • 我们将把batch_size=32和num_workers设置为机器上尽可能多的cpu(这取决于您使用的机器)。

# 1. Load and transform data
from torchvision import datasets
train_data_simple = datasets.ImageFolder(root=train_dir, transform=simple_transform)
test_data_simple = datasets.ImageFolder(root=test_dir, transform=simple_transform)# 2. Turn data into DataLoaders
import os
from torch.utils.data import DataLoader# Setup batch size and number of workers 
BATCH_SIZE = 32
NUM_WORKERS = os.cpu_count()
print(f"Creating DataLoader's with batch size {BATCH_SIZE} and {NUM_WORKERS} workers.")# Create DataLoader's
train_dataloader_simple = DataLoader(train_data_simple, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)test_dataloader_simple = DataLoader(test_data_simple, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)print(train_dataloader_simple, test_dataloader_simple)

输出为:

Creating DataLoader's with batch size 32 and 16 workers.
<torch.utils.data.dataloader.DataLoader object at 0x0000024974F734D0> <torch.utils.data.dataloader.DataLoader object at 0x0000024974F07A80>

很好dataloader已经创建好了,现在让我们设立模型。

7.2创建TinyVGG模型类

在上一篇文章中,我们使用了来自CNN解释器网站的TinyVGG模型。

让我们重新创建相同的模型,只不过这次我们将使用彩色图像而不是灰度图像(对于RGB像素,in_channels=3而不是in_channels=1)。

class TinyVGG(nn.Module):"""Model architecture copying TinyVGG from: https://poloclub.github.io/cnn-explainer/"""def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:super().__init__()self.conv_block_1 = nn.Sequential(nn.Conv2d(in_channels=input_shape, out_channels=hidden_units, kernel_size=3, # how big is the square that's going over the image?stride=1, # defaultpadding=1), # options = "valid" (no padding) or "same" (output has same shape as input) or int for specific number nn.ReLU(),nn.Conv2d(in_channels=hidden_units, out_channels=hidden_units,kernel_size=3,stride=1,padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2,stride=2) # default stride value is same as kernel_size)self.conv_block_2 = nn.Sequential(nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(2))self.classifier = nn.Sequential(nn.Flatten(),# Where did this in_features shape come from? # It's because each layer of our network compresses and changes the shape of our input data.nn.Linear(in_features=hidden_units*16*16,out_features=output_shape))def forward(self, x: torch.Tensor):x = self.conv_block_1(x)# print(x.shape)x = self.conv_block_2(x)# print(x.shape)x = self.classifier(x)# print(x.shape)return x# return self.classifier(self.conv_block_2(self.conv_block_1(x))) # <- leverage the benefits of operator fusiontorch.manual_seed(42)
model_0 = TinyVGG(input_shape=3, # number of color channels (3 for RGB) hidden_units=10, output_shape=len(train_data.classes)).to(device)
print(model_0)

输出为:

TinyVGG((conv_block_1): Sequential((0): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU()(2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU()(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(conv_block_2): Sequential((0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU()(2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU()(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(classifier): Sequential((0): Flatten(start_dim=1, end_dim=-1)(1): Linear(in_features=2560, out_features=3, bias=True)</

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/news/907328.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

DeepSeek实战:打造智能数据分析与可视化系统

DeepSeek实战:打造智能数据分析与可视化系统 1. 数据智能时代:DeepSeek数据分析系统入门 在数据驱动的决策时代,智能数据分析系统正成为企业核心竞争力。本节将使用DeepSeek构建一个从数据清洗到可视化分析的全流程智能系统。 1.1 系统核心功能架构 class DataAnalysisS…

力扣100题---字母异位词分组

1.字母异位词分组 给你一个字符串数组&#xff0c;请你将 字母异位词 组合在一起。可以按任意顺序返回结果列表。 字母异位词 是由重新排列源单词的所有字母得到的一个新单词。 方法一&#xff1a;字母排序 class Solution {public List<List<String>> groupAnagr…

使用子查询在 SQL Server 中进行数据操作

在 SQL Server 中&#xff0c;子查询&#xff08;Subquery&#xff09;是一种在查询中嵌套另一个查询的技术&#xff0c;可以用来执行复杂的查询、过滤数据或进行数据计算。子查询通常被用在 SELECT、INSERT、UPDATE 或 DELETE 语句中&#xff0c;可以帮助我们高效地解决问题。…

Flask集成pyotp生成动态口令

Python中的pyotp模块是一个用于生成和验证一次性密码&#xff08;OTP&#xff09;的库&#xff0c;支持基于时间&#xff08;TOTP&#xff09;和计数器&#xff08;HOTP&#xff09;的两种主流算法。它遵循RFC 4226&#xff08;HOTP&#xff09;和RFC 6238&#xff08;TOTP&…

触控精灵 ADB运行模式填写电脑端IP教程

•ADB模式&#xff0c;如果你手机已经root则可以直接运行&#xff0c;无需安装电脑端。 •ADB模式&#xff0c;如果你手机没有root&#xff0c;那你可以windows电脑下载【极限投屏】软件&#xff0c;然后你的手机和电脑的网络要同一个wifi&#xff0c;然后把你电脑的ip地址填写…

【Python】 -- 趣味代码 - 佩奇

文章目录 文章目录 00 佩奇程序设计框架1. 绘图设置2. 绘制卡通人物的各个部分3. 主程序总结01 佩奇程序设计00 佩奇程序设计框架 这段代码使用 turtle 模块绘制了一个粉色的卡通人物图像,主要功能包括绘制鼻子、头、耳朵、眼睛、腮、嘴、身体、手、脚和尾巴等部分。代码的主…

uniapp-商城-69-shop(2-商品列表,点击商品展示,商品的详情, vuex的使用,rich-text使用)

页面中将我们的数据进行了罗列,对于单个数据的展示,还需要进行开发,这里使用了点击商品后,进行弹窗展示。 同样这里用一个组件来进行实现该弹窗的展示。 本文介绍了商品详情弹窗的实现方案。主要采用Vuex进行状态管理,通过几个关键组件协同工作: 商品列表组件productItem…

C# Datatable筛选过滤各方式详解

在C#中&#xff0c;DataTable提供了多种筛选过滤数据的方法&#xff0c;以下是常用的几种方式及其特点&#xff1a; 1. ‌Select方法筛选‌ 这是最基础的筛选方式&#xff0c;支持类似SQL的表达式语法 // 单条件筛选 DataRow[] rows dt.Select("Age > 25");// …

计算机网络中的路由算法:互联网的“路径规划师”

计算机网络中的路由算法&#xff1a;互联网的“路径规划师” 当你打开浏览器&#xff0c;输入 www.example.com 并敲下回车&#xff0c;数据会从你的电脑出发&#xff0c;穿越一个个路由器&#xff0c;最终抵达目标服务器。这一路上&#xff0c;数据包是怎么知道该走哪条路的&…

硬件工程师笔记——三极管Multisim电路仿真实验汇总

目录 1 三极管基础 更多电子器件基础知识汇总链接 1.1 工作原理 NPN型三极管的工作原理 PNP型三极管的工作原理 1.2 三极管的特性曲线 输入特性曲线 理想和现实输出特性 三极管的主要参数包括&#xff1a; 2 三极管伏安特性 2.1 伏安特性仿真 Multisim使用说明链接…

Linux 进阶命令篇

一、Linux 系统软件安装命令 &#xff08;一&#xff09;Ubuntu 系统&#xff08;基于 Debian&#xff09; apt &#xff1a;是 Ubuntu 系统中常用的包管理工具&#xff0c;可以自动处理软件依赖关系。 安装命令格式 &#xff1a;sudo apt install 软件名 示例 &#xff1a;…

LVS-DR 负载均衡群集

目录 一、LVS-DR集群 1、LVS-DR 工作原理 2、数据包流向分析 3、LVS-DR 模式特点 二、直接路由模式&#xff08;LVS-DR&#xff09; 1、准备案例环境 2、配置负载调度器&#xff08;101&#xff09; &#xff08;1&#xff09;配置虚拟IP 地址&#xff08;VIP&#xff…

提升 GitHub Stats 的 6 个关键策略

哈哈&#xff0c;GitHub 的 “B-” 评级 其实是个玄学问题&#xff0c;但确实有一些 快速提升的技巧&#xff01;你的数据看起来 提交数&#xff08;147&#xff09;和 PR&#xff08;9&#xff09;不算少&#xff0c;但 Stars&#xff08;21&#xff09;和贡献项目数&#xff…

常见的垃圾回收算法原理及其模拟实现

1.标记 - 清除&#xff08;Mark - Sweep&#xff09;算法&#xff1a; 这是一种基础的垃圾回收算法。首先标记所有可达的对象&#xff0c;然后清除未被标记的对象。 缺点是会产生内存碎片。 原理&#xff1a; 如下图分配一段内存&#xff0c;假设已经存储上数据了 标记所有…

卷积神经网络(CNN):原理、架构与实战

卷积神经网络&#xff08;CNN&#xff09;&#xff1a;原理、架构与实战 卷积神经网络&#xff08;Convolutional Neural Network, CNN&#xff09;是深度学习领域的一项重要突破&#xff0c;特别擅长处理具有网格结构的数据&#xff0c;如图像、音频和视频。自 2012 年 AlexN…

RabbitMQ 集群与高可用方案设计(二)

三、为什么需要集群与高可用方案 &#xff08;一&#xff09;业务需求驱动 随着业务的快速发展和用户量的急剧增长&#xff0c;系统面临的挑战也日益严峻。在这种情况下&#xff0c;对消息队列的可靠性、吞吐量和负载均衡能力提出了更高的要求&#xff0c;而单机部署的 Rabbi…

《ChatGPT o3抗命:AI失控警钟还是成长阵痛?》

ChatGPT o3 “抗命” 事件起底 在人工智能的飞速发展进程中&#xff0c;OpenAI 于 2025 年推出的 ChatGPT o3 推理模型&#xff0c;犹如一颗重磅炸弹投入了技术的海洋&#xff0c;激起千层浪。它被视为 “推理模型” 系列的巅峰之作&#xff0c;承载着赋予 ChatGPT 更强大问题解…

RK3568DAYU开发板-平台驱动开发:I2C驱动(原理、源码、案例分析)

1、程序介绍 本程序是基于OpenHarmony标准系统编写的平台驱动案例&#xff1a;I2C 系统版本:openharmony5.0.0 开发板:dayu200 编译环境:ubuntu22 部署路径&#xff1a; //sample/04_platform_i2c 2、基础知识 2.1、I2C简介 I2C&#xff08;Inter Integrated Circuit&a…

在UniApp中开发微信小程序实现图片、音频和视频下载功能

随着微信小程序的迅猛发展&#xff0c;越来越多的开发者选择通过UniApp框架来进行跨平台应用开发。UniApp能够让开发者在一个代码库中同时发布iOS、Android和小程序等多平台应用。而在实际开发过程中&#xff0c;很多应用都需要实现一些常见的下载功能&#xff0c;例如图片、音…

鸿蒙5.0项目开发——接入有道大模型翻译

鸿蒙5.0项目开发——接入有道大模型翻译 【高心星出品】 项目效果图 项目功能 文本翻译功能 支持文本输入和翻译结果显示 使用有道翻译API进行翻译 支持自动检测语言&#xff08;auto&#xff09; 支持双向翻译&#xff08;源语言和目标语言可互换&#xff09; 文本操作…