PyTorch 中mm和bmm函数的使用详解

torch.mm 是 PyTorch 中用于 二维矩阵乘法(matrix-matrix multiplication) 的函数,等价于数学中的 A × B 矩阵乘积。


一、函数定义

torch.mm(input, mat2) → Tensor

执行的是两个 2D Tensor(矩阵)的标准矩阵乘法。

  • input: 第一个二维张量,形状为 (n × m)
  • mat2: 第二个二维张量,形状为 (m × p)
  • 返回:形状为 (n × p) 的张量

二、使用条件和注意事项

条件说明
仅支持 2D 张量一维或三维以上使用 torch.matmul@ 操作符
维度要匹配input.shape[1] == mat2.shape[0]
不支持广播两个矩阵维度不匹配会直接报错
结果是普通矩阵乘积不是逐元素乘法(Hadamard),即不是 *torch.mul()

三、示例代码

示例 1:基本矩阵乘法

import torchA = torch.tensor([[1., 2.], [3., 4.]])   # 2x2
B = torch.tensor([[5., 6.], [7., 8.]])   # 2x2C = torch.mm(A, B)
print(C)

输出:

tensor([[19., 22.],[43., 50.]])

计算步骤:

C[0][0] = 1*5 + 2*7 = 19
C[0][1] = 1*6 + 2*8 = 22
...

示例 2:不匹配维度导致报错

A = torch.rand(2, 3)
B = torch.rand(4, 2)
C = torch.mm(A, B)  # ❌ 会报错

报错:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 4x2)

示例 3:推荐写法(推荐使用 @matmul

A = torch.rand(3, 4)
B = torch.rand(4, 5)C1 = torch.mm(A, B)
C2 = A @ B                # 推荐用法
C3 = torch.matmul(A, B)   # 推荐用法

四、与其他乘法函数的比较

函数名支持维度运算类型支持广播
torch.mm仅限二维矩阵乘法❌ 不支持
torch.matmul1D, 2D, ND自动判断点乘 / 矩阵乘✅ 支持
torch.bmm批量二维乘法3D Tensor batch × batch❌ 不支持
torch.mul任意维度元素乘(Hadamard)✅ 支持
* 运算符任意维度元素乘✅ 支持
@ 运算符ND(推荐用)矩阵乘法(和 matmul 一样)

五、典型应用场景

  • 神经网络权重乘法:output = torch.mm(W, x)
  • 点云 / 图像变换:x' = torch.mm(R, x) + t
  • 多层感知机中的矩阵计算
  • 注意力机制中 QK^T 乘积

六、总结:什么时候用 mm

使用场景用什么
仅二维矩阵乘法torch.mm
高维或支持广播乘法torch.matmul / @
批量矩阵乘法 (如 batch_size×3×3)torch.bmm
元素乘torch.mul or *

在 PyTorch 中,torch.bmm批量矩阵乘法(batch matrix multiplication) 的操作,专用于处理三维张量(batch of matrices)。它的主要作用是对一组矩阵成对进行乘法,效率远高于手动循环计算。


一、torch.bmm 语法

torch.bmm(input, mat2, *, out=None) → Tensor
  • input: Tensor,形状为 (B, N, M)
  • mat2: Tensor,形状为 (B, M, P)
  • 返回结果形状为 (B, N, P)

这表示对 BN×MM×P 的矩阵进行成对相乘。


二、示例演示

示例 1:基础用法

import torch# 定义两个 batch 矩阵
A = torch.randn(4, 2, 3)  # shape: (B=4, N=2, M=3)
B = torch.randn(4, 3, 5)  # shape: (B=4, M=3, P=5)# 批量矩阵乘法
C = torch.bmm(A, B)       # shape: (4, 2, 5)print(C.shape)  # 输出: torch.Size([4, 2, 5])

示例 2:手动循环 vs bmm 效率对比

# 慢速手动方式
C_manual = torch.stack([A[i] @ B[i] for i in range(A.size(0))])# 等效于 bmm
C_bmm = torch.bmm(A, B)print(torch.allclose(C_manual, C_bmm))  # True

三、注意事项

1. 维度必须是三维张量

  • 否则会报错:
RuntimeError: batch1 must be a 3D tensor

你可以通过 .unsqueeze() 手动调整维度:

a = torch.randn(2, 3)
b = torch.randn(3, 4)# 升维
a_batch = a.unsqueeze(0)  # (1, 2, 3)
b_batch = b.unsqueeze(0)  # (1, 3, 4)c = torch.bmm(a_batch, b_batch)  # (1, 2, 4)

2. 维度必须满足矩阵乘法规则

  • (B, N, M) × (B, M, P)(B, N, P)
  • M 不一致会报错:
RuntimeError: Expected size for the second dimension of batch2 tensor to match the first dimension of batch1 tensor

3. bmm 不支持广播(broadcasting)

  • 必须显式提供相同的 batch size。
  • 如果只有一个矩阵固定,可以使用 .expand()
A = torch.randn(1, 2, 3)  # 单个矩阵
B = torch.randn(4, 3, 5)  # 4 个矩阵# 扩展 A 以进行 batch 乘法
A_expand = A.expand(4, -1, -1)
C = torch.bmm(A_expand, B)  # (4, 2, 5)

四、在实际应用中的例子

在点云变换中:批量乘旋转矩阵

# 假设有 B 个旋转矩阵和点坐标
R = torch.randn(B, 3, 3)       # 旋转矩阵
points = torch.randn(B, 3, N)  # 点云# 先转置点坐标为 (B, N, 3)
points_T = points.transpose(1, 2)  # (B, N, 3)# 用 bmm 做点变换:每组点乘旋转
transformed = torch.bmm(points_T, R.transpose(1, 2))  # (B, N, 3)

五、总结

特性torch.bmm
操作对象三维张量(batch of matrices)
核心规则(B, N, M) x (B, M, P) = (B, N, P)
是否支持广播❌ 不支持,需要手动 .expand()
matmul 区别matmul 支持更多广播,bmm 更高效用于纯批量矩阵乘法
应用场景批量线性变换、点云配准、神经网络前向传播等

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/910304.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/910304.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt 解析复杂对象构成

Qt 解析复杂对象构成 dumpStructure 如 QComboBox / QCalendarWidget / QSpinBox … void Widget::Widget(QWidget* parent){auto c new QCalendarWidget(this);dumpStructure(c,4); }void Widget::dumpStructure(const QObject *obj, int spaces) {qDebug() << QString…

山姆·奥特曼:从YC到OpenAI,硅谷创新之星的崛起

名人说&#xff1a;路漫漫其修远兮&#xff0c;吾将上下而求索。—— 屈原《离骚》 创作者&#xff1a;Code_流苏(CSDN)&#xff08;一个喜欢古诗词和编程的Coder&#x1f60a;&#xff09; 山姆奥特曼&#xff1a;从YC到OpenAI&#xff0c;硅谷创新之星的崛起 在人工智能革命…

PHP语法基础篇(五):流程控制

任何 PHP 脚本都是由一系列语句构成的。一条语句可以是一个赋值语句&#xff0c;一个函数调用&#xff0c;一个循环&#xff0c;一个条件语句或者甚至是一个什么也不做的语句&#xff08;空语句&#xff09;。语句通常以分号结束。此外&#xff0c;还可以用花括号将一组语句封装…

怎么隐藏关闭或恢复显示输入法的悬浮窗

以搜狗输入法为例&#xff0c;隐藏输入法悬浮窗 悬浮窗在输入法里的官方叫法为【状态栏】。 假设目前大家的输入法相关显示呈现如下状态&#xff1a; 那我们只需在输入法悬浮窗&#xff08;状态栏&#xff09;的任意位置鼠标右键单击&#xff0c;调出输入法菜单&#xff0c;就…

Electron (02)集成 SpringBoot:服务与桌面程序协同启动方案

本篇是关于把springboot生成的jar打到electron里&#xff0c;在生成的桌面程序启动时springboot服务就会自动启动。 虽然之后并不需要这种方案&#xff0c;更好的是部署[一套服务端&#xff0c;多个客户端]...但是既然搭建成功了&#xff0c;也记录一下。 前端文件 1、main.js…

2025年计算机应用与神经网络国际会议(CANN 2025)

2025 International Conference on Computer Applications and Neural Networks &#xff08;一&#xff09;会议信息 会议简称&#xff1a;CANN 2025 大会地点&#xff1a;中国重庆 收录检索&#xff1a;提交Ei Compendex,CPCI,CNKI,Google Scholar等 &#xff08;二&#x…

振动分析中的低频噪声问题:从理论到实践的完整解决方案

前言 在振动监测和结构健康监测领域&#xff0c;我们经常需要从加速度信号计算速度和位移。然而&#xff0c;许多工程师在实际应用中都会遇到一个令人困扰的问题&#xff1a;通过积分计算得到的速度和位移频谱中低频噪声异常放大。 本文将深入分析这个问题的根本原因&#xf…

ncu学习笔记01——合并访存

全局内存通过缓存实现加载和存储过程。其中&#xff0c;L1为一级缓存&#xff0c;每个SM都有自己的L1&#xff1b;L2为二级缓存&#xff0c;L2则被所有SM共有。 数据从全局内存到SM的传输过程中&#xff0c;会去L1和L2中查询是否有缓存。对全局内存的访问将经过L1&#xff1b;…

2012 - 正方形矩阵

​​​​题目描述 晶晶同学非常喜欢方形&#xff0c;她希望打印出来的字符串也是方形的。老师给了晶晶同学一个字符串"ACM"&#xff0c;晶晶同学突发奇想&#xff0c;如果任意给定义一个整数n&#xff0c;能不能打印出由这个字符串组成的正方形字符串呢&#xff1f;…

C++中set的常见用法

在 C 里&#xff0c;std::set属于标准库容器的一种&#xff0c;其特性是按照特定顺序存储唯一的元素。下面为你详细介绍它的常见使用方法&#xff1a; 1. 头文件引入 要使用std::set&#xff0c;需要在代码中包含相应的头文件&#xff1a; #include <set> 2. 集合的定…

stm32移植freemodbus

1、设置串口 开启串口中断 2、设置定时器 已知在freemodbus中默认定义&#xff1a;当波特率大于19200时&#xff0c;判断一帧数据超时时间固定为1750us&#xff0c;当波特率小于19200时&#xff0c;超时时间为3.5个字符时间。这里移植的是115200&#xff0c;所以一帧数据超时…

鸿蒙next 使用canvas实现ecg动态波形绘制

该代码可在Arkts 与 前端使用&#xff0c;基于canvas 仓库地址&#xff1a;https://gitee.com/harmony_os_example/harmony-os-ecg-waveform.git 代码中的list数组为波形数据&#xff0c;该示例需要根据自己业务替换绘制频率&#xff0c;波形数据&#xff0c;ecg原始数据生成…

基于原生能力的键盘控制

基于原生能力的键盘控制 前言一、进入页面TextInput获焦1、方案2、核心代码 二、点击按钮或其他事件触发TextInput获焦1、方案2、核心代码 三、键盘弹出后只上抬特定的输入组件1、方案2、核心代码 四、监听键盘高度1、方案2、核心代码 五、设置窗口在键盘抬起时的页面避让模式为…

大数据治理域——数据存储与成本管理

摘要 本文主要探讨了数据存储与成本管理的多种策略。介绍了数据压缩技术&#xff0c;如MaxCompute的archive压缩方法&#xff0c;通过RAID file形式存储数据&#xff0c;可有效节省空间&#xff0c;但恢复时间较长&#xff0c;适用于冷备与日志数据。还详细阐述了数据生命周期…

国产Linux银河麒麟操作系统上使用自带openssh远程工具SSH方式登陆华为交换机或服务器

在Windows和Linux Debian系统上我一直使用electerm远程工具访问服务器或交换机&#xff0c; 一、 electerm简介 简介&#xff1a;electerm是一款开源免费的SSH工具&#xff0c;具有良好的跨平台兼容性&#xff0c;适用于Windows、macOS、Linux以及麒麟操作系统。特点&#xf…

Logback 在java中的使用

Logback 是 Java 应用中广泛使用的日志框架&#xff0c;以下是其核心使用方法及最佳实践&#xff1a; 1. 引入依赖 在 Maven 或 Gradle 项目中添加 Logback 及 SLF4J 依赖&#xff1a; <!-- Maven --> <dependency><groupId>ch.qos.logback</groupId>…

Axure应用交互设计:中继器—整行、条件行、当前行赋值

亲爱的小伙伴,如有帮助请订阅专栏!跟着老师每课一练,系统学习Axure交互设计课程! Axure产品经理精品视频课https://edu.csdn.net/course/detail/40420 课程主题:对中继器中:整行、符合某种条件的任意行、当前行的赋值操作 课程视频:

ToolsSet之:TTS及Morse编解码

ToolsSet是微软商店中的一款包含数十种实用工具数百种细分功能的工具集合应用&#xff0c;应用基本功能介绍可以查看以下文章&#xff1a; Windows应用ToolsSet介绍https://blog.csdn.net/BinField/article/details/145898264其中Text菜单中的TTS & Morse可用于将文本转换…

【C++】编码传输:创建零拷贝帧对象4:shared_ptr转unique_ptr给到rtp打包

【C++】编码传输:创建零拷贝帧对象3: dll api转换内部的共享内存根本原因 你想要的是基于 packet 指向的那个已有对象,拷贝(或移动)出一个新的 VideoDataPacket3 实例,因此需要把那个对象本身传进去——也就是 *packet。copilot的原因分析与gpt一致 The issue is with t…

基于UDP的套接字通信

udp是一个面向无连接的&#xff0c;不安全的&#xff0c;报式传输层协议&#xff0c;udp的通信过程默认也是阻塞的。使用UDP进行通信&#xff0c;服务器和客户端的处理步骤比TCP要简单很多&#xff0c;并且两端是对等的 &#xff08;通信的处理流程几乎是一样的&#xff09;&am…