Pytorch 实战四 VGG 网络训练

系列文章目录


文章目录

  • 系列文章目录
  • 前言
  • 一、源码
    • 1. 解决线程冲突
    • 2.代码框架
  • 二、代码详细介绍
    • 1.基础定义
    • 2. epoch 的定义
    • 3. 每组图片的训练和模型保存


前言

  前面我们已经完成了数据集的制作,VGG 网络的搭建,现在进行网络模型的训练。


一、源码


import torch.nn as nn
import torchvision
from VggNet import VGGNet
from load_cifa10 import train_data_loader, test_data_loader
import torch.multiprocessing as mp
import torch
import multiprocessing
from torch.utils.data import DataLoaderfrom model.ClassModel import netdef main():# 训练模型到底放在 CPU 还是GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 如果cuda有效,就在GPU训练,否则CPU训练# 我们会对样本遍历20次epoch_num = 20# 学习率lr = 0.01# 正确率计算相关batch_num=0correct0 =0# 网络定义# print("need初始化")net = VGGNet().to(device)# 定义损失函数loss,多分类问题,采用交叉熵loss_func = nn.CrossEntropyLoss()# 定义优化器optimizeroptimizer = torch.optim.Adam(net.parameters(),lr=lr)# 动态调整学习率,第一个参数是优化器,起二个参数每5个epoch后调整学习率,第三个参数调整为原来的0.9倍lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5,gamma=0.9)# 定义循环for epoch in range(epoch_num):# print("epoch is:",epoch+1)# 定义网络训练的过程net.train()   # BatchNorm 和 dropout 会选择相应的参数# 对数据进行遍历for i,data in enumerate(train_data_loader):batch_num = len(train_data_loader)# 获取输入和标签inputs, labels = datainputs, labels = inputs.to(device), labels.to(device) # 放到GPU上去# 拿到输出# print("need output")outputs = net(inputs)  # 这句就会调用前向传播,在PyTorch中,当执行outputs = net(inputs)时会自动触发前向传播,# 这是通过nn.Module的__call__方法实现的特殊机制6。具体原理可分为三个关键环节:# 计算损失loss = loss_func(outputs, labels)# 定义优化器,梯度要归零,loss反向传播,更新参数optimizer.zero_grad()loss.backward()optimizer.step()_, predicted = torch.max(outputs.data, dim=1)  # 得到一个batch的预测类别# 在cpu上面运行,当labels是普通张量时,.data属性返回‌剥离计算图的纯数值张量‌(与原始张量共享内存但无梯度追踪)correct = predicted.eq(labels.data).cpu().sum()correct=100.0*correct/len(inputs)correct0+=correct#自动更新学习率lr_scheduler.step()lr = optimizer.state_dict()['param_groups'][0]['lr']print("loss:{},acc:{}",lr,correct0/batch_num)if __name__ == '__main__':# Windows必须设置spawn,Linux/Mac自动选择最佳方式mp.set_start_method('spawn' if torch.cuda.is_available() else 'fork')torch.multiprocessing.freeze_support()try:main()except RuntimeError as e:print(f"多进程错误: {str(e)}")print("降级到单进程模式...")train_data_loader = DataLoader(..., num_workers=0)main()

1. 解决线程冲突

  windows 跑代码需要解决线程冲突的问题:需要自行定义main函数,然后把主题加在里面。当我们运行时自动调用main,就会执行下面的 if 语句,然后运行我们的代码

if __name__ == '__main__':# Windows必须设置spawn,Linux/Mac自动选择最佳方式mp.set_start_method('spawn' if torch.cuda.is_available() else 'fork')torch.multiprocessing.freeze_support()try:main()except RuntimeError as e:print(f"多进程错误: {str(e)}")print("降级到单进程模式...")train_data_loader = DataLoader(..., num_workers=0)main()

2.代码框架

  代码分成四个部分,第一个部分是基础变量定义,第二个部分是循环 epoch ,第三部分是每个 batch 的处理,第四个保存模型,其中最重要的便是第三个。

图 1 代码框架

二、代码详细介绍

1.基础定义

    # 训练模型到底放在 CPU 还是GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 如果cuda有效,就在GPU训练,否则CPU训练# 我们会对样本遍历200次epoch_num = 200# 学习率lr = 0.01# 正确率计算相关batch_num=0correct0 =0# 网络定义# print("need初始化")net = VGGNet().to(device)# 定义损失函数loss,多分类问题,采用交叉熵loss_func = nn.CrossEntropyLoss()# 定义优化器optimizeroptimizer = torch.optim.Adam(net.parameters(),lr=lr)# 动态调整学习率,第一个参数是优化器,起二个参数每5个epoch后调整学习率,第三个参数调整为原来的0.9倍lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5,gamma=0.9)

  基础定义最开始需要定义跑数据的设备,CPU还是GPU.这个是死的。然后定义epoch次数、学习率,至于准确率看自己的使用情况,如果每次跑完一遍数据集打印准确率也不着急定义。我在最终跑完数据才打印准确率,所以需要定义一个全局的变量。接下来便是网络初始化,初始化的网络加载到设备上面。在网络搭建的时候,我们只定义了网络的层次和前向传播。后面的损失函数和优化器需要在训练中进行。那么基础定义里面需要损失函数的选择,优化器的选择和动态调整学习率。epoch 改成200,我的电脑跑了1h还没出结果,现在还在等,建议别弄大了。

图 2 基础定义

当然顺序可以变,最好自己能记住需要的内容。

2. epoch 的定义

  epoch里面开始调用网络,net.train() 会把网络的参数进行初始化,BatchNorm 会自动启用训练模式,dropout层会全部激活,而这在测试集上不需要dropout的。后面便是每组图片 batch的训练行为。最后每一次处理整个数据集需要动态改变学习率,以及打印学习率的方法如下:

        #自动更新学习率lr_scheduler.step()# 打印学习率lr = optimizer.state_dict()['param_groups'][0]['lr']print("学习率", lr)

3. 每组图片的训练和模型保存

      for i,data in enumerate(train_data_loader):batch_num = len(train_data_loader)# 获取输入和标签inputs, labels = datainputs, labels = inputs.to(device), labels.to(device) # 放到GPU上去# 拿到输出# print("need output")outputs = net(inputs)  # 这句就会调用前向传播,在PyTorch中,当执行outputs = net(inputs)时会自动触发前向传播,# 这是通过nn.Module的__call__方法实现的特殊机制6。具体原理可分为三个关键环节:# 计算损失loss = loss_func(outputs, labels)# 定义优化器,梯度要归零,loss反向传播,更新参数optimizer.zero_grad()loss.backward()optimizer.step()_, predicted = torch.max(outputs.data, dim=1)  # 得到一个batch的预测类别

  这里进行数据加载,分批次加载,此处的batch 大小是128,数量是391(用于准确率计算)。加载了数据,获取数据的输入和真实标签。outputs = net(inputs) 这句对网络传入数据,自行前向传播计算获得输出。拿到输出后,进行损失函数计算。损失函数计算,是需要预测值和真实值的,看看偏差多少,因此传入这两个参数。优化器优化,开始梯度归零,然后后向传播,这个过程是自带的,我们只定义了前向传播,后向传播优化参数后固定参数 optimizer.step(),最后使用torch.max() 输出最相似的标签。过程如图:

图 3 batch 循环

模型的保存就一句话:

torch.save(net.state_dict(),"./model/VGGNet.pth")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/85883.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/85883.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

课程专注度分析系统文档

一、项目概述 本项目基于 Flask 框架开发,结合计算机视觉技术(利用 YOLOv10 等模型 ),实现对课堂视频的智能分析。可检测视频中学生手机使用情况、面部表情(专注、分心等 ),统计专注度、手机使…

中国设计 全球审美 | 安贝斯新产品发布会:以东方美学开辟控制台仿生智造新纪元

6月17日,安贝斯(武汉)控制技术有限公司(以下简称“安贝斯”)在武汉隆重举行“新产品发布暨协会联合创新峰会”。近百位来自政府机构、行业协会、行业用户及战略合作伙伴的嘉宾齐聚现场,共同见证以“中国设计…

在微信小程序wxml文件调用函数实现时间转换---使用wxs模块实现

1. 创建 WXS 模块文件(推荐单独存放) 在项目目录下新建 utils.wxs 文件,编写时间转换逻辑: // utils.wxs module.exports {// 将毫秒转换为分钟(保留1位小数)convertToMinutes: function(ms) {if (typeo…

ByteMD 插件系统详解

ByteMD 插件系统详解 ByteMD 的插件系统是其强大扩展性的核心。它允许开发者在 Markdown 解析、AST 转换、HTML 渲染、以及编辑器 UI 交互的各个阶段注入自定义逻辑。这得益于 ByteMD 深度集成了 unified 处理器和其丰富的生态系统(remark 用于 Markdown&#xff0c…

每日一练之 Lua 表

Lua 的 table 是什么数据结构?如何创建和访问? 数据结构:Lua的table是一种哈希表,使用键值对存储数据,支持动态扩容 创建方式: local t1 {} local t2 {10,20,30} local t3 {name"Alice",age25}访问方式&#xff1a…

实现自动胡批量抓取唯品会商品详情数据的途径分享(官方API、网页爬虫)

在电商领域,数据就是企业的核心资产。无论是市场分析、竞品研究,还是精准营销,都离不开对大量商品详情数据的深入挖掘。唯品会作为知名的电商平台,其丰富的商品信息对于众多从业者而言极具价值。本文将详细探讨实现自动批量抓取唯…

Zephyr 高阶实践:彻底讲透 west 构建系统、模块管理与跨平台 CI/CD 配置

本文是 Zephyr 项目管理体系的高阶解构与实战指南,全面覆盖 west 构建系统原理、模块解耦与 west.yml 多模块维护机制,结合企业级多平台 CI/CD 落地流程,深入讲解如何构建可靠、可维护、跨芯片架构的一体化 Zephyr 工程。 一、为什么 Zephyr …

我开源了一套springboot3快速开发模板

我开源了一套springboot3快速开发模板 开箱即用、按需组合、可快速二次开发的后端通用模板。 ✨ 主要特性 Spring Boot 3.x Java 17:跟随 Spring 最新生态,利用现代语法特性。多模块分层:common 抽象通用能力、starter 负责启动、modules…

OpenCV CUDA模块设备层-----在GPU上计算两个uchar1类型像素值的反正切(arctangent)比值函数atan2()

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 对输入的两个 uchar1 像素值 a 和 b,先分别归一化到 [0.0, 1.0] 浮点区间,然后计算它们的 四象限反正切函数。 函数原型…

从C++编程入手设计模式——观察者模式

从C编程入手设计模式——观察者模式 ​ 观察者模式简直就是字如其名,观察观察,观察到了告诉别人。观察手的作用如此,观察者模式的工作机制也是如此。这个模式的核心思路是:一个对象的状态发生变化时,自动通知依赖它的…

MITM 中间人攻击

​据Akamai 2023网络安全报告显示,MITM攻击在数据泄露事件中占比达32.7%,平均每次事件造成企业损失$380,000​ ​NIST研究指出:2022-2023年高级MITM攻击增长41%,近70%针对金融和医疗行业​ 一、MITM攻击核心原理与技术演进 1. 中…

llama_index chromadb实现RAG的简单应用

此demo是自己提的一个需求:用modelscope下载的本地大模型实现RAG应用。毕竟大模型本地化有利于微调,RAG使内容更有依据。 为什么要用RAG? 由于大模型存在一定的局限性:知识时效性不足、专业领域覆盖有限以及生成结果易出现“幻觉…

TDMQ CKafka 版事务:分布式环境下的消息一致性保障

解锁 CKafka 事务能力的神秘面纱 在当今数字化浪潮下,分布式系统已成为支撑海量数据处理和高并发业务的中流砥柱。但在这看似坚不可摧的架构背后,数据一致性问题却如影随形,时刻考验着系统的稳定性与可靠性。 CKafka 作为分布式流处理平台的…

常见的负载均衡算法

常见的负载均衡算法 在实现水平扩展过程中,负载均衡算法是决定请求如何在多个服务实例间分配的核心逻辑。一个合理的负载均衡策略能够有效分散系统压力,提升系统吞吐能力与稳定性。 负载均衡算法可部署在多种层级中,如七层HTTP反向代理&…

数据结构转换与离散点生成

在 C 开发中&#xff0c;我们常常需要在不同的数据结构之间进行转换&#xff0c;以满足特定库或框架的要求。本文将探讨如何将 std::vector<gp_Pnt> 转换为 QVector<QPointF>&#xff0c;并生成特定范围内的二维离散点。 生成二维离散点 我们首先需要生成一系列…

零基础学习Redis(12) -- Java连接redis服务器

在我们之前的内容中&#xff0c;我们会发现通过命令行操作redis是十分不科学的&#xff0c;所以redis官方提供了redis的应用层协议RESP&#xff0c;更具这个协议可以实现一个和redis服务器通信的客户端程序&#xff0c;来简化和完善redis的使用。现阶段有很多封装了RESP协议的库…

clangd LSP 不能找到项目中的文件

clangd LSP 不能找到项目中的文件 clangd LSP 不能找到项目中的文件 clangd LSP 不能找到项目中的文件 Normally you need to create compile_commands.json。 如果你使用 cmake 作为构建工具&#xff0c;请执行下面的命令&#xff1a; cmake -DCMAKE_EXPORT_COMPILE_COMMAN…

【内存】Linux 内核优化实战 - vm.overcommit_memory

目录 vm.overcommit_memory 解释一、概念与作用二、参数取值与含义三、相关参数与配置方式四、实际应用场景建议五、注意事项 vm.overcommit_memory 解释 一、概念与作用 vm.overcommit_memory 是 Linux 内核中的一个参数&#xff0c;用于控制内存分配的“过度承诺”&#xf…

Python:.py文件转换为双击可执行的Windows程序(版本2)

流程步骤&#xff1a; 这个流程图展示了将 Python .py 文件转换为 Windows 可执行程序的完整过程&#xff0c;主要包括以下步骤&#xff1a; 1、准备 Python文件&#xff0c;确保代码可独立运行 2、安装打包工具&#xff08;如 PyInstaller&#xff09; 3、打开命令提示符并定位…

【请关注】mysql一些经常用到的高级SQL

经常去重复数据&#xff0c;数据需要转等操作&#xff0c;汇总高级SQL MySQL操作 一、数据去重&#xff08;Data Deduplication&#xff09; 去重常用于清除重复记录&#xff0c;保留唯一数据。 1. 使用DISTINCT关键字去重单列 -- 从用户表中获取唯一的邮箱地址 SELECT DISTIN…