总结
- self-attention 的计算复杂度会呈二次方增长,因此实用性不大
- 本文基于两个外部的、小尺寸的、可学习的,且共享的存储,提出 external attention,可以方便地在现有的流行模型中,替换掉self-attention结构。
-
external attention 有线性的计算复杂度,同时它考虑到了
所有样本
之间的相关性
Core Ideas and Contribution
- 首先,通过计算 self query vector 和 external learnable key memory 之间的近似性得到注意力图,再将得到的注意力图和另一个 external learnable value memory 相乘,得到一个精炼过的特征图。
- 这两个memory都是用线性层(linear layer)实现的。
- 它们和单个的样本之间互相独立,但是在整个数据集之间共享。
- 本结构能做到轻量化的最主要原因,是因为在这两个 external memory 里的参数要远小于输入特征中的参数。
- external memory 设计成了用来学习整个数据集中最有区分度的特征,捕捉最有信息量的部分,以及排除掉其它样本中具有干涉/迷惑性的信息。
Methods and Approaches
目前常用的 self-attention 模块
先来看一下普通的 self-attention 模块是怎样操作的:
给到一个尺寸为
F
∈
R
N
×
d
F \in \mathbb{R}^{N \times d}
F
∈
R
N
×
d
的输入,此处N为像素的数量,d为特征维度(feature dimensions)的数量,普通的self-attention首先将会把输入线性地映射到一个query矩阵
Q
∈
R
N
×
d
′
Q \in \mathbb{R}^{N \times d^{\prime}}
Q
∈
R
N
×
d
′
,一个key矩阵
K
∈
R
N
×
d
′
K \in \mathbb{R}^{N \times d^{\prime}}
K
∈
R
N
×
d
′
,和一个value矩阵
V
∈
R
N
×
d
V \in \mathbb{R}^{N \times d}
V
∈
R
N
×
d
。接下来,用以下的式子得到最终的结果:
A
=
(
α
)
i
,
j
=
softmax
(
Q
K
T
)
F
out
=
A
V
\begin{aligned}A &=(\alpha)_{i, j}=\operatorname{softmax}\left(Q K^{T}\right) \\F_{\text {out }} &=A V\end{aligned}
A
F
out
=
(
α
)
i
,
j
=
s
o
f
t
m
a
x
(
Q
K
T
)
=
A
V
上式中的
A
∈
R
N
×
N
A \in \mathbb{R}^{N \times N}
A
∈
R
N
×
N
即为注意力矩阵,矩阵中的
α
i
,
j
\alpha_{i, j}
α
i
,
j
是第i个像素点和第j个像素点之间的相似度
由此看来,之前的工作基本都是用的图片块(image patch),而不是所有的像素,是因为不这样的话对计算力的要求实在太大
当把注意力图给可视化出来后,注意到大部分的像素其实只与很少的几个像素之间有很强的相关性,因此一个 N x N 的注意力矩阵实在太过于冗余了
External Attention
于是,作者提出以下的方法,用external memory来计算注意力矩阵,计算得到的是
输入像素
与这个
external memory
之间的attention。external memory 的尺寸是
M
∈
R
S
×
d
M \in \mathbb{R}^{S \times d}
M
∈
R
S
×
d
。此处
(
α
)
i
,
j
(\alpha)_{i, j}
(
α
)
i
,
j
代表的是第
i
i
i
个像素和记性模块
M
M
M
第
j
j
j
行的相似度,
M
M
M
是个可学习的参数,且与输入独立(不相关),充当的是整个数据集的记忆模块.
A
=
(
α
)
i
,
j
=
Norm
(
F
M
T
)
F
out
=
A
M
\begin{aligned}A &=(\alpha)_{i, j}=\operatorname{Norm}\left(F M^{T}\right) \\F_{\text {out }} &=A M\end{aligned}
A
F
out
=
(
α
)
i
,
j
=
N
o
r
m
(
F
M
T
)
=
A
M
实际使用中,我们用的是两个不同的memory模块,称为
M
k
M_k
M
k
和
M
v
M_v
M
v
,前者是key,后者是value,以达到提升网络能力的目的,计算如下:
A
=
Norm
(
F
M
k
T
)
F
out
=
A
M
v
\begin{aligned}A &=\operatorname{Norm}\left(F M_{k}^{T}\right) \\F_{\text {out }} &=A M_{v}\end{aligned}
A
F
out
=
N
o
r
m
(
F
M
k
T
)
=
A
M
v
这样,我们的算法就是和像素的数量呈线性相关的了,复杂度为
O
(
d
S
N
)
\mathcal{O}(d S N)
O
(
d
S
N
)
,其中
d
d
d
和
S
S
S
为超参。而且试验中发现,即使S的值很少,比如设为64,也有很好的效果。
Python pesudo-code for external attention
# Input: F, an array with shape [B, N, C] (batch size, pixels, channels)
# Parameter: M_k, a linear layer without bias
# Parameter: M_v, a linear layer without bias
# Output: out, an array with shape [B, N, C]
attn = M_k(F) # shape=(B, N, M)
attn = softmax(attn, dim=1)
attn = l1_norm(attn, dim=2)
out = M_v(attn) # shape=(B, N, C)
Normalization
-
The attention calculated by matrix multiplication is
sensitive to the scale of input features
, thus need to be normalized, we use
double-normalization
here, which
seperately normalize columns and rows
. - Softmax is used here.
(
α
~
)
i
,
j
=
F
M
k
T
α
i
,
j
=
exp
(
α
~
i
,
j
)
∑
k
exp
(
α
~
k
,
j
)
α
i
,
j
=
α
i
,
j
∑
k
α
i
^
,
k
\begin{aligned}(\tilde{\alpha})_{i, j} &=F M_{k}^{T} \\\alpha_{i, j} &=\frac{\exp \left(\tilde{\alpha}_{i, j}\right)}{\sum_{k} \exp \left(\tilde{\alpha}_{k, j}\right)} \\\alpha_{i, j} &=\frac{\alpha_{i, j}}{\sum_{k} \alpha_{\hat{i}, k}}\end{aligned}
(
α
~
)
i
,
j
α
i
,
j
α
i
,
j
=
F
M
k
T
=
∑
k
exp
(
α
~
k
,
j
)
exp
(
α
~
i
,
j
)
=
∑
k
α
i
^
,
k
α
i
,
j
Official Code
https://github.com/MenghaoGuo/-EANet
class External_attention(nn.Module):
'''
Arguments:
c (int): The input and output channel number.
'''
def __init__(self, c):
super(External_attention, self).__init__()
self.conv1 = nn.Conv2d(c, c, 1)
self.k = 64
self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False)
self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False)
self.linear_1.weight.data = self.linear_0.weight.data.permute(1, 0, 2)
self.conv2 = nn.Sequential(
nn.Conv2d(c, c, 1, bias=False),
norm_layer(c))
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.Conv1d):
n = m.kernel_size[0] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, _BatchNorm):
m.weight.data.fill_(1)
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
idn = x
x = self.conv1(x)
b, c, h, w = x.size()
n = h*w
x = x.view(b, c, h*w) # b * c * n
attn = self.linear_0(x) # b, k, n
attn = F.softmax(attn, dim=-1) # b, k, n
attn = attn / (1e-9 + attn.sum(dim=1, keepdim=True)) # # b, k, n
x = self.linear_1(attn) # b, c, n
x = x.view(b, c, h, w)
x = self.conv2(x)
x = x + idn
x = F.relu(x)
return x