【机器学习】-torch相关知识01

学习代码时遇到的问题,GPT给的答案,如有错误请指出。

问题1 torch.empty nn.init.xavier
问题2 nn.Parameter 是什么?
问题3 self.add_module
问题4 torch.matmul torch.mm

文章目录

    • 问题1 torch.empty nn.init.xavier
    • 问题2 `nn.Parameter` 是什么?
    • 问题3 self.add_module
    • 问题4 `torch.matmul` `torch.mm`

问题1 torch.empty nn.init.xavier

在 PyTorch 里:

self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))

这里的 torch.empty 并不是“空的”的意思,而是 创建一个未初始化的张量

  • torch.empty(shape) 会直接分配一块内存,但 不对数值做初始化
  • 里面的值通常是内存里的随机数(垃圾值)
  • 一般会 马上再用 nn.init(比如 xavier_uniform_)去初始化
self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
nn.init.xavier_uniform_(self.W.data, gain=1.414)

解释:

  1. torch.empty(...):先申请一个形状为 (in_features, out_features) 的张量,里面值是随机的垃圾值。
  2. nn.init.xavier_uniform_:再用 Xavier 初始化self.W 的值设成合理的初始权重。

🔹 举个例子

x = torch.empty(2, 3)
print(x)

可能输出:

tensor([[ 4.5908e-41,  0.0000e+00, -1.8754e+38],[ 1.4013e-45,  0.0000e+00,  2.8026e-45]])

⚠️ 这些值是 内存里原本的残留值,完全没意义。
所以我们几乎总是 立刻再初始化

📌 总结一句:
torch.empty = 只开辟空间,不赋值(可能是内存垃圾值)。
这里用它的原因是:要立刻用 xavier_uniform_ 来做权重初始化。

  • nn.init.xavier_uniform_:Xavier 初始化方法(又叫 Glorot 初始化)
  • self.W.data:表示直接访问 Parameter 的底层 Tensor 数据(绕过 autograd 的计算图)
  • gain=1.414:一个缩放系数(通常是根据激活函数来设定),这里取 2\sqrt{2}2,适合 ReLU / ELU 激活

🔹 Xavier 初始化的原理

在神经网络训练中,如果权重初始化得不好,可能会出现:

  • 前向传播时数值逐层变得很大或很小
  • 反向传播时梯度消失或爆炸

Xavier 初始化的目标是:
让每一层输入和输出的方差尽量保持一致,避免数值逐层放大/缩小。

公式(均匀分布版):

Wij∼U(−6fan_in+fan_out⋅gain,6fan_in+fan_out⋅gain)W_{ij} \sim U\Bigg(-\sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}} \cdot \text{gain}, \; \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}} \cdot \text{gain}\Bigg) WijU(fan_in+fan_out6gain,fan_in+fan_out6gain)

其中:

  • fan_in = 输入维度
  • fan_out = 输出维度
  • gain = 缩放系数(跟激活函数有关)

🔹 为什么 gain=1.414

  • PyTorch 里可以用 nn.init.calculate_gain('relu') 得到对应的 gain
  • 对于 ReLU / ELU 激活,推荐的 gain = 2≈1.414\sqrt{2} ≈ 1.41421.414
  • 因为 GAT 里面用了 ELU 激活,所以这里选 1.414

🔹 举个例子

import torch
import torch.nn as nnW = torch.empty(3, 2)
nn.init.xavier_uniform_(W, gain=1.414)
print(W)

可能输出:

tensor([[-0.6465,  0.5042],[ 0.7393, -0.1285],[-0.2268,  0.6471]])

这些值就是在合适的范围内随机初始化的权重。

📌 总结:
nn.init.xavier_uniform_(self.W.data, gain=1.414)
= 用 Xavier 均匀分布初始化 W,保证前后层方差平衡,避免梯度消失/爆炸,适合 ELU 激活。

问题2 nn.Parameter 是什么?

nn.Parameter 是一种 特殊的张量(Tensor),和普通 torch.Tensor 的区别在于:

👉 如果你把一个 Tensor 包装成 nn.Parameter,并把它赋值给 nn.Module 的属性,那么它就会 自动被认为是模型的可训练参数,会在训练时出现在 model.parameters() 里,交给优化器更新。
🔹 举个例子

import torch
import torch.nn as nn# 普通 Tensor
w1 = torch.randn(3, 3)# Parameter
w2 = nn.Parameter(torch.randn(3, 3))print(isinstance(w1, nn.Parameter))  # False
print(isinstance(w2, nn.Parameter))  # True

如果我们放到一个模型里:

class MyModel(nn.Module):def __init__(self):super().__init__()self.w1 = torch.randn(3, 3)                 # 普通张量self.w2 = nn.Parameter(torch.randn(3, 3))   # 可训练参数model = MyModel()
print(list(model.parameters()))

输出只会包含 w2,而不会包含 w1
也就是说,只有 Parameter 会被 优化器更新

🔹 在你代码里的用法

self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))

这里 self.W线性变换的权重矩阵,必须在训练中更新,所以定义成 nn.Parameter
同样:

self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))

是注意力机制里的权重,也需要更新,所以用 Parameter

🔹 总结

  • torch.Tensor:普通张量,不会自动训练
  • nn.Parameter:可训练张量,自动注册进模型参数
  • 在定义自定义层(如 GAT 层)时,所有需要 梯度更新的权重 都要用 Parameter

问题3 self.add_module

差不多可以这么理解 ✅,但它不是普通的 Python 字典,而是 PyTorch 的模块管理机制
我帮你详细解释一下 self.add_module

定义

nn.Module.add_module(name, module)

  • name: 字符串,子模块的名字
  • module: 一个 nn.Module 对象

作用就是 把子模块注册到当前 Module 里
🔹 内部原理

实际上,nn.Module 内部有一个 OrderedDict 存储子模块:

self._modules: Dict[str, Module]

当你调用:

self.add_module("attention_0", GraphAttentionLayer(...))

就相当于:

self._modules["attention_0"] = GraphAttentionLayer(...)

所以你可以像字典一样访问:

print(self._modules["attention_0"])

🔹 区别

和普通字典不同,_modules 里的对象会被:

  • model.parameters() 自动收集
  • model.to(device) 时自动搬到 GPU
  • model.train() / model.eval() 时自动切换模式

也就是说,只要是 add_module 注册进去的,就能参与训练

🔹 示例

import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super().__init__()self.add_module("layer1", nn.Linear(10, 20))self.add_module("layer2", nn.ReLU())model = MyModel()
print(model._modules)  # 就是一个 OrderedDict

输出类似:

OrderedDict([('layer1', Linear(in_features=10, out_features=20, bias=True)),('layer2', ReLU())])

所以你理解得没错:self.add_module 背后确实是往一个 OrderedDict 里加条目,但它和普通字典不同,会多了很多 PyTorch 的训练支持功能。

问题4 torch.matmul torch.mm

torch.matmul 是 PyTorch 中一个 更通用的矩阵/向量乘法函数,功能比 torch.mm 更强,支持多种情况。

1️⃣ 基本区别

函数支持的输入输出备注
torch.mm(A, B)A、B 都是二维矩阵矩阵乘法 (m,n) × (n,p) → (m,p)仅二维矩阵
torch.matmul(A, B)支持向量、矩阵和高维张量自动选择矩阵乘法或批量乘法通用

2️⃣ 输入和输出规则

(1) 二维矩阵

A.shape = (m, n)
B.shape = (n, p)
C = torch.matmul(A, B)  # C.shape = (m, p)
  • torch.mm 一致

(2) 向量

a.shape = (n,)
b.shape = (n,)
c = torch.matmul(a, b)  # 内积,标量

(3) 高维张量(批量矩阵乘法)

A.shape = (batch, m, n)
B.shape = (batch, n, p)
C = torch.matmul(A, B)  # C.shape = (batch, m, p)
  • torch.matmul 会对 batch 维度自动广播
  • 非常适合 批量图神经网络计算

3️⃣ 在 GAT 中的应用

Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])
  • Wh:节点特征 (N, out_features)
  • self.a[:out_features, :]:权重 (out_features, 1)
  • 输出 Wh1:形状 (N, 1)
  • 功能:对每个节点的特征向量进行线性组合,得到注意力机制的部分输入

相比 torch.mm,这里用 matmul 可以更通用,如果以后是批量矩阵或者高维张量也能工作

4️⃣ 总结

  • torch.mm:只能二维矩阵

  • torch.matmul:通用版本,支持:

    • 向量内积
    • 矩阵乘法
    • 批量矩阵乘法
  • GAT 中:

    • Wh = torch.mm(h, W) → 将输入特征线性映射
    • Wh1 = torch.matmul(Wh, a[:out_features,:]) → 注意力机制的线性组合

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/920806.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/920806.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Hutool DsFactory多数据源切换

一、简单上手&#xff1a;从配置到使用全流程 DsFactory 的核心优势是零侵入配置&#xff0c;支持多种配置方式&#xff0c;不管是 properties 文件还是代码里直接定义&#xff0c;都能快速初始化数据源。先引依赖&#xff08;Maven&#xff09;&#xff1a; <dependency>…

Mysql中事务隔离级别有哪些?

Mysql中事务隔离级别有哪些&#xff1f; 读未提交&#xff1a; 一个事务可以看到另一个事务尚未提交的数据。可能导致脏读。 读已提交&#xff1a; 一个事务只能看到其他事务提交后的数据。避免了脏读&#xff0c;仍可能引发不可重复读。 可重复读&#xff1a; 可以确保一个事务…

el-carousel在新增或者删除el-carousel-item时默认跳到第一页的原因和解决

现象 使用走马灯效果时 当el-carousel-item增加或者减少时&#xff0c;页会跳到第一页 体验很不友好。 原因 当新增或这删除el-carousel-item时&#xff0c;会触发setActiveIndex&#xff08;props.initialindex&#xff09;, setActiveIndex的行为是小于0或者大于最大页会有一…

人工智能学习:机器学习相关面试题(二)

7、有监督学习和无监督学习的区别 有监督学习&#xff1a; 对具有概念标记&#xff08;分类&#xff09;的训练样本进行 学习&#xff0c;以尽可能对训练样本集外的数据进行 标记&#xff08;分类&#xff09;预测。 这里 &#xff0c;所有的标记&#xff08;分类&#xff09…

python如何下载svg图片

# 生成博客文章框架代码 import datetimeblog_content f"""# Python如何下载SVG图片## 引言 SVG&#xff08;可缩放矢量图形&#xff09;作为一种基于XML的矢量图形格式&#xff0c;在Web开发中广泛应用。本文将介绍如何使用Python从网络下载SVG图片&#xff0…

Linux(一) | 初识Linux与目录管理基础命令掌握

个人主页-爱因斯晨 文章专栏-Linux 最近学习人工智能时遇到一个好用的网站分享给大家&#xff1a; 人工智能学习 文章目录个人主页-爱因斯晨文章专栏-Linux一、前言1.为什么学习Linux2.操作系统概述&#xff1a;3.常见的操作系统&#xff1a;二、初识Linux1.诞生2.什么是Linux…

android-studio 安装

下载地址 国内&#xff1a;https://developer.android.google.cn/studio?hlzh-cn 全国&#xff1a;https://developer.android.com/studio 1.设置 ANDROID_HOME 环境变量 ANDROID_HOME D:\zhy\android-studio\sdk 2. 更新 PATH 环境变量 %ANDROID_HOME%\platform-tools %AN…

【重学MySQL】九十三、MySQL字符集与比较规则完全解析

【重学MySQL】九十三、MySQL字符集与比较规则完全解析一、字符集概述1.1 支持的字符集1.2 UTF8与UTF8MB4的区别二、比较规则&#xff08;Collation&#xff09;2.1 比较规则分类2.2 常见比较规则差异三、配置层级与继承关系3.1 配置层级3.2 继承关系四、最佳实践与问题解决4.1 …

基于Kafka的延迟队列

实现原理 通过topic区分不同的延迟时长&#xff0c;每个topic对于一个延迟&#xff0c;比如 topic100 仅存储延迟 100ms 的消息&#xff0c;topic1000 仅存储延迟 1s 的消息&#xff0c;依次类推。生产消息时&#xff0c;消息需按延迟时长投递到对应的topic。消费消息时&#x…

LabVIEW转速仪校准系统

LabVIEW 与机器视觉的智能校准系统以工控机为核心&#xff0c;整合标准源、智能相机等硬件&#xff0c;通过软件实现校准流程自动化&#xff0c;支持 500-6000r/min 转速范围校准&#xff0c;覆盖 5 类转速测量仪&#xff0c;校准时间缩短约 70%&#xff0c;满足计量院高效、精…

Synchronized 概述

1. 初识 synchronized 是 Java 中的关键字&#xff0c;是一种 同步锁 &#xff0c;可重入锁&#xff0c;悲观锁。它修饰的对象有以下几种&#xff1a; 具体表现为以下3种形式。 对于普通同步方法&#xff0c;锁是当前实例对象。 对于静态同步方法&#xff0c;锁是当前类的 Clas…

通过Auth.log来查看VPS服务器是否被扫描和暴力破解及解决办法

说明&#xff1a;很多人vps可能出现过被扫的情况&#xff0c;有的还被爆破了&#xff0c;这里提供下查看方法 查看用密码登陆成功的IP地址及次数grep "Accepted password for root" /var/log/auth.log | awk {print $11} | sort | uniq -c | sort -nr | more查看用密…

碰一碰发视频手机版源码开发:支持OEM

**从事开发 20 年&#xff0c;见过不少技术风口起起落落&#xff0c;最近 “碰一碰发视频” 又成了热门话题。不少同行或刚入行的年轻人来问我&#xff0c;手机版源码开发该从哪下手&#xff0c;怕踩坑、怕走弯路。今天就以一个老程序员的视角&#xff0c;把碰一碰发视频手机版…

只出现一次的数字(总结)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录前言一、给定一个整数数组nums&#xff0c;除了某个元素只出现一次以外&#xff0c;其余元素均出现两次。找出那个只出现一次的元素二、给你一个整数数组nums&#x…

Cesium 入门教程(十一):Camera相机功能展示

文章目录一&#xff0c;Cesium 实际示例&#xff08;含源代码&#xff09;1&#xff0c;vuecesium&#xff1a; 围绕一个固定点自动左右旋转2&#xff0c;vuecesium&#xff1a; flyto一个具体的实体位置3&#xff0c;vuecesium&#xff1a; flyto一个具体的点位置4&#xff0c…

go语言基本排序算法

package mainimport "fmt"func main() {BubbleSort()SelectSort()InsertSort()MergeSort()QuickSort()HeapSort()ShellSort() }//冒泡排序 func BubbleSort() {str : []int{9, 1, 5, 8, 3, 7, 4, 6, 2}for i : 0; i < len(str)-1; i {flag : falsefor j : len(str…

一步完成CalDAV账户同步,日历服务助力钉钉日历日程集中管理

在信息爆炸节奏飞快的今天&#xff0c;高效的管理时间已经成为我们工作和生活中的核心竞争力&#xff0c;复杂纷繁的日程安排&#xff0c;无处不在的提醒需求以及跨设备同步的困扰&#xff0c;这些问题仿佛都在呼唤着一个更智能、更便捷、更可靠的解决方案。 而华为日历App&am…

企业内部机密视频安全保护|如何防止企业内部机密视频泄露?

在企业数字化进程飞速发展的今天&#xff0c;视频内容已成为承载企业内部培训、战略会议、产品机密和核心技术的关键载体。一次意外的泄露&#xff0c;不仅可能导致知识产权流失&#xff0c;更会让企业声誉和市场竞争力遭受重创。面对无孔不入的安全威胁&#xff0c;企业该如何…

C# Deconstruct | 简化元组与对象的数据提取

官方文档&#xff1a;析构元组和其他类型 - C# | Microsoft Learn 标签&#xff1a;Deconstruct、Tuple、record、模式匹配 PS&#xff1a;record相关内容后续还会继续更新&#x1f504; 模式匹配可以查看我的另一篇&#x1f449;模式匹配 目录1. 概述2. 基本用法2.1 元组解…

R 语言 ComplexUpset 包实战:替代 Venn 图的高级集合可视化方案

摘要 在生物信息学、数据挖掘等领域的集合分析中,传统 Venn 图在多维度数据展示时存在信息拥挤、可读性差等问题。本文基于 R 语言的 ComplexUpset 包,以基因表达研究为场景,从包安装、数据准备到可视化实现,完整演示如何制作正刊级别的集合交集图,解决多条件下差异基因(…