PyTorch nn.Parameter理解及初始化方法总结


一、理解 nn.Parameter

  1. 本质是什么?

    • nn.Parametertorch.Tensor 的一个子类
    • 这意味着它继承了 Tensor 的所有属性和方法(如 .data, .grad, .requires_grad, .shape, .dtype, .device, .backward() 等)。
    • 它本身不是一个函数或模块,而是一种特殊的张量类型
  2. 核心目的/作用:

    • 标识模型参数: 它的主要作用是标记一个 Tensor,告诉 PyTorch 这个 Tensor 是模型的一部分,是需要被优化器在训练过程中更新(学习) 的参数。
    • 自动注册: 当一个 nn.Parameter 被分配为一个 nn.Module属性时(通常在 __init__ 方法中),PyTorch 会自动将其注册到该 Module 的 parameters() 列表中。这是最关键的特性!
    • 优化器目标: 优化器(如 torch.optim.SGD, torch.optim.Adam)通过调用 model.parameters() 来获取所有需要更新的参数。只有注册了的 nn.Parameter(以及 nn.Module 子模块中的 nn.Parameter)才会被包含在这个列表中。
  3. 与普通 Tensor (torch.Tensor) 的区别:

    特性nn.Parameter普通 torch.Tensor
    自动注册✅ 当是 Module 属性时自动加入 parameters()❌ 不会自动注册
    优化器更新目标✅ 默认会被优化器更新❌ 默认不会被优化器更新
    requires_grad默认为 True默认为 False
    用途定义模型需要学习的权重 (Weights) 和偏置 (Biases)存储输入数据、中间计算结果、常量、缓冲区等
  4. 关键结论:

    • 如果你想定义一个会被优化器更新的模型参数(权重、偏置),务必使用 nn.Parameter 包装你的 Tensor,并将其设置为 nn.Module 的属性。
    • 普通 Tensor 即使 requires_grad=True,如果没有被注册(通过 nn.Parameterregister_parameter()),也不会被优化器更新。它们可能用于存储需要梯度的中间状态或自定义计算。

二、nn.Parameter 的创建与初始化方法大全

创建 nn.Parameter 的核心是:先创建一个 Tensor,然后用 nn.Parameter() 包装它。初始化方法的多样性体现在如何创建这个底层的 Tensor。以下是常见方法:

方法 1:直接包装 Tensor (最灵活)

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self, input_size, output_size):super().__init__()# 方法 1a: 使用 torch.tensor 创建并包装self.weight = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))  # 显式指定值 (不常用)# 方法 1b: 使用 torch 函数创建并包装 (最常用!)self.weight = nn.Parameter(torch.randn(input_size, output_size))  # 正态分布初始化self.bias = nn.Parameter(torch.zeros(output_size))               # 常数初始化 (0)def forward(self, x):return x @ self.weight + self.bias
  • 优点: 绝对控制,可以使用任何创建 Tensor 的函数。
  • 常用函数:
    • torch.randn(*size): 标准正态分布 (均值 0, 标准差 1) 初始化。最常用基础初始化。
    • torch.rand(*size): [0, 1) 均匀分布初始化。
    • torch.zeros(*size): 全 0 初始化 (常用于偏置)。
    • torch.ones(*size): 全 1 初始化 (较少直接用于权重)。
    • torch.full(size, fill_value): 用指定值填充。
    • torch.empty(*size).uniform_(-a, a): 在 [-a, a] 均匀分布初始化 (手动实现均匀分布)。
    • torch.empty(*size).normal_(mean, std): 指定均值和标准差的正态分布初始化 (手动实现正态分布)。

方法 2:使用 torch.nn.init 模块 (推荐用于特定初始化策略)

PyTorch 提供了 torch.nn.init 模块,包含许多常用的、研究证明有效的初始化函数。这些函数通常原地修改传入的 Tensor。

class MyModule(nn.Module):def __init__(self, input_size, output_size):super().__init__()# 1. 先创建一个未初始化的或基础初始化的 Tensor (通常用 empty 或 zeros)self.weight = nn.Parameter(torch.empty(input_size, output_size))self.bias = nn.Parameter(torch.zeros(output_size))# 2. 应用 nn.init 函数进行初始化 (原地操作)nn.init.xavier_uniform_(self.weight)  # Xavier/Glorot 均匀初始化 (适用于 tanh, sigmoid)# nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu')  # Kaiming/He 正态初始化 (适用于 ReLU)nn.init.constant_(self.bias, 0.1)    # 将偏置初始化为常数 0.1
  • 优点: 使用经过验证的、针对不同激活函数设计的初始化策略,通常能获得更好的训练起点和稳定性。代码更清晰表达意图。
  • 常用 nn.init 函数:
    • 常数初始化:
      • nn.init.constant_(tensor, val):用 val 填充。
      • nn.init.zeros_(tensor):全 0。
      • nn.init.ones_(tensor):全 1。
    • 均匀分布初始化:
      • nn.init.uniform_(tensor, a=0.0, b=1.0)[a, b] 均匀分布。
    • 正态分布初始化:
      • nn.init.normal_(tensor, mean=0.0, std=1.0):指定 meanstd 的正态分布。
    • Xavier / Glorot 初始化 (适用于饱和激活函数如 tanh, sigmoid):
      • nn.init.xavier_uniform_(tensor, gain=1.0):均匀分布,范围基于 gain 和输入/输出神经元数 (fan_in/fan_out) 计算。
      • nn.init.xavier_normal_(tensor, gain=1.0):正态分布,标准差基于 gain 和 fan_in/fan_out 计算。
    • Kaiming / He 初始化 (适用于 ReLU 及其变种):
      • nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):均匀分布。
      • nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):正态分布。
      • mode: 'fan_in' (默认, 保持前向传播方差) 或 'fan_out' (保持反向传播方差)。
      • nonlinearity: 'relu' (默认) 或 'leaky_relu'
    • 正交初始化 (Orthogonal Initialization):
      • nn.init.orthogonal_(tensor, gain=1):生成正交矩阵(或行正交、列正交的张量),有助于缓解深度网络中的梯度消失/爆炸。
    • 单位矩阵初始化 (Identity Initialization):
      • nn.init.eye_(tensor):尽可能将张量初始化为单位矩阵(对于非方阵,会初始化成尽可能接近单位矩阵的形式)。适用于某些 RNN 或残差连接。
    • 对角矩阵初始化 (Diagonal Initialization):
      • nn.init.dirac_(tensor):尽可能初始化为 Dirac delta 函数(多维卷积核中常用,保留输入通道信息)。主要用于卷积层。

方法 3:使用 nn.Linear, nn.Conv2d 等内置模块 (隐式初始化)

当你使用 PyTorch 提供的标准层(如 nn.Linear, nn.Conv2d, nn.LSTM 等)时,它们内部已经使用 nn.Parameter 定义了权重和偏置,并自动应用了合理的默认初始化策略

class MyModule(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(784, 256)  # 内部包含 weight (Parameter) 和 bias (Parameter)self.fc2 = nn.Linear(256, 10)# 这些层的参数已经被初始化了 (通常是某种均匀或正态分布)
  • 优点: 最方便,无需手动创建和初始化参数。对于标准层,默认初始化通常足够好。
  • 查看/修改内置模块的初始化:
    • 你可以通过 nn.Linear.weight.datann.Conv2d.bias 访问这些内置的 nn.Parameter
    • 如果你想修改默认初始化,可以在创建层后,使用 nn.init 函数重新初始化它们的 .weight.bias
      self.fc1 = nn.Linear(784, 256)
      nn.init.kaiming_normal_(self.fc1.weight, mode='fan_out', nonlinearity='relu')
      nn.init.constant_(self.fc1.bias, 0.0)
      

方法 4:从另一个模块或状态字典加载 (预训练/迁移学习)

# 假设 pretrained_model 是一个已经训练好的模型
pretrained_dict = pretrained_model.state_dict()# 创建新模型
new_model = MyModule()# 获取新模型的状态字典
model_dict = new_model.state_dict()# 1. 严格加载: 键必须完全匹配
new_model.load_state_dict(pretrained_dict)# 2. 非严格加载: 只加载键匹配的参数
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
new_model.load_state_dict(model_dict)# 3. 部分加载/初始化: 手动指定
new_model.some_layer.weight.data.copy_(pretrained_model.some_other_layer.weight.data)  # 直接复制数据
  • 优点: 迁移学习、微调、模型集成的基础。利用预训练知识加速训练或提升性能。

三、总结与最佳实践

  1. 何时用 nn.Parameter

    • 总是用它来定义你的模型层中需要被优化器更新的权重 (Weights)偏置 (Biases)
    • 对于模型中的输入数据、中间计算结果、常量、统计量(如 BatchNorm 的 running_mean)等不需要更新的张量,使用普通 torch.Tensor(通常通过 self.register_buffer() 注册为缓冲区,以便正确地在设备间移动和序列化)。
  2. 初始化方法选择:

    • 新手/快速原型: 使用内置层 (nn.Linear, nn.Conv2d 等),它们有合理的默认初始化。
    • 自定义层/需要特定策略:
      • 基础: torch.randn (正态), torch.zeros (偏置)。
      • 推荐: 优先使用 torch.nn.init 模块中的函数
        • 对于使用 tanh / sigmoid 的网络:考虑 nn.init.xavier_uniform_ / nn.init.xavier_normal_
        • 对于使用 ReLU / LeakyReLU 的网络:强烈推荐 nn.init.kaiming_uniform_ / nn.init.kaiming_normal_ (根据 modenonlinearity 选择)。
      • 特殊需求: 常数 (nn.init.constant_),正交 (nn.init.orthogonal_),单位矩阵 (nn.init.eye_),Dirac (nn.init.dirac_)。
  3. 关键步骤:

    • __init__ 中:
      1. 使用 torch.* 函数或 torch.empty 创建一个 Tensor。
      2. nn.Parameter() 包装这个 Tensor。
      3. (可选但推荐) 使用 nn.init.*_ 函数对这个 nn.Parameter 进行原地初始化(如果使用 torch.randn / torch.zeros 等创建时已初始化,此步可省)。
    • 对于内置层,初始化已自动完成,但可以按需修改。
    • 对于预训练模型,使用 load_state_dict 加载参数进行初始化。
  4. 注意:

    • 初始化参数的大小 (size) 必须根据层的设计(输入维度、输出维度、卷积核大小等)正确设置。
    • 确保 nn.Parameter 被设置为 nn.Module属性(直接赋值),否则不会被自动注册到 parameters() 中。
    • 理解不同初始化方法背后的原理(如 Xavier/Glorot 考虑输入输出方差, Kaiming/He 考虑 ReLU 的修正)对于设计深层网络非常重要。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/87396.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/87396.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】环境基础和开发工具

Linux 软件包管理器 yum 什么是软件包 在Linux下安装软件, 一个通常的办法是下载程序的源代码, 并进行编译, 得到可执行程序. 但是这样太麻烦了, 于是有些人把一些常用的软件提前编译好, 做成软件包(可以理解成windows上的安装程序)放在一个服务器上, 通过包管理器可以很方便…

多模态进化论:GPT-5V图文推理能力在工业质检中的颠覆性应用

前言 前些天发现了一个巨牛的人工智能免费学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站 🚀《多模态进化论:GPT-5V图文推理能力在工业质检中的颠覆性应用》 副标题:2025年实测报告显…

Linux实现一主二从模式

主从复制: 复制概念中分为两类数据库,一类是主数据库(master),一类是从数据(slave),主 数据库可以进行读写操作,并把写的操作同步给从数据库,一般从数据库是只…

大势智慧亮相第十八届中国智慧城市大会

6月26日-28日,第十八届中国智慧城市大会在武汉盛大举行。本次大会以“数智赋能城市创新协同共治发展蓝图”为主题,汇聚了李德仁、刘经南等八位院士及全国智慧城市领域的专家学者、行业精英,共同探讨行业发展新方向。作为实景三维技术领域领军…

Xbox One 控制器转换为 macOS HID 设备的工作原理分析

Xbox One 控制器转换为 macOS HID 设备的工作原理分析 源代码在 https://github.com/guilhermearaujo/xboxonecontrollerenabler.git 这个工程的核心功能是将 Xbox One 控制器(macOS 原生不支持的设备)转换为 macOS 可识别的 HID 设备。这里通过分析代…

Notepad++ 复制宏、编辑宏的方法

Notepad具有宏的功能,能够记录当下所有操作,后续只需要一键就可以重复执行,大大减少工作量。 比如我需要把很多文件里面的字符完成替换,那我只需要把替换的过程录制成宏,后续打开文件就可以一键替换了。 但是Notepad的…

Oracle:报错jdbc:oracle:thin:@IP地址:端口:实例名, errorCode 28001, state 99999

报错原因是oracle密码过期,根本解决办法是让密码不再过期,永久有效。具体操作记录一下。 cmd命令行输入: sqlplus / as sysdba修改Oracle密码期限为无限: SQL> ALTER PROFILE DEFAULT LIMIT PASSWORD_LIFE_TIME UNLIMITED;SQL&…

Apipost 签约中原消费金融:共建企业级 API 全链路协作平台,推动接口管理与测试智能化升级

随着企业数字化转型的不断深化,API 正在从技术细节演变为业务协作的核心枢纽。特别是在金融行业,微服务架构、系统联动、合规要求等多重因素交织下,接口数量激增、管理复杂度提升、质量保障难度加大。近日,Apipost 与中原消费金融…

AntV L7 之LarkMap 地图

一、安装$ npm install -S antv/l7 antv/larkmap # or $ yarn add antv/l7 antv/larkmap二、引入包import type { LarkMapProps, LineLayerProps } from antv/larkmap; import { LarkMap, LineLayer, Marker } from antv/larkmap;三、config配置const layerOptions:Omit<Lin…

客户案例 | 某新能源车企依托Atlassian工具链+龙智定制开发服务,打造符合ASPICE标准的研发管理体系

客户案例 ASPICE标准已成为衡量整车厂及供应商研发能力的重要标尺。某知名车企在其重点项目研发过程中&#xff0c;面临着ASPICE 4.0评估认证的挑战——项目团队缺乏体系经验、流程规范和数字化支撑工具。 为帮助该客户团队顺利通过ASPICE认证并提升研发合规性&#xff0c;At…

stm32的USART使用DMA配置成循环模式时发送和接收有着本质区别

stm32的USART使用DMA配置成循环模式时发送和接收有着本质区别&#xff0c;不要被网上误导了。发送数据时会不停的发送数据&#xff0c;而接收只有有数据时才会接收&#xff0c;没有数据时就会挂起等待。 一、触发机制的差异‌ ‌发送方向&#xff08;TX&#xff09;——状态驱…

银河麒麟系统上利用WPS的SDK进行WORD的二次开发

目录 1.下载安装包 2.安装WPS 3.获取示例代码 4.编译示例代码 5.完整示例代码 相关链接 1.下载安装包 去wps的官网 https://www.wps.cn/ 下载linux版本。 下载的安装包名称为&#xff1a;wps-office_12.8.2.21176.AK.preload.sw_amd64.deb, 官网有介绍适用于Ubuntu、麒麟…

人工智能之数学基础:如何判断正定矩阵和负定矩阵?

本文重点 正定矩阵和负定矩阵是线性代数中的重要概念,在优化理论、数值分析、统计学等领域有广泛应用。 正定矩阵(负定矩阵) 如上所示,我们可以看到满足上面的性质的时候,我们可以认为矩阵A称为正定矩阵(负定矩阵) 举例: 半正定(半负定) 如果≥或者≤的时候,我们认为矩…

汇编基础介绍——ARMv8指令集(四)

一、CMP 指令 CMP 指令用来比较两个数的大小。在 A64 指令集的实现中&#xff0c;CMP 指令内部调用 SUBS 指令来实现。 1.1、使用立即数的 CMP 指令 使用立即数的 CMP 指令的格式如下。 CMP <Xn|SP>, #<imm>{, <shift>} 上述指令等同于如下指令。 SUBS …

深入剖析 Electron 性能瓶颈及优化策略

Electron 是一个流行的跨平台桌面应用开发框架&#xff0c;基于 Chromium 和 Node.js&#xff0c;使得开发者可以使用 Web 技术&#xff08;HTML、CSS、JavaScript&#xff09;构建跨平台的桌面应用。许多知名应用如 VS Code、Slack、Discord 和 Figma 都采用了 Electron。然而…

Qt的前端和后端过于耦合(0/7)

最近在写一个软件&#xff0c;这个软件稍微复杂一些&#xff0c;界面大概需要十几个&#xff0c;后端也是要开多线程读各种传感器数据。然后鼠鼠我呀就发现一个致命的问题&#xff0c;那就是前端要求的控件太多了&#xff0c;点一下就需要通知后端&#xff0c;即调用后端的函数…

碰一碰发视频源码搭建定制化开发:支持OEM

在移动互联网与物联网深度融合的当下&#xff0c;“碰一碰发视频” 作为一种创新的信息交互方式&#xff0c;正逐渐应用于营销推广、产品展示、社交互动等多个领域。其核心在于通过近场通信技术&#xff08;如 NFC、蓝牙&#xff09;实现设备间的快速连接&#xff0c;无需复杂操…

机器学习文本特征提取:CountVectorizer与TfidfVectorizer详解

一、文本特征提取概述 在自然语言处理&#xff08;NLP&#xff09;和文本挖掘任务中&#xff0c;文本特征提取是将原始文本数据转换为机器学习模型可以理解的数值特征的关键步骤。scikit-learn提供了两种常用的文本特征提取方法&#xff1a;CountVectorizer&#xff08;词频统…

【PHP】.Hyperf 框架-collection 集合数据(内置函数归纳-实用版)

&#x1f4cc; Article::query()->where(article_id, 6)->select()->first()✍️ 进行数据结果的循环&#xff0c;遍历 1.each() 方法遍历集合中的项目并将每个项目传递给闭包&#xff0c;进行处理数据 Article::query()->get()->each(function ($item) {// 可…

巨兽的阴影:大型语言模型的挑战与伦理深渊

当GPT-4这样的庞然大物能够流畅对话、撰写诗歌、编写代码、解析图像&#xff0c;甚至在某些测试中媲美人类专家时&#xff0c;大型语言模型&#xff08;LLM&#xff09;仿佛成为了无所不能的“智能神谕”。然而&#xff0c;在这令人目眩的成就之下&#xff0c;潜藏着复杂而严峻…