torch 高维矩阵乘法分析,一文说透

文章目录

    • 简介
    • 向量乘法
    • 二维矩阵乘法
    • 三维矩阵乘法
      • 广播
    • 高维矩阵乘法
    • 开源

简介

一提到矩阵乘法,大家对于二维矩阵乘法都很了解,即 A 矩阵的行乘以 B 矩阵的列。
但对于高维矩阵乘法可能就不太清楚,不知道高维矩阵乘法是怎么在计算。

建议使用torch.matmul 做矩阵乘法,其支持向量乘法 和 二维、乃至多维的矩阵乘法。

向量乘法

a1 = torch.tensor([1, 2])
res1 = torch.matmul(a1, a1)
print(res1)
print(res1.shape)

输出:

tensor(5)
torch.Size([])

torch 也支持使用 @ 完成乘法操作

二维矩阵乘法

a2 = torch.tensor([[1, 2]])
res2 = torch.matmul(a2, a2.transpose(-2, -1))
print(res2)
print(res2.shape)

输出:

tensor([[5]])
torch.Size([1, 1])

torch.mm@ 也可以做二维矩阵乘法:

  • a2 @ a2.transpose(-2, -1)
  • torch.mm(a2, a2.transpose(-2, -1))

三维矩阵乘法

torch.bmm 支持三维矩阵乘法,不支持更高维度的矩阵乘法

a3 = torch.randn(2, 3, 2)
res3 = torch.bmm(a3,a3.transpose(-1, -2)
)
print(res3)
print(res3.shape)

输出:

tensor([[[ 4.5979,  0.6648,  2.9231],[ 0.6648,  0.1155,  0.4713],[ 2.9231,  0.4713,  1.9805]],[[ 1.0323,  1.8212, -0.3546],[ 1.8212,  3.5445, -0.3834],[-0.3546, -0.3834,  0.2988]]])
torch.Size([2, 3, 3])

a3 的 shape是(2, 3, 2),a3 底层的两个维度做转置之后变成(2, 2, 3),才可以做矩阵乘法。
可以发现第一位的数字都是2。高维矩阵做乘法的时候,除了最后两个维度,高维矩阵前面的维度两个矩阵要保持一致。

torch.randn(2, 3, 2) @ torch.randn(3, 2, 3)

在这里插入图片描述
虽然上述两个矩阵,在最后两个维度满足矩阵运算的条件,但是第一个维度两个矩阵的值不一样,所以不能做矩阵乘法。

广播

但是发现:

t1 = torch.randn(1, 3, 2)
t2 = torch.randn(3, 2, 3)
t1 @ t2

输出:

tensor([[[-0.6557,  1.0518,  0.3055],[-0.2876, -2.5104, -1.4417],[ 1.4447, -0.1799,  0.4602]],[[ 0.2971,  0.0060, -0.2612],[-0.9089,  1.0824,  0.7131],[ 0.0929, -0.7898, -0.0199]],[[ 0.0027,  1.2031,  0.1543],[-0.5603, -1.8567, -0.1302],[ 0.3978, -0.9356, -0.1977]]])

理论上两个矩阵的高维度的shape不一样,就不可以做矩阵乘法。但上述 t1t2可以做矩阵乘法。这是因为 t1 的第一个维度是1,就会自动做广播。

广播的效果类似于,把 t1 在第一个维度复制成与t2一样,第一个维度都变成3。
在下述使用 concat完成复制工作,再做矩阵乘法,发现可以得到上述一样的结果。

torch.concat((t1, t1, t1)) @ t2

输出:

tensor([[[-0.6557,  1.0518,  0.3055],[-0.2876, -2.5104, -1.4417],[ 1.4447, -0.1799,  0.4602]],[[ 0.2971,  0.0060, -0.2612],[-0.9089,  1.0824,  0.7131],[ 0.0929, -0.7898, -0.0199]],[[ 0.0027,  1.2031,  0.1543],[-0.5603, -1.8567, -0.1302],[ 0.3978, -0.9356, -0.1977]]])

高维矩阵乘法

矩阵乘法只会在最后两个维度,用A矩阵的行乘以B矩阵的列。
其他的维度都是对应位置的数据,互相做乘法(类似向量乘法)。

high_matrix1 = torch.randn(2, 3, 4, 5)
high_matrix2 = torch.randn(2, 3, 5, 4)
high_result = high_matrix1 @ high_matrix2

把最后两个维度看成一个点。更高的维度的矩阵乘法,可想象为两个矩阵对应位置的点相乘。

比如,shape(2, 3, 4, 5)与shape(2, 3, 5, 4)的矩阵相乘,若把最后两个维度看成一个点。就可以类比为 (2, 3) 与 (2, 3)的两个矩阵做向量乘法,就是对应位置的点做乘法。

如下面的运行结果所示。针对两个矩阵,在高维空间中,选取(1,2)对应的小矩阵数据做矩阵乘法得到的结果。与两个矩阵乘法的结果对应(1,2)的值是一样的。

(high_matrix1[1][2] @  high_matrix2[1][2]) == high_result[1][2]

输出:

tensor([[True, True, True, True],[True, True, True, True],[True, True, True, True],[True, True, True, True]])

开源

https://github.com/JieShenAI/csdn/blob/main/25/06/torch_matmul/run.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/web/83637.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

瑞萨RA-T系列芯片马达类工程TCM加速化设置

本篇介绍在使用RA8-T系列芯片,建立马达类工程应用时,如何将电流环部分的指令和变量设置到TCM单元,以提高电流环执行速度,从而提高系统整体的运行性能,在伺服和高端工业领域有很高的实用价值。本文以RA8T1为范例&#x…

获取Unity节点路径

解决目的: 避免手动拼写节点路径的时候,出现路径错误导致获取不到节点的情况。解决效果: 添加如下脚本之后,将自动复制路径到剪贴板中,在代码中通过 ctrlv 粘贴路径代码如下: public class CustomMenuItems…

Docker 安装 Oracle 12C

镜像 https://docker.aityp.com/image/docker.io/truevoly/oracle-12c:latest docker pull swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/truevoly/oracle-12c:latest docker tag swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/truevoly/oracle-12c:latest d…

Linux内核网络协议注册与初始化:从proto_register到tcp_v4_init_sock的深度解析

一、协议注册:proto_register的核心使命 在Linux网络协议栈中,proto_register是协议初始化的基石,主要完成三项关键任务: Slab缓存创建(内存管理核心) prot->slab = kmem_cache_create_usercopy(prot->name, prot->obj_size, ...); if (prot->twsk_prot) pr…

GD32 MCU的真随机数发生器(TRNG)

GD32 MCU的真随机数发生器(TRNG) 文章目录 GD32 MCU的真随机数发生器(TRNG)一、定义与核心特征二、物理机制:量子与经典随机性三、生成方法四、应用场景五、与伪随机数的对比六、局限性⚙️ 七、物理熵源原理🔧 八、硬件实现流程(以GD32F450 GD32L233为例)8.1. **初始…

Vulkan学习笔记6—渲染呈现

一、渲染循环核心 while (!glfwWindowShouldClose(window)) {glfwPollEvents();helloTriangleApp.drawFrame(); // 绘制帧} 在 Vulkan 中渲染帧包含一组常见的步骤 等待前一帧完成(vkWaitForFences) 从交换链获取图像(vkAcquireNextImageKH…

React第六十二节 Router中 createStaticRouter 的使用详解

前言 createStaticRouter 是 React Router 专为 服务端渲染(SSR) 设计的 API,用于在服务器端处理路由匹配和数据加载。它在构建静态 HTML 响应时替代了客户端的 BrowserRouter,确保 SSR 和客户端 Hydration 的路由状态一致。 一…

qt 双缓冲案例对比

双缓冲 1.双缓冲原理 单缓冲:在paintEvent中直接绘制到屏幕,绘制过程被用户看到 双缓冲:先在redrawBuffer绘制到缓冲区,然后一次性显示完整结果 代码结构 单缓冲:所有绘制逻辑在paintEvent中 双缓冲:绘制…

华为云AI开发平台ModelArts

华为云ModelArts:重塑AI开发流程的“智能引擎”与“创新加速器”! 在人工智能浪潮席卷全球的2025年,企业拥抱AI的意愿空前高涨,但技术门槛高、流程复杂、资源投入巨大的现实,却让许多创新构想止步于实验室。数据科学家…

ParaGraphX [特殊字符]

https://github.com/stevechampion1/paragraphx 一个基于 JAX 的、为 CPU/GPU 加速而生的超高性能图算法库。 ParaGraphX 是一个实验性的 Python 库,旨在利用 JAX 的即时编译 (JIT) 和大规模并行计算能力,为经典的图算法提供惊人的性能提升。我们的目标…

如何用4 种可靠的方法更换 iPhone(2025 年指南)

Apple 每年都会发布新版本的 iPhone。升级到新 iPhone 是一种令人兴奋的体验,但转移所有宝贵数据的想法有时会让人感到畏惧。幸运的是,我们准备了 4 种有效的更换 iPhone 的方法,让你可以毫不费力地更换到你的新 iPhone。 此外,您…

GitLab 拉取变慢的原因及排查方法

前言:在软件开发的快节奏世界里,高效协作与快速交付是制胜关键。然而,当开发团队兴高采烈地投入工作,却发现从GitLab拉取代码的速度慢如蜗牛,那种沮丧感简直能瞬间浇灭热情。在分布式开发环境中,这种情况时…

落水人员目标检测数据集(猫脸码客第253期)

落水人员目标检测:科技守护生命之舟 一、背景与意义 随着人类海洋活动和水上活动的日益频繁,海上与水域安全事故频发。每年都会开展大量的海上救援行动,以搜救数以万计的落难人员。在水上活动区域,如水库、河道等,溺…

JAVA_强制类型转换:

类型范围大的变量,不可以直接赋值给类型变量小的变量 需要进行强制类型转换: 想要完成类型范围大的变量传给类型范围小的变量需要先创建一个新的变量(类型与方法的形参类型要相同)。将类型范围大的变量前面加上(转换类…

打卡第44天:无人机数据集分类

重复以下内容 作业: kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化 进阶: 并拆分成多个文件 import os import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader,…

个人网站大更新,还是有个总站比较好

个人网站大更新,还是有个总站比较好 放弃了所有框架,用纯htmlcssjs撸了个网站,这回可以想改啥改啥了。 选择了黑紫作为主色调,暂时看着还算可以。 为什么不用那些框架了 几个原因: 尝试用vuepress、vitepress、not…

高精度算法详解:从原理到加减乘除的完整实现

文章目录 一、为什么需要高精度算法二、高精度算法的数据结构设计2.1 基础工具函数2.2 高精度加法实现2.3 高精度减法实现2.4 高精度乘法实现2.5 高精度除法实现 三、完整测试程序四、总结 一、为什么需要高精度算法 在编程中,处理极大数值是常见需求,例…

排序--计数排序

一,引言 计数排序是一种针对整数数据的高效排序算法。其主要流程可分为三个步骤:首先计算整数数据的数值范围;接着按大小顺序统计各数值的出现次数;最后根据统计结果输出排序后的数据序列。 二,求最值 遍历现有数据,获取最大值…

Kubernetes安全机制深度解析(四):动态准入控制和Webhook

#作者:程宏斌 文章目录 动态准入控制什么是准入 Webhook? 尝试准入Webhook先决条件编写一个准入 Webhook 服务器部署准入 Webhook 服务即时配置准入 Webhook对 API 服务器进行身份认证 Webhook 请求与响应Webhook 配置匹配请求-规则匹配请求&#xff1a…

WDK 10.0.19041.685,可在32位win7 sp1系统下搭配vs2019使用,可以编译出xp驱动。

(14)[驱动开发]配置环境 VS2019 WDK10 写 xp驱动 (14)[驱动开发]配置环境 VS2019 WDK10 写 xp驱动_microsoft visual 2019 wdk-CSDN博客文章浏览阅读3k次,点赞8次,收藏17次。本文介绍了如何在VS2019环境下安装和配置Windows Driver Kit(WDK)&#xff0…