1. 定义
nn.Linear
是 PyTorch 中最基础的全连接(fully‐connected)线性层,也称仿射变换层(affine layer)。它对输入张量做一次线性变换:
output = x W T + b \text{output} = x W^{T} + b output=xWT+b
其中, W W W是形状为 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)的权重矩阵, b b b是长度为 out_features \text{out\_features} out_features的偏置向量。
2. 输入与输出
-
输入(Input)
- 类型:浮点型张量(如
torch.float32
,也可与权重同 dtype) - 形状: ( … , i n _ f e a t u r e s ) (\dots, \mathrm{in\_features}) (…,in_features),最后一维必须与
in_features
匹配;前面的…
可以是 batch 大小或任意额外维度。
- 类型:浮点型张量(如
-
输出(Output)
- 类型:浮点型张量
- 形状: ( … , o u t _ f e a t u r e s ) (\dots, \mathrm{out\_features}) (…,out_features),即将最后一维从
in_features
变换为out_features
,前面的维度保持不变。
3. 底层原理
- 矩阵乘法
对输入张量的最后一维做矩阵乘法:
y [ . . . , j ] = ∑ i = 1 i n _ f e a t u r e s x [ . . . , i ] × W j , i ( ∀ j = 1 … o u t _ f e a t u r e s ) y[..., j] = \sum_{i=1}^{\mathrm{in\_features}} x[..., i] \times W_{j,i} \quad (\forall\,j=1\ldots\mathrm{out\_features}) y[...,j]=i=1∑in_featuresx[...,i]×Wj,i(∀j=1…out_features) - 加偏置
若bias=True
,则在乘积结果上逐元素加偏置 b j b_j bj:
y [ . . . , j ] = ( x W T ) [ . . . , j ] + b j y[..., j] \;=\; (x W^T)[..., j] \;+\; b_j y[...,j]=(xWT)[...,j]+bj - 梯度更新
- 权重 W W W:梯度 ∂ L ∂ W j , i = ∑ x [ . . . , i ] × δ y [ . . . , j ] \frac{\partial \mathcal{L}}{\partial W_{j,i}} = \sum x[..., i]\times \delta y[..., j] ∂Wj,i∂L=∑x[...,i]×δy[...,j]
- 偏置 b b b:梯度 ∑ δ y [ . . . , j ] \sum \delta y[..., j] ∑δy[...,j]
优化器(如 SGD、Adam 等)根据梯度更新参数。
4. 构造函数参数详解
参数 | 类型 & 默认 | 说明 |
---|---|---|
in_features | int | 必填。输入特征维度,即每个样本最后一维的大小。 |
out_features | int | 必填。输出特征维度,即映射后的最后一维大小(神经元个数)。 |
bias | bool ,默认 True | 是否使用偏置向量 b 。若为 False ,则不加偏置。 |
device | torch.device 或 None | 指定权重和偏置所在设备(CPU/GPU);若为 None ,默认继承父模块设备。 |
dtype | torch.dtype 或 None | 指定权重和偏置的数据类型;若为 None ,默认继承父模块 dtype。 |
权重与偏置初始化
- 默认情况下,
W
按均匀分布初始化:
W j , i ∼ U ( − 1 i n _ f e a t u r e s , 1 i n _ f e a t u r e s ) W_{j,i}\sim \mathcal{U}\Bigl(-\sqrt{\tfrac{1}{\mathrm{in\_features}}},\;\sqrt{\tfrac{1}{\mathrm{in\_features}}}\Bigr) Wj,i∼U(−in_features1,in_features1) - 偏置 b b b初始化为全零。
5. 使用示例
import torch
import torch.nn as nn# 1. 定义线性层
in_dim = 64 # 输入维度
out_dim = 10 # 输出维度(例如分类 10 类的 logits)
linear = nn.Linear(in_features=in_dim,out_features=out_dim,bias=True
)# 2. 构造输入
# 假设 batch_size=8
x = torch.randn(8, in_dim) # shape=[8,64]# 3. 前向计算
# 输出 shape=[8,10]
y = linear(x)
print(y.shape) # torch.Size([8, 10])
如果不希望使用偏置:
linear_no_bias = nn.Linear(in_dim, out_dim, bias=False)
在更高维场景下也可:
# 输入 shape=[batch, seq_len, in_dim]
x_seq = torch.randn(8, 5, in_dim)
# 输出 shape=[8, 5, out_dim]
y_seq = linear(x_seq)
6. 注意事项
- 维度匹配
- 确保输入最后一维等于
in_features
,否则会报维度不匹配错误。
- 确保输入最后一维等于
- 批量处理
- 对多维输入,
nn.Linear
自动应用到最后一维,无需手动 reshape(除非想将多个维度合并后统一处理)。
- 对多维输入,
- 初始化
- 默认初始化适用于常见场景;若训练不稳定,可手动调用
nn.init
系列方法(如kaiming_uniform_
,xavier_normal_
等)重新初始化。
- 默认初始化适用于常见场景;若训练不稳定,可手动调用
- 无偏置场景
- 对于某些网络结构(如批归一化紧跟线性层),可关闭
bias
减少参数不用多余偏置。
- 对于某些网络结构(如批归一化紧跟线性层),可关闭
- 设备与 dtype
- 在多 GPU 或混合精度训练时,可通过
device
与dtype
参数显式控制,避免后续.to()
调用。
- 在多 GPU 或混合精度训练时,可通过
- 与 Conv1d 的联系
- 本质上,
nn.Conv1d(in_dim, out_dim, kernel_size=1)
相当于在时间维度(或序列维度)上对每个位置做一个nn.Linear
;理解这一点有助于模型设计。
- 本质上,