1. 定义
nn.Embedding
是 PyTorch 中的 查表式嵌入层(lookup‐table),用于将离散的整数索引(如词 ID、实体 ID、离散特征类别等)映射到一个连续的、可训练的低维向量空间。它通过维护一个形状为 (num_embeddings, embedding_dim)
的权重矩阵,实现高效的“索引 → 向量”转换。
2. 输入与输出
-
输入
- 类型:整型张量(
torch.long
或torch.int64
),必须是 LongTensor,其他类型会报错。 - 形状:任意形状
(*, L)
,其中最内层长度L
常为序列长度,前面的*
可以是 batch 及其他维度。 - 取值范围:
0 ≤ index < num_embeddings
;超出范围会抛出IndexError
。
- 类型:整型张量(
-
输出
- 类型:浮点型张量(与权重相同的
dtype
,默认为torch.float32
)。 - 形状:
(*, L, embedding_dim)
;就是在输入张量后追加一个维度embedding_dim
。 - 语义:若输入某位置的值为
j
,则该位置对应输出就是权重矩阵的第j
行。
- 类型:浮点型张量(与权重相同的
3. 底层原理
-
查表操作 vs. One-hot 乘法
- 直观上,Embedding 相当于:
output = one_hot ( i n p u t ) × W \text{output} = \text{one\_hot}(input) \;\times\; W output=one_hot(input)×W
其中W
是(num_embeddings×embedding_dim)
的权重矩阵。 - 为避免显式构造稀疏的 one-hot 张量,PyTorch 直接根据索引做“取行”操作,效率更高、内存更省。
- 直观上,Embedding 相当于:
-
梯度更新
- 稠密模式(默认):整个
W
都有梯度缓冲,优化器根据梯度更新所有行。 - 稀疏模式(
sparse=True
):仅对被索引过的行计算和存储梯度,可配合optim.SparseAdam
高效更新,适合超大字典(百万级以上)但每次只访问少量索引的场景。
- 稠密模式(默认):整个
-
范数裁剪
- 若指定
max_norm
,每次前向都会对输出向量(即对应的行)做范数裁剪,保证其 L-norm_type
范数不超过max_norm
,有助于防止某些频繁访问的词向量过大。
- 若指定
-
权重初始化
- 默认初始化使用均匀分布:
W i , j ∼ U ( − 1 num_embeddings , 1 num_embeddings ) W_{i,j} \sim \mathcal{U}\Bigl(-\sqrt{\tfrac{1}{\text{num\_embeddings}}},\;\sqrt{\tfrac{1}{\text{num\_embeddings}}}\Bigr) Wi,j∼U(−num_embeddings1,num_embeddings1) - 可以通过
_weight
参数传入外部预训练权重(如 Word2Vec、GloVe 等)。
- 默认初始化使用均匀分布:
4. 构造函数参数详解
参数 | 类型及默认 | 说明 |
---|---|---|
num_embeddings | int | 必填。嵌入表行数,等于类别总数(最大索引 + 1)。 |
embedding_dim | int | 必填。每个向量的维度。 |
padding_idx | int 或 None | 默认 None 。指定该索引对应行始终输出全零,并且该行的梯度永远为 0,适合做序列填充。 |
max_norm | float 或 None | 默认 None 。若设为数值,每次前向时对取出的向量做范数裁剪(L-norm_type ≤ max_norm )。 |
norm_type | float ,默认 2 | 与 max_norm 配合使用时定义范数类型,如 1-范数、2-范数等。 |
scale_grad_by_freq | bool ,默认 False | 若为 True ,在反向传播阶段按照索引在 batch 中出现的频次对梯度做缩放(出现越多,梯度越小),有助于高频词的梯度平滑。 |
sparse | bool ,默认 False | 若为 True ,开启稀疏更新,仅对被访问行生成梯度;必须配合 optim.SparseAdam 使用,不支持常规稠密优化器。 |
_weight | Tensor 或 None | 若提供,则用此张量(形状应为 (num_embeddings, embedding_dim) )作为权重初始化,否则随机初始化。 |
5. 使用示例
import torch
import torch.nn as nn# 1. 参数设定
vocab_size = 10000 # 词表大小
embed_dim = 300 # 嵌入维度# 2. 创建 Embedding 层
embedding = nn.Embedding(num_embeddings=vocab_size,embedding_dim=embed_dim,padding_idx=0, # 将 0 作为填充索引,输出全 0max_norm=5.0, # 向量范数不超过 5norm_type=2.0,scale_grad_by_freq=True,sparse=False
)# 3. 构造输入
# batch_size=2, seq_len=6
input_ids = torch.tensor([[ 1, 234, 56, 789, 0, 23],[123, 4, 567, 8, 9, 0],
], dtype=torch.long)# 4. 前向计算
# 输出 shape = [2, 6, 300]
output = embedding(input_ids)
print(output.shape) # -> torch.Size([2, 6, 300])
加载并冻结预训练权重
import numpy as np# 假设有预训练权重 pre_trained.npy,shape=(10000,300)
weights = torch.from_numpy(np.load("pre_trained.npy"))
embed_pre = nn.Embedding(num_embeddings=vocab_size,embedding_dim=embed_dim,_weight=weights
)
# 冻结所有权重
embed_pre.weight.requires_grad = False
6. 注意事项
- 类型与范围
- 输入必须为 LongTensor,且所有索引满足
0 ≤ index < num_embeddings
。
- 输入必须为 LongTensor,且所有索引满足
- Padding 与 Mask
- 仅指定
padding_idx
会返回零向量,但上游网络(如 RNN、Transformer)还需显式 mask,避免无效位置影响注意力或累积状态。
- 仅指定
- 性能考量
max_norm
每次前向都做范数计算和裁剪,若不需要可关闭以提升速度。
- 稀疏更新限制
sparse=True
可节省内存,但只支持SparseAdam
,且在 GPU 上效率有时不如稠密模式。
- EmbeddingBag
- 对于可变长度序列的 sum/mean/power-mean 汇聚,可使用
nn.EmbeddingBag
,避免中间张量开销。
- 对于可变长度序列的 sum/mean/power-mean 汇聚,可使用
- 分布式与大词表
- 在分布式训练时,可将嵌入表切分到多个进程上(
torch.nn.parallel.DistributedDataParallel
+torch.nn.Embedding
支持参数分布式)。 - 超大词表(千万级)时,可考虑动态加载、分布式哈希表或专用库(如 DeepSpeed 的嵌入稀疏优化)。
- 在分布式训练时,可将嵌入表切分到多个进程上(