SDPA(Scaled Dot-Product Attention)详解

SDPA(Scaled Dot-Product Attention)详解

SDPA(Scaled Dot-Product Attention,缩放点积注意力)是 Transformer 模型的核心计算单元,最早由 Vaswani 等人在 2017 年的论文《Attention Is All You Need》提出。它通过计算查询(Query)、键(Key)和值(Value)之间的相似度,生成上下文感知的表示。


1. SDPA 的数学定义

给定:

  • 查询矩阵(Query) Q ∈ R n × d k Q \in \mathbb{R}^{n \times d_k} QRn×dk
  • 键矩阵(Key) K ∈ R m × d k K \in \mathbb{R}^{m \times d_k} KRm×dk
  • 值矩阵(Value) V ∈ R m × d v V \in \mathbb{R}^{m \times d_v} VRm×dv

SDPA 的计算公式为:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V Attention(Q,K,V)=softmax(dk QKT)V

其中:

  • Q K T QK^T QKT 计算查询和键的点积(相似度)。
  • d k \sqrt{d_k} dk 用于缩放点积,防止梯度消失或爆炸(尤其是 d k d_k dk 较大时)。
  • softmax 将注意力权重归一化为概率分布。
  • 最终加权求和 V V V 得到输出。

2. SDPA 的计算步骤

  1. 计算相似度(Dot-Product)
  • 计算 Q Q Q K K K 的点积:
    S = Q K T S = QK^T S=QKT
  • 相似度矩阵 S ∈ R n × m S \in \mathbb{R}^{n \times m} SRn×m 表示每个查询对所有键的匹配程度。
  1. 缩放(Scaling)

    • 除以 d k \sqrt{d_k} dk (键向量的维度),防止点积值过大导致 softmax 梯度消失:
      S scaled = S d k S_{\text{scaled}} = \frac{S}{\sqrt{d_k}} Sscaled=dk S
  2. Softmax 归一化

    • 对每行(每个查询)做 softmax,得到注意力权重 A A A
      A = softmax ( S scaled ) A = \text{softmax}(S_{\text{scaled}}) A=softmax(Sscaled)
    • 保证 ∑ j A i , j = 1 \sum_j A_{i,j} = 1 jAi,j=1,权重总和为 1。
  3. 加权求和(Value 聚合)

    • 用注意力权重 A A A V V V 加权求和,得到最终输出:
      Output = A ⋅ V \text{Output} = A \cdot V Output=AV
    • 输出维度: R n × d v \mathbb{R}^{n \times d_v} Rn×dv

3. SDPA 的作用与优势

核心作用

  • 让模型动态关注输入的不同部分(类似人类注意力机制)。
  • 适用于序列数据(如文本、语音、视频),捕捉长距离依赖。

优势

  1. 并行计算友好
  • 矩阵乘法(GEMM)可高效并行加速(GPU/TPU 优化)。
  1. 可解释性
    • 注意力权重可视化(如 BertViz)可分析模型关注哪些 token。
  2. 灵活扩展
    • 可结合 多头注意力(Multi-Head Attention) 增强表达能力。

4. SDPA 的变体与优化

变体/优化核心改进应用场景
多头注意力(MHA)并行多个 SDPA,增强特征多样性Transformer (BERT, GPT)
FlashAttention优化内存访问,减少 HBM 读写长序列推理(如 8K+ tokens)
Sparse Attention只计算局部或稀疏的注意力降低计算复杂度(如 Longformer)
Linear Attention用线性近似替代 softmax低资源设备(如 RetNet)

5. 代码实现(PyTorch 示例)

import torch
import torch.nn.functional as Fdef scaled_dot_product_attention(Q, K, V, mask=None):d_k = Q.size(-1)scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)attn_weights = F.softmax(scores, dim=-1)output = torch.matmul(attn_weights, V)return output# 示例输入
Q = torch.randn(2, 5, 64)  # (batch_size, seq_len, d_k)
K = torch.randn(2, 5, 64)
V = torch.randn(2, 5, 128)
output = scaled_dot_product_attention(Q, K, V)
print(output.shape)  # torch.Size([2, 5, 128])

6. 总结

  • SDPA 是 Transformer 的基石,通过 Query-Key-Value 机制 + Softmax 归一化 实现动态注意力。
  • 关键优化点:缩放(防止梯度问题)、并行计算、内存效率(如 FlashAttention)
  • 现代优化(如 SageAttention2)进一步结合 量化、稀疏化、离群值处理 提升效率。

SDPA 及其变体已成为 NLP、CV、多模态领域的核心组件,理解其原理对模型优化至关重要。

SDPA计算过程举例

我们通过一个具体的数值例子,逐步演示 SDPA 的计算过程。假设输入如下(简化版,便于手动计算):

输入数据(假设 d_k = 2, d_v = 3
  • Query (Q):2 个查询(n=2),每个查询维度 d_k=2
    Q = [ 1 2 3 4 ] Q = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ \end{bmatrix} Q=[1324]
  • Key (K):3 个键(m=3),每个键维度 d_k=2
    K = [ 5 6 7 8 9 10 ] K = \begin{bmatrix} 5 & 6 \\ 7 & 8 \\ 9 & 10 \\ \end{bmatrix} K= 5796810
  • Value (V):3 个值(m=3),每个值维度 d_v=3
    V = [ 1 0 1 0 1 0 1 1 0 ] V = \begin{bmatrix} 1 & 0 & 1 \\ 0 & 1 & 0 \\ 1 & 1 & 0 \\ \end{bmatrix} V= 101011100

Step 1: 计算 Query 和 Key 的点积(Dot-Product)

计算 S = Q K T S = QK^T S=QKT

Q K T = [ 1 ⋅ 5 + 2 ⋅ 6 1 ⋅ 7 + 2 ⋅ 8 1 ⋅ 9 + 2 ⋅ 10 3 ⋅ 5 + 4 ⋅ 6 3 ⋅ 7 + 4 ⋅ 8 3 ⋅ 9 + 4 ⋅ 10 ] = [ 5 + 12 7 + 16 9 + 20 15 + 24 21 + 32 27 + 40 ] = [ 17 23 29 39 53 67 ] QK^T = \begin{bmatrix} 1 \cdot 5 + 2 \cdot 6 & 1 \cdot 7 + 2 \cdot 8 & 1 \cdot 9 + 2 \cdot 10 \\ 3 \cdot 5 + 4 \cdot 6 & 3 \cdot 7 + 4 \cdot 8 & 3 \cdot 9 + 4 \cdot 10 \\ \end{bmatrix} = \begin{bmatrix} 5+12 & 7+16 & 9+20 \\ 15+24 & 21+32 & 27+40 \\ \end{bmatrix} = \begin{bmatrix} 17 & 23 & 29 \\ 39 & 53 & 67 \\ \end{bmatrix} QKT=[15+2635+4617+2837+4819+21039+410]=[5+1215+247+1621+329+2027+40]=[173923532967]


Step 2: 缩放(Scaling)

除以 d k = 2 ≈ 1.414 \sqrt{d_k} = \sqrt{2} \approx 1.414 dk =2 1.414

S scaled = S 2 = [ 17 / 1.414 23 / 1.414 29 / 1.414 39 / 1.414 53 / 1.414 67 / 1.414 ] ≈ [ 12.02 16.26 20.51 27.58 37.48 47.38 ] S_{\text{scaled}} = \frac{S}{\sqrt{2}} = \begin{bmatrix} 17/1.414 & 23/1.414 & 29/1.414 \\ 39/1.414 & 53/1.414 & 67/1.414 \\ \end{bmatrix} \approx \begin{bmatrix} 12.02 & 16.26 & 20.51 \\ 27.58 & 37.48 & 47.38 \\ \end{bmatrix} Sscaled=2 S=[17/1.41439/1.41423/1.41453/1.41429/1.41467/1.414][12.0227.5816.2637.4820.5147.38]


Step 3: Softmax 归一化(计算注意力权重)

对每一行(每个 Query)做 softmax:

$\text{softmax}([12.02, 16.26, 20.51]) \approx [2.06 \times 10^{-4}, 0.016, 0.984] $
$\text{softmax}([27.58, 37.48, 47.38]) \approx [1.67 \times 10^{-9}, 0.0001, 0.9999] $

因此,注意力权重矩阵 A A A 为:

A ≈ [ 2.06 × 10 − 4 0.016 0.984 1.67 × 10 − 9 0.0001 0.9999 ] A \approx \begin{bmatrix} 2.06 \times 10^{-4} & 0.016 & 0.984 \\ 1.67 \times 10^{-9} & 0.0001 & 0.9999 \\ \end{bmatrix} A[2.06×1041.67×1090.0160.00010.9840.9999]

解释

  • 第 1 个 Query 主要关注第 3 个 Key(权重 0.984)。
  • 第 2 个 Query 几乎只关注第 3 个 Key(权重 0.9999)。

Step 4: 加权求和(聚合 Value)

计算 Output = A ⋅ V \text{Output} = A \cdot V Output=AV

Output = [ 2.06 × 10 − 4 ⋅ 1 + 0.016 ⋅ 0 + 0.984 ⋅ 1 2.06 × 10 − 4 ⋅ 0 + 0.016 ⋅ 1 + 0.984 ⋅ 1 2.06 × 10 − 4 ⋅ 1 + 0.016 ⋅ 0 + 0.984 ⋅ 0 ] T ≈ [ 0.984 1.000 0.0002 ] T \text{Output} = \begin{bmatrix} 2.06 \times 10^{-4} \cdot 1 + 0.016 \cdot 0 + 0.984 \cdot 1 \\ 2.06 \times 10^{-4} \cdot 0 + 0.016 \cdot 1 + 0.984 \cdot 1 \\ 2.06 \times 10^{-4} \cdot 1 + 0.016 \cdot 0 + 0.984 \cdot 0 \\ \end{bmatrix}^T \approx \begin{bmatrix} 0.984 \\ 1.000 \\ 0.0002 \\ \end{bmatrix}^T Output= 2.06×1041+0.0160+0.98412.06×1040+0.0161+0.98412.06×1041+0.0160+0.9840 T 0.9841.0000.0002 T

Output = [ 0.984 1.000 0.0002 0.9999 0.9999 0.0001 ] \text{Output} = \begin{bmatrix} 0.984 & 1.000 & 0.0002 \\ 0.9999 & 0.9999 & 0.0001 \\ \end{bmatrix} Output=[0.9840.99991.0000.99990.00020.0001]

解释

  • 第 1 行:主要聚合了第 3 个 Value [1, 1, 0],但受前两个 Value 微弱影响。
  • 第 2 行:几乎完全由第 3 个 Value 决定。

最终输出

Output ≈ [ 0.984 1.000 0.0002 0.9999 0.9999 0.0001 ] \text{Output} \approx \begin{bmatrix} 0.984 & 1.000 & 0.0002 \\ 0.9999 & 0.9999 & 0.0001 \\ \end{bmatrix} Output[0.9840.99991.0000.99990.00020.0001]


总结

  1. 点积:计算 Query 和 Key 的相似度。
  2. 缩放:防止梯度爆炸/消失。
  3. Softmax:归一化为概率分布。
  4. 加权求和:聚合 Value 得到最终表示。

这个例子展示了 SDPA 如何动态分配注意力权重,并生成上下文感知的输出。实际应用中(如 Transformer),还会结合 多头注意力(Multi-Head Attention) 增强表达能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/bicheng/84762.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java通过hutool工具生成二维码实现扫码跳转功能

实现&#xff1a; 首先引入zxing和hutool工具依赖 <dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.5.2</version></dependency><dependency><groupId>com.google.zxi…

数据库数据导出到Excel表格

1.后端代码 第一步&#xff1a;UserMapper定义根据ID列表批量查询用户方法 // 批量查询用户信息List<User> selectUserByIds(List<Integer> ids); 第二步&#xff1a;UserMapper.xml写动态SQL&#xff0c;实现批量查询用户 <!--根据Ids批量查询用户-->&l…

Altera系列FPGA基于ADV7180解码PAL视频,纯verilog去隔行,提供2套Quartus工程源码和技术支持

目录 1、前言工程概述免责声明 2、相关方案推荐我已有的所有工程源码总目录----方便你快速找到自己喜欢的项目Altera系列FPGA相关方案推荐我这里已有的PAL视频解码方案 3、设计思路框架工程设计原理框图输入PAL相机ADV7180芯片解读BT656视频解码模块图像缓存架构输出视频格式转…

【教程】Windows安全中心扫描设置排除文件

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhagn.cn] 如果本文帮助到了你&#xff0c;欢迎[点赞、收藏、关注]哦~ 目录 背景说明 解决方法 背景说明 即使已经把实时防护等设置全都关了&#xff0c;但Windows还是会不定时给你扫描&#xff0c;然后把风险软件给删了…

OPenCV CUDA模块立体匹配------对立体匹配生成的视差图进行双边滤波处理类cv::cuda::DisparityBilateralFilter

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 cv::cuda::DisparityBilateralFilter 是 OpenCV CUDA 模块中的一个类&#xff0c;用于对立体匹配生成的视差图进行双边滤波处理。这种滤波方法可…

自然语言处理期末复习

自然语言处理期末复习 一单元 自然语言处理基础 两个核心任务&#xff1a; 自然语言理解&#xff08;NLU, Natural Language Understanding&#xff09; 让计算机“读懂”人类语言&#xff0c;理解文本的语义、结构和意图。 典型子任务包括&#xff1a;分词、词性标注、句法分…

黄仁勋在2025年巴黎VivaTech大会上的GTC演讲:AI工厂驱动的工业革命(上)

引言 2025年6月12日,在巴黎VivaTech大会上,英伟达创始人兼CEO黄仁勋发表了题为"AI工厂驱动的工业革命"的GTC主题演讲。这场持续约1小时35分钟的演讲不仅详细阐述了英伟达在AI基础设施、智能体技术、量子计算及机器人领域的最新突破,更系统性地勾勒出了人工智能如…

DMC-E 系列总线控制卡----雷赛板卡介绍(六)

应用软件开发方法 DMC-E 系列总线运动控制卡的应用软件可以在 Visual Basic 、 Visual C++ 、 C# 等高级语言 环境下开发。应用软件开发之前,需保证 DMC-E 系列总线运动控制卡连接好从站,通过控制 卡 Motion 的 EtherCAT 总线配置界面扫描从站、设置总线通信周期…

题目类型——左右逢源

1、针对的题目&#xff1a;&#xff08;不一定正确或完整&#xff09; 数据结构为数组之类的线性结构&#xff08;也许可以拓展&#xff09;&#xff0c;于是数组中每个元素和其他元素的相对关系为左右或前后需要对数组中每个元素求解或者说最终解要根据每个元素的解得出每个元…

RAG检索前处理

1. 查询构建&#xff08;包括Text2SQL&#xff09; 查询构建的相关技术栈&#xff1a; Text-to-SQLText-to-Cypher 从查询中提取元数据&#xff08;Self-query Retriever&#xff09; 1.1 Text-to-SQL&#xff08;关系数据库&#xff09; 1.1.1 大语言模型方法Text-to-SQL样…

OmoFun动漫官网,动漫共和国最新入口|网页版

OmoFun 动漫&#xff0c;又叫动漫共和国&#xff0c;是一个专注于提供丰富动漫资源的在线平台&#xff0c;深受广大动漫爱好者的喜爱。它汇集了海量的动漫资源&#xff0c;涵盖日本动漫、国产动漫、欧美动漫等多种类型&#xff0c;无论是最新上映的热门番剧还是经典老番&#x…

ue5的blender4.1groom毛发插件v012安装和使用方法(排除了冲突错误)

关键出错不出错是看这个文件pyalembic-1.8.8-cp311-cp311-win_amd64.whl&#xff0c;解决和Alembic SQL工具&#xff09;的加载冲突&#xff01; 其他blender版本根据其内部的python版本选择对应的文件解压安装。 1、安装插件&#xff01;把GroomExporter_v012_Blender4.1.1(原…

windows安装jekyll

windows安装jekyll 安装ruby 首先需要下载ruby RubyInstaller for Windows - RubyInstaller国内镜像站 我的操作系统是win10所以我安装的最新版&#xff0c;你们安装的时候&#xff0c;也可以安装最新版&#xff0c;我这里就不附加图片了 如果你的ruby安装完成之后&#x…

DBever工具自适应mysql不同版本的连接

DBever工具的连接便捷性 最近在使用DBever工具连接不同版本的mysql数据库&#xff0c;发现这个工具确实比mysql-log工具要兼容性好很多&#xff0c;直接就可以连接不同版本的数据库&#xff0c;比如常见的mysql数据库版本&#xff1a;8.0和5.7&#xff0c;而且链接成功后&…

K8S认证|CKS题库+答案| 10. Trivy 扫描镜像安全漏洞

目录 10. Trivy 扫描镜像安全漏洞 免费获取并激活 CKA_v1.31_模拟系统 题目 开始操作&#xff1a; 1&#xff09;、切换集群 2&#xff09;、切换到master并提权 3&#xff09;、查看Pod和镜像对应关系 4&#xff09;、查看并去重镜像名称 5&#xff09;、扫描所有镜…

Rust高级抽象

Rust 的高级抽象能力是其核心优势之一&#xff0c;允许开发者通过特征&#xff08;Traits&#xff09;、泛型&#xff08;Generics&#xff09;、闭包&#xff08;Closures&#xff09;、迭代器&#xff08;Iterators&#xff09;等机制实现高度灵活和可复用的代码。今天我们来…

Vue里面的映射方法

111.getters配置项 112.mapstate和mapgetter 113.&#xfeff;mapActions与&#xfeff;mapMutations 114.多组件共享数据 115.vuex模块化&#xff0c;namespaces1 116.name&#xfeff;s&#xfeff;pace2

Node.js特训专栏-基础篇:2. JavaScript核心知识在Node.js中的应用

我将从变量、函数、异步编程等方面入手&#xff0c;结合Node.js实际应用场景&#xff0c;为你详细阐述JavaScript核心知识在其中的运用&#xff1a; JavaScript核心知识在Node.js中的应用 在当今的软件开发领域&#xff0c;Node.js凭借其高效的性能和强大的功能&#xff0c;成…

负载均衡LB》》LVS

LO 接口 LVS简介 LVS&#xff08;Linux Virtual Server&#xff09;即Linux虚拟服务器&#xff0c;是由章文嵩博士主导的开源负载均衡项目&#xff0c;通过LVS提供的负载均衡技术和Linux操作系统实现一个高性能、高可用的服务器集群&#xff0c;它具有良好可靠性、可扩展性和可…

Modbus TCP转DeviceNet网关配置温控仪配置案例

某工厂生产线需将Modbus TCP协议的智能仪表接入DeviceNet网络&#xff08;主站为PLC&#xff0c;如Rockwell ControlLogix&#xff09;&#xff0c;实现集中监控。需通过开疆智能Modbus TCP转DeviceNet网关KJ-DVCZ-MTCPS完成协议转换。Modbus TCP设备&#xff1a;温控器&#x…