0. 背景
假设我们有两个矩阵:
- 矩阵 A,尺寸为
(n, d_k)
- 矩阵 B,尺寸为
(d_k, n)
我们要计算它们的乘积 C = A * B。
那么这个过程所需的计算量是多少?
1. 结果矩阵的尺寸
首先,结果矩阵 C 的尺寸是由第一个矩阵的行数和第二个矩阵的列数决定的。
- C 的行数 = A 的行数 =
n
- C 的列数 = B 的列数 =
n
所以,结果矩阵 C 的尺寸为(n, n)
。
2. 单个元素的计算量
接下来,我们看如何计算结果矩阵 C 中的任意一个元素 C_ij
(第 i 行,第 j 列的元素)。
根据矩阵乘法的定义,C_ij
是由 A 的第 i 行和 B 的第 j 列的点积(dot product)得到的。
- A 的第 i 行是一个有
d_k
个元素的行向量。 - B 的第 j 列是一个有
d_k
个元素的列向量。
计算过程如下:
C_ij = A_i1 * B_1j + A_i2 * B_2j + ... + A_id_k * B_d_kj
为了计算这一个 C_ij
元素,我们需要:
d_k
次乘法 (每个A_ik
乘以B_kj
)d_k - 1
次加法 (将d_k
个乘积相加)
3. 总计算量
现在我们来计算整个矩阵 C 的总计算量。
结果矩阵 C 是一个 (n, n)
的矩阵,所以它总共有 n * n = n^2
个元素。
我们将单个元素的计算量乘以总元素数量:
-
总乘法次数 = (每个元素的乘法次数) × (总元素个数)
=d_k * n^2
-
总加法次数 = (每个元素的加法次数) × (总元素个数)
=(d_k - 1) * n^2
4. 结论
将一个 (n, d_k)
矩阵与一个 (d_k, n)
矩阵相乘:
- 总乘法运算量为
n² * d_k
次。 - 总加法运算量为
n² * (d_k - 1)
次。
在计算机科学和机器学习领域,我们通常使用浮点运算次数 (FLOPs, Floating Point Operations) 来衡量计算量。一次乘法和一次加法通常被打包看作一次操作(特别是在现代硬件的FMA指令中)。
总FLOPs ≈ 总乘法次数 + 总加法次数
= (n² * d_k) + (n² * (d_k - 1))
= n² * (d_k + d_k - 1)
= n² * (2d_k - 1)
当 d_k
比较大时,我们通常近似为 2 * n² * d_k
FLOPs。
5. 应用背景(非常重要)
这个计算 (n, d_k) * (d_k, n)
在 Transformer模型 的自注意力(Self-Attention)机制中非常核心。
n
通常代表序列长度(Sequence Length)。d_k
代表Query和Key向量的维度。
这个计算对应的是 Query (Q) 矩阵 和 Key (K) 矩阵的转置 (Kᵀ) 相乘,以得到注意力分数矩阵(Attention Score Matrix)。
- Q 的尺寸是
(n, d_k)
- K 的尺寸是
(n, d_k)
,所以 Kᵀ 的尺寸是(d_k, n)
- Q * Kᵀ 的结果是一个
(n, n)
的矩阵,其计算复杂度就是 O(n² * d_k) 。
这也解释了为什么标准Transformer模型的计算量和内存占用会随着序列长度 n
的增加而呈平方级增长,这是限制其处理非常长序列的主要瓶颈之一。
6. 附录(泛化)
补充一种更加泛化的计算方式。我们来分析一下将一个 (a, b)
矩阵与一个 (b, c)
矩阵相乘的计算量。
假设我们有两个矩阵:
- 矩阵 A,尺寸为
(a, b)
(a
行,b
列) - 矩阵 B,尺寸为
(b, c)
(b
行,c
列)
我们要计算它们的乘积 C = A * B。
6.1 结果矩阵的尺寸
首先,结果矩阵 C 的尺寸由 A 的行数和 B 的列数决定。
- C 的行数 = A 的行数 =
a
- C 的列数 = B 的列数 =
c
所以,结果矩阵 C 的尺寸为(a, c)
。
6.2 单个元素的计算量
接下来,我们计算结果矩阵 C 中的任意一个元素 C_ij
(第 i
行, 第 j
列)。
C_ij
是由 A 的第 i
行和 B 的第 j
列的点积(dot product)得到的。
- A 的第
i
行是一个长度为b
的行向量。 - B 的第
j
列是一个长度为b
的列向量。
计算公式为:
C_ij = A_i1 * B_1j + A_i2 * B_2j + ... + A_ib * B_bj
为了计算这一个 C_ij
元素,我们需要:
b
次乘法b - 1
次加法
6.3 总计算量
结果矩阵 C 是一个 (a, c)
的矩阵,它总共有 a * c
个元素。
我们将单个元素的计算量乘以总元素数量,得到整个矩阵的计算量:
-
总乘法次数 = (每个元素的乘法次数) × (总元素个数)
=b * (a * c)
=a * b * c
-
总加法次数 = (每个元素的加法次数) × (总元素个数)
=(b - 1) * (a * c)
=a * c * (b - 1)
6.4 结论与总结
对于一个 (a, b)
矩阵和一个 (b, c)
矩阵的乘法:
- 总乘法运算量为
a * b * c
次。 - 总加法运算量为
a * c * (b - 1)
次。
在衡量算法复杂度时,我们通常使用 Big O 表示法,或者计算总的 浮点运算次数 (FLOPs)。
-
总FLOPs ≈ 总乘法次数 + 总加法次数
=(a * b * c) + (a * c * (b - 1))
=a * c * (b + b - 1)
=a * c * (2b - 1)
-
时间复杂度 (Time Complexity):
当a
,b
,c
都很大时,常数2
和-1
可以忽略。因此,计算复杂度为 O(abc)。
6.5 验证一下之前的问题
让我们用这个通用公式来验证你之前的问题:一个 (n, d_k)
矩阵乘以一个 (d_k, n)
矩阵。
这里:
a = n
b = d_k
c = n
代入通用公式:
- 总乘法次数 =
a * b * c
=n * d_k * n
=n² * d_k
- 总加法次数 =
a * c * (b - 1)
=n * n * (d_k - 1)
=n² * (d_k - 1)
这与我们之前得到的结论完全一致。这个 O(abc)
的公式是矩阵乘法计算量分析的基础。