Transformer在计算机视觉中的工作原理

2024年02月28日 由 alex 发表 202 0

自从2017年《Attention is All You Need》提出以来,Transformer 模型已经成为自然语言处理(NLP)领域的最先进技术。在2021年,《An Image is Worth 16x16 Words》2 成功将 Transformer 模型应用到了计算机视觉任务中。此后,许多基于 Transformer 的架构也被提出用于计算机视觉。


本文将深入探讨注意力层在计算机视觉中的工作原理。我们将涵盖单头和多头注意力,并提供了注意力层的开源代码,以及底层数学的概念性解释。代码使用了 Python 的 PyTorch 包。


1


注意事项

在自然语言处理(NLP)应用中,注意力通常被描述为句子中单词(标记)之间的关系。在计算机视觉应用中,注意力关注的是图像中各个区块(标记)之间的关系。


将图像分解为一系列标记有多种方法。最初的 ViT 将图像分割为各个区块,然后将其展平为标记;而 Tokens-to-Token ViT 则采用了更复杂的方法来从图像中创建标记;有关该方法的更多信息可以在《Tokens-To-Token ViT》文章中找到。


本文将在假定标记为输入的情况下继续讨论注意力层。在 Transformer 的开始阶段,这些标记将代表输入图像中的区块。然而,更深层的注意力层将计算已经被前面层修改过的标记之间的注意力,从而消除了表示的直接性。


本文将详细探讨《Attention is All You Need》 中定义的点积(等效于乘法)注意力。这与《An Image is Worth 16x16 Words》 和 Tokens-to-Token ViT等衍生作品中使用的相同注意力机制。代码基于公开可用的 Tokens-to-Token ViT 的 GitHub 代码,经过一些修改。对源代码的更改包括但不限于将两个注意力模块合并为一个,并实现多头注意力。


完整的注意力模块如下所示:


class Attention(nn.Module):
    def __init__(self, 
                dim: int,
                chan: int,
                num_heads: int=1,
                qkv_bias: bool=False,
                qk_scale: NoneFloat=None):
        """ Attention Module
            Args:
                dim (int): input size of a single token
                chan (int): resulting size of a single token (channels)
                num_heads(int): number of attention heads in MSA
                qkv_bias (bool): determines if the qkv layer learns an addative bias
                qk_scale (NoneFloat): value to scale the queries and keys by; 
                                    if None, queries and keys are scaled by ``head_dim ** -0.5``
        """
        super().__init__()
        ## Define Constants
        self.num_heads = num_heads
        self.chan = chan
        self.head_dim = self.chan // self.num_heads
        self.scale = qk_scale or self.head_dim ** -0.5
        assert self.chan % self.num_heads == 0, '"Chan" must be evenly divisible by "num_heads".'
        ## Define Layers
        self.qkv = nn.Linear(dim, chan * 3, bias=qkv_bias)
        #### Each token gets projected from starting length (dim) to channel length (chan) 3 times (for each Q, K, V)
        self.proj = nn.Linear(chan, chan)
    def forward(self, x):
        B, N, C = x.shape
        ## Dimensions: (batch, num_tokens, token_len)
        ## Calcuate QKVs
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        #### Dimensions: (3, batch, heads, num_tokens, chan/num_heads = head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        ## Calculate Attention
        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        #### Dimensions: (batch, heads, num_tokens, num_tokens)
        ## Attention Layer
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.chan)
        #### Dimensions: (batch, heads, num_tokens, chan)
        ## Projection Layers
        x = self.proj(x)
        ## Skip Connection Layer
        v = v.transpose(1, 2).reshape(B, N, self.chan)
        x = v + x     
        #### Because the original x has different size with current x, use v to do skip connection
        return x


单头注意力 

从只有一个注意力头开始,让我们逐行分析前向传播,并在此过程中查看一些矩阵图。我们使用 7∗7=49 作为我们的起始标记大小,因为这是 T2T-ViT 模型中的起始标记大小。我们使用 64 个通道,因为这也是 T2T-ViT 的默认值。我们使用 100 个标记,因为这是一个不错的数字。我们使用批量大小为 13,因为它是质数,不会与其他参数混淆。


# Define an Input
token_len = 7*7
channels = 64
num_tokens = 100
batch = 13
x = torch.rand(batch, num_tokens, token_len)
B, N, C = x.shape
print('Input dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken size:', x.shape[2])
# Define the Module
A = Attention(dim=token_len, chan=channels, num_heads=1, qkv_bias=False, qk_scale=None)
A.eval();


Input dimensions are
   batchsize: 13 
   number of tokens: 100 
   token size: 49


从《Attention is All You Need》中,注意力是通过查询(Queries)、键(Keys)和值(Values)矩阵来定义的。首先,我们通过一个可学习的线性层来计算这些矩阵。布尔值 qkv_bias 表示这些线性层是否具有偏置项。此步骤还将输入的标记长度从 49 改变为我们设置的 64。


2


qkv = A.qkv(x).reshape(B, N, 3, A.num_heads, A.head_dim).permute(2, 0, 3, 1, 4)3, A.num_heads, A.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
print('Dimensions for Queries are\n\tbatchsize:', q.shape[0], '\n\tattention heads:', q.shape[1], '\n\tnumber of tokens:', q.shape[2], '\n\tnew length of tokens:', q.shape[3])
print('See that the dimensions for queries, keys, and values are all the same:')
print('\tShape of Q:', q.shape, '\n\tShape of K:', k.shape, '\n\tShape of V:', v.shape)


Dimensions for Queries are
   batchsize: 13 
   attention heads: 1 
   number of tokens: 100 
   new length of tokens: 64
See that the dimensions for queries, keys, and values are all the same:
   Shape of Q: torch.Size([13, 1, 100, 64]) 
   Shape of K: torch.Size([13, 1, 100, 64]) 
   Shape of V: torch.Size([13, 1, 100, 64])


现在,我们可以开始计算注意力,其定义如下:


3


在这个公式中,Q、K 和 V 分别表示查询、键和值;dk​ 是键的维度,等于键标记的长度,也等于 chan 的长度。


我们将按照代码中的实现逐步解析这个方程。我们将把中间矩阵称为 Attn。


首先,我们要计算:


4


在代码中设置了:


5


默认情况下:


6


然而,用户可以将替代的缩放值作为超参数进行指定。


分子中的矩阵乘法 Q·Kᵀ 如下所示:


7


所有这些在代码中看起来像:


attn = (q * A.scale) @ k.transpose(-2, -1)2, -1)
print('Dimensions for Attn are\n\tbatchsize:', attn.shape[0], '\n\tattention heads:', attn.shape[1], '\n\tnumber of tokens:', attn.shape[2], '\n\tnumber of tokens:', attn.shape[3])


Dimensions for Attn are
   batchsize: 13 
   attention heads: 1 
   number of tokens: 100 
   number of tokens: 100


接下来,我们计算 A 的 softmax,这不会改变它的形状。


attn = attn.softmax(dim=-1)1)
print('Dimensions for Attn are\n\tbatchsize:', attn.shape[0], '\n\tattention heads:', attn.shape[1], '\n\tnumber of tokens:', attn.shape[2], '\n\tnumber of tokens:', attn.shape[3])


Dimensions for Attn are
   batchsize: 13 
   attention heads: 1 
   number of tokens: 100 
   number of tokens: 100


最后,我们计算 A·V=x,如下所示:


8


x = attn @ v
print('Dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tattention heads:', x.shape[1], '\n\tnumber of tokens:', x.shape[2], '\n\tlength of tokens:', x.shape[3])


Dimensions for x are
   batchsize: 13 
   attention heads: 1 
   number of tokens: 100 
   length of tokens: 64


输出 x 被重塑以删除注意力头维度。


x = x.transpose(1, 2).reshape(B, N, A.chan)1, 2).reshape(B, N, A.chan)
print('Dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])


Dimensions for x are
   batchsize: 13 
   number of tokens: 100 
   length of tokens: 64


然后我们将 x 输入一个可学习的线性层,该层不会改变它的形状。


x = A.proj(x)
print('Dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])


Dimensions for x are
   batchsize: 13 
   number of tokens: 100 
   length of tokens: 64


最后,我们实现了一个跳跃连接。由于当前的 x 形状与输入 x 的形状不同,我们使用 V 来进行跳跃连接。我们在注意力头维度上对 V 进行了展平。


orig_shape = (batch, num_tokens, token_len)
curr_shape = (x.shape[0], x.shape[1], x.shape[2])0], x.shape[1], x.shape[2])
v = v.transpose(1, 2).reshape(B, N, A.chan)
v_shape = (v.shape[0], v.shape[1], v.shape[2])
print('Original shape of input x:', orig_shape)
print('Current shape of x:', curr_shape)
print('Shape of V:', v_shape)
x = v + x     
print('After skip connection, dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])


Original shape of input x: (13, 100, 49)
Current shape of x: (13, 100, 64)
Shape of V: (13, 100, 64)
After skip connection, dimensions for x are
   batchsize: 13 
   number of tokens: 100 
   length of tokens: 64


多头注意力

在计算机视觉中,多头注意力通常被称为多头自注意力(Multi-headed Self Attention,MSA)。


与单头注意力相同,我们使用 7∗7=49 作为起始标记大小,64 个通道是 T2T-ViT 的默认值³。我们使用 100 个标记,因为这是一个不错的数字。我们使用批量大小为 13,因为它是质数,不会与其他参数混淆。


注意力头的数量必须能够整除通道数,因此在本示例中,我们将使用 4 个注意力头。


# Define an Input
token_len = 7*7
channels = 64
num_tokens = 100
batch = 13
num_heads = 4
x = torch.rand(batch, num_tokens, token_len)
B, N, C = x.shape
print('Input dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken size:', x.shape[2])
# Define the Module
MSA = Attention(dim=token_len, chan=channels, num_heads=num_heads, qkv_bias=False, qk_scale=None)
MSA.eval();


Input dimensions are
   batchsize: 13 
   number of tokens: 100 
   token size: 49


计算查询(Queries)、键(Keys)和值(Values)的过程与单头注意力中相同。然而,你可以看到标记的新长度为 chan/num_heads。Q、K 和 V 矩阵的总大小没有改变;它们的内容只是在头维度上分布。你可以将其视为将单头矩阵分割为多个头:


9


我们将查询头 i 的子矩阵表示为 Qₕᵢ。


qkv = MSA.qkv(x).reshape(B, N, 3, MSA.num_heads, MSA.head_dim).permute(2, 0, 3, 1, 4)3, MSA.num_heads, MSA.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
print('Head Dimension = chan / num_heads =', MSA.chan, '/', MSA.num_heads, '=', MSA.head_dim)
print('Dimensions for Queries are\n\tbatchsize:', q.shape[0], '\n\tattention heads:', q.shape[1], '\n\tnumber of tokens:', q.shape[2], '\n\tnew length of tokens:', q.shape[3])
print('See that the dimensions for queries, keys, and values are all the same:')
print('\tShape of Q:', q.shape, '\n\tShape of K:', k.shape, '\n\tShape of V:', v.shape)


Head Dimension = chan / num_heads = 64 / 4 = 16
Dimensions for Queries are
   batchsize: 13 
   attention heads: 4 
   number of tokens: 100 
   new length of tokens: 16
See that the dimensions for queries, keys, and values are all the same:
   Shape of Q: torch.Size([13, 4, 100, 16]) 
   Shape of K: torch.Size([13, 4, 100, 16]) 
   Shape of V: torch.Size([13, 4, 100, 16])


下一步计算:


10


对于每个头我。在这种情况下,密钥的长度是:


11


与单头注意力一样,我们使用默认值


12


然而,用户可以将替代的缩放值作为超参数进行指定。


我们在这一步中得到了 num_heads = 4 个不同的 Attn 矩阵,如下所示:


13


attn = (q * MSA.scale) @ k.transpose(-2, -1)2, -1)
print('Dimensions for Attn are\n\tbatchsize:', attn.shape[0], '\n\tattention heads:', attn.shape[1], '\n\tnumber of tokens:', attn.shape[2], '\n\tnumber of tokens:', attn.shape[3])


Dimensions for Attn are
   batchsize: 13 
   attention heads: 4 
   number of tokens: 100 
   number of tokens: 100


接下来我们计算 A 的 softmax,这不会改变它的形状。


然后,我们可以计算


14


这在多个注意力头中也有类似的分布:


15


attn = attn.softmax(dim=-1)1)
x = attn @ v
print('Dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tattention heads:', x.shape[1], '\n\tnumber of tokens:', x.shape[2], '\n\tlength of tokens:', x.shape[3])


Dimensions for x are
   batchsize: 13 
   attention heads: 4 
   number of tokens: 100 
   length of tokens: 16


现在,我们通过一些重塑操作将所有的 xₕᵢ 连接在一起。这是与第一步相反的操作:


16


x = x.transpose(1, 2).reshape(B, N, MSA.chan)1, 2).reshape(B, N, MSA.chan)
print('Dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])


Dimensions for x are
   batchsize: 13 
   number of tokens: 100 
   length of tokens: 64


现在,我们已经将所有注意力头连接在一起,注意力模块的其余部分保持不变。对于跳跃连接,我们仍然使用 V,但我们需要对其进行重塑以去除头维度。


x = MSA.proj(x)
print('Dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])
orig_shape = (batch, num_tokens, token_len)
curr_shape = (x.shape[0], x.shape[1], x.shape[2])
v = v.transpose(1, 2).reshape(B, N, A.chan)
v_shape = (v.shape[0], v.shape[1], v.shape[2])
print('Original shape of input x:', orig_shape)
print('Current shape of x:', curr_shape)
print('Shape of V:', v_shape)
x = v + x     
print('After skip connection, dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])


Dimensions for x are
   batchsize: 13 
   number of tokens: 100 
   length of tokens: 64
Original shape of input x: (13, 100, 49)
Current shape of x: (13, 100, 64)
Shape of V: (13, 100, 64)
After skip connection, dimensions for x are
   batchsize: 13 
   number of tokens: 100 
   length of tokens: 64


结论

这篇文章详细介绍了视觉 transformer 中注意力层的每个步骤。注意力层中的可学习权重位于从标记到查询、键和值的第一个投影以及最终投影中。大部分注意力层是确定性的矩阵乘法。然而,当使用长标记时,线性层中可能包含大量权重。QKV 投影层中的权重数量等于 input_token_len * chan * 3,最终投影层中的权重数量等于 chan²。



文章来源:https://medium.com/towards-data-science/attention-for-vision-transformers-explained-70f83984c673
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
写评论取消
回复取消