使用DataLoader加载本地数据 食物分类案例

目录

一.食物分类案例

1..整合训练集测试集文档

2.导入相关的库

3.设置图片数据的格式转换

3.数据处理

4.数据打包

5.定义卷积神经网络

6.创建模型

7.训练和测试方法定义

8.损失函数和优化器

9.训练模型,测试准确率

10.测试模型


之前我们DataLoader加载的Minist手写数字集是已经封装处理好的数据,所以我们可以直接使用,现在我们学习如何使用DataLoader加载本地数据

一.食物分类案例

1..整合训练集测试集文档

我们可以看到在food_dataset文件夹有训练集和测试集两个文件夹分别存放不同食物文件夹共20种食物的照片

为了方便后面模型的对数据的读取和训练,我们将每个图片的地址和标签(标签可以自己定义)以空格分开都写在一个txt文件中,形如:

定义一个函数用来完成填写txt文件

需要注意的是os.walk()函数实在path里面漫步进入文件后还会出来进入下一个文件夹

第一次循环进入到train文件中directories为各食物名称列表,len不为0,赋值给dirs方便后面命名标签值

第二次循环时,root已经进入到第一个食物文件夹,遍历该文件夹内的所有食物照片得到路径,对root进行split()方便我们得到食物名now_dir[-1]后,再用dirs.index(now_dir[-1])获取该食物在dires中的下标值作为标签,最后将路径与标签写入文件

遍历完一个食物文件后顺便在字典中保存对应的食物和其标签

import os
dire={}
def train_test_file(root,dir):f_out=open(dir+'.txt','w')path=os.path.join(root,dir)for root,directories,files in os.walk(path):if len(directories)!=0:dirs=directorieselse:now_dir=root.split('\\')for file in files:path=os.path.join(root,file)f_out.write(path+' '+str(dirs.index(now_dir[-1]))+'\n')dire[dirs.index(now_dir[-1])]=now_dir[-1]f_out.close()

最后就是调用该函数,传入相关路径即可得到对应train.txt和test.txt

root=r'.\food_dataset'
train_dir='train'
test_dir='test'
train_test_file(root,train_dir)
train_test_file(root,test_dir)

2.导入相关的库

import torch
from torch.utils.data import Dataset,DataLoader
import numpy as np
from PIL import Image
from torchvision import transforms

3.设置图片数据的格式转换

用字典来保存训练集和测试集的相应格式转换

transforms.Compose()是将一些格式的转换组合在一起相当于一个容器

由于每个照片的大小都可能不同会影响到后面全连接层的展开输入总个数,所以这里必须统一大小

还需要将数据转化为Tensor张量类型

data_transforms={'train':transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor()]),'valid':transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor()])
}

3.数据处理

由于我们使用的是自己的数据,所以我们必须要让我们的数据集可以通过[]即索引来获取,这样DataLoader才能拿到数据并打包

定义一个类来实现上述要求,传入文件路径和上述的图片数据转换方式即可

在初始化方法__inint__中,完成一些共享空间self赋值后,我们将传入文件中的路径和标签分别保存在存储路径和存储标签的列表中

__len__方法也是必不可少的,DataLoader打包数据前会先检查总数据的大小长度够不够再完成打包

__getitem__则是我们通过索引获取信息的关键方法,只要使用索引就会调用该方法,我们在该方法中通过传进来的索引我们可以通过之前的存储列表获取图片的路径和标签,并用PIL库的Image.open()方法读取图片后根据初始化时传进来的转化格式进行转换,标签则用torch.from_numpy()方法也转化为tensor张量,最后返回图片数据和标签

class food_dataset(Dataset):#能通过索引的方式返回图片数据和标签结果def __init__(self,file_path,transform=None):self.file_path=file_pathself.imgs_paths=[]self.labels=[]self.transform=transformwith open(self.file_path) as f:samples=[x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs_paths.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs_paths)def __getitem__(self, idx):image=Image.open(self.imgs_paths[idx])if self.transform:image=self.transform(image)label=self.labels[idx]label=torch.from_numpy(np.array(label,dtype=np.int64))#label也转化为tensorreturn image,label

创建该类的对象,传入训练集和测试集的路径得到可以被DataLoader打包的数据

train_data=food_dataset(file_path='./train.txt',transform=data_transforms['train'])
test_data=food_dataset(file_path='./test.txt',transform=data_transforms['valid'])

4.数据打包

由于图片大小比较大我们就将一个图片数据打包成一个批次

train_loader=DataLoader(train_data,batch_size=1,shuffle=True)
test_loader=DataLoader(test_data,batch_size=1,shuffle=True)device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')

5.定义卷积神经网络

只需注意必须继承nn.Module类,和我们此次训练的是彩色图片,所以第一个in_channel输入通道数是3而不是之前手写数字灰度图的1,最后全连接层的输出通道是20因为我们有20种食物

from  torch import  nn
class CNN(nn.Module):def __init__(self):super().__init__()#nn.Sequential()是将网络层组合在一起,内部不能写函数self.conv1=nn.Sequential(#1*3*256*256nn.Conv2d(in_channels=3,#输入通道数out_channels=8,kernel_size=5,stride=1,padding=2),#1*8*256*256nn.ReLU(),nn.MaxPool2d(kernel_size=2)#1*8*128*128)self.conv2 = nn.Sequential(nn.Conv2d(8,16,5,1,2),#1*16*128*128nn.ReLU(),nn.Conv2d(16,32,5,1,2),#1*32*128*128nn.ReLU(),nn.MaxPool2d(kernel_size=2)##1*32*64*64)self.conv3 = nn.Sequential(nn.Conv2d(32,64,5,1,2),#1*64*64*64nn.ReLU(),nn.Conv2d(64, 64, 5, 1, 2),#1*64*64*64nn.ReLU())self.flatten=nn.Flatten()self.out=nn.Linear(64*64*64,20)def forward(self,x):x=self.conv1(x)x=self.conv2(x)x=self.conv3(x)# x=x.view(x.size(0),-1)x=self.flatten(x)output=self.out(x)return output

6.创建模型

model=CNN().to(device)

7.训练和测试方法定义

与之前手写数字的方法并无任何不同

def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num=1for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model(X)loss=loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value=loss.item()if batch_size_num % 100 == 0:print(f'loss:{loss_value:>7f} [number:{batch_size_num}]')batch_size_num += 1
def test(dataloader,model,loss_fn):model.eval()len_data=len(dataloader.dataset)correct,loss_sum=0,0num_batch=0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss_sum += loss_fn(pred, y).item()correct+=(pred.argmax(1)==y).type(torch.float).sum().item()num_batch+=1loss_avg=loss_sum/num_batchaccuracy=correct/len_dataprint(f'Accuracy:{100 * accuracy}%\nLoss Avg:{loss_avg}')return pred.argmax(1)

8.损失函数和优化器

多分类问题选择交叉熵损失函数

loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.0001)

9.训练模型,测试准确率

设置20轮训练

epochs=20
for i in range(epochs):print(f'==========第{i + 1}轮训练==============')train(train_loader, model, loss_fn, optimizer)print(f'第{i + 1}轮训练结束')test(test_loader,model,loss_fn)

准确率很低,后续会改进

10.测试模型

自己输入一个照片的路径通过模型来判断类别,我们需要将照片数据用上面的格式转化处理,需要注意的是我们必须手动为其添加batch维度,因为这里没用Dataloader加载数据不会自动添加batch维度

pred.argmax(1).item()是获取前向传播之后输出的最大概率的标签,.item()是将其转化为可读的形式

path=input('请输入一个图片地址: ')
image=Image.open(path)
image=data_transforms['valid'](image)
image=image.unsqueeze(0).to(device)#添加batch维度
# 注意使用DataLoader加载数据时,它会自动为批量数据添加 batch 维度
model.eval()
with torch.no_grad():pred=model.forward(image)label=pred.argmax(1).item()print('该图片是: '+dire[label])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/95390.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/95390.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

从零开始的python学习——函数(2)

ʕ • ᴥ • ʔ づ♡ど 🎉 欢迎点赞支持🎉 个人主页:励志不掉头发的内向程序员; 专栏主页:python学习专栏; 文章目录 前言 一、变量作用域 二、函数执行过程 三、链式调用 四、嵌套调用 五、函数递归 六、…

RAG 的完整流程是怎么样的?

RAG(检索增强生成)的完整流程可分为5个核心阶段:数据准备:清洗文档、分块处理(如PDF转文本切片);向量化:使用嵌入模型(如BERT、BGE)将文本转为向量&#xff1…

研发文档版本混乱的根本原因是什么,怎么办

研发文档版本混乱的根本原因通常包括缺乏统一的版本控制制度、团队协作不畅、文档管理工具使用不当以及项目需求频繁变化等因素。这些问题使得研发团队在日常工作中容易出现文档版本混乱的情况,导致信息的不一致性、沟通不畅以及开发进度的延误。为了解决这一问题&a…

ChartView的基本使用

Qt ChartView(准确类名 QChartView)是 Qt Charts 模块里最常用的图表显示控件。一句话概括:“它把 QChart 画出来,并自带缩放、平移、抗锯齿等交互能力”。QML ChartView 简介(一句话先记住:ChartView 是 Q…

系统扩展策略

1、核心指导思想:扩展立方体 在讨论具体策略前,先了解著名的扩展立方体(Scale Cube),它定义了三种扩展维度: X轴:水平复制(克隆) 策略:通过负载均衡器&#…

HBuilder X 4.76 开发微信小程序集成 uview-plus

简介 本文记录了在HBuilder中创建并配置uni-app项目的完整流程。 首先创建项目并测试运行,确认无报错后添加uView-Plus组件库。 随后修改了main.js、uni.scss、App.vue等核心文件,配置manifest.json并安装dayjs、clipboard等依赖库。 通过调整vite.c…

第4章:内存分析与堆转储

本章概述内存分析是 Java 应用性能调优的核心环节之一。本章将深入探讨如何使用 VisualVM 进行内存分析,包括堆内存监控、堆转储生成与分析、内存泄漏检测以及内存优化策略。通过本章的学习,你将掌握识别和解决内存相关问题的专业技能。学习目标理解 Jav…

面经分享一:分布式环境下的事务难题:理论边界、实现路径与选型逻辑

一、什么是分布式事务? 分布式事务是指事务的参与者、支持事务的服务器、资源服务器以及事务管理器分别位于不同的分布式系统的不同节点之上。 一个典型的例子就是跨行转账: 用户从银行A的账户向银行B的账户转账100元。 这个操作包含两个步骤: 从A账户扣减100元。 向B账户…

C++的演化历史

C是一门这样的编程语言: 兼顾底层计算机硬件系统和高层应用抽象机制从实际问题出发,注重零成本抽象、性能、可移植性、与C兼容语言特性和细节很多,学习成本较高,是一门让程序员很难敢说精通的语言 C是自由的,支持5种…

Qt6实现绘图工具:12种绘图工具全家桶!这个项目满足全部2D场景

项目概述 一个基于Qt框架开发的专业绘图工具,实现了完整的2D图形绘制、编辑和管理功能。该项目采用模块化设计,包含图形绘制、图层管理、命令模式撤销重做、用户界面等多个子系统,是学习现代C++和Qt框架的最佳实践。 核心功能特性 12种专业绘图工具 多图层绘制系统 完整的…

Linux驱动开发学习笔记

第1章 Linux驱动开发的方式mmap映射型设计方法。【不推荐】将芯片上的物理地址映射到用户空间的虚拟地址上,用户操作虚拟地址来操作硬件。使用文件操作集(file_operatiopns)设计方法。【极致推荐】platfrom总线型设置方法。【比较流行】设备树。【推荐】第2章 Linux…

mac中进行适用于IOS的静态库构建

前沿: 进行C开发完成之后,需要将代码编译成静态库,并且在IOS的手机系统中执行,因此记录该实现过程. 1主要涉及内容 1.1 整体文件架构 gongyonglocalhost Attention % tree -L 2 . ├── build │ ├── __.SYMDEF │ ├── cmake_install.cmake │ ├── CMakeCache…

C++二维数组的前缀和

C二维数组的前缀和的方法很简单&#xff0c;可以利用公式res[i][j]arr[i][j]res[i-1][j]prefix[i][j-1]-res[i-1][j-1]。输入4 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16输出1 3 6 10 6 14 24 36 15 33 54 78 28 60 96 136#include<bits/stdc.h> using namespace std; int…

Wifi开发上层学习1:实现一个wifi搜索以及打开的app

Wifi开发上层学习1&#xff1a;实现一个wifi搜索以及打开的app 文章目录Wifi开发上层学习1&#xff1a;实现一个wifi搜索以及打开的app背景demo实现1.添加系统权限以及系统签名2.布局配置3.逻辑设计3.1 wifi开关的实现3.2 wifi扫描功能3.3 连接wifi总结一、WiFi 状态控制接口二…

【DSP28335 入门教程】定时器中断:为你的系统注入精准的“心跳”

大家好&#xff0c;欢迎来到 DSP28335 的核心精讲系列。我们已经掌握了如何通过外部中断来响应“外部事件”&#xff0c;但系统内部同样需要一个精准的节拍器来处理“内部周期性任务”。单纯依靠 DELAY_US() 这样的软件延时&#xff0c;不仅精度差&#xff0c;而且会在延时期间…

从零开始:用代码解析区块链的核心工作原理

区块链技术被誉为信任的机器&#xff0c;它正在重塑金融、供应链、数字身份等众多领域。但对于许多开发者来说&#xff0c;它仍然像一个神秘的黑盒子。今天&#xff0c;我们将抛开炒作的泡沫&#xff0c;深入技术本质&#xff0c;用大约100行Python代码构建一个简易的区块链&am…

网络通信IP细节

目录 1.通信的NAT技术 2.代理服务器 3.内网穿透和内网打洞 1.通信的NAT技术 NAT技术产生的背景是我们为了解决IPV4不够用的问题&#xff0c;NAT在通信的时候可以对IP将私网IP转化为公网IP&#xff0c;全局IP要求唯一&#xff0c;但是私人IP不是唯一的。 将报文发给路由器进行…

国内真实的交换机、路由器和分组情况

一、未考虑拥挤情况理想状态的网络通信 前面我对骨干网&#xff1a; 宜春城区SDH网图分析-CSDN博客 数据链路层MAC传输&#xff1a; 无线通信网卡底层原理&#xff08;Inter Wi-Fi AX201&#xff09;_ax201ngw是cnvio转pci-e-CSDN博客 物理层、数据链路层、网络层及传输层…

atomic常用类方法

Java中的java.util.concurrent.atomic包提供了多种原子操作工具类&#xff0c;以下是核心类及其方法&#xff1a;‌1. AtomicBoolean‌‌方法‌&#xff1a;get()&#xff1a;获取当前值set(boolean newValue)&#xff1a;强制设置值compareAndSet(boolean expect, boolean upd…

算法题打卡力扣第3题:无重复字符的最长子串(mid)

文章目录题目描述解法一&#xff1a;暴力解解法二&#xff1a;滑动窗口题目描述 解法一&#xff1a;暴力解 遍历每一个可能的子串&#xff0c;然后逐一判断每个子串中是否有重复字符。 具体步骤&#xff1a; 使用两层嵌套循环来生成所有子串的起止位置&#xff1a; 外层循环 i…