「日拱一码」087 机器学习——SPARROW

目录

SPARROW 介绍

核心思想:稀疏掩码训练

与 Lottery Ticket Hypothesis (LTH) 的关系

代码示例

代码关键点解释:


在机器学习领域,"SPARROW" 并不是一个像 Scikit-learn、TensorFlow 或 PyTorch 那样广为人知的通用框架或算法名称。经过查询,最相关的 "SPARROW" 是指一篇重要的研究论文或其中提出的技术。

最著名的 "SPARROW" 来自 Google Research 在 2020年发表的一篇论文 《Rigging the Lottery: Making All Tickets Winners》

SPARROW 介绍

核心思想:稀疏掩码训练

传统的模型剪枝流程是:训练一个大模型 -> 剪枝(移除不重要的权重) -> 微调。这个过程通常非常耗时。

SPARROW(在论文中更常被称为 稀疏掩码训练 或 Lottery Ticket Hypothesis 的扩展)提出了一种截然不同的方法:

在训练一开始就随机初始化一个网络,并立即应用一个预先定义好的稀疏性掩码(Sparsity Mask),使得网络从一开始就是稀疏的。然后,在整个训练过程中,这个掩码保持不变,只更新那些未被掩码掩盖的权重。

这种方法的核心优势在于:

  1. 效率高:模型从始至终都是稀疏的,训练和推理的计算开销、内存占用都显著降低。
  2. 性能好:论文表明,通过找到合适的初始化和固定掩码(即“中奖彩票”),这种稀疏网络可以达到甚至有时超过原始稠密模型的精度。
  3. 简单直接:无需复杂的剪枝调度或微调阶段。

与 Lottery Ticket Hypothesis (LTH) 的关系

LTH 假设指出:一个随机初始化的稠密网络中,包含一个子网络(“中奖彩票”),当被单独训练时,其性能可以媲美原始网络。

SPARROW 可以看作是 LTH 的一个极其高效的实践版本。它不是先训练再找彩票,而是假设一个随机初始化的掩码就是一张“潜在”的中奖彩票,并直接训练这个稀疏子网络,省去了寻找彩票的昂贵过程。

代码示例

下面实现一个最简单的 SPARROW 风格训练示例。流程如下:

  1. 创建一个全连接神经网络。
  2. 随机生成一个固定掩码,指定哪些权重参与训练(例如,50% 的稀疏度)。
  3. 在训练过程中,应用这个掩码:在前向传播时,权重会被掩码;在反向传播后,只有未被掩码的权重会被更新。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms# 1. 定义模型
class SimpleFC(nn.Module):def __init__(self):super(SimpleFC, self).__init__()self.fc1 = nn.Linear(784, 256)self.fc2 = nn.Linear(256, 128)self.fc3 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 784) # 展平输入x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return F.log_softmax(x, dim=1)# 2. 创建模型、优化器、损失函数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleFC().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()# 3. 创建并应用固定稀疏掩码 (50% 稀疏度)
def create_sparsity_mask(model, sparsity=0.5):masks = {}for name, param in model.named_parameters():if 'weight' in name:# 为权重矩阵创建一个相同形状的随机掩码mask = torch.rand_like(param.data) > sparsitymasks[name] = mask.to(device)# 初始应用掩码:将不参与训练的权重置零param.data *= maskreturn maskssparsity_mask = create_sparsity_mask(model, sparsity=0.5)# 4. 训练循环
def train(model, device, train_loader, optimizer, criterion, mask, epochs=5):model.train()for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()# SPARROW 关键步骤:在 optimizer.step() 之前,应用掩码到梯度上# 确保只有被掩码选中的权重才有非零梯度,从而被更新with torch.no_grad():for name, param in model.named_parameters():if name in mask:param.grad *= mask[name]optimizer.step()# SPARROW 另一个关键步骤:在参数更新后,再次应用掩码到权重上# 确保被剪枝的权重始终保持为零with torch.no_grad():for name, param in model.named_parameters():if name in mask:param.data *= mask[name]if batch_idx % 100 == 0:print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}'f' ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')# 5. 加载数据并开始训练
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)print("开始训练 SPARROW 稀疏模型...")
train(model, device, train_loader, optimizer, criterion, sparsity_mask, epochs=3)
# 开始训练 SPARROW 稀疏模型...
# Epoch: 0 [0/60000 (0%)]	Loss: 2.300159
# Epoch: 0 [6400/60000 (11%)]	Loss: 0.135472
# Epoch: 0 [12800/60000 (21%)]	Loss: 0.146012
# Epoch: 0 [19200/60000 (32%)]	Loss: 0.177537
# Epoch: 0 [25600/60000 (43%)]	Loss: 0.034564
# Epoch: 0 [32000/60000 (53%)]	Loss: 0.165950
# Epoch: 0 [38400/60000 (64%)]	Loss: 0.214527
# Epoch: 0 [44800/60000 (75%)]	Loss: 0.239639
# Epoch: 0 [51200/60000 (85%)]	Loss: 0.173407
# Epoch: 0 [57600/60000 (96%)]	Loss: 0.087583
# Epoch: 1 [0/60000 (0%)]	Loss: 0.040576
# Epoch: 1 [6400/60000 (11%)]	Loss: 0.092811
# Epoch: 1 [12800/60000 (21%)]	Loss: 0.397150
# Epoch: 1 [19200/60000 (32%)]	Loss: 0.221431
# Epoch: 1 [25600/60000 (43%)]	Loss: 0.218968
# Epoch: 1 [32000/60000 (53%)]	Loss: 0.164273
# Epoch: 1 [38400/60000 (64%)]	Loss: 0.122340
# Epoch: 1 [44800/60000 (75%)]	Loss: 0.197523
# Epoch: 1 [51200/60000 (85%)]	Loss: 0.268147
# Epoch: 1 [57600/60000 (96%)]	Loss: 0.203193
# Epoch: 2 [0/60000 (0%)]	Loss: 0.115242
# Epoch: 2 [6400/60000 (11%)]	Loss: 0.276544
# Epoch: 2 [12800/60000 (21%)]	Loss: 0.515723
# Epoch: 2 [19200/60000 (32%)]	Loss: 0.202442
# Epoch: 2 [25600/60000 (43%)]	Loss: 0.092944
# Epoch: 2 [32000/60000 (53%)]	Loss: 0.090384
# Epoch: 2 [38400/60000 (64%)]	Loss: 0.145279
# Epoch: 2 [44800/60000 (75%)]	Loss: 0.155133
# Epoch: 2 [51200/60000 (85%)]	Loss: 0.091369
# Epoch: 2 [57600/60000 (96%)]	Loss: 0.216552

代码关键点解释:

  1. 创建掩码 (create_sparsity_mask): 为每个权重矩阵生成一个随机二进制掩码(1表示保留,0表示剪枝)
  2. 初始化应用掩码: 在训练开始前,将模型的权重与掩码相乘,使部分权重归零
  3. 梯度掩码: 在反向传播计算出梯度后、优化器更新权重之前,将梯度与掩码相乘。这确保了只有被保留的权重才会被更新,被剪枝的权重的梯度始终为零
  4. 权重掩码: 在优化器更新完权重后,再次将权重与掩码相乘。这是一个保护步骤,防止由于优化器(如带有动量的SGD)的更新操作可能使本应为零的权重产生微小的数值变化

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/96814.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/96814.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

18、决策树与集成学习 - 从单一智慧到群体决策

学习目标:理解决策树的构建原理和分裂标准,掌握信息增益、基尼系数等概念,学会决策树的剪枝方法,深入理解集成学习的思想,掌握随机森林和梯度提升的基本原理。 > 从第17章到第18章:从概率模型到规则模型 在第17章中,我们学习了逻辑回归——一个基于概率的线性分类器…

王道计算机组成原理 学习笔记

第一章计算机系统概述1.1计算机的发展历程1.2计算机系统层次结构1.2.11.2.2 计算机硬件的基本组成1.2.2 各个硬件的工作原理1.2.3 计算机软件1.2.4 计算机系统的层次结1.2.5 计算机系统的工作原理1.3计算机的性能指标第二章数据的表示和运算第三章存储系统第四章指令系统第五章…

Oracle 笔记1 表空间及用户

Oracle 笔记1 表空间及用户1 安装Oracle2 创建表空间3 创建表空间用户1. 核心管理用户2. 示例与工具用户3. 系统与服务用户4. 创建表空间用户5. 修改表空间用户特性OracleMySQL开发商Oracle 公司最初由 MySQL AB 开发,后被 Sun 收购,现属 Oracle 公司数据…

MyBatis主键返回机制解析

关于 MyBatis 主键返回的深入解释 核心问题:信息隔离 数据库和应用程序是两个独立的系统: 数据库在服务器上执行 INSERT 操作并生成主键应用程序在另一个进程或甚至另一台机器上运行如果没有明确的机制,应用程序无法自动知道数据库生成了什么…

【Python】Python内置函数大全解析(附源码)

目录专栏导读前言🚀 功能特性1. 全面的函数覆盖2. 多种查询工具3. 完整的测试验证🛠️ 使用方法基本使用交互式查询运行测试📚 支持的内置函数分类数学运算 (13个)类型转换 (8个)序列操作 (8个)迭代器 (6个)输入输出 (3个)对象操作 (31个)&am…

每日算法题推送

题目1:快乐数 我们先来结合实例看一下判断快乐数的整个过程: 结合题目可以知道,如果一个数是快乐数,那么这个数最终就会变成1,如果一个数不是快乐数,那么变化序列最终就会陷入循环。想一下,如果…

Oracle体系结构-数据文件(Data Files)

一、 数据文件的本质与原理 物理存储的基石: 数据文件是 Oracle 数据库在操作系统层面最核心、最基础的物理存储单元。它们是存储在服务器硬盘(或存储阵列)上的操作系统文件(如 .dbf, .ora 扩展名常见,但非强制&#x…

【C++练习】18.C++求两个整数的最小公倍数(LCM)

目录C求两个整数的最小公倍数(LCM)的方法方法一:利用最大公约数(GCD)计算代码实现方法二:逐次增加法代码实现方法三:质因数分解法代码实现方法比较处理大数和特殊情况改进版GCD方法实现 C求两个整数的最小公倍数(LCM)的方法 最小公倍数(LCM)是…

Linux网络:应用层协议http

前言 虽然我们说,应用层协议是我们程序猿自己定的。但实际上,已经有大佬们定义了一些现成的,又非常好用的应用层协议,供我们直接参考使用.HTTP(超文本传输协议)就是其中之一。 我们之前已经学了UDP与TCP套接字的简单使用,以及讲解了进程间的各种关系&a…

ffmpeg推流测试

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录前言一、操作步骤1.测试12.测试2总结前言 提示:这里可以添加本文要记录的大概内容: 环境信息: 摄像头:usb摄像头 &a…

Docker的使用及核心命令

文章目录Docker基础概念镜像管理命令镜像查看和搜索镜像下载和删除镜像构建容器生命周期管理创建和启动容器容器控制命令容器清理容器交互和调试进入容器文件操作日志和监控数据管理数据卷(Volume)绑定挂载网络管理网络基础操作端口映射Dockerfile和Dock…

考研408计算机网络第36题真题解析(2021-2023)

(2023.36)在使用 CSMA/CD 协议的环境中,使用截断二进制指数退避算法,来选择重传时机,算法 有如下规定: (1)基本的退避时间为争用期 2τ,假设某网络具体的争用期为 51.2us…

Asio C++ Library是用来做什么的

hriskohlhoff/asio 是由 Chris Kohlhoff 主导维护的开源 C 库,专注于提供高效、跨平台的异步 I/O 支持,广泛应用于网络编程、并发控制和高性能系统开发。 📘 项目概述 项目名称:Asio C Library 下载地址:https://down…

ac791的按键ad_channel

每次ad_channel这个参数都要给我一定的迷惑性,让我以为这是通道的数量

机器人巡检与巡逻的区别进行详细讲解和对比

机器人巡检与巡逻的区别进行详细讲解和对比 尽管这两个词经常被混用,但在技术和应用层面上,它们有着本质的区别。核心区别在于:巡检是“深度体检”,而巡逻是“治安巡查”。 以下将从多个维度进行详细讲解和对比。 一、核心概念与目…

先进电机拓扑及控制算法介绍(3)——以“数据”驱动电机实现真正的无模型

1. 背景介绍 之前已经介绍过“无模型预测控制(Model-Free Predictive Control/MFPC)”中的“无模型预测电流控制(Model-Free Predictive Current Control/MFPCC)”,可参考下面知乎。 https://zhuanlan.zhihu.com/p/6…

C primer plus (第六版)第十一章 编程练习第5,6题

题目:5.设计并测试⼀个函数,搜索第1个函数形参指定的字符串,在其中查找第2个函数形参指定的字符⾸次出现的位置。如果成功,该函数返指向该字符的指针,如果在字符串中未找到指定字符,则返回空指针…

Altium Designer(AD)PCB丝印批量修改

目录 1 Altium Designer(AD)PCB丝印的字体批量修改 1.1选中所有丝印 1.1.1选中一个丝印:鼠标左键点击 1.1.2查找相似对象:鼠标右键或快捷键N 1.1.3如下图所示丝印被全部选中 1.2丝印字体信息修改 1.2.1打开属性面板——>位置/属性/字体修改 1.2.2丝印字体修改 1.2.…

AI+华为HarmonyOS开发工具DevEco Studio详细安装指南

作者:长江支流 日期:2025-09-13 第一部分:AI工具使用 一、如何使用DeepSeek帮助自己的工作? (一)提示词 为了与时俱进,充分利用最新技术、提高效率,采用AI生成部分材料&#xf…

【Ambari监控】— API请求逻辑梳理

附录:完整内容和源代码下载请参照 https://doc.janettr.com/ 一、前序章节回忆 我们在前面章节拆解了 Collector 的启动过程,并定位了控制器 TimelineWebServices。 本节聚焦 Collector 对外暴露的 REST 服务,搭建「接口全景图」。 二、接口…