【模型细节】MHSA:多头自注意力 (Multi-head Self Attention) 详细解释,使用 PyTorch代码示例说明

MHSA:使用 PyTorch 实现的多头自注意力 (Multi-head Self Attention) 代码示例,包含详细注释说明:

  1. 线性投影
    通过三个线性层分别生成查询(Q)、键(K)、值(V)矩阵:
    Q=Wq⋅x,K=Wk⋅x,V=Wv⋅xQ = W_q·x, \quad K = W_k·x, \quad V = W_v·xQ=Wqx,K=Wkx,V=Wvx

  2. 分割多头
    将每个矩阵分割为 hhh 个头部:
    Q→[Q1,Q2,...,Qh],每个Qi∈Rdk\text{Q} \rightarrow [Q_1, Q_2, ..., Q_h], \quad \text{每个} Q_i \in \mathbb{R}^{d_k}Q[Q1,Q2,...,Qh],每个QiRdk

  3. 计算注意力分数
    对每个头部计算缩放点积注意力:
    Attention(Qi,Ki,Vi)=softmax(QiKiTdk)Vi\text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_iK_i^T}{\sqrt{d_k}}\right)V_iAttention(Qi,Ki,Vi)=softmax(dkQiKiT)Vi

  4. 合并多头
    拼接所有头部的输出并通过线性层:
    MultiHead=Wo⋅[head1;head2;...;headh]\text{MultiHead} = W_o·[\text{head}_1; \text{head}_2; ... ; \text{head}_h]MultiHead=Wo[head1;head2;...;headh]

数学原理:

多头注意力允许模型同时关注不同表示子空间的信息:
MultiHead(Q,K,V)=Concat(head1,...,headh)WO\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^OMultiHead(Q,K,V)=Concat(head1,...,headh)WO
其中每个头的计算为:
headi=Attention(QWiQ,KWiK,VWiV)\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)headi=Attention(QWiQ,KWiK,VWiV)

以下是一个使用 PyTorch 实现的多头自注意力 (Multi-head Self Attention) 代码示例,包含详细注释说明:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):"""embed_dim: 输入向量维度num_heads: 注意力头的数量"""super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_heads  # 每个头的维度# 检查维度是否可整除assert self.head_dim * num_heads == embed_dim# 定义线性变换层self.query = nn.Linear(embed_dim, embed_dim)self.key = nn.Linear(embed_dim, embed_dim)self.value = nn.Linear(embed_dim, embed_dim)self.fc_out = nn.Linear(embed_dim, embed_dim)def forward(self, x):"""x: 输入张量,形状为 (batch_size, seq_len, embed_dim)"""batch_size = x.shape[0]  #[4,10,512]# 1. 线性投影Q = self.query(x)  # (batch_size, seq_len, embed_dim) #[4,10,512]K = self.key(x)    # (batch_size, seq_len, embed_dim) #[4,10,512]V = self.value(x)  # (batch_size, seq_len, embed_dim) #[4,10,512]# 2. 分割多头Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  #[4,8,10,64]K = K.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  #[4,8,10,64]V = V.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  #[4,8,10,64]# 现在形状: (batch_size, num_heads, seq_len, head_dim)# 3. 计算注意力分数# 计算 Q·K^T / sqrt(d_k)energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5) #[4,8,10,64]* #[4,8,64,10] = [4,8,10,10]# 形状: (batch_size, num_heads, seq_len, seq_len)# 4. 应用softmax获取注意力权重attention = F.softmax(energy, dim=-1)# 形状: (batch_size, num_heads, seq_len, seq_len)# 5. 计算加权和out = torch.matmul(attention, V)#[4,8,10,10]* [4,8,10,64] = [4,8,10,64]# 形状: (batch_size, num_heads, seq_len, head_dim)# 6. 合并多头out = out.permute(0, 2, 1, 3).contiguous()out = out.view(batch_size, -1, self.embed_dim)# 形状: (batch_size, seq_len, embed_dim)# 7. 最终线性变换out = self.fc_out(out)return out# 使用示例
if __name__ == "__main__":# 参数设置embed_dim = 512  # 输入维度num_heads = 8    # 注意力头数seq_len = 10     # 序列长度batch_size = 4   # 批大小# 创建多头注意力模块mha = MultiHeadAttention(embed_dim, num_heads)# 生成模拟输入数据input_data = torch.randn(batch_size, seq_len, embed_dim)# 前向传播output = mha(input_data)print("输入形状:", input_data.shape)print("输出形状:", output.shape)

输出示例:

输入形状: torch.Size([4, 10, 512])
输出形状: torch.Size([4, 10, 512])

此实现保持了输入输出维度一致,可直接集成到Transformer等架构中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/91288.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/91288.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PGSQL运维优化:提升vacuum执行时间观测能力

本文是 IvorySQL 2025 生态大会暨 PostgreSQL 高峰论坛上的演讲内容,作者:NKYoung。 6 月底济南召开的 HOW2025 IvorySQL 生态大会上,我在内核论坛分享了“提升 vacuum 时间观测能力”的主题,提出了新增统计信息的方法&#xff0c…

神奇的数据跳变

目的 上周遇上了一个非常奇怪的问题,就是软件的数据在跳变,本来数据应该是158吧,数据一会变成10,一会又变成158,数据在不断地跳变,那是怎么回事?? 这个问题非常非常的神奇,让人感觉太不可思议了。 这是这段时间,我遇上的最神奇的事了,没有之一,最神奇的事,下面…

【跨国数仓迁移最佳实践3】资源消耗减少50%!解析跨国数仓迁移至MaxCompute背后的性能优化技术

本系列文章将围绕东南亚头部科技集团的真实迁移历程展开,逐步拆解 BigQuery 迁移至 MaxCompute 过程中的关键挑战与技术创新。本篇为第3篇,解析跨国数仓迁移背后的性能优化技术。注:客户背景为东南亚头部科技集团,文中用 GoTerra …

【MySQL集群架构与实践3】使用Dcoker实现读写分离

目录 一. 在Docker中安装ShardingSphere 二 实践:读写分离 2.1 应用场景 2.2 架构图 2.3 服务器规划 2.4 启动数据库服务器 2.5. 配置读写分离 2.6 日志配置 2.7 重启ShardingSphere 2.8 测试 2.9. 负载均衡 2.9.1. 随机负载均衡算法示例 2.9.2. 轮询负…

maven的阿里云镜像地址

在 Maven 中配置阿里云镜像可以加速依赖包的下载,尤其是国内环境下效果明显。以下是阿里云 Maven 镜像的配置方式: 配置步骤:找到 Maven 的配置文件 settings.xml 全局配置:位于 Maven 安装目录的 conf/settings.xml用户级配置&am…

大语言模型信息抽取系统解析

这段代码实现了一个基于大语言模型的信息抽取系统,能够从金融和新闻类文本中提取结构化信息。下面我将详细解析整个代码的结构和功能。1. 代码整体结构代码主要分为以下几个部分:模式定义:定义不同领域(金融、新闻)需要抽取的实体类型示例数据…

Next实习项目总结串联讲解(一)

下面是一些 Next.js 前端面试中常见且具深度的问题,按照逻辑模块整理,同时提供示范回答建议,便于你条理清晰地展示理解与实践经验。 ✅ 面试讲述结构建议 先讲 Next.js 是什么,它为什么比 React 更高级。(支持 SSR/SSG/ISR,提升S…

React开发依赖分析

1. React小案例: 在界面显示一个文本:Hello World点击按钮后,文本改为为:Hello React 2. React开发依赖 2.1. 开发React必须依赖三个库: 2.1.1. react: 包含react所必须的核心代码2.1.2. react-dom: react渲染在不同平…

工具(一)Cursor

目录 一、介绍 二、如何打开文件 1、从idea跳转文件 2、单独打开项目 三、常见使用 1、Chat 窗口 Ask 对话模式 1.1、使用技巧 1.2 发送和使用 codebase 发送区别 1.3、问题快速修复 2、Chat 窗口 Agent 对话模式 2.1、agent模式功能 2.2、Chat 窗口回滚&撤销 2.3…

Prompt编写规范指引

1、📖 引言 随着人工智能生成内容(AIGC)技术的快速发展,越来越多的开发者开始利用AIGC工具来辅助代码编写。然而,如何编写有效的提示词(Prompt)以引导AIGC生成高质量的代码,成为了许…

自我学习----绘制Mark点

在PCB的Layout过程中我们需在光板上放置Mark点以方便生产时的光学定位(三点定位);我个人Mark点绘制步骤如下: layer层:1.放置直径1mm的焊盘(无网络连接) 2.放置一个圆直径2mm,圆心与…

2025年财税行业拓客破局:小蓝本财税版AI拓客系统助力高效拓客

2025年,在"金税四期"全面实施的背景下,中国财税服务市场迎来爆发式增长,根据最新的市场研究报告,2025年中国财税服务行业产值将达2725.7亿元。然而,行业高速发展的背后,80%的财税公司却陷入获客成…

双向链表,对其实现头插入,尾插入以及遍历倒序输出

1.创建一个节点,并将链表的首节点返回创建一个独立节点,没有和原链表产生任何关系#include "head.h"typedef struct Node { int num; struct Node*pNext; struct Node*pPer; }NODE;后续代码:NODE*createNode(int value) {NODE*new …

2025年自动化工程与计算机网络国际会议(ICAECN 2025)

2025年自动化工程与计算机网络国际会议(ICAECN 2025) 2025 International Conference on Automation Engineering and Computer Networks一、大会信息会议简称:ICAECN 2025 大会地点:中国柳州 审稿通知:投稿后2-3日内通…

12.Origin2021如何绘制误差带图?

12.Origin2021如何绘制误差带图?选中Y3列→点击统计→选择描述统计→选择行统计→选择打开对话框输入范围选择B列到D列点击输出量→勾选均值和标准差Control选择下面三列点击绘图→选择基础2D图→选择误差带图双击图像→选择符号和颜色点击第二个Sheet1→点击误差棒→连接选择…

如何使用API接口获取淘宝店铺订单信息

要获取淘宝店铺的订单信息,您需要通过淘宝开放平台(Taobao Open Platform, TOP)提供的API接口来实现。以下是详细步骤:1. 注册淘宝开放平台账号访问淘宝开放平台注册开发者账号并完成实名认证创建应用获取App Key和App Secret2. 申请API权限在"我的…

【Kiro Code 从入门到精通】重要的功能

一、Kiro 是什么? Kiro 是一款智能型集成开发环境(IDE),借助规格说明(specs)、向导(steer)、钩子(hooks)帮助你高效完成工作。 二、Specs 规格说明 规范&…

直播间里的酒旅新故事:内容正在重构消费链路

文/李乐编辑/子夜今年暑期,旅游的热浪席卷全国。机场、火车站人潮涌动,电子屏上滚动的航班信息与检票口前的长队交织成繁忙的出行图景,酒店预订量也在这股热潮中节节攀升。连线 Insight关注到,今年的暑期游有了一些新变化&#xf…

50天50个小项目 (Vue3 + Tailwindcss V4) ✨ | VerifyAccountUi(验证码组件)

&#x1f4c5; 我们继续 50 个小项目挑战&#xff01;—— VerifyAccountUi组件 仓库地址&#xff1a;https://github.com/SunACong/50-vue-projects 项目预览地址&#xff1a;https://50-vue-projects.vercel.app/ 使用 Vue 3 的 <script setup> 语法结合 Tailwind CS…

AbstractAuthenticationToken 认证流程中​​认证令牌的核心抽象类详解

AbstractAuthenticationToken 认证流程中​​认证令牌的核心抽象类详解在 Spring Security 中&#xff0c;AbstractAuthenticationToken 是 Authentication 接口的​​抽象实现类​​&#xff0c;其核心作用是为具体的认证令牌&#xff08;如用户名密码令牌、JWT 令牌等&#x…