嵌入式学习-PyTorch(8)-day24

torch.optim 优化器

torch.optim 是 PyTorch 中用于优化神经网络参数的模块,里面实现了一系列常用的优化算法,比如 SGD、Adam、RMSprop 等,主要负责根据梯度更新模型的参数。


🏗️ 核心组成

1. 常用优化器

优化器作用典型参数
torch.optim.SGD标准随机梯度下降,支持 momentumlr, momentum, weight_decay
torch.optim.Adam自适应学习率,效果稳定lr, betas, weight_decay
torch.optim.RMSprop平滑梯度,常用于RNNlr, alpha, momentum
torch.optim.AdamW改进版Adam,解耦正则化lr, weight_decay
torch.optim.Adagrad稀疏特征场景,自动调整每个参数的学习率lr, lr_decay, weight_decay

 演示代码

import torch
import torch.nn as nn
import torch.optim as optimmodel = nn.Linear(10, 1)  # 一个简单的线性层
optimizer = optim.Adam(model.parameters(), lr=0.001)for epoch in range(100):output = model(torch.randn(4, 10))  # 模拟一个输入loss = (output - torch.randn(4, 1)).pow(2).mean()  # 假设是 MSE 损失optimizer.zero_grad()  # 梯度清零loss.backward()        # 反向传播optimizer.step()       # 更新参数

import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2ddataset = torchvision.datasets.CIFAR10(root='./data_CIF', train=False, download=True, transform=torchvision.transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)class Tudui(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)self.maxpool1 = nn.MaxPool2d(kernel_size=2)self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)self.maxpool2 = nn.MaxPool2d(kernel_size=2)self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)self.maxpool3 = nn.MaxPool2d(kernel_size=2)self.flatten = nn.Flatten()self.linear1 = nn.Linear(in_features=1024, out_features=64)self.linear2 = nn.Linear(in_features=64, out_features=10)self.model1 = nn.Sequential(Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(in_features=1024, out_features=64),nn.Linear(in_features=64, out_features=10))def forward(self, x):x = self.model1(x)return x
loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01 )
for epoch in range(100):running_loss = 0.0for data in dataloader:imgs,targets = dataoutputs = tudui(imgs)result_loss = loss(outputs, targets)#梯度置零optim.zero_grad()#反向传播result_loss.backward()#更新参数optim.step()running_loss += result_lossprint(running_loss)

 

 对网络模型的修改

import torchvision
from torch import nn# train_data = torchvision.datasets.ImageNet(root='./data_IMG',split="train", transform=torchvision.transforms.ToTensor())
#学习如何改变现有的网络结构
vgg16_false = torchvision.models.vgg16(pretrained=False)vgg16_true = torchvision.models.vgg16(pretrained=True)train_data = torchvision.datasets.CIFAR10(root='./data_CIF',train=True,transform=torchvision.transforms.ToTensor(),download=True)
#加一个线性层
vgg16_true.add_module('add_linear',nn.Linear(in_features=1000,out_features=10))
vgg16_true.classifier.add_module('add_linear',nn.Linear(in_features=1000,out_features=10))
#修改一个线性层
vgg16_false.classifier[6] = nn.Linear(in_features=4096,out_features=10)
print(vgg16_false)

网络模型的保存与读取

#model_save.pyimport torch
import torchvision
from torch import nnvgg16 = torchvision.models.vgg16(pretrained=False)
#保存方式一:模型结构+模型参数
torch.save(vgg16,"vgg16.pth")#保存方式二:模型参数(官方推荐)
torch.save(vgg16.state_dict(),"vgg16_state_dict.pth")#陷阱
class Tudui(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3,64,kernel_size=3)def forward(self, x):x = self.conv1(x)return xtudui = Tudui()
torch.save(tudui,"tudui_method1.pth")
#model_load.pyimport torch
import torchvisionfrom torch import nn#保存方式一,加载模型
# model = torch.load("vgg16.pth",weights_only=False)
# print(model)#方式二,加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# model = torch.load("vgg16_state_dict.pth")
vgg16.load_state_dict(torch.load("vgg16_state_dict.pth"))
# print(vgg16)#陷阱
#陷阱
class Tudui(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3,64,kernel_size=3)def forward(self, x):x = self.conv1(x)return x#如果直接这么调用的话,机器会找不到类在哪里
# 当你 torch.save(model) 保存整个模型时,它会把整个类的信息序列化。如果加载时当前文件找不到 Tudui 类,自然就炸了。
#可以将定义写到这个类来,也可以在开头写from model_save import *
#!!!更推荐一下模式:
"""
# 保存
torch.save(model.state_dict(), "tudui_method2.pth")# 加载
model = Tudui()
model.load_state_dict(torch.load("tudui_method2.pth"))优点:不管类在哪个文件,只要 Tudui() 存在就能加载;避免因为 class 变动导致报错;更加灵活,适合后期修改网络结构。
"""
model = torch.load("tudui_method1.pth",weights_only=False)
print(model)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/89656.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/89656.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PostgreSQL实战:高效SQL技巧

PostgreSQL PG 在不同领域可能有不同的含义,以下是几种常见的解释: PostgreSQL PostgreSQL(简称 PG)是一种开源的关系型数据库管理系统(RDBMS),支持 SQL 标准并提供了丰富的扩展功能。它广泛应用于企业级应用、Web 服务和数据分析等领域。 PostgreSQL 的详细介绍 Po…

3-大语言模型—理论基础:生成式预训练语言模型GPT(代码“活起来”)

目录 1、GPT的模型结构如图所示 2、介绍GPT自监督预训练、有监督下游任务微调及预训练语言模型 2.1、GPT 自监督预训练 2.1.1、 输入编码:词向量与位置向量的融合 2.1.1.1、 输入序列与词表映射 2.1.1.2、 词向量矩阵与查表操作 3. 位置向量矩阵 4. 词向量与…

【Redis 】看门狗:分布式锁的自动续期

在分布式系统的开发中,保证数据的一致性和避免并发冲突是至关重要的任务。Redis 作为一种广泛使用的内存数据库,提供了实现分布式锁的有效手段。然而,传统的 Redis 分布式锁在设置了过期时间后,如果任务执行时间超过了锁的有效期&…

MYSQL--快照读和当前读及并发 UPDATE 的锁阻塞

快照读和当前读在 MySQL 中,数据读取方式主要分为 快照读 和 当前读,二者的核心区别在于是否依赖 MVCC(多版本并发控制)的历史版本、是否加锁,以及读取的数据版本是否为最新。以下是详细说明:一、快照读&am…

css样式中的选择器和盒子模型

目录 一、行内样式二、内部样式三、外部样式四、结合选择器五、属性选择器六、包含选择器七、子选择器八、兄弟选择器九、选择器组合十、伪元素选择器十一、伪类选择器十二、盒子模型 相关文章 学习标签、属性、选择器和外部加样式积累CSS样式属性:padding、marg…

关于基于lvgl库做的注册登录功能的代码步骤:

以下是完整的文件拆分和代码存放说明,按功能模块化划分,方便工程管理:一、需要创建的文件清单 文件名 作用 类型 main.c 程序入口,初始化硬件和LVGL 源文件 ui.h 声明界面相关函数 头文件 ui.c 实现登录、注册、主页面的UI 源文…

RAII机制以及在ROS的NodeHandler中的实现

好的,这是一个非常核心且优秀的设计问题。我们来分两步详细解析:先彻底搞懂什么是 RAII,然后再看 ros::NodeHandle 是如何巧妙地运用这一机制的。1. 什么是 RAII 机制? RAII 是 “Resource Acquisition Is Initialization” 的缩写…

Linux LVS集群技术

LVS集群概述1、集群概念1.1、介绍集群是指多台服务器集中在一起,实现同一业务,可以视为一台计算机。多台服务器组成的一组计算机,作为一个整体存在,向用户提供一组网络资源,这些单个的服务器就是集群的节点。特点&…

spring-ai-alibaba如何上传文件并解析

问题引出 在我们日常使用大模型时,有一类典型的应用场景,就是将文件发送给大模型,然后由大模型进行解析,提炼总结等,这一类功能在官方app中较为常见,但是在很多模型的api中都不支持,那如何使用…

「双容器嵌套布局法」:打造清晰层级的网页架构设计

一、命名与核心概念 “双容器嵌套布局法”,核心是通过两层容器嵌套构建网页结构:外层容器负责控制布局的“宏观约束”(如页面最大宽度、背景色等),内层容器聚焦“微观排版”(内容居中、内边距调整、红色内容…

基于深度学习的自然语言处理:构建情感分析模型

前言 自然语言处理(NLP)是人工智能领域中一个非常活跃的研究方向,它致力于使计算机能够理解和生成人类语言。情感分析(Sentiment Analysis)是NLP中的一个重要应用,其目标是从文本中识别和提取情感倾向&…

JWT原理及利用手法

JWT 原理 JSON Web Token (JWT) 是一种开放的行业标准,用于在系统之间以 JSON 对象的形式安全地传输信息。这些信息经过数字签名,因此可以被验证和信任。其常用于身份验证、会话管理和访问控制机制中传递用户信息。 与传统的会话令牌相比,JWT…

DeepSeek 助力 Vue3 开发:打造丝滑的日历(Calendar),日历_睡眠记录日历示例(CalendarView01_30)

前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 目录DeepS…

git的diff命令、Config和.gitignore文件

diff命令:比较git diff xxx:工作目录 vs 暂存区(比较现在修改之后的工作区和暂存区的内容)git diff --cached xxx:暂存区 vs Git仓库(现在暂存区内容和最一开始提交的文件内容的比较)git diff H…

Linux中的LVS集群技术

一、实验环境(RHEL 9)1、NAT模式的实验环境主机名IP地址网关网络适配器功能角色client172.25.254.111/24(NAT模式的接口)172.25.254.2NAT模式客户机lvs172.25.254.100/24(NAT模式的接口)192.168.0.100/24&a…

【数据结构】「队列」(顺序队列、链式队列、双端队列)

- 第 112篇 - Date: 2025 - 07 - 20 Author: 郑龙浩(仟墨) 文章目录队列(Queue)1 基本介绍1.1 定义1.2 栈 与 队列的区别1.3 重要术语2 基本操作3 顺序队列(循环版本)两种版本两种版本区别版本1.1 - rear指向队尾后边 且 无 size …

Java行为型模式---解释器模式

解释器模式基础概念解释器模式(Interpreter Pattern)是一种行为型设计模式,其核心思想是定义一个语言的文法表示,并定义一个解释器,使用该解释器来解释语言中的句子。这种模式将语法解释的责任分开,使得语法…

[spring6: PointcutAdvisor MethodInterceptor]-简单介绍

Advice Advice 是 AOP 联盟中所有增强(通知)类型的标记接口,表示可以被织入目标对象的横切逻辑,例如前置通知、后置通知、异常通知、拦截器等。 package org.aopalliance.aop;public interface Advice {}BeforeAdvice 前置通知的标…

地图定位与导航

定位 1.先申请地址权限(大致位置精准位置) module.json5文件 "requestPermissions": [{"name": "ohos.permission.INTERNET" },{"name": "ohos.permission.LOCATION","reason": "$string:app_name",&qu…

【数据结构】揭秘二叉树与堆--用C语言实现堆

文章目录1.树1.1.树的概念1.2.树的结构1.3.树的相关术语2.二叉树2.1.二叉树的概念2.2.特殊的二叉树2.2.1.满二叉树2.2.2.完全二叉树2.3.二叉树的特性2.4.二叉树的存储结构2.4.1.顺序结构2.4.2.链式结构3.堆3.1.堆的概念3.2.堆的实现3.2.1.堆结构的定义3.2.2.堆的初始化3.2.3.堆…