One Line Description
规约操作合集
描述
核心是通过一个字符串方程来描述输入张量的维度如何组合、变换、相乘和求和,最终得到输出张量。
字符串方程格式
格式为 输入下标 -> 输出下标,例如 "ij,jk->ik"。
>左边:描述一个或多个输入张量的维度。每个字母代表一个维度,用逗号分隔不同的输入张量。>右边:描述输出张量的维度。
核心规则
- 求和/收缩 (Summation/Contraction):在输入中出现,但在输出中没有出现的字母,代表这些维度需要被求和(或“收缩”)。
- 保留 (Preservation):在输入和输出中都出现的字母,代表这些维度被保留下来。它们在输出中的顺序决定了最终的维度排列。
常见用例
- 矩阵转置 (Transpose)
- 方程:
"ij->ji" - 解释: 输入是一个2D张量,维度为
i和j。输出也是一个2D张量,但维度顺序换成了j和i。
- 方程:
- 求和 (Summation)
- 沿特定轴求和:
"ij->i"(对j轴求和) - 所有元素求和:
"ij->"(对i和j轴都求和,输出一个标量)
- 沿特定轴求和:
- 矩阵乘法 (Matrix Multiplication)
- 方程:
"ij,jk->ik" - 解释: 输入两个2D张量。第一个形状为
(i, j),第二个为(j, k)。j维度在输入中出现,但在输出中没有,所以它被求和。最终输出形状为(i, k)。
- 方程:
- 批量矩阵乘法 (Batch Matrix Multiplication)
- 方程:
"bij,bjk->bik" - 解释:
b是批次维度,它在输入和输出中都存在,所以被保留。j维度被求和。这会在批次维度上独立执行矩阵乘法。
- 方程:
- 点积 (Dot Product)
- 方程:
"i,i->" - 解释: 输入两个1D向量,长度都为
i。i维度被求和,输出一个标量。
- 方程:
- 提取对角线 (Extract Diagonal)
- 方程:
"ii->i" - 解释: 输入一个2D方阵
(i, i)。只取i和j相同的元素,组成一个1D向量。
- 方程:
为什么引入这个算子?
- 简洁性 (Conciseness): 一个
Einsum操作可以替代多个Transpose,MatMul,ReduceSum等操作的组合,让计算图更清晰。 - 可读性 (Readability): 对于熟悉它的人来说,方程本身就清晰地描述了张量的变换逻辑,比一步步的操作链更容易理解。
- 通用性 (Versatility): 它覆盖了绝大多数常见的张量代数运算。
实际例子

逐步分解
- 输入1 (
embed): 维度为(b, m, c, h, w)b: Batch size (批次大小)m: Modalities / Heads (模态或头的数量)c: Channels (通道数)h: Height (高度)w: Width (宽度)- 从命名
embed和维度(b, c, h, w)来看,这很可能是一个卷积网络输出的特征图(Feature Map)。
- 输入2 (
guide): 维度为(b, n, m, c)b: Batch size (与输入1匹配)n: Number of new features / tokens (新特征或令牌的数量)m: Heads (与输入1匹配)c: Channels (与输入1匹配)- 从命名
guide来看,它可能是一个用于指导特征变换的“引导张量”或“注意力图”。
- 输出: 维度为
(b, m, h, w, n)
核心规则分析
- 求和维度 (被“消除”的维度) 我们找出在输入中出现、但在输出中没有出现的字母。
- 输入字母:
b, m, c, h, w, n - 输出字母:
b, m, h, w, n - 对比发现,字母
c在输入中存在,但在输出中消失了。 - 结论:
Einsum操作会对c这个维度进行求和(Summation)或收缩(Contraction)。这意味着它会沿着通道维度c对两个输入进行逐元素相乘,然后将结果累加起来。这本质上是一个**点积(Dot Product)**操作。
- 输入字母:
- 保留维度 (被“保留”的维度) 所有在输入和输出中都存在的字母都会被保留下来。
b和m在两个输入和输出中都存在,是批次和头的维度。h和w只在输入1中存在,并被保留到输出中,代表空间维度。n只在输入2中存在,并被保留到输出中,代表新的特征维度。
直接解释
对于批次中的每一项 (b) 和每一个头 (m):
- 我们有一个
(c, h, w)的特征图 (embed)。 - 我们还有一个
(n, c)的引导矩阵 (guide)。 - 操作的核心是,
guide矩阵告诉我们如何将embed特征图的c个通道线性组合成n个新的特征。 - 对于
embed特征图上的每一个空间位置(h, w),它的c维通道向量都会与guide矩阵中的每一行(共n行)进行点积运算。 - 这样,对于一个
(h, w)位置,经过计算后,它的特征就从原来的c维向量变成了新的n维向量。 - 这个过程在所有空间位置
(h, w)上重复,最终生成一个形状为(h, w, n)的新特征图。 - 将
b和m维度考虑回来,最终的输出形状就是(b, m, h, w, n)。
一句话总结:该操作使用 guide 张量作为权重,将 embed 特征图的通道 c 进行了重新线性组合,从 c 个旧通道映射到了 n 个新通道。
伪代码形式
# embed.shape = (b, m, c, h, w)
# guide.shape = (b, n, m, c)
# output = np.zeros((b, m, h, w, n))
for b_idx in range(b):
for m_idx in range(m):
for h_idx in range(h):
for w_idx in range(w):
for n_idx in range(n):
# 以下是核心的点积操作,对 c 维度进行求和
sum_val = 0
for c_idx in range(c):
sum_val += embed[b_idx, m_idx, c_idx, h_idx, w_idx] * guide[b_idx, n_idx, m_idx, c_idx]
output[b_idx, m_idx, h_idx, w_idx, n_idx] = sum_val等价算子替换
分析
- 目标: 实现
bmchw,bnmc -> bmhwn - 输入1 (
embed):A, 形状为(b, m, c, h, w) - 输入2 (
guide):B, 形状为(b, n, m, c)
MatMul 算子通常对输入张量的最后两个维度进行矩阵乘法,而视前面的维度为批次(Batch)维度。标准的 MatMul(X, Y) 要求 X 的形状是 (..., N, K),Y 的形状是 (..., K, M),输出形状为 (..., N, M)。我们的目标就是把 A 和 B 变成这种形式。
第1步: 重排输入1 (embed)
我们需要将 c 维度放到最后,并将 h 和 w 合并成一个维度,作为矩阵的“行”。
Transpose(转置):- 原始维度:
(b, m, c, h, w) - 目标排列: 将
c移到最后,hw移到c前面。排列为(b, m, h, w, c)。 perm参数:[0, 1, 3, 4, 2]- 输出形状:
(b, m, h, w, c)
- 原始维度:
Reshape(重塑):- 将
h和w两个空间维度合并成一个,以形成矩阵的“行”。 - 目标形状:
(b, m, h*w, c) - 输出: 得到准备好的
A_prepared,其形状为(b, m, h*w, c)。
- 将
第2步: 重排输入2 (guide)
我们需要将 c 维度放到倒数第二位,n 维度放到最后,以匹配 MatMul 的要求。
Transpose(转置):- 原始维度:
(b, n, m, c) - 目标排列: 将
bm作为批次维度,cn作为矩阵维度。排列为(b, m, c, n)。 perm参数:[0, 2, 3, 1]- 输出: 得到准备好的
B_prepared,其形状为(b, m, c, n)。
- 原始维度:
第3步: 核心计算 (MatMul)
现在两个输入的形状已经完美匹配 MatMul 的要求。
MatMul(A_prepared, B_prepared):A_prepared形状:(b, m, h*w, c)B_prepared形状:(b, m, c, n)- 前面的
(b, m)被视为批次维度。 - 核心矩阵乘法是
(h*w, c)@(c, n)。 - 输出 (
C_temp) 形状:(b, m, h*w, n)
第4步: 重塑输出
MatMul 的输出结果 C_temp 的维度顺序和我们最终想要的 (b, m, h, w, n) 还差一点,需要最后一次 Reshape。
Reshape(重塑):- 将
h*w这个维度重新拆分成h和w。 - 目标形状:
(b, m, h, w, n) - 最终输出: 这就是我们想要的结果。
- 将
注意: 在进行最后一步 Reshape 时,如果 h 和 w 是动态的,你需要先用 Shape 算子从原始输入 embed中获取 h 和 w 的值,然后与 b, m, n 拼接成最终 Reshape 的目标形状张量。
graph TD subgraph "替换Einsum" A["embed (b,m,c,h,w)"] --> T1["Transpose (perm=[0,1,3,4,2])"]; T1 --> R1["Reshape to (b,m,h*w,c)"]; B["guide (b,n,m,c)"] --> T2["Transpose (perm=[0,2,3,1])"]; R1 --> M[MatMul]; T2 --> M; M --> R2["Reshape to (b,m,h,w,n)"]; R2 --> O["Output (b,m,h,w,n)"]; end