MHSA:使用 PyTorch 实现的多头自注意力 (Multi-head Self Attention) 代码示例,包含详细注释说明:
-
线性投影
通过三个线性层分别生成查询(Q)、键(K)、值(V)矩阵:
Q=Wq⋅x,K=Wk⋅x,V=Wv⋅xQ = W_q·x, \quad K = W_k·x, \quad V = W_v·xQ=Wq⋅x,K=Wk⋅x,V=Wv⋅x -
分割多头
将每个矩阵分割为 hhh 个头部:
Q→[Q1,Q2,...,Qh],每个Qi∈Rdk\text{Q} \rightarrow [Q_1, Q_2, ..., Q_h], \quad \text{每个} Q_i \in \mathbb{R}^{d_k}Q→[Q1,Q2,...,Qh],每个Qi∈Rdk -
计算注意力分数
对每个头部计算缩放点积注意力:
Attention(Qi,Ki,Vi)=softmax(QiKiTdk)Vi\text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_iK_i^T}{\sqrt{d_k}}\right)V_iAttention(Qi,Ki,Vi)=softmax(dkQiKiT)Vi -
合并多头
拼接所有头部的输出并通过线性层:
MultiHead=Wo⋅[head1;head2;...;headh]\text{MultiHead} = W_o·[\text{head}_1; \text{head}_2; ... ; \text{head}_h]MultiHead=Wo⋅[head1;head2;...;headh]
数学原理:
多头注意力允许模型同时关注不同表示子空间的信息:
MultiHead(Q,K,V)=Concat(head1,...,headh)WO\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^OMultiHead(Q,K,V)=Concat(head1,...,headh)WO
其中每个头的计算为:
headi=Attention(QWiQ,KWiK,VWiV)\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)headi=Attention(QWiQ,KWiK,VWiV)
以下是一个使用 PyTorch 实现的多头自注意力 (Multi-head Self Attention) 代码示例,包含详细注释说明:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):"""embed_dim: 输入向量维度num_heads: 注意力头的数量"""super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_heads # 每个头的维度# 检查维度是否可整除assert self.head_dim * num_heads == embed_dim# 定义线性变换层self.query = nn.Linear(embed_dim, embed_dim)self.key = nn.Linear(embed_dim, embed_dim)self.value = nn.Linear(embed_dim, embed_dim)self.fc_out = nn.Linear(embed_dim, embed_dim)def forward(self, x):"""x: 输入张量,形状为 (batch_size, seq_len, embed_dim)"""batch_size = x.shape[0] #[4,10,512]# 1. 线性投影Q = self.query(x) # (batch_size, seq_len, embed_dim) #[4,10,512]K = self.key(x) # (batch_size, seq_len, embed_dim) #[4,10,512]V = self.value(x) # (batch_size, seq_len, embed_dim) #[4,10,512]# 2. 分割多头Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3) #[4,8,10,64]K = K.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3) #[4,8,10,64]V = V.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3) #[4,8,10,64]# 现在形状: (batch_size, num_heads, seq_len, head_dim)# 3. 计算注意力分数# 计算 Q·K^T / sqrt(d_k)energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5) #[4,8,10,64]* #[4,8,64,10] = [4,8,10,10]# 形状: (batch_size, num_heads, seq_len, seq_len)# 4. 应用softmax获取注意力权重attention = F.softmax(energy, dim=-1)# 形状: (batch_size, num_heads, seq_len, seq_len)# 5. 计算加权和out = torch.matmul(attention, V)#[4,8,10,10]* [4,8,10,64] = [4,8,10,64]# 形状: (batch_size, num_heads, seq_len, head_dim)# 6. 合并多头out = out.permute(0, 2, 1, 3).contiguous()out = out.view(batch_size, -1, self.embed_dim)# 形状: (batch_size, seq_len, embed_dim)# 7. 最终线性变换out = self.fc_out(out)return out# 使用示例
if __name__ == "__main__":# 参数设置embed_dim = 512 # 输入维度num_heads = 8 # 注意力头数seq_len = 10 # 序列长度batch_size = 4 # 批大小# 创建多头注意力模块mha = MultiHeadAttention(embed_dim, num_heads)# 生成模拟输入数据input_data = torch.randn(batch_size, seq_len, embed_dim)# 前向传播output = mha(input_data)print("输入形状:", input_data.shape)print("输出形状:", output.shape)
输出示例:
输入形状: torch.Size([4, 10, 512])
输出形状: torch.Size([4, 10, 512])
此实现保持了输入输出维度一致,可直接集成到Transformer等架构中。