Day 3 second: Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks

  • Post author:
  • Post category:其他




总结

  • self-attention 的计算复杂度会呈二次方增长,因此实用性不大
  • 本文基于两个外部的、小尺寸的、可学习的,且共享的存储,提出 external attention,可以方便地在现有的流行模型中,替换掉self-attention结构。
  • external attention 有线性的计算复杂度,同时它考虑到了

    所有样本

    之间的相关性



Core Ideas and Contribution

  • 首先,通过计算 self query vector 和 external learnable key memory 之间的近似性得到注意力图,再将得到的注意力图和另一个 external learnable value memory 相乘,得到一个精炼过的特征图。
  • 这两个memory都是用线性层(linear layer)实现的。
  • 它们和单个的样本之间互相独立,但是在整个数据集之间共享。
  • 本结构能做到轻量化的最主要原因,是因为在这两个 external memory 里的参数要远小于输入特征中的参数。
  • external memory 设计成了用来学习整个数据集中最有区分度的特征,捕捉最有信息量的部分,以及排除掉其它样本中具有干涉/迷惑性的信息。



Methods and Approaches



目前常用的 self-attention 模块

先来看一下普通的 self-attention 模块是怎样操作的:

给到一个尺寸为



F

R

N

×

d

F \in \mathbb{R}^{N \times d}






F















R












N


×


d













的输入,此处N为像素的数量,d为特征维度(feature dimensions)的数量,普通的self-attention首先将会把输入线性地映射到一个query矩阵



Q

R

N

×

d

Q \in \mathbb{R}^{N \times d^{\prime}}






Q















R












N


×



d

































,一个key矩阵



K

R

N

×

d

K \in \mathbb{R}^{N \times d^{\prime}}






K















R












N


×



d

































,和一个value矩阵



V

R

N

×

d

V \in \mathbb{R}^{N \times d}






V















R












N


×


d













。接下来,用以下的式子得到最终的结果:





A

=

(

α

)

i

,

j

=

softmax

(

Q

K

T

)

F

out 

=

A

V

\begin{aligned}A &=(\alpha)_{i, j}=\operatorname{softmax}\left(Q K^{T}\right) \\F_{\text {out }} &=A V\end{aligned}
















A









F












out















































=




(


α



)











i


,


j





















=





s


o


f


t


m


a


x







(



Q



K











T











)














=




A


V






















上式中的



A

R

N

×

N

A \in \mathbb{R}^{N \times N}






A















R












N


×


N













即为注意力矩阵,矩阵中的



α

i

,

j

\alpha_{i, j}







α











i


,


j






















是第i个像素点和第j个像素点之间的相似度

由此看来,之前的工作基本都是用的图片块(image patch),而不是所有的像素,是因为不这样的话对计算力的要求实在太大


当把注意力图给可视化出来后,注意到大部分的像素其实只与很少的几个像素之间有很强的相关性,因此一个 N x N 的注意力矩阵实在太过于冗余了


在这里插入图片描述



External Attention

于是,作者提出以下的方法,用external memory来计算注意力矩阵,计算得到的是

输入像素

与这个

external memory

之间的attention。external memory 的尺寸是



M

R

S

×

d

M \in \mathbb{R}^{S \times d}






M















R












S


×


d













。此处



(

α

)

i

,

j

(\alpha)_{i, j}






(


α



)











i


,


j






















代表的是第



i

i






i





个像素和记性模块



M

M






M









j

j






j





行的相似度,



M

M






M





是个可学习的参数,且与输入独立(不相关),充当的是整个数据集的记忆模块.





A

=

(

α

)

i

,

j

=

Norm

(

F

M

T

)

F

out 

=

A

M

\begin{aligned}A &=(\alpha)_{i, j}=\operatorname{Norm}\left(F M^{T}\right) \\F_{\text {out }} &=A M\end{aligned}
















A









F












out















































=




(


α



)











i


,


j





















=





N


o


r


m







(



F



M











T











)














=




A


M






















实际使用中,我们用的是两个不同的memory模块,称为



M

k

M_k







M










k

























M

v

M_v







M










v





















,前者是key,后者是value,以达到提升网络能力的目的,计算如下:





A

=

Norm

(

F

M

k

T

)

F

out 

=

A

M

v

\begin{aligned}A &=\operatorname{Norm}\left(F M_{k}^{T}\right) \\F_{\text {out }} &=A M_{v}\end{aligned}
















A









F












out















































=





N


o


r


m







(



F



M











k










T




















)














=




A



M











v







































这样,我们的算法就是和像素的数量呈线性相关的了,复杂度为



O

(

d

S

N

)

\mathcal{O}(d S N)







O



(


d


S


N


)





,其中



d

d






d









S

S






S





为超参。而且试验中发现,即使S的值很少,比如设为64,也有很好的效果。



Python pesudo-code for external attention

# Input: F, an array with shape [B, N, C] (batch size, pixels, channels)
# Parameter: M_k, a linear layer without bias
# Parameter: M_v, a linear layer without bias
# Output: out, an array with shape [B, N, C]
attn = M_k(F) # shape=(B, N, M)
attn = softmax(attn, dim=1)
attn = l1_norm(attn, dim=2)
out = M_v(attn) # shape=(B, N, C)



Normalization

  • The attention calculated by matrix multiplication is

    sensitive to the scale of input features

    , thus need to be normalized, we use

    double-normalization

    here, which

    seperately normalize columns and rows

    .
  • Softmax is used here.





(

α

~

)

i

,

j

=

F

M

k

T

α

i

,

j

=

exp

(

α

~

i

,

j

)

k

exp

(

α

~

k

,

j

)

α

i

,

j

=

α

i

,

j

k

α

i

^

,

k

\begin{aligned}(\tilde{\alpha})_{i, j} &=F M_{k}^{T} \\\alpha_{i, j} &=\frac{\exp \left(\tilde{\alpha}_{i, j}\right)}{\sum_{k} \exp \left(\tilde{\alpha}_{k, j}\right)} \\\alpha_{i, j} &=\frac{\alpha_{i, j}}{\sum_{k} \alpha_{\hat{i}, k}}\end{aligned}
















(










α







~








)











i


,


j


























α











i


,


j


























α











i


,


j














































=




F



M











k










T





























=




























k





















exp





(











α







~
















k


,


j



















)















exp





(











α







~
















i


,


j



















)































=




























k






















α



















i








^








,


k
































α











i


,


j



























































Official Code


https://github.com/MenghaoGuo/-EANet

class External_attention(nn.Module):
    '''
    Arguments:
        c (int): The input and output channel number.
    '''
    def __init__(self, c):
        super(External_attention, self).__init__()
        
        self.conv1 = nn.Conv2d(c, c, 1)

        self.k = 64
        self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False)

        self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False)
        self.linear_1.weight.data = self.linear_0.weight.data.permute(1, 0, 2)        
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(c, c, 1, bias=False),
            norm_layer(c))        
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv1d):
                n = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, _BatchNorm):
                m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
 

    def forward(self, x):
        idn = x
        x = self.conv1(x)

        b, c, h, w = x.size()
        n = h*w
        x = x.view(b, c, h*w)   # b * c * n 

        attn = self.linear_0(x) # b, k, n
        attn = F.softmax(attn, dim=-1) # b, k, n

        attn = attn / (1e-9 + attn.sum(dim=1, keepdim=True)) #  # b, k, n
        x = self.linear_1(attn) # b, c, n

        x = x.view(b, c, h, w)
        x = self.conv2(x)
        x = x + idn
        x = F.relu(x)
        return x



版权声明:本文为ttppss原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。