【大模型LLM学习】Flash-Attention的学习记录

【大模型LLM学习】Flash-Attention的学习记录

  • 0. 前言
  • 1. flash-attention原理简述
  • 2. 从softmax到online softmax
    • 2.1 safe-softmax
    • 2.2 3-pass safe softmax
    • 2.3 Online softmax
    • 2.4 Flash-attention
    • 2.5 Flash-attention tiling

0. 前言

  Flash Attention可以节约模型训练和推理时间,很多模型可以通过config参数来选择attention是标准的attention实现还是flash-attention方式。在这里记录一下flash attention的学习过程,发现了一位博主以及参考的资料特别好:

  • zhihu一位做高性能计算的博主博文
  • 华盛顿大学的课程note

1. flash-attention原理简述

a t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V attention(Q,K,V)=softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V attention(Q,K,V)=softmax(dk QKT)V
  标准的attention操作的时间卡点不是在运算上,而是卡在数据读写上。SRAM的读写速度快,但是存储空间有限,无法一次存下来所有的中间计算结果,一次attention计算存在SRAM<->HBM的多次读写操作。
在这里插入图片描述
  与标准的attention操作比较,flash-attention通过减少数据在HBM和SRAM间的读写操作,来节约时间(甚至backward时还进行了重新计算,重新计算的速度也比把数据从HBM读取到SRAM要快)。
https://huggingface.co/docs/text-generation-inference/conceptual/flash_attention

2. 从softmax到online softmax

  直接看flash-attention的论文比较难看明白,发现华盛顿大学的那份note写得特别清晰,跟着它从softmax看到flash-attention会比较容易。

2.1 safe-softmax

  首先是safe的softmax计算方式。原始的softmax,对于N个数:
s o f t m a x ( { x 1 , . . . , x N } ) = { e x i ∑ j = 1 N e x j } i = 1 N softmax(\{x_1,...,x_N\})=\left\{\frac{e^{x_i}}{\sum_{j=1}^{N}e^{x_j}}\right\}_{i=1}^{N} softmax({x1,...,xN})={j=1Nexjexi}i=1N
  对于FP16,最大能表示的数据为65536,当 x > = 11 x>=11 x>=11时, e x e^x ex就会超过FP16的最大表示范围影响结果的正确性。为了避免这个问题,SafeSoftmax 通过减去输入向量中的最大值来调整输入,使得最大的指数项变为 e 0 = 1 e^0=1 e0=1从而防止了上溢的发生。同时,由于所有的指数项都除以同一个数,它们的比例关系不会改变,因此也不会影响最终的概率分布。
e x i ∑ j = 1 N e x j = e x i − m ∑ j = 1 N e x j − m , m = m a x { x j } j = 1 N \frac{e^{x_i}}{\sum_{j=1}{N}e^{x_j}}=\frac{e^{x_i-m}}{\sum_{j=1}{N}e^{x_j-m}}, \quad m=max\left\{x_j\right\}_{j=1}^{N} j=1Nexjexi=j=1Nexjmexim,m=max{xj}j=1N

2.2 3-pass safe softmax

  • 对于一个行向量 { x i } i = 1 N \{x_i\}_{i=1}^N {xi}i=1N,最直白的softmax计算方式是直接for循环

在这里插入图片描述
  这个算法计算softmax需要执行3次从1->N的循环,在attention中, { x i } \{x_i\} {xi} Q K T QK^T QKT的结果,但是如果SRAM里面存不下这个大的矩阵,上面的计算过程,就需要从HBM里面加载3次 { x i } \{x_i\} {xi},时间花在了数据读写上。

2.3 Online softmax

  如果能把上面(7)(8)(9)这3个式子的计算放一个for循环,就只需要一次load数据。但是 m N m_N mN是全局最大值,计算 m N m_N mN就已经需要一次遍历了。
  Online softmax算法把(7)(8)进行了合并,把3次遍历缩减为2个。它提出计算 d i ′ = ∑ j = 1 i e x j − m i d_i^{\prime}=\sum_{j=1}^{i}e^{x_j-m_i} di=j=1iexjmi来代替计算 d i d_i di,当算到最后 i = N i=N i=N时会发现, d N = d N ′ d_N=d_N^{\prime} dN=dN。具体的,迭代计算 d i ′ d_i^{\prime} di的方式为:
d i ′ = ∑ j = 1 i e x j − m i = ( ∑ j = 1 i − 1 e x j − m i ) + e x i − m i = ( ∑ j = 1 i − 1 e x j − m i − 1 ) e m i − 1 − m i + e x i − m i = d i − 1 ′ e m i − 1 − m i + e x i − m i \begin{aligned} d_i^{\prime} &= \sum_{j=1}^{i} e^{x_j - m_i} \\ &= \left( \sum_{j=1}^{i-1} e^{x_j - m_i} \right) + e^{x_i - m_i} \\ &= \left( \sum_{j=1}^{i-1} e^{x_j - m_{i-1}} \right) e^{m_{i-1} - m_i} + e^{x_i - m_i} \\ &= d_{i-1}^{\prime} e^{m_{i-1} - m_i} + e^{x_i - m_i} \end{aligned} di=j=1iexjmi=(j=1i1exjmi)+eximi=(j=1i1exjmi1)emi1mi+eximi=di1emi1mi+eximi

  所以就可以用迭代的方式,在找最大值 m N m_N mN的时候,同时来计算 d i ′ d_i^{\prime} di,把(7)和(8)一起计算,这样只需要加载两次 x i x_i xi

在这里插入图片描述

2.4 Flash-attention

  上面的online softmax仍然需要2个for循环,加载2次 x i x_i xi来完成softmax的计算。完成softmax的计算,没法更进一步地压缩到1次遍历。但是attention计算的最终目标是获取输出结果,也就是注意力分数与 V V V相乘的结果 O = A × V O=A \times V O=A×V,计算 O O O可以通过一次遍历完成。
在这里插入图片描述
  可以使用类似online softmax把计算 d i d_i di变成计算 d i ′ d_i^{\prime} di的方式,把 o i o_i oi的计算也改成迭代式的,首先把 a i a_i ai带入 o i o_i oi的表达式
o i = ∑ j = 1 i ( e x j − m N d N ′ V [ j , : ] ) o_i=\sum_{j=1}^{i}\left(\frac{e^{x_j-m_{N}}}{d_N^{\prime}}V[j,:]\right) oi=j=1i(dNexjmNV[j,:])

  可以找到一个 o i ′ o_i^{\prime} oi,它不依赖于全局的 d N ′ d_N^{\prime} dN m N m_N mN
o i ′ = ∑ j = 1 i ( e x j − m i d i ′ V [ j , : ] ) o_i^{\prime}=\sum_{j=1}^{i}\left(\frac{e^{x_j-m_{i}}}{d_i^{\prime}}V[j,:]\right) oi=j=1i(diexjmiV[j,:])

  对于 o i ′ o_i^{\prime} oi的计算可以使用迭代的方式,同样的是有 o N = o N ′ o_N=o_N^{\prime} oN=oN
o i ′ = ∑ j = 1 i e x j − m i d i ′ V [ j , : ] = ( ∑ j = 1 i − 1 e x j − m i d i ′ V [ j , : ] ) + e x i − m i d i ′ V [ i , : ] = ( ∑ j = 1 i − 1 e x j − m i − 1 d i − 1 ′ e x j − m i e x j − m i − 1 d i − 1 ′ d i ′ V [ j , : ] ) + e x i − m i d i ′ V [ i , : ] = ( ∑ j = 1 i − 1 e x j − m i − 1 d i − 1 ′ V [ j , : ] ) d i − 1 ′ d i ′ e m i − 1 − m i + e x i − m i d i ′ V [ i , : ] = o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + e x i − m i d i ′ V [ i , : ] \begin{aligned} o_i' &= \sum_{j=1}^{i} \frac{e^{x_j - m_i}}{d_i'} V[j,:] \\ &= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_i}}{d_i'} V[j,:] \right) + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \\ &= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_{i-1}}}{d_{i-1}'} \frac{e^{x_j - m_i}}{e^{x_j - m_{i-1}}} \frac{d_{i-1}'}{d_i'} V[j,:] \right) + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \\ &= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_{i-1}}}{d_{i-1}'} V[j,:] \right) \frac{d_{i-1}'}{d_i'} e^{m_{i-1} - m_i} + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \\ &= o_{i-1}' \frac{d_{i-1}' e^{m_{i-1} - m_i}}{d_i'} + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \end{aligned} oi=j=1idiexjmiV[j,:]=(j=1i1diexjmiV[j,:])+dieximiV[i,:]=(j=1i1di1exjmi1exjmi1exjmididi1V[j,:])+dieximiV[i,:]=(j=1i1di1exjmi1V[j,:])didi1emi1mi+dieximiV[i,:]=oi1didi1emi1mi+dieximiV[i,:]

  这样计算attention的输出结果可以只进行一次遍历就完成
在这里插入图片描述

2.5 Flash-attention tiling

  上面是每次计算一个元素 [ i ] [i] [i],实际上可以一次读取一个大小为b的块(tile)来计算

在这里插入图片描述在这里插入图片描述

  此外,在flash-attention的paper里面,对 Q Q Q K K K V V V O O O分块,其中 Q Q Q
O O O每块大小为 m i n ( M / 4 d , d ) × d min(M/4d,d) \times d min(M/4d,d)×d K / V K/V K/V的每块大小为 M / 4 d × d M/4d \times d M/4d×d,加起来正好不会超过SRAM的大小M,完整的算法在paper中:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/web/82710.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python打卡day46@浙大疏锦行

知识点回顾&#xff1a; 不同CNN层的特征图&#xff1a;不同通道的特征图什么是注意力&#xff1a;注意力家族&#xff0c;类似于动物园&#xff0c;都是不同的模块&#xff0c;好不好试了才知道。通道注意力&#xff1a;模型的定义和插入的位置通道注意力后的特征图和热力图 内…

JavaSec-SPEL - 表达式注入

简介 SPEL(Spring Expression Language)&#xff1a;SPEL是Spring表达式语言&#xff0c;允许在运行时动态查询和操作对象属性、调用方法等&#xff0c;类似于Struts2中的OGNL表达式。当参数未经过滤时&#xff0c;攻击者可以注入恶意的SPEL表达式&#xff0c;从而执行任意代码…

SpringCloud——OpenFeign

概述&#xff1a; OpenFeign是基于Spring的声明式调用的HTTP客户端&#xff0c;大大简化了编写Web服务客户端的过程&#xff0c;用于快速构建http请求调用其他服务模块。同时也是spring cloud默认选择的服务通信工具。 使用方法&#xff1a; RestTemplate手动构建: // 带查询…

【深入学习Linux】System V共享内存

目录 前言 一、共享内存是什么&#xff1f; 共享内存实现原理 共享内存细节理解 二、接口认识 1.shmget函数——申请共享内存 2.ftok函数——生成key值 再次理解ftok和shmget 1&#xff09;key与shmid的区别与联系 2&#xff09;再理解key 3&#xff09;通过指令查看/释放系统中…

探索 Java 垃圾收集:对象存活判定、回收流程与内存策略

个人主页-爱因斯晨 文章专栏-JAVA学习笔记 热门文章-赛博算命 一、引言 在 Java 技术体系里&#xff0c;垃圾收集器&#xff08;Garbage Collection&#xff0c;GC&#xff09;与内存分配策略是自动内存管理的核心支撑。深入探究其原理与机制&#xff0c;对优化程序内存性能…

hbase资源和数据权限控制

hbase适合大数据量下点查 https://zhuanlan.zhihu.com/p/471133280 HBase支持对User、NameSpace和Table进行请求数和流量配额限制&#xff0c;限制频率可以按sec、min、hour、day 对于请求大小限制示例&#xff08;5K/sec,10M/min等&#xff09;&#xff0c;请求大小限制单位如…

大数据-275 Spark MLib - 基础介绍 机器学习算法 集成学习 随机森林 Bagging Boosting

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; 大模型篇章已经开始&#xff01; 目前已经更新到了第 22 篇&#xff1a;大语言模型 22 - MCP 自动操作 FigmaCursor 自动设计原型 Java篇开…

Delphi 实现远程连接 Access 数据库的指南

方法一&#xff1a;通过局域网共享 Access 文件&#xff08;简单但有限&#xff09; 步骤 1&#xff1a;共享 Access 数据库 将 .mdb 或 .accdb 文件放在局域网内某台电脑的共享文件夹中。 右键文件夹 → 属性 → 共享 → 启用共享并设置权限&#xff08;需允许网络用户读写&a…

VR视频制作有哪些流程?

VR视频制作流程知识 VR视频制作&#xff0c;作为融合了创意与技术的复杂制作过程&#xff0c;涵盖从初步策划到最终呈现的多个环节。在这个过程中&#xff0c;我们可以结合众趣科技的产品&#xff0c;解析每一环节的实现与优化&#xff0c;揭示背后的奥秘。 VR视频制作有哪些…

文件上传/下载接口开发

接口特性 文件传输接口与传统接口的核心差异体现在数据传输格式&#xff1a; 上传接口采用 multipart/form-data 格式支持二进制文件传输下载接口接收二进制流并实现本地文件存储 文件上传接口开发 接口规范 请求地址&#xff1a;/createbyfile 请求方式&#xff1a;POST…

深入学习RabbitMQ队列的知识

目录 1、AMQP协议 1.1、介绍 1.2、AMQP的特点 1.3、工作流程 1.4、消息模型 1.5、消息结构 1.6、AMQP 的交换器类型 2、RabbitMQ结构介绍 2.1、核心组件 2.2、最大特点 2.3、工作原理 3、消息可靠性保障 3.1、生产端可靠性 1、生产者确认机制 2、持久化消息 3.…

【计算机网络】NAT、代理服务器、内网穿透、内网打洞、局域网中交换机

&#x1f525;个人主页&#x1f525;&#xff1a;孤寂大仙V &#x1f308;收录专栏&#x1f308;&#xff1a;计算机网络 &#x1f339;往期回顾&#x1f339;&#xff1a;【计算机网络】数据链路层——ARP协议 &#x1f516;流水不争&#xff0c;争的是滔滔不息 一、网络地址转…

[论文阅读] 人工智能 | 大语言模型计划生成的新范式:基于过程挖掘的技能学习

#论文阅读# 大语言模型计划生成的新范式&#xff1a;基于过程挖掘的技能学习 论文信息 Skill Learning Using Process Mining for Large Language Model Plan Generation Andrei Cosmin Redis, Mohammadreza Fani Sani, Bahram Zarrin, Andrea Burattin Cite as: arXiv:2410.…

C文件操作2

五、文件的随机读写 这些函数都需要包含头文件 #include<stdio.h> 5.1 fseek 根据文件指针的位置和偏移量来定位文件指针&#xff08;文件内容的光标&#xff09; &#xff08;重新定位流位置指示器&#xff09; int fseek ( FILE * stream, long int offset, int or…

react私有样式处理

react私有样式处理 Nav.jsx Menu.jsx vue中通过scoped来实现样式私有化。加上scoped&#xff0c;就属于当前组件的私有样式。 给视图中的元素都加了一个属性data-v-xxx&#xff0c;然后给这些样式都加上属性选择器。&#xff08;deep就是不加属性也不加属性选择器&#xff09; …

【信创-k8s】海光/兆芯+银河麒麟V10离线部署k8s1.31.8+kubesphere4.1.3

❝ KubeSphere V4已经开源半年多&#xff0c;而且v4.1.3也已经出来了&#xff0c;修复了众多bug。介于V4优秀的LuBan架构&#xff0c;核心组件非常少&#xff0c;资源占用也显著降低&#xff0c;同时带来众多功能和便利性。我们决定与时俱进&#xff0c;使用1.30版本的Kubernet…

单片机内部结构基础知识 FLASH相关解读

一、总线简单说明 地址总线、控制总线、数据总线 什么是8位8051框架结构的微控制器&#xff1f; 数据总线宽度为8位&#xff0c;即CPU一次处理或传输的数据量为8位&#xff08;1字节&#xff09; 同时还有一个16位的地址总线&#xff0c;这个地方也刚好对应了为什么能看到内存…

HTTPS加密的介绍

HTTPS&#xff08;HyperText Transfer Protocol Secure&#xff0c;超文本传输安全协议&#xff09;是HTTP协议的安全版本。它在HTTP的基础上加入了SSL/TLS协议&#xff0c;用于对数据进行加密&#xff0c;并确保数据传输过程中的机密性、完整性和身份验证。 在HTTPS出现之前&a…

【freertos-kernel】stream_buffer

文章目录 补充任务通知发送处理ulTaskGenericNotifyTakexTaskGenericNotifyWait 清除xTaskGenericNotifyStateClearulTaskGenericNotifyValueClear 结构体StreamBufferHandle_tStreamBufferCallbackFunction_t 创建xStreamBufferGenericCreatestream buffer的类型 删除vStreamB…

在word中点击zotero Add/Edit Citation没有反应的解决办法

重新安装了word插件 1.关掉word 2.进入Zotero左上角编辑-引用 3.往下滑找到Microsoft Word&#xff0c;点重新安装加载项