Pytorch-07 如何快速把已经有的视觉模型权重扒拉过来为己所用

下载,保存,加载,使用模型权重

在这一节里面我们会过一遍对模型权重的常用操作,比如:

  • 如何下载常用模型的预训练权重
  • 如何下载常用模型的无训练权重(只下载网络结构)
  • 如何加载模型权重
  • 如何保存权重
  • 加载模型权重后进行推理的注意事项

写在开头:权重是以什么形式存在的?

光用,是肯定够的,但是如果你能稍微懂一点原理,那么你很有可能在某一日突然融会贯通,做出非常牛逼的优化。

Pytorch模型会将学习到的参数存储在称为state_dict的内部状态字典中。为了深入探究,我们可以创建一个单线性层简单模型,然后看看它的状态字典长啥样:
在这里插入图片描述

import torch
import torch.nn as nn
import jsonclass SimpleLinearModel(nn.Module):def __init__(self):super().__init__()self.linear = nn.Linear(5, 2) # 五个输入,两个输出def forward(self, x):return self.linear(x)model = SimpleLinearModel()
model_state_dict = model.state_dict()
print(model_state_dict)

以上代码块的输出如下:
在这里插入图片描述
虽然它是个字典,但是由于张量的存在让他不是一个友好的键值对结构,而是元组结构,但是我们还是可以把他转换JSON来直观感受一下:

{"linear.weight": [[0.1941000074148178,0.22420001029968262,-0.3236999809741974,-0.1558000087738037,0.2337000072002411],[0.15130000114440918,0.11470000594854355,0.3953999876976013,-0.33970001339912415,-0.20650000870227814]],"linear.bias": [0.046799998730421066,-0.03530000150203705]
}

这下直观多了,可以看到,state_dict存储了两个学习的参数,其中包括了一个W2×5W_{2\times5}W2×5的全连接矩阵和一个长度为2的偏置向量bbb

不过需要注意的是这是一个有序字典,这样才能保证数据在流经权重文件的时候才能一层一层的被处理。

下载并保存常用预训练模型权重

torchvision.models包内置了很多的不同任务模型权重,包括但不限于图像分类,语义分割,实例分割,关键点检测,视频分类,光流等等,你可以逛逛这个权重菜市场,这里我就不放图介绍了。

一般情况下,你可能要使用这些预训练的模型来进行 迁移学习, pytorch加载这些权重相当简单,一般的调用公式为:

model = model.模型名(weights='用什么数据集预训练的')

举个例子,我们加载一下vgg的在IMAGENET1K_V1上的权重,看看它结构如何

model = models.vgg16(weights='IMAGENET1K_V1')
print(model)

在这里插入图片描述
可以看到这个包含预训练权重的网络模型已经被我们保存到state_dict中了,但是它目前还只是在内存里面,没有写入外存(硬盘),如果你想把它保存到本地,你可以这样做:

torch.save(model.state_dict(), 'vgg1-_model_weights.pth')

这样就会把你当前的权重保存为一个.pth文件。

加载模型权重-仅加载权重

OK, 现在你已经有了一个vgg模型权重,那你要怎么把它加载到对应的网络上呢?

正常的步骤是这样的:

  1. 创建同一模型的实例 (不指定数据集的时候说明只要结构不要预训练权重)
  2. 使用load_state_dict()方法加载参数
model = models.vgg16() # 这里没有指定数据集,说明只要结构
model.load_state_dict(torch.load('model_weights.pth', weights_only=True)) # 这里加载权重,该.pth文件只包含权重,不包含结构。

注意,如果你使用torch.save(model.state_dict(), 'path'),只有权重会被保存!如果你想在保存权重的同时也保存模型结构,你可以这么做:

torch.save(model, 'model.pth')

这个做法的优点是可以在加载被这样保存的权重的时候无需初始化对应的网络:

model = torch.load('model.pth', weights_only=False) # 说明这个model.pth并不是只保存了权重,还有模型架构,所以不需要先实例化再加载权重

但是,模型和权重一起加载并不是pytorch官方推荐的最佳实践! pytorch官方推荐的方式还是只保存模型权重,要加载的时候先实例化网络再加载权重(后一段就是讲这个的)。这是因为.pth文件的解析基于pickle协议实现,而pickle文件不仅仅是数据存储,它还可以包含可执行代码。当 torch.load() 反序列化一个 pickle 文件时,它会执行文件中的字节码来重新创建对象。

这代表:如果一个 .pth 文件是恶意创建的,它可能包含恶意代码。当你在不知情的情况下加载这个文件时,这些恶意代码会被执行,从而导致你的系统受到攻击,例如被植入病毒、窃取数据等。weights_only=True 的作用就是切断这个风险链。它告诉 PyTorch:“我只信任文件中的张量数据。不要执行任何其他的 Python 对象代码,即使它们存在于文件中。”这样就有效地防止了潜在的恶意代码被执行。

使用预训练权重进行推理

这里要说的不多,用预训练权重加载好模型之后记得打开.eval评估模式再开始推理:

model.eval()

.eval() 方法的作用是将模型切换到评估(evaluation)模式。这个模式会关闭一些在训练时才需要的特殊层,以确保模型在推理时能够产生一致且可预测的结果。

具体来说,他会关闭这两个层的以下作用:

.eval() 是 PyTorch 模型推理时一个非常重要的步骤,你提到的这一点非常关键。

为什么需要在推理前调用 model.eval()

.eval() 方法的作用是将模型切换到评估(evaluation)模式。这个模式会关闭一些在训练时才需要的特殊层,以确保模型在推理时能够产生一致且可预测的结果。

具体来说,.eval() 主要影响以下两种类型的层:

  1. Dropout

    • 训练模式 (model.train())Dropout 层会以一定的概率随机“丢弃”一些神经元的输出,以防止模型过拟合。这意味着每次前向传播(forward pass)时,网络结构都是不一样的。
    • 评估模式 (model.eval())Dropout 层会被关闭。所有神经元都参与计算,不再随机丢弃。这确保了在推理时,每次对同一输入进行预测,都会得到完全相同的结果。
  2. BatchNorm(批量归一化)层

    • 训练模式 (model.train())BatchNorm 层会根据当前批次(batch)的输入数据来计算均值和方差,并用这些统计量进行归一化。
    • 评估模式 (model.eval())BatchNorm 层会停止更新均值和方差。它会使用在训练阶段已经学到的、全局的、固定的均值和方差来进行归一化。这同样是为了确保推理结果的稳定性,因为在推理时,我们通常只处理单个样本或小批量的样本,它们的统计量没有代表性。

如果没有调用 model.eval(),模型将保持在训练模式。这将导致:

  1. 结果不稳定:因为 Dropout 层会随机丢弃神经元,即使输入相同,每次推理的结果也可能不同。
  2. 结果不准确BatchNorm 层会使用不稳定的批次统计量进行归一化,而不是使用训练时学到的稳定统计量,这会导致推理结果的准确性下降。

所以,为了得到稳定、准确且可复现的推理结果,在使用预训练模型进行预测时,必须在推理循环之前调用 model.eval()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/92167.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/92167.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言零基础第9讲:指针基础

目录 1.内存和地址 2.指针变量和地址 2.1 取地址操作符(&) 2.2 指针变量 2.3 解引用操作符(*) 2.4 指针变量的大小 3.指针变量类型的意义 3.1 指针的解引用 3.2 指针 - 整数 3.3 void*指针 4.指针运算 4.1 指针…

013 HTTP篇

3.1 HTTP常见面试题 1、HTTP基本概念: 超文本传输协议:在计算机世界里专门在「两点」之间「传输」文字、图片、音频、视频等「超文本」数据的「约定和规范」HTTP常见的状态码 [[Pasted image 20250705140705.png]]HTTP常见字段 Host 字段:客户…

每日面试题20:spring和spring boot的区别

我曾经写过一道面试题,题目是为什么springboot项目可以直接打包给别人运行?其实这涉及到的就是springboot的特点。今天来简单了解一下springboot和spring的区别, Spring 与 Spring Boot:从“全能框架”到“开箱即用”的进化之路 …

ClickHouse数据迁移

ClickHouse实例是阿里云上的云实例,想同步数据到本地,本地部署有ClickHouse实例,下面为单库单表 源实例:阿里云cc-gs5xxxxxxx.public.clickhouse.ads.aliyuncs.com:8123 目标实例:本地172.16.22.10:8123 1、目标实例建…

sqli-labs-master/Less-41~Less-50

Less-41这一关还是用堆叠注入,这关数字型不需要闭合了。用堆叠的话,我们就不爆信息了。我们直接用堆叠,往进去写一条数据?id-1 union select 1,2,3;insert into users (id,username,password) values(666,zk,180)--看一下插进去了没?id-1 u…

Tiger任务管理系统-10

十是个很好美好的数字,十全十美,确实没让人失望,收获还是很大的。 温习了前端知识,巩固了jQuery,thymeleaf等被忽视的框架,意外将之前的所学所用的知识都连起来了,感觉有点像打通了任督二脉一样…

ora-01658 无法为表空间 users中的段创建initial区

ora-01658 无法为表空间 users中的段创建initial区 参考1 参考2 参考3 参考4 给用户新增表空间 alter tablespace system add datafile D:\APP\ADMINISTRATOR\ORADATA\ORCL\SYSTEM03.DBF size 5G autoextend on next 10M;设置表空间文件自动扩展 ALTER DATABASE DATAFILE /…

lodash的替代品es-toolkit详解

一、es-toolkit简介 es-toolkit 是一款先进的高性能 JavaScript 实用程序库,体积小巧,并支持强类型注释,典型特征包括: 提供各种日常实用函数并采用现代实现,例如: debounce、delay、chunk、sum 和 pick 等 设计充分考虑了性能,在现代 JavaScript 环境中实现了 2-3 倍…

【原创】基于gemini-2.5-flash-preview-05-20多模态模型实现短视频的自动化二创

画面和解说保持一致,这个模型就是NB[16:57:37] [*] 正在从视频中提取帧和时长 (频率: 1.0 帧/秒)... [16:57:55] [] 提取完成。视频时长: 83.40秒, 提取了 84 帧。 [16:57:55] [*] 使用AI供应商: gemini [16:57:55] [*] 正在进行视觉分析... [16:57:55] L-> 正…

数仓架构 数据表建模

数仓架构 主要用来描述 数据加工的实时链路 和 离线链路之间的关系,即 流批 关系; lamda 架构, 是两条路, 实时计算式的, 维护数据的实时性。然后每天经过批计算后, 覆盖实时的计算结果。 保证数据准确性。 kappa架构, 即流批一体了 数据建模 星型模型是数据仓库中最…

vscode调试python脚本时无法进入函数内部的解决方法

只需在launch.json配置文件中添加“justMyCode”:false.

Python day37

浙大疏锦行 python day37. 内容: 保存模型只需要保存模型的参数即可,使用的时候直接构建模型再导入参数即可 # 保存模型参数 torch.save(model.state_dict(), "model_weights.pth")# 加载参数(需先定义模型结构) mod…

ORACLE进阶操作

1 事务 事务的任务便是使数据库从一种状态变换成为另一种状态,这不同于文件系统,它是数据库所特用的。 所有的数据库中,事务只针对DML(增删改),不针对select select只能查看其他事务提交或回滚的数据,不能查…

Modbus 的一些理解

疑问:(使用的是Modbustcp)我在 Modbus slave 上面设置了slave地址为1,位置为40001的位置的值为1,40001这个位置上面的值是怎么存储的,存储在哪里的?他们是怎么进行交互的?在Modbus协…

【运动控制框架】WPF运动控制框架源码,可用于激光切割机,雕刻机,分板机,点胶机,插件机等设备,开箱即用

WPF运动控制框架源码,可用于激光切割机,雕刻机,分板机,点胶机,插件机等设备,考虑到各运动控制硬件不同,视觉应用功能(应用视觉软件)也不同,所以只开发各路径编…

RabbitMQ-日常运维命令

作者介绍:简历上没有一个精通的运维工程师。请点击上方的蓝色《运维小路》关注我,下面的思维导图也是预计更新的内容和当前进度(不定时更新)。中间件,我给它的定义就是为了实现某系业务功能依赖的软件,包括如下部分:Web服务器代理…

【Linux基础知识系列】第九十篇 - 使用awk进行文本处理

在Linux系统中,文本处理是一个常见的任务,尤其是在处理日志文件、配置文件和数据文件时。awk是一个功能强大的文本处理工具,广泛用于数据提取、分析和格式化。它不仅可以处理简单的文本文件,还可以处理复杂的结构化数据&#xff0…

第二十七天(数据结构:图)

图:是一种非线性结构形式化的描述: G{V,R}V:图中各个顶点元素(如果这个图代表的是地图,这个顶点就是各个点的地址)R:关系集合,图中顶点与顶点之间的关系(如果是地图,这个关系集合可能就代表的是各个地点之间的距离)在顶点与顶点…

数据赋能(386)——数据挖掘——迭代过程

概述重要性如下:提升挖掘效果:迭代过程能不断优化数据挖掘模型,提高挖掘结果的准确性和有效性,从而更好地满足业务需求。适应复杂数据:数据往往具有复杂性和多样性,通过迭代可以逐步探索和适应数据的特点&a…

什么是键值缓存?让 LLM 闪电般快速

一、为什么 LLMs 需要 KV 缓存?大语言模型(LLMs)的文本生成遵循 “自回归” 模式 —— 每次仅输出一个 token(如词语、字符或子词),再将该 token 与历史序列拼接,作为下一轮输入,直到…