深度学习——基于卷积神经网络实现食物图像分类【2】(数据增强)

文章目录

    • 引言
    • 一、项目概述
    • 二、环境准备
    • 三、数据预处理
      • 3.1 数据增强与标准化
      • 3.2 数据集准备
    • 四、自定义数据集类
    • 五、构建CNN模型
    • 六、训练与评估
      • 6.1 训练函数
      • 6.2 评估函数
      • 6.3 训练流程
    • 七、关键技术与优化
    • 八、常见问题与解决
    • 九、完整代码
    • 十、总结

引言

本文将详细介绍如何使用PyTorch框架构建一个食物图像分类系统,涵盖数据预处理、模型构建、训练和评估全过程。我们将使用自定义的食物数据集,构建一个卷积神经网络(CNN)模型,并实现完整的训练流程。

一、项目概述

食物图像分类是计算机视觉中的一个常见应用场景。在本项目中,我们将构建一个能够识别20种不同食物的分类系统。整个流程包括:

  1. 数据准备与预处理
  2. 构建自定义数据集类
  3. 设计CNN模型架构
  4. 训练模型并评估性能
  5. 优化与结果分析

二、环境准备

首先确保已安装必要的Python库:

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os

三、数据预处理

3.1 数据增强与标准化

我们为训练集和验证集分别定义不同的转换策略:

data_transforms = {'train': transforms.Compose([transforms.Resize([300,300]),transforms.RandomRotation(45),transforms.CenterCrop(256),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),transforms.RandomGrayscale(p=0.1),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'valid': transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

关键点解析:

  1. 训练集增强

    • 随机旋转(-45°到45°)
    • 随机水平和垂直翻转
    • 色彩抖动(亮度、对比度、饱和度和色调)
    • 随机灰度化(概率10%)
  2. 标准化处理

    • 使用ImageNet的均值和标准差进行归一化
    • 有助于模型更快收敛

3.2 数据集准备

我们编写了一个函数来生成训练和测试的标注文件:

def train_test_file(root, dir):file_txt = open(dir+'.txt','w')path = os.path.join(root,dir)for roots, directories, files in os.walk(path):if len(directories) != 0:dirs = directorieselse:now_dir = roots.split('\\')for file in files:path_1 = os.path.join(roots,file)file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')file_txt.close()

该函数会遍历指定目录,生成包含图像路径和对应标签的文本文件。

四、自定义数据集类

我们继承PyTorch的Dataset类创建自定义数据集:

class food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label

关键方法:

  1. __init__: 初始化数据集,读取标注文件
  2. __len__: 返回数据集大小
  3. __getitem__: 根据索引返回图像和标签,应用预处理

五、构建CNN模型

我们设计了一个三层的CNN网络:

class CNN(nn.Module):def __init__(self):super(CNN,self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 16, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.out = nn.Linear(64*32*32, 20)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)output = self.out(x)return output

网络结构分析:

  1. 卷积层1

    • 输入通道:3 (RGB)
    • 输出通道:16
    • 卷积核:5×5
    • 输出尺寸:(16, 128, 128)
  2. 卷积层2

    • 输入通道:16
    • 输出通道:32
    • 输出尺寸:(32, 64, 64)
  3. 卷积层3

    • 输入通道:32
    • 输出通道:64
    • 输出尺寸:(64, 32, 32)
  4. 全连接层

    • 输入:64×32×32 = 65536
    • 输出:20 (对应20类食物)

六、训练与评估

6.1 训练函数

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num % 1 == 0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1

6.2 评估函数

def Test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test result: \n Accuracy:{(100*correct)}%, Avg loss:{test_loss}")

6.3 训练流程

# 初始化模型
model = CNN().to(device)# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练10个epoch
epochs = 10
for t in range(epochs):print(f"epoch {t+1}\n---------------")train(train_dataloader, model, loss_fn, optimizer)# 最终评估
Test(test_dataloader, model, loss_fn)

七、关键技术与优化

  1. 数据增强:通过多种变换增加数据多样性,防止过拟合
  2. 批标准化:使用ImageNet统计量进行标准化,加速收敛
  3. 学习率选择:使用Adam优化器,初始学习率0.001
  4. 设备选择:自动检测并使用GPU加速训练

八、常见问题与解决

  1. 内存不足

    • 减小batch size
    • 使用更小的图像尺寸
  2. 过拟合

    • 增加数据增强
    • 添加Dropout层
    • 使用L2正则化
  3. 训练不收敛

    • 检查学习率
    • 检查数据预处理
    • 检查模型结构

九、完整代码

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import osdata_transforms = { #字典'train':transforms.Compose([            #对图片预处理的组合transforms.Resize([300,300]),   #对数据进行改变大小transforms.RandomRotation(45),  #随机旋转,-45到45之间随机选transforms.CenterCrop(256),     #从中心开始裁剪[256,256]transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转,p是指选择一个概率翻转,p=0.5表示百分之50transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1),transforms.RandomGrayscale(p=0.1),#概率转换成灰度率,3通道就是R=G=Btransforms.ToTensor(),#数据转换为tensortransforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#标准化,均值,标准差]),'valid':transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 标准化,均值,标准差]),
}
#做了数据增强不代表训练效果一定会好,只能说大概率会变好
def train_test_file(root,dir):file_txt = open(dir+'.txt','w')path = os.path.join(root,dir)for roots,directories,files in os.walk(path):if len(directories) !=0:dirs = directorieselse:now_dir = roots.split('\\')for file in files:path_1 = os.path.join(roots,file)print(path_1)file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')file_txt.close()root = r'.\食物分类\food_dataset'
train_dir = 'train'
test_dir = 'test'
train_test_file(root,train_dir)
train_test_file(root,test_dir)#Dataset是用来处理数据的
class food_dataset(Dataset):        # food_dataset是自己创建的类名称,可以改为你需要的名称def __init__(self,file_path,transform=None):    #类的初始化,解析数据文件txtself.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f: #是把train.txt文件中的图片路径保存在self.imgssamples = [x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs.append(img_path)  #图像的路径self.labels.append(label)   #标签,还不是tensor# 初始化:把图片目录加到selfdef __len__(self):  #类实例化对象后,可以使用len函数测量对象的个数return  len(self.imgs)#training_data[1]def __getitem__(self, idx):    #关键,可通过索引的形式获取每一个图片的数据及标签image = Image.open(self.imgs[idx])  #读取到图片数据,还不是tensor,BGRif self.transform:                  #将PIL图像数据转换为tensorimage = self.transform(image)   #图像处理为256*256,转换为tensorlabel = self.labels[idx]    #label还不是tensorlabel = torch.from_numpy(np.array(label,dtype=np.int64))    #label也转换为tensorreturn image,label
#training_data包含了本次需要训练的全部数据集
training_data = food_dataset(file_path='train.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='test.txt', transform=data_transforms['valid'])#training_data需要具备索引的功能,还要确保数据是tensor
train_dataloader = DataLoader(training_data,batch_size=16,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=16,shuffle=True)'''判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU'''
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")   #字符串的格式化,CUDA驱动软件的功能:pytorch能够去执行cuda的命令
# 神经网络的模型也需要传入到GPU,1个batch_size的数据集也需要传入到GPU,才可以进行训练''' 定义神经网络  类的继承这种方式'''
class CNN(nn.Module): #通过调用类的形式来使用神经网络,神经网络的模型,nn.mdouledef __init__(self): #输入大小:(3,256,256)super(CNN,self).__init__()  #初始化父类self.conv1 = nn.Sequential( #将多个层组合成一起,创建了一个容器,将多个网络组合在一起nn.Conv2d(              # 2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序列数据in_channels=3,      # 图像通道个数,1表示灰度图(确定了卷积核 组中的个数)out_channels=16,     # 要得到多少个特征图,卷积核的个数kernel_size=5,      # 卷积核大小 3×3stride=1,           # 步长padding=2,          # 一般希望卷积核处理后的结果大小与处理前的数据大小相同,效果会比较好),                      # 输出的特征图为(16,256,256)nn.ReLU(),  # Relu层,不会改变特征图的大小nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2×2操作),输出结果为(16,128,128))self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),  #输出(32,128,128)nn.ReLU(),  #Relu层  (32,128,128)nn.MaxPool2d(kernel_size=2),    #池化层,输出结果为(32,64,64))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),  # 输出(64,64,64)nn.ReLU(),  # Relu层  (64,64,64)nn.MaxPool2d(kernel_size=2),  # 池化层,输出结果为(64,32,32))self.out = nn.Linear(64*32*32,20)  # 全连接层得到的结果def forward(self,x):   #前向传播,你得告诉它 数据的流向 是神经网络层连接起来,函数名称不能改x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0),-1)    # flatten操作,结果为:(batch_size,32 * 64 * 64)output = self.out(x)return output
model = CNN().to(device) #把刚刚创建的模型传入到GPU
print(model)def train(dataloader,model,loss_fn,optimizer):model.train() #告诉模型,我要开始训练,模型中w进行随机化操作,已经更新w,在训练过程中,w会被修改的
# pytorch提供2种方式来切换训练和测试的模式,分别是:model.train() 和 mdoel.eval()
# 一般用法是:在训练开始之前写上model.train(),在测试时写上model.eval()batch_size_num = 1for X,y in dataloader:              #其中batch为每一个数据的编号X,y = X.to(device),y.to(device) #把训练数据集和标签传入cpu或GPUpred = model.forward(X)         # .forward可以被省略,父类种已经对此功能进行了设置loss = loss_fn(pred,y)          # 通过交叉熵损失函数计算损失值loss# Backpropagation 进来一个batch的数据,计算一次梯度,更新一次网络optimizer.zero_grad()           # 梯度值清零loss.backward()                 # 反向传播计算得到每个参数的梯度值woptimizer.step()                # 根据梯度更新网络w参数loss_value = loss.item()        # 从tensor数据种提取数据出来,tensor获取损失值if batch_size_num %1 ==0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1def Test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)  # 打包的数量model.eval()        #测试,w就不能再更新test_loss,correct =0,0with torch.no_grad():       #一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候for X,y in dataloader:X,y = X.to(device),y.to(device)pred = model.forward(X)test_loss += loss_fn(pred,y).item() #test_loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y) #dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值对应的索引号b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batches #能来衡量模型测试的好坏correct /= size  #平均的正确率print(f"Test result: \n Accuracy:{(100*correct)}%, Avg loss:{test_loss}")loss_fn = nn.CrossEntropyLoss()  #创建交叉熵损失函数对象,因为手写字识别一共有十种数字,输出会有10个结果
#
optimizer = torch.optim.Adam(model.parameters(),lr=0.001) #创建一个优化器,SGD为随机梯度下降算法
# # params:要训练的参数,一般我们传入的都是model.parameters()
# # lr:learning_rate学习率,也就是步长
#
# # loss表示模型训练后的输出结果与样本标签的差距。如果差距越小,就表示模型训练越好,越逼近真实的模型
train(train_dataloader,model,loss_fn,optimizer) #训练1次完整的数据。多轮训练
Test(test_dataloader,model,loss_fn)epochs = 10
for t in range(epochs):print(f"epoch {t+1}\n---------------")train(train_dataloader,model,loss_fn,optimizer)
print("Done!")
Test(test_dataloader,model,loss_fn)

十、总结

本文详细介绍了使用PyTorch实现食物分类的全流程。通过合理的网络设计、数据增强和训练策略,我们能够构建一个有效的分类系统。读者可以根据实际需求调整网络结构、超参数和数据增强策略,以获得更好的性能。

完整代码已在上文展示,建议在实际应用中根据具体数据集调整相关参数。希望本文能帮助读者掌握PyTorch图像分类的基本流程和方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/85421.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/85421.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

详细说说分布式Session的几种实现方式

1. 基于客户端存储(Cookie-Based) 原理:将会话数据直接存储在客户端 Cookie 中 实现: // Spring Boot 示例 Bean public CookieSerializer cookieSerializer() {DefaultCookieSerializer serializer new DefaultCookieSerializ…

用mac的ollama访问模型,为什么会出现模型胡乱输出,然后过一会儿再访问,就又变成正常的

例子:大模型推理遇到内存不足 1. 场景还原 你在Mac上用Ollama运行如下代码(以Python为例,假设Ollama有API接口): import requestsprompt "请写一首关于夏天的诗。" response requests.post("http:…

简说 Linux 用户组

Linux 用户组 的核心概念、用途和管理方法,尽量简明易懂。 🌟 什么是 Linux 用户组? 在 Linux 系统中: 👉 用户组(group) 是一组用户的集合,用来方便地管理权限。 👉 用…

S32DS上进行S32K328的时钟配置,LPUART时钟配置步骤详解

1:S32K328的基础信息 S32K328官网介绍 由下图可知,S32K328的最大主频为 240MHz 2:S32K328时钟树配置 2.1 system clock node 节点说明 根据《S32K3xx Reference Manual》资料说明 Table 143 各个 系统时钟节点 的最大频率如下所示&#…

wordpress小语种网站模板

wordpress朝鲜语模板 紫色风格的韩语wordpress主题,适合做韩国、朝鲜的外贸公司官方网站使用。 https://www.jianzhanpress.com/?p8486 wordpress日文模板 绿色的日语wordpress外贸主题,用来搭建日文外贸网站很实用。 https://www.jianzhanpress.co…

网络:Wireshark解析https协议,firefox

文章目录 问题浏览器访问的解决方法python requests问题 现在大部分的网站已经切到https,很多站点即使开了80的端口,最终还是会返回301消息,让客户端转向到https的一个地址。 所以在使用wireshark进行问题分析的时候,解析tls上层的功能,是必不可少的,但是这个安全交换的…

ollama部署开源大模型

1. 技术概述 Spring AI:Spring 官方推出的 AI 框架,简化大模型集成(如文本生成、问答系统),支持多种 LLM 提供商。Olama:开源的本地 LLM 推理引擎,支持量化模型部署,提供 REST API …

Kafka 可靠性保障:消息确认与事务机制(二)

Kafka 事务机制 1. 幂等性与事务的关系 在深入探讨 Kafka 的事务机制之前,先来了解一下幂等性的概念。幂等性,简单来说,就是对接口的多次调用所产生的结果和调用一次是一致的。在 Kafka 中,幂等性主要体现在生产者端&#xff0c…

使用 React.Children.map遍历或修改 children

使用场景: 需要对子组件进行统一处理(如添加 key、包裹额外元素、过滤特定类型等)。 动态修改 children 的 props 或结构。 示例代码:遍历并修改 children import React from react;// 一个组件,给每个子项添加边框…

智能体三阶:LLM→Function Call→MCP

哈喽,我是老刘 老刘是个客户端开发者,目前主要是用Flutter进行开发,从Flutter 1.0开始到现在已经6年多了。 那为啥最近我对MCP和AI这么感兴趣的呢? 一方面是因为作为一个在客户端领域实战多年的程序员,我觉得客户端开发…

flutter的常规特征

前言 Flutter 是由 Google 开发的开源 UI 软件开发工具包,用于构建跨平台的高性能、美观且一致的应用程序。 一、跨平台开发能力 1.多平台支持:Flutter 支持构建 iOS、Android、Web、Windows、macOS 和 Linux 应用,开发者可以使用一套代码库在…

【Git】代码托管服务

博主:👍不许代码码上红 欢迎:🐋点赞、收藏、关注、评论。 格言: 大鹏一日同风起,扶摇直上九万里。 文章目录 Git代码托管服务概述Git核心概念主流Git托管平台Git基础配置仓库创建方式Git文件状态管理常用…

Android 网络请求的选择逻辑(Connectivity Modules)

代码分析 ConnectivityManager packages/modules/Connectivity/framework/src/android/net/ConnectivityManager.java 许多APN已经弃用,应用层统一用 requestNetwork() 来请求网络。 [ConnectivityManager] example [ConnectivityManager] requestNetwork() [Connectivi…

C#建立与数据库连接(版本问题的解决方案)踩坑总结

1.如何优雅的建立数据库连接 今天使用这个deepseek写代码,主要就是建立数据库的链接,包括这个建库建表啥的都是他整得,我就是负责执行,然后解决这个里面遇到的一些问题; 其实我学习这个C#不过是短短的4天的时间&…

FastAPI的初步学习(Django用户过来的)

我一直以来是Django重度用户。它有清晰的MVC架构模式、多应用组织结构。它内置用户认证、数据库ORM、数据库迁移、管理后台、日志等功能,还有强大的社区支持。再搭配上Django REST framework (DRF) ,开发起来效率极高。主打功能强大、易于使用。 曾经也…

提升IT运维效率 贝锐向日葵推出自动化企业脚本功能

在企业进行远程IT运维管理的过程中,难免会涉及很多需要批量操作下发指令的场景,包括但不限于下列场景: ● ⼤规模设备部署与初始化、设备配置更新 ● 业务软件安装与系统维护,进行安全加固或执行问题修复命令 ● 远程设备监控与…

最简单的远程桌面连接方法是什么?系统自带内外网访问实现

在众多远程桌面连接方式中,使用 Windows 系统自带的远程桌面连接功能是较为简单的方法之一,无论是在局域网内还是通过公网进行远程连接,都能轻松实现。 一、局域网内连接步骤 1、 开启目标计算机远程桌面功能:在目标计算机&…

JVM(2)——垃圾回收算法

本文将穿透式解析JVM垃圾回收核心算法,涵盖7大基础算法4大现代GC实现3种内存分配策略,通过15张动态示意图GC日志实战分析,带您彻底掌握JVM内存自动管理机制。 一、GC核心概念体系 1.1 对象存亡判定法则 引用计数法致命缺陷: // …

基于Spring Boot+Vue的“暖寓”宿舍管理系统设计与实现(源码及文档)

基于Spring BootVue的“暖寓”宿舍管理系统设计与实现 第 1 章 绪论 1.1 论文研究主要内容 1.1.1 系统概述 1.1.2 系统介绍 1.2 国内外研究现状 第 2 章 关键技术介绍 2.1 关键性开发技术的介绍 2.1.1 Java简介 2.1.2 Spring Boot框架 2.2 其他相关技术 2.2.1 Vue.J…

基于Java的不固定长度字符集在指定宽度和自适应模型下图片绘制生成实战

目录 前言 一、需求介绍 1、指定宽度生成 2、指定列自适应生成 二、Java生成实现 1、公共方法 2、指定宽度生成 3、指定列自适应生成 三、总结 前言 在当今数字化与信息化飞速发展的时代,图像的生成与处理技术正日益成为众多领域关注的焦点。从创意设计到数…