PyTorch深度学习快速入门学习总结(三)

现有网络模型的使用与调整

VGG — Torchvision 0.22 documentation

        VGG 模型是由牛津大学牛津大学(Oxford University)的 Visual Geometry Group 于 2014 年提出的卷积神经网络模型,在 ImageNet 图像分类挑战赛中表现优异,以其简洁统一的网络结构设计而闻名。

  • 优点:结构简洁统一,易于理解和实现;小卷积核设计提升了特征提取能力,泛化性能较好。
  • 缺点:参数量巨大(主要来自全连接层),计算成本高,训练和推理速度较慢,对硬件资源要求较高。

ImageNet — Torchvision 0.22 documentation

        在 PyTorch 的torchvision库中,ImageNet相关功能主要用于加载和预处理 ImageNet 数据集,方便用户在该数据集上训练或评估模型。

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader# train_data = torchvision.datasets.ImageNet('./data_image_net', split='train', download=True,
#                                        transform=torchvision.transforms.ToTensor())
# 指定要加载的数据集子集。这里设置为'train',表示加载的是 ImageNet 的训练集(包含约 120 万张图像)。
# 若要加载验证集,可将该参数改为'val'(验证集包含约 5 万张图像)。vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)print(vgg16_true)# 利用现有网络修改结构
train_data = torchvision.datasets.CIFAR10('./dataset', train=True, download=True,transform=torchvision.transforms.ToTensor())vgg16_true.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)
# (add_linear): Linear(in_features=1000, out_features=10, bias=True)# 修改位置不同
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))# 直接修改某层
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

vgg16_false:适合用于训练全新的任务,vgg16_true:常用于迁移学习场景

  • vgg16_false:pretrained=False表示不加载在大型数据集上(如 ImageNet)预训练好的权重参数。此时,模型的权重参数会按照默认的随机初始化方式进行初始化,比如卷积层和全连接层的权重会从一定范围内的随机值开始,偏置项通常初始化为 0 或一个较小的值。这种初始化方式下,模型需要从头开始学习训练数据中的特征表示。
  • vgg16_true:pretrained=True表示加载在 ImageNet 数据集上预训练好的权重参数。VGG16 在 ImageNet 上经过大量图像的训练,已经学习到了通用的图像特征,比如边缘、纹理、形状等。加载这些预训练权重后,模型在新的任务上可以利用已学习到的特征,减少训练所需的样本数量和训练时间,在很多情况下能更快地收敛到较好的性能。

模型的保存与加载

模型保存

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoadervgg16 = torchvision.models.vgg16(pretrained=False)# 方式一保存 模型结构+模型参数
torch.save(vgg16, 'vgg16_method1.pth')# 方式二保存 模型参数 (推荐)
torch.save(vgg16.state_dict(), 'vgg16_method2.pth')# 陷阱
class Chenxi(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.conv1 = nn.Conv2d(3, 64, kernel_size=3)def forward(self, x):x = self.conv1(x)return xchenxi = Chenxi()
torch.save(chenxi, 'chenxi_method1.pth')

相应模型加载方法

import torch
import torchvision
from module_save import * # 注意from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader# 方式一, 加载模型
model = torch.load('vgg16_method1.pth')
# print(model)# 方式二, 加载模型
vgg16 = torchvision.models.vgg16(pretrained=False) # 新建网络模型结构
vgg16.load_state_dict(torch.load('vgg16_method2.pth'))
# print(vgg16)
# model = torch.load('vgg16_method2.pth')
# print(model)# 陷阱1
# 方法一必须要有模型# class Chenxi(nn.Module):
#     def __init__(self, *args, **kwargs) -> None:
#         super().__init__(*args, **kwargs)
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
#     def forward(self, x):
#         x = self.conv1(x)
#         return x# chenxi = Chenxi() 不需要model = torch.load('chenxi_method1.pth')
print(model)

完整的模型训练套路

CIFAR10为例

model文件

import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear# 搭建神经网络
class Chenxi(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10),)def forward(self, x):x = self.model1(x)return xif __name__ == '__main__':chenxi = Chenxi()input = torch.ones((64, 3, 32, 32))output = chenxi(input)print(output.shape)# torch.Size([64, 10])

train文件

import torch
import torchvision.datasets
from torch import nn
from torch.nn import MaxPool2d, Sequential, Conv2d, Flatten, Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom model import *# 准备数据集
train_data = torchvision.datasets.CIFAR10('./dataset', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10('./dataset', train=False, download=True,transform=torchvision.transforms.ToTensor())train_data_size = len(train_data)
test_data_size = len(test_data)print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))
# 训练数据集的长度为:50000
# 测试数据集的长度为:10000train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 创建网络模型
chenxi = Chenxi()# 损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(chenxi.parameters(),lr=learning_rate)# 设置训练网络的参数
# 记录训练次数
total_train_step = 0
# 记录测试次数
total_test_step = 0
# 训练的轮数
epoch = 10# 添加Tensorboard
writer = SummaryWriter('./logs_train')for i in range(epoch):print('---------第{}轮训练开始-----------'.format(i + 1))# 训练开始chenxi.train()for data in train_dataloader:imgs, target = dataoutputs = chenxi(imgs)loss = loss_fn(outputs, target)# 优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:print("训练次数:{}, loss:{}".format(total_train_step, loss.item()))writer.add_scalar('train_loss', loss.item(), total_train_step)# 测试步骤开始chenxi.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():# 禁用梯度计算for data in test_dataloader:imgs, targets = dataoutputs = chenxi(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar('test_loss', loss.item(), total_test_step)writer.add_scalar('test_accuracy', loss.item(), total_test_step)total_test_step += 1torch.save(chenxi, "chenxi_{}.pth".format(i))# torch.save(chenxi.state_dict(), "chenxi_{}.pth".format(i))print("模型已保存")writer.close()

train优化部分

import torch
outputs = torch.tensor([[0.1, 0.2],[0.05, 0.4]])
# print(outputs.argmax(0)) # 纵向
# print(outputs.argmax(1)) # 横向preds = outputs.argmax(0)
targets = torch.tensor([0, 1])
print((preds == targets).sum())

使用GPU进行训练

方法一:使用.cuda()

        对 网络模型、数据及标注、损失函数, 进行 .cuda()操作

        如果电脑不支持GPU,可以使用谷歌浏览器:

https://colab.research.google.com/drive/1HKuF0FtulVXkHaiWV8VzT-VXZmbq4kK4#scrollTo=861yC3qEpi3F

方法二:使用.to(device)

device = torch.device('cpu')
# device = torch.device('cuda')
x = x.to(device)
# 代替.cuda()

import torch
import torchvision.datasets
from torch import nn
from torch.nn import MaxPool2d, Sequential, Conv2d, Flatten, Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter# from model import *# 准备数据集
train_data = torchvision.datasets.CIFAR10('./dataset', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10('./dataset', train=False, download=True,transform=torchvision.transforms.ToTensor())train_data_size = len(train_data)
test_data_size = len(test_data)print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))
# 训练数据集的长度为:50000
# 测试数据集的长度为:10000train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 创建网络模型
class Chenxi(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10),)def forward(self, x):x = self.model1(x)return xchenxi = Chenxi()
if torch.cuda.is_available():chenxi = chenxi.cuda()
# 损失函数
loss_fn = nn.CrossEntropyLoss()
if torch.cuda.is_available():loss_fn = loss_fn.cuda()# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(chenxi.parameters(),lr=learning_rate)# 设置训练网络的参数
# 记录训练次数
total_train_step = 0
# 记录测试次数
total_test_step = 0
# 训练的轮数
epoch = 10# 添加Tensorboard
writer = SummaryWriter('./logs_train')for i in range(epoch):print('---------第{}轮训练开始-----------'.format(i + 1))# 训练开始chenxi.train()for data in train_dataloader:imgs, targets = dataif torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()outputs = chenxi(imgs)loss = loss_fn(outputs, targets)# 优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:print("训练次数:{}, loss:{}".format(total_train_step, loss.item()))writer.add_scalar('train_loss', loss.item(), total_train_step)# 测试步骤开始chenxi.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():# 禁用梯度计算for data in test_dataloader:imgs, targets = dataif torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()outputs = chenxi(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar('test_loss', loss.item(), total_test_step)writer.add_scalar('test_accuracy', loss.item(), total_test_step)total_test_step += 1torch.save(chenxi, "chenxi_{}.pth".format(i))# torch.save(chenxi.state_dict(), "chenxi_{}.pth".format(i))print("模型已保存")writer.close()

完整的模型验证套路

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/91291.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/91291.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

是否需要买一个fpga开发板?

纠结要不要买个 FPGA 开发板?真心建议搞一块,尤其是想在数字电路、嵌入式领域扎根的同学,这玩意儿可不是可有可无的摆设。入门级的选择不少,全新的像 Cyclone IV、Artix 7 系列,几百块就能拿下,要是去二手平…

【模型细节】MHSA:多头自注意力 (Multi-head Self Attention) 详细解释,使用 PyTorch代码示例说明

MHSA:使用 PyTorch 实现的多头自注意力 (Multi-head Self Attention) 代码示例,包含详细注释说明:线性投影 通过三个线性层分别生成查询(Q)、键(K)、值(V)矩阵: QWq⋅x,KWk⋅x,VWv⋅xQ W_qx, \quad K W_kx, \quad V W_vxQWq​⋅x,KWk​⋅x…

PGSQL运维优化:提升vacuum执行时间观测能力

本文是 IvorySQL 2025 生态大会暨 PostgreSQL 高峰论坛上的演讲内容,作者:NKYoung。 6 月底济南召开的 HOW2025 IvorySQL 生态大会上,我在内核论坛分享了“提升 vacuum 时间观测能力”的主题,提出了新增统计信息的方法&#xff0c…

神奇的数据跳变

目的 上周遇上了一个非常奇怪的问题,就是软件的数据在跳变,本来数据应该是158吧,数据一会变成10,一会又变成158,数据在不断地跳变,那是怎么回事?? 这个问题非常非常的神奇,让人感觉太不可思议了。 这是这段时间,我遇上的最神奇的事了,没有之一,最神奇的事,下面…

【跨国数仓迁移最佳实践3】资源消耗减少50%!解析跨国数仓迁移至MaxCompute背后的性能优化技术

本系列文章将围绕东南亚头部科技集团的真实迁移历程展开,逐步拆解 BigQuery 迁移至 MaxCompute 过程中的关键挑战与技术创新。本篇为第3篇,解析跨国数仓迁移背后的性能优化技术。注:客户背景为东南亚头部科技集团,文中用 GoTerra …

【MySQL集群架构与实践3】使用Dcoker实现读写分离

目录 一. 在Docker中安装ShardingSphere 二 实践:读写分离 2.1 应用场景 2.2 架构图 2.3 服务器规划 2.4 启动数据库服务器 2.5. 配置读写分离 2.6 日志配置 2.7 重启ShardingSphere 2.8 测试 2.9. 负载均衡 2.9.1. 随机负载均衡算法示例 2.9.2. 轮询负…

maven的阿里云镜像地址

在 Maven 中配置阿里云镜像可以加速依赖包的下载,尤其是国内环境下效果明显。以下是阿里云 Maven 镜像的配置方式: 配置步骤:找到 Maven 的配置文件 settings.xml 全局配置:位于 Maven 安装目录的 conf/settings.xml用户级配置&am…

大语言模型信息抽取系统解析

这段代码实现了一个基于大语言模型的信息抽取系统,能够从金融和新闻类文本中提取结构化信息。下面我将详细解析整个代码的结构和功能。1. 代码整体结构代码主要分为以下几个部分:模式定义:定义不同领域(金融、新闻)需要抽取的实体类型示例数据…

Next实习项目总结串联讲解(一)

下面是一些 Next.js 前端面试中常见且具深度的问题,按照逻辑模块整理,同时提供示范回答建议,便于你条理清晰地展示理解与实践经验。 ✅ 面试讲述结构建议 先讲 Next.js 是什么,它为什么比 React 更高级。(支持 SSR/SSG/ISR,提升S…

React开发依赖分析

1. React小案例: 在界面显示一个文本:Hello World点击按钮后,文本改为为:Hello React 2. React开发依赖 2.1. 开发React必须依赖三个库: 2.1.1. react: 包含react所必须的核心代码2.1.2. react-dom: react渲染在不同平…

工具(一)Cursor

目录 一、介绍 二、如何打开文件 1、从idea跳转文件 2、单独打开项目 三、常见使用 1、Chat 窗口 Ask 对话模式 1.1、使用技巧 1.2 发送和使用 codebase 发送区别 1.3、问题快速修复 2、Chat 窗口 Agent 对话模式 2.1、agent模式功能 2.2、Chat 窗口回滚&撤销 2.3…

Prompt编写规范指引

1、📖 引言 随着人工智能生成内容(AIGC)技术的快速发展,越来越多的开发者开始利用AIGC工具来辅助代码编写。然而,如何编写有效的提示词(Prompt)以引导AIGC生成高质量的代码,成为了许…

自我学习----绘制Mark点

在PCB的Layout过程中我们需在光板上放置Mark点以方便生产时的光学定位(三点定位);我个人Mark点绘制步骤如下: layer层:1.放置直径1mm的焊盘(无网络连接) 2.放置一个圆直径2mm,圆心与…

2025年财税行业拓客破局:小蓝本财税版AI拓客系统助力高效拓客

2025年,在"金税四期"全面实施的背景下,中国财税服务市场迎来爆发式增长,根据最新的市场研究报告,2025年中国财税服务行业产值将达2725.7亿元。然而,行业高速发展的背后,80%的财税公司却陷入获客成…

双向链表,对其实现头插入,尾插入以及遍历倒序输出

1.创建一个节点,并将链表的首节点返回创建一个独立节点,没有和原链表产生任何关系#include "head.h"typedef struct Node { int num; struct Node*pNext; struct Node*pPer; }NODE;后续代码:NODE*createNode(int value) {NODE*new …

2025年自动化工程与计算机网络国际会议(ICAECN 2025)

2025年自动化工程与计算机网络国际会议(ICAECN 2025) 2025 International Conference on Automation Engineering and Computer Networks一、大会信息会议简称:ICAECN 2025 大会地点:中国柳州 审稿通知:投稿后2-3日内通…

12.Origin2021如何绘制误差带图?

12.Origin2021如何绘制误差带图?选中Y3列→点击统计→选择描述统计→选择行统计→选择打开对话框输入范围选择B列到D列点击输出量→勾选均值和标准差Control选择下面三列点击绘图→选择基础2D图→选择误差带图双击图像→选择符号和颜色点击第二个Sheet1→点击误差棒→连接选择…

如何使用API接口获取淘宝店铺订单信息

要获取淘宝店铺的订单信息,您需要通过淘宝开放平台(Taobao Open Platform, TOP)提供的API接口来实现。以下是详细步骤:1. 注册淘宝开放平台账号访问淘宝开放平台注册开发者账号并完成实名认证创建应用获取App Key和App Secret2. 申请API权限在"我的…

【Kiro Code 从入门到精通】重要的功能

一、Kiro 是什么? Kiro 是一款智能型集成开发环境(IDE),借助规格说明(specs)、向导(steer)、钩子(hooks)帮助你高效完成工作。 二、Specs 规格说明 规范&…