【arXiv 2025】新颖方法:基于快速傅里叶变换的高效自注意力,即插即用!

一、整体介绍 

The FFT Strikes Again: An Efficient Alternative to Self-Attention

FFT再次出击:一种高效的自注意力替代方案

图片

图1:FFTNet整体流程,包括局部窗口处理(STFT或小波变换,可选)和全局FFT,随后在频率/变换域进行等距融合(或门控)。

朋友们,今天为大家介绍一个非常有潜力,未来可能会在自然语言处理、计算机视觉、图像处理等领域发挥重大作用的方法。

中心思想:该方法来源arXiv[1],是2025年3月16日最新公开论文,提出了一种名为FFTNet的自适应频谱滤波框架,该框架利用快速傅里叶变换(FFT)在O(nlogn)时间内实现全局标记混合,有效解决了传统自注意力机制在处理长序列时的二次复杂度问题,把自注意力机制(Self-Attention)的时间复杂度从O(n²)降到O(nlogn)。

实现动机:传统的注意力机制在计算成对标记交互时,随着序列长度n的增加,成本呈二次方增长,这使得处理长序列变得昂贵。相比之下,离散傅里叶变换(DFT)在O(nlogn)时间内自然编码全局交互,因为它将标记序列分解为正交频率分量。

核心原理:根据帕塞瓦尔定理,在傅里叶变换下,对于输入序列X及其傅里叶变换F=FFT(X),信号的总能量保持不变,除了一个常数缩放因子。这一能量保持保证了自适应滤波和非线性操作不会意外扭曲输入信号的固有信息。(学过数字信号处理课程的朋友应该更容易理解,总结起来就是一句话:把信号转换到频域进行处理,不会丢失信号信息,但是可以减少计算量)

证明公式和复杂度计算的公式较为枯燥,本文省略。

下面以代码为例,展示原理及用法。

二、代码与原理解读 

1. 基于快速傅里叶变换的基础网络块——FFTNetBlock

import torch
import torch.nn as nn
import torch.nn.functional as F
class ModReLU(nn.Module):def __init__(self, features):super().__init__()self.b = nn.Parameter(torch.Tensor(features))self.b.data.uniform_(-0.1, 0.1)def forward(self, x):return torch.abs(x) * F.relu(torch.cos(torch.angle(x) + self.b))
class FFTNetBlock(nn.Module):def __init__(self, dim):super().__init__()self.dim = dimself.filter = nn.Linear(dim, dim)self.modrelu = ModReLU(dim)def forward(self, x):# x: [batch_size, seq_len, dim]x_fft = torch.fft.fft(x, dim=1)  # FFT along the sequence dimensionx_filtered = self.filter(x_fft.real) + 1j * self.filter(x_fft.imag)x_filtered = self.modrelu(x_filtered)x_out = torch.fft.ifft(x_filtered, dim=1).realreturn x_out
if __name__ == '__main__':# 参数设置batch_size = 1      # 批量大小seq_len = 224 * 224 # 序列长度(Transformer 中的 token 数量)dim = 32      # 维度# 创建随机输入张量,形状为 (batch_size, seq_len, embed_dim)x = torch.randn(batch_size, seq_len, dim)# 初始化 FFTNetBlock 模块model = FFTNetBlock(dim = dim)print(model)print("微信公众号: AI缝合术!")output = model(x)print(x.shape)print(output.shape) 

运行结果:

图片

该代码实现了一个基于 FFT(快速傅里叶变换)的神经网络块,称为 FFTNetBlock,并在 forward 过程中对输入信号进行频域处理。

图片

实现流程:

①使用 FFT 进行频域转换:输入 x 通过 FFT 转换到频域,在频域进行操作。

②使用可学习的滤波器:通过 nn.Linear 进行频域的线性变换,相当于卷积核在频域对信号进行加权处理。

③使用 ModReLU 进行非线性处理:由于 FFT 产生的结果是复数,传统的 ReLU 不能直接作用,因此使用 ModReLU 进行非线性变换。ModReLU为修正的 ReLU 激活函数,作用类似于ReLU在实数域上的作用,但应用于复数域,通过修改相位角(angle)并结合 ReLU 进行修正。

④最终通过 iFFT 还原回时序空间:经过处理的频域信息通过逆 FFT(ifft)变换回时序域,得到最终输出。

2. 基于快速傅里叶变换的ViT网络——FFTNetViT

import torch
import torch.nn as nn
import torch.nn.functional as F
def drop_path(x, drop_prob: float = 0., training: bool = False):"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""if drop_prob == 0. or not training:return xkeep_prob = 1 - drop_prob# Generate binary tensor mask; shape: (batch_size, 1, 1, ..., 1)shape = (x.shape[0],) + (1,) * (x.ndim - 1)random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)random_tensor.floor_()  # binarizeoutput = x.div(keep_prob) * random_tensorreturn output
class DropPath(nn.Module):"""DropPath module that performs stochastic depth."""def __init__(self, drop_prob=None):super(DropPath, self).__init__()self.drop_prob = drop_probdef forward(self, x):return drop_path(x, self.drop_prob, self.training)class MultiHeadSpectralAttention(nn.Module):def __init__(self, embed_dim, seq_len, num_heads=4, dropout=0.1, adaptive=True):"""频谱注意力模块,在保持 O(n log n) 计算复杂度的同时,引入额外的非线性和自适应能力。参数:- embed_dim: 总的嵌入维度。- seq_len: 序列长度(例如 Transformer 中 token 的数量,包括类 token)。- num_heads: 注意力头的数量。- dropout: 逆傅里叶变换(iFFT)后的 dropout 率。- adaptive: 是否启用自适应 MLP 以生成乘法和加法的自适应调制参数。"""super().__init__()if embed_dim % num_heads != 0:raise ValueError("embed_dim 必须能被 num_heads 整除")self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.seq_len = seq_lenself.adaptive = adaptive# 频域的 FFT 频率桶数量: (seq_len//2 + 1)self.freq_bins = seq_len // 2 + 1# 基础乘法滤波器: 每个注意力头和频率桶一个self.base_filter = nn.Parameter(torch.ones(num_heads, self.freq_bins, 1))# 基础加性偏置: 作为频率幅度的学习偏移self.base_bias = nn.Parameter(torch.full((num_heads, self.freq_bins, 1), -0.1))if adaptive:# 自适应 MLP: 每个头部和频率桶生成 2 个值(缩放因子和偏置)self.adaptive_mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim),nn.GELU(),nn.Linear(embed_dim, num_heads * self.freq_bins * 2))self.dropout = nn.Dropout(dropout)# 预归一化层,提高傅里叶变换的稳定性self.pre_norm = nn.LayerNorm(embed_dim)def complex_activation(self, z):"""对复数张量应用非线性激活函数。该函数计算 z 的幅度,将其传递到 GELU 进行非线性变换,并按比例缩放 z,以保持相位不变。参数:z: 形状为 (B, num_heads, freq_bins, head_dim) 的复数张量返回:经过非线性变换的复数张量,形状相同。"""mag = torch.abs(z)# 对幅度进行非线性变换,GELU 提供平滑的非线性mag_act = F.gelu(mag)# 计算缩放因子,防止除零错误scale = mag_act / (mag + 1e-6)return z * scaledef forward(self, x):"""增强型频谱注意力模块的前向传播。参数:x: 输入张量,形状为 (B, seq_len, embed_dim)返回:经过频谱调制和残差连接的张量,形状仍为 (B, seq_len, embed_dim)"""B, N, D = x.shape# 预归一化,提高频域变换的稳定性x_norm = self.pre_norm(x)# 重新排列张量以分离不同的注意力头,形状变为 (B, num_heads, seq_len, head_dim)x_heads = x_norm.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)# 沿着序列维度计算 FFT,结果为复数张量,形状为 (B, num_heads, freq_bins, head_dim)F_fft = torch.fft.rfft(x_heads, dim=2, norm='ortho')# 计算自适应调制参数(如果启用)if self.adaptive:# 全局上下文:对 token 维度求均值,形状为 (B, embed_dim)context = x_norm.mean(dim=1)# 经过 MLP 计算自适应参数,输出形状为 (B, num_heads*freq_bins*2)adapt_params = self.adaptive_mlp(context)adapt_params = adapt_params.view(B, self.num_heads, self.freq_bins, 2)# 划分为乘法缩放因子和加法偏置adaptive_scale = adapt_params[..., 0:1]  # 形状: (B, num_heads, freq_bins, 1)adaptive_bias  = adapt_params[..., 1:2]  # 形状: (B, num_heads, freq_bins, 1)else:# 如果不使用自适应机制,则缩放因子和偏置设为 0adaptive_scale = torch.zeros(B, self.num_heads, self.freq_bins, 1, device=x.device)adaptive_bias  = torch.zeros(B, self.num_heads, self.freq_bins, 1, device=x.device)# 结合基础滤波器和自适应调制参数# effective_filter: 影响频谱响应的缩放因子effective_filter = self.base_filter * (1 + adaptive_scale)# effective_bias: 影响频谱响应的偏置effective_bias = self.base_bias + adaptive_bias# 在频域进行自适应调制# 先进行乘法缩放,再添加偏置(在 head_dim 维度上广播)F_fft_mod = F_fft * effective_filter + effective_bias# 在频域应用非线性激活F_fft_nl = self.complex_activation(F_fft_mod)# 逆傅里叶变换(iFFT)还原到时序空间# 需要指定 n=self.seq_len 以确保输出长度匹配输入x_filtered = torch.fft.irfft(F_fft_nl, dim=2, n=self.seq_len, norm='ortho')# 重新排列张量,将注意力头合并回嵌入维度x_filtered = x_filtered.permute(0, 2, 1, 3).reshape(B, N, D)# 残差连接并应用 Dropoutreturn x + self.dropout(x_filtered)
class TransformerEncoderBlock(nn.Module):def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.1, attention_module=None, drop_path=0.0):"""一个通用的 Transformer 编码器块,集成了 drop path 随机深度 。- embed_dim: 嵌入维度。- mlp_ratio: MLP 的扩展因子。- dropout: dropout 比率。- attention_module: 处理自注意力的模块。- drop_path: 随机深度的 drop path 比率。"""super().__init__()if attention_module is None:raise ValueError("必须提供一个注意力模块! 此处应调用 MultiHeadSpectralAttention")self.attention = attention_moduleself.mlp = nn.Sequential(nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),nn.GELU(),nn.Dropout(dropout),nn.Linear(int(embed_dim * mlp_ratio), embed_dim),nn.Dropout(dropout))self.norm = nn.LayerNorm(embed_dim)# 用于随机深度的 drop path 层self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()def forward(self, x):# 在残差连接中应用带有 drop path 的注意力。x = x + self.drop_path(self.attention(x))# 在残差连接中应用 MLP(经过层归一化)并加入 drop path。x = x + self.drop_path(self.mlp(self.norm(x)))return xif __name__ == '__main__':# 参数设置batch_size = 1      # 批大小seq_len = 224 * 224 # 序列长度embed_dim = 32      # 嵌入维度num_heads = 4       # 注意力头数# 创建随机输入张量 (batch_size, seq_len, embed_dim)x = torch.randn(batch_size, seq_len, embed_dim)# 初始化 MultiHeadSpectralAttentionattention_module = MultiHeadSpectralAttention(embed_dim=embed_dim, seq_len=seq_len, num_heads=num_heads)# 初始化 TransformerEncoderBlocktransformer_block = TransformerEncoderBlock(embed_dim=embed_dim, attention_module=attention_module)print(transformer_block)print("微信公众号: AI缝合术!")# 前向传播测试output = transformer_block(x)# 打印输出形状print("输入形状:", x.shape)print("输出形状:", output.shape)    

运行结果:

图片

乍一看代码比较多,其实原理非常简单,该代码实现了一个标准的Transformer编码器结构,除去两个固定操作的随机深度DropPath,剩下仅有两个类组成,MultiHeadSpectralAttention实现了基于快速傅里叶变换的高效多头自注意力,TransformerEncoderBlock是一个通用的Transformer编码器模块。

图片

上图是ViT的经典结构图,我们只看右侧编码器部分,上述代码实现的就是右侧的编码器,只是将多头注意力转换到频域来进行计算,非常容易理解。

采用上面方法构建的FFTNetViT在LRA和ImageNet两个数据集上的广泛评估确认,FFTNet不仅实现了有竞争力的准确性,而且与固定傅里叶方法和标准自注意力相比,显著提高了计算效率。

以上两个模块均可即插即用,应用在自然语言处理、图像处理、计算机视觉等各类任务上,是非常好的创新!

https://github.com/AIFengheshu/Plug-play-modules

2025年全网最全即插即用模块,免费分享!包含人工智能全领域(机器学习、深度学习等),适用于图像分类、目标检测、实例分割、语义分割、全景分割、姿态识别、医学图像分割、视频目标分割、图像抠图、图像编辑、单目标跟踪、多目标跟踪、行人重识别、RGBT、图像去噪、去雨、去雾、去阴影、去模糊、超分辨率、去反光、去摩尔纹、图像恢复、图像修复、高光谱图像恢复、图像融合、图像上色、高动态范围成像、视频与图像压缩、3D点云、3D目标检测、3D语义分割、3D姿态识别等各类计算机视觉和图像处理任务,以及自然语言处理、大语言模型、多模态等其他各类人工智能相关任务。持续更新中.....

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/88388.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/88388.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

通过vue如何利用 Three 绘制 简单3D模型(源码案例)

目录 Three 介绍 创建基础3D场景 创建不同类型的3D模型 1. 球体 2. 圆柱体​​​​​​​ 3. 平面​​​​​​​ 加载外部3D模型 添加交互控制 创建可交互的3D场景 Three 介绍 Three.js是一个强大的JavaScript 3D库,可以轻松地在网页中创建3D图形。下面我…

云蝠智能 Voice Agent 落地展会邀约场景:重构会展行业的智能交互范式

一、行业痛点与 AI 破局在会展行业数字化转型的浪潮中,传统展会邀约模式面临多重挑战:人工外呼日均仅能处理 300-500 通电话,且无效号码占比高达 40% 以上,导致邀约效率低下。同时,个性化邀约话术设计依赖经验&#xf…

idea如何打开extract surround

在 IntelliJ IDEA 中,"Extract Surrounding"(提取周围代码)通常指 ​将一段代码提取到新的方法、变量或类中,但更常见的操作是 ​​"Surround With"(用代码结构包围)​。以下是两种场景…

window显示驱动开发—XR_BIAS 和 BltDXGI

Direct3D 运行时调用驱动程序的 BltDXGI 函数,以仅对XR_BIAS源资源执行以下操作:复制到也XR_BIAS的目标未修改的源数据的副本可接受点样本的拉伸旋转由于 XR_BIAS 不支持 MSAA) (多个示例抗锯齿,因此驱动程序不需要解析XR_BIAS资源。核心规则…

web网页开发,在线%ctf管理%系统,基于html,css,webform,asp.net mvc, sqlserver, mysql

webform,asp.net mvc。数据库支持mysql,sqlserver经验心得 每次我们写crud没啥技术含量,这没法让咱们进入大厂,刚好这次与客户沟通优化方案建议,咱们就把能加的帮他都加上去。一个ctf管理系统基本crud,并进行不同分层开发&#xf…

面试技术问题总结一

MySQL的几种锁机制一、从锁的粒度角度划分表级锁机制:它是对整张表进行锁定的一种锁。当一个事务对表执行写操作时,会获取写锁,在写锁持有期间,其他事务无法对该表进行读写操作;而当事务执行读操作时,会获取…

π0.5的KI改进版——知识隔离:让VLM在不受动作专家负反馈的同时,继续输出离散动作token,并根据反馈做微调(而非冻结VLM)

前言 过去的一个月(25年6.4-7.4),我司「七月在线」具身长沙分部为冲刺一些为客户来现场看的演示项目,基本都用lerobot的那套框架 比如上周五(7.4日)晚上,通过上周五下午新采的第五波数据做『耳机线插入耳机孔』的任务,推理十次之…

Eigen中Isometry3d的使用详解和实战示例

Eigen::Isometry3d 是 Eigen 库中用于表示 三维空间中的刚性变换(Rigid Transformation) 的类,属于 Eigen::Transform 模板类的一个特化版本。它结合了 旋转和平移,广泛应用于机器人学、SLAM、三维几何计算等场景。一、核心定义 #…

《未来已来:当人类智慧遇上AI智能体》

在这个充满奇迹的时代,人类的智慧与科技的力量正以前所未有的速度交织在一起。 我们站在一个新时代的门槛上,一边是古老而深邃的自然规律,另一边是充满可能性的未来世界。 今天,就让我们一起走进这场关于人类智慧与AI智能体Kimi的对话,看看未来究竟会带给我们怎样的惊喜…

【三维生成】FlashDreamer:基于扩散模型的单目图像到3D场景

标题&#xff1a;<Enhancing Monocular 3D Scene Completion with Diffusion Model> 代码&#xff1a;https://github.com/CharlieSong1999/FlashDreamer 来源&#xff1a;澳大利亚国立大学 文章目录摘要一、前言二、相关工作2.1 场景重建2.2 扩散模型2.3 Vision languag…

CANFD记录仪设备在无人驾驶快递车的应用

随着物流行业的快速发展&#xff0c;无人驾驶快递车因其高效、低成本的优势&#xff0c;逐渐成为“最后一公里”配送的重要解决方案。然而&#xff0c;无人驾驶系统的稳定性和安全性高度依赖车辆总线数据的精准采集与分析。南金研CANFDlog4 4路记录仪凭借其多通道、高带宽、高可…

Kubernetes存储入门

目录 前言 一、Volume 的概念 二、Volume 的类型 常见的卷类型 Kubernetes 独有的卷类型 三、通过 emptyDir 共享数据 1. 编写 emptyDir 的 Deployment 文件 2. 部署该 Deployment 3. 查看部署结果 4. 登录 Pod 中的第一个容器 5. 登录 Pod 中的第二个容器查看/mnt下…

10.Docker安装mysql

(1)docker pull mysql:版本号eg&#xff1a;docker pull mysql(默认安装最新版本)docker pull mysql:5.7(2)启动并设置mysql镜像docker run -d -p 3306:3306 -e MYSQL_ROOT_PASSWORD123456 --name mysql1 mysql其他参数都不多讲&#xff0c;下面这个参数指的是设置数据库用户ro…

Debian-10编译安装Mysql-5.7.44 笔记250706

Debian-10编译安装Mysql-5.7.44 笔记250706 单一脚本安装 ### 1. 安装编译依赖 sudo apt install -y cmake gcc g build-essential libncurses5-dev libssl-dev \ pkg-config libreadline-dev zlib1g-dev bison curl wget libaio-dev \ libjson-perl libnuma-dev libsystemd-d…

HarmonyOS 中状态管理 V2和 V1 的区别

鸿蒙ArkUI框架中的ComponentV2与V1在状态管理、组件开发模式、性能优化等方面存在显著差异。以下是两者的核心区别及技术解析&#xff1a;一、状态管理机制V1的局限性V1的Observed装饰器只能观察对象的第一层属性变化&#xff0c;需配合ObjectLink手动拆解嵌套对象。例如&#…

centos7 安装jenkins

文章目录前言一、pandas是什么&#xff1f;二、安装依赖环境1.前提准备2.安装git3.安装jdk&#xff0c;以及jdk版本选择4.安装maven5.安装NodeJS6.验证三、安装Jenkins四、验证Jenkins总结前言 正在学习jenkinsdocker部署前后端分离项目&#xff0c;安装jenkins的时候遇到了一…

Leetcode刷题营第二十题:删除链表中的重复节点

面试题 02.01. 移除重复节点 编写代码&#xff0c;移除未排序链表中的重复节点。保留最开始出现的节点。 示例1&#xff1a; 输入&#xff1a;[1, 2, 3, 3, 2, 1]输出&#xff1a;[1, 2, 3]示例2&#xff1a; 输入&#xff1a;[1, 1, 1, 1, 2]输出&#xff1a;[1, 2]提示&…

关于市场主流自动化测试工具和框架的简要介绍

下面我会分别讲解 Selenium、Appium、Playwright 等主流自动化框架的区别、联系、适用场景和归属范畴&#xff0c;帮助你更系统地理解它们。&#x1f527; 一、它们都属于哪一类工具&#xff1f;Selenium、Appium、Playwright、Cypress 等都属于&#xff1a;▶️ 自动化测试框架…

基于cornerstone3D的dicom影像浏览器 第三十二章 文件夹做pacs服务端,fake-pacs-server

文章目录 前言一、实现思路二、项目与代码三、dicom浏览器调用1. view2d.vue前言 本系列最后一章,提供一个模拟pacs服务,供访问dicom图像测试。 修改nodejs本地目录做为http服务根目录,提供一个根目录,其中的每个子目录代表一个检查。在dicom浏览器url中带入参数studyId=目…

【Python 核心概念】深入理解可变与不可变类型

文章目录一、故事从变量赋值说起二、不可变类型 (Immutable Types)三、可变类型 (Mutable Types)四、一个常见的陷阱&#xff1a;当元组遇到列表五、为什么这个区别如此重要&#xff1f;1. 函数参数的传递2. 字典的键 (Dictionary Keys)3. 函数的默认参数陷阱六、进阶话题与扩展…