通过注意力机制优化提升Transformer模型效率

2024年11月21日 由 alex 发表 13 0

Transformer 架构于 2017 年的里程碑式论文《Attention Is All You Need》(Vaswani 等人,2017)中提出,被广泛认为是过去十年最具影响力的科学突破之一。Transformer 的核心是注意力机制,这是一种新颖的方法,它使 AI 模型能够根据手头的任务关注输入序列的不同部分,从而理解复杂的结构。Transformers 架构最初在自然语言处理领域得到展示,其成功已迅速传播到许多其他领域,包括语音识别、场景理解、强化学习、蛋白质结构预测等。然而,注意力层非常耗费资源,随着这些层成为越来越大的模型的标准,与其训练和部署相关的成本也大幅增加。这迫切需要制定策略来降低这一核心层的计算成本,以提高基于 Transformer 的 AI 模型的效率和可扩展性。


在本文中,我们将探索在PyTorch中优化注意力的几种工具。我们的重点将放在保持注意力层准确性的方法上,包括PyTorch SDPA、FlashAttention、TransformerEngine Attention、FlexAttention和xFormer attention。我们不会考虑通过近似注意力计算来降低计算成本的其他方法(如DeepSpeed的Sparse Attention、Longformer、Linformer等)。此外,我们也不会讨论虽然对注意力性能有益但并非特定于注意力计算本身的通用优化技术(如FP8训练、模型分片等)。


重要的是,注意力优化是一个活跃的研究领域,新方法层出不穷。我们的目标是提高你对现有解决方案的认识,并为你提供进一步探索和实验的基础。我们下面分享的代码仅用于演示目的——我们不对其准确性、最优性或稳健性做出任何声明。请勿将我们提及的任何平台、库或优化技术视为对其使用的认可。最适合你的选项将极大地取决于你自己用例的具体情况。


玩具模型

为了便于讨论,我们使用流行的timm Python包(版本0.9.7)构建了一个基于视觉Transformer(ViT)的分类模型。我们将使用这个模型来说明各种注意力核的性能影响。


首先,我们定义了一个简化的Transformer块,它允许通过将其传递给构造函数来编程注意力函数。由于注意力实现假设了特定的输入张量格式,我们还包括了一个控制格式的选项,以确保与我们选择的注意力核兼容。


# general imports
import os, time, functools
# torch imports
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
# timm imports
from timm.models.vision_transformer import VisionTransformer
from timm.layers import Mlp
IMG_SIZE = 224
BATCH_SIZE = 128
# Define ViT settings
NUM_HEADS = 16
HEAD_DIM = 64
DEPTH = 24
PATCH_SIZE = 16
SEQ_LEN = (IMG_SIZE // PATCH_SIZE)**2 # 196
class MyAttentionBlock(nn.Module):
    def __init__(
            self,
            attn_fn,
            format = None,
            dim: int = 768,
            num_heads: int = 12,
            **kwargs
    ) -> None:
        super().__init__()
        self.attn_fn = attn_fn
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=dim * 4,
        )
        permute = (2, 0, 3, 1, 4)
        self.permute_attn = functools.partial(torch.transpose,dim0=1,dim1=2)
        if format == 'bshd':
            permute = (2, 0, 1, 3, 4)
            self.permute_attn = nn.Identity()
        self.permute_qkv = functools.partial(torch.permute,dims=permute)
    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        x = self.norm1(x_in)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # permute tensor based on the specified format
        qkv = self.permute_qkv(qkv)
        q, k, v = qkv.unbind(0)
        # use the attention function specified by the user
        x = self.attn_fn(q, k, v)
        # permute output according to the specified format
        x = self.permute_attn(x).reshape(B, N, C)
        x = self.proj(x)
        x = x + x_in
        x = x + self.mlp(self.norm2(x))
        return x


我们定义了一个随机生成的数据集,将在训练过程中将其输入到我们的模型中。


# Use random data
class FakeDataset(Dataset):
    def __len__(self):
        return 1000000
    def __getitem__(self, index):
        rand_image = torch.randn([3, IMG_SIZE, IMG_SIZE],
                                 dtype=torch.float32)
        label = torch.tensor(data=index % 1000, dtype=torch.int64)
        return rand_image, label 


接下来,我们定义我们的ViT训练函数。虽然我们的示例主要关注展示训练工作负载,但必须强调的是,在模型推理过程中,优化注意力层同样(如果不是更加)重要。


我们定义的训练函数接受一个自定义的Transformer块和一个控制是否使用torch.compile的标志。


def train_fn(block_fn, compile):
    torch.random.manual_seed(0)
    device = torch.device("cuda:0")
    torch.set_float32_matmul_precision("high")
    # Create dataset and dataloader
    train_set = FakeDataset()
    train_loader = DataLoader(
        train_set, batch_size=BATCH_SIZE,
        num_workers=12, pin_memory=True, drop_last=True)
    model = VisionTransformer(
       img_size=IMG_SIZE,
       patch_size=PATCH_SIZE,
       embed_dim=NUM_HEADS*HEAD_DIM,
       depth=DEPTH,
       num_heads=NUM_HEADS,
       class_token=False,
       global_pool="avg",
       block_fn=block_fn
    ).to(device)
    if compile:
        model = torch.compile(model)
    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters())
    model.train()
    t0 = time.perf_counter()
    summ = 0
    count = 0
    for step, data in enumerate(train_loader):
        # Copy data to GPU
        inputs = data[0].to(device=device, non_blocking=True)
        label = data[1].to(device=device, non_blocking=True)
        with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
            outputs = model(inputs)
            loss = criterion(outputs, label)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        # Capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # Skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step > 100:
            break
    print(f'average step time: {summ / count}')
# define compiled and uncompiled variants of our train function
train = functools.partial(train_fn, compile=False)
train_compile = functools.partial(train_fn, compile=True)


在下面的代码块中,我们定义了一个PyTorch原生的注意力函数,并使用它来训练我们的ViT模型:


def attn_fn(q, k, v):
    scale = HEAD_DIM ** -0.5
    q = q * scale
    attn = q @ k.transpose(-2, -1)
    attn = attn.softmax(dim=-1)
    x = attn @ v
    return x
block_fn = functools.partial(MyAttentionBlock, attn_fn=attn_fn)
print('Default Attention')
train(block_fn)
print('Compiled Default Attention')
train_compile(block_fn)


我们在配备CUDA 12.4和PyTorch 2.5.1的NVIDIA H100上运行了上述代码。未编译的版本平均每一步耗时370毫秒(ms),而编译后的版本则提升至242ms。在我们考虑其他执行注意力计算的方法时,将使用这些结果作为比较的基准。


PyTorch SDPA

在PyTorch中提升注意力层性能的最简单方法之一是使用scaled_dot_product_attention(SDPA)函数。目前处于测试阶段的PyTorch SDPA整合了多个内核级别的优化,并根据输入的特性动态选择最高效的优化方案。目前支持的后端包括:FlashAttention-2、内存高效注意力、基于C++的数学注意力以及CuDNN。这些后端融合了高层操作,同时采用GPU级别的优化来提高计算效率和内存利用率。


SDPA在不断演进,新的和改进的后端实现会定期推出。及时更新到最新的PyTorch版本是利用最新性能改进的关键。例如,PyTorch 2.5引入了一个更新的CuDNN后端,其中包含一个专门为在NVIDIA Hopper架构GPU上训练而定制的SDPA原语。


在下面的代码块中,我们遍历了支持的后端列表,并评估了使用每个后端进行训练的运行时性能。我们使用了一个辅助函数set_sdpa_backend来编程SDPA后端:


from torch.nn.functional import scaled_dot_product_attention as sdpa
def set_sdpa_backend(backend):
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(False)
    torch.backends.cuda.enable_cudnn_sdp(False)
    if backend in ['flash_sdp','all']:
        torch.backends.cuda.enable_flash_sdp(True)
    if backend in ['mem_efficient_sdp','all']:
        torch.backends.cuda.enable_mem_efficient_sdp(True)
    if backend in ['math_sdp','all']:
        torch.backends.cuda.enable_math_sdp(True)
    if backend in ['cudnn_sdp','all']:
        torch.backends.cuda.enable_cudnn_sdp(True)
for backend in ['flash_sdp', 'mem_efficient_sdp',
                'math_sdp', 'cudnn_sdp']:
    set_sdpa_backend(backend)
    block_fn = functools.partial(MyAttentionBlock,
                                 attn_fn=sdpa)
    print(f'PyTorch SDPA - {backend}')
    train(block_fn)
    print(f'Compiled PyTorch SDPA - {backend}')
    train_compile(block_fn)


我们将中期结果总结在下表中。


16


虽然SDPA后端的选择在急切模式下对性能有显著影响,但模型编译所执行的优化似乎掩盖了注意力内核之间的差异。我们再次提醒,不要从这些结果中得出任何结论,因为不同注意力函数对性能的影响可能会因具体模型和使用场景的不同而显著差异。


第三方注意力内核

虽然PyTorch SDPA是一个很好的起点,但使用第三方注意力内核可以进一步加速你的机器学习工作负载。这些替代方案通常具有更高的灵活性,为注意力提供了更广泛的配置选项。有些还可能包含针对特定硬件加速器或更新GPU架构的优化。


在本节中,我们将探索一些可用的第三方注意力内核,并评估它们对运行时性能的潜在影响。


FlashAttention-3

虽然PyTorch SDPA支持FlashAttention后端,但在flash-attn库中可以找到更高级的FlashAttention实现。在这里,我们将探索FlashAttention-3的测试版,其速度相比FlashAttention-2提高了多达2倍。鉴于其仍处于开发早期阶段,FlashAttention-3只能直接从GitHub存储库安装,并且其使用仅限于某些头维度。此外,它还不支持模型编译。在下面的代码块中,我们将transformer块配置为使用flash-attn-3,同时将注意力输入格式设置为“bshd”(批次、序列、头、深度),以满足库的期望。


# flash attention 3
from flash_attn_interface import flash_attn_func as fa3
attn_fn = lambda q,k,v: fa3(q,k,v)[0]
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=attn_fn,
                             format='bshd')
print(f'Flash Attention 3')
train(block_fn)


最终得到的每步时间为240毫秒,比SDPA的flash-attn快了5%。


Transformer Engine

Transformer Engine(TE)是一个专为在NVIDIA GPU上加速Transformer模型而设计的库。TE定期更新优化,充分利用最新NVIDIA硬件和软件的功能,让用户能够在这些优化被集成到如PyTorch等通用框架之前很久,就能访问到专用的内核。


在下面的代码块中,我们使用了TE版本1.11.0中的DotProductAttention。与PyTorch SDPA类似,TE支持多种后端,这些后端通过环境变量进行控制。在这里,我们展示了如何使用NVTE_FUSED_ATTN后端。


def set_te_backend(backend):
    # must be applied before first use of
    # transformer_engine.pytorch.attention
    os.environ["NVTE_FLASH_ATTN"] = '0'
    os.environ["NVTE_FUSED_ATTN"] = '0'
    os.environ["NVTE_UNFUSED_ATTN"] = '0'
    if backend == 'flash':
        os.environ["NVTE_FLASH_ATTN"] = '1'
    if backend == 'fused':
        os.environ["NVTE_FUSED_ATTN"] = '1'
    if backend == 'unfused':
        os.environ["NVTE_UNFUSED_ATTN"] = '1'
from transformer_engine.pytorch.attention import DotProductAttention
set_te_backend('fused')
attn_fn = DotProductAttention(NUM_HEADS, HEAD_DIM, NUM_HEADS,
                              qkv_format='bshd',
                              # disable masking (default is causal mask)
                              attn_mask_type='no_mask')
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=attn_fn,
                             format='bshd')
print(f'Transformer Engine Attention')
train(block_fn)
print(f'Compiled Transformer Engine Attention')
train_compile(block_fn)


TE注意力在急切模式和编译模型变体下分别产生了243毫秒和204毫秒的平均每步时间。


XFormer Attention

PyTorch SDPA的内存高效后端底层是由xFormers库提供的一个注意力内核。同样,我们可以直接访问源代码,以受益于最新的内核优化和完整的API功能集。在下面的代码块中,我们使用了xFormers版本0.0.28中的memory_efficient_attention操作符。


# xformer memory efficient attention
from xformers.ops import memory_efficient_attention as mea
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=mea,
                             format='bshd')
print(f'xFormer Attention ')
train(block_fn)
print(f'Compiled xFormer Attention ')
train_compile(block_fn)


这个急切模型变体的平均每步时间为246毫秒,比SDPA的内存高效内核快了10.5%。编译变体的每步时间为203毫秒。


结果

下表总结了我们的实验结果:

17


在急切模式下,表现最佳的是flash-attn-3,其平均每步时间比我们的基线模型快了54%。这意味着训练成本也相应降低了大约54%。在编译模式下,各个优化内核的性能表现大致相当,最快的实现达到了202毫秒,相比基线实验提高了20%。


如上所述,具体的节省效果在很大程度上取决于模型定义。为了评估这种变异性,我们使用修改后的设置重新进行了实验,将注意力序列长度增加到3136个标记。


IMG_SIZE = 224224
BATCH_SIZE = 8
# Define ViT settings
NUM_HEADS = 12
HEAD_DIM = 64
DEPTH = 6
PATCH_SIZE = 4
SEQ_LEN = (IMG_SIZE // PATCH_SIZE)**2 # 3136


结果总结在下表中:


18


我们的直接观察是,当序列长度增加时,注意力内核的性能影响更为显著。再次地,在急切执行模式下,flash-attn-3表现最佳,这次与PyTorch原生函数相比,性能提高了约5倍。对于编译模型,我们看到TE内核脱颖而出,总体最佳每步时间为53毫秒。


使用FlexAttention定制注意力

到目前为止,我们一直关注标准的注意力函数。然而,有时我们可能想使用典型注意力计算的一种变体,在这种变体中,我们要么屏蔽中间张量的一些值,要么对它们应用某种操作。这些类型的更改可能会干扰我们使用上面介绍的优化注意力块的能力。在本节中,我们将讨论解决这个问题的一些方法:


利用高级内核API

许多优化后的注意力内核提供了广泛的API,可以控制注意力计算的定制。在实现新解决方案之前,请探索这些API,以确定它们是否已经支持你所需的功能。


实现自定义内核:

如果现有的API无法满足你的需求,你可以考虑创建自己的自定义注意力实现。在之前的文章(例如,这里)中,我们讨论了自定义内核开发的一些优缺点。实现最佳性能可能非常困难。如果你选择走这条路,一种方法可能是从现有的(最优)内核开始,并进行最小更改以集成所需的功能。


使用FlexAttention:

FlexAttention是PyTorch最近添加的功能,它使用户能够实现各种注意力变体,而无需在性能上做出妥协。将查询和键标记的点积结果表示为分数,flex_attention允许编程一个score_mod函数或一个block_mask掩码,该掩码会自动应用于分数张量。


FlexAttention通过将score_mod操作符编译到注意力操作符中,从而创建一个单一的融合内核来工作。它还利用block_masks的稀疏性来避免不必要的计算。FlexAttention文档中报告的基准测试显示,在各种用例中,性能都有显著提高。


让我们来看看score_mod和block_mask的实际应用。


Score Mod示例:使用Tanh进行软上限

软上限是一种常用于控制logit大小的技术(例如,见此处)。以下代码块将我们的PyTorch原生注意力内核扩展为包含软上限:


def softcap_attn(q, k, v):
    scale = HEAD_DIM ** -0.5
    q = q * scale
    attn = q @ k.transpose(-2, -1)
    # apply soft-capping
    attn = 30 * torch.tanh(attn/30)
    attn = attn.softmax(dim=-1)
    x = attn @ v
    return x


在下面的代码块中,我们首先使用PyTorch原生内核训练我们的模型,然后使用优化的Flex Attention API进行训练。这些实验是在序列长度为3136的设置下运行的。


# flex attention imports
from torch.nn.attention.flex_attention import (
    create_block_mask,
    create_mask,
    flex_attention
)
compiled_flex = torch.compile(flex_attention)
# score_mod definition
def tanh_softcap(score, b, h, q_idx, kv_idx):
    return 30 * torch.tanh(score/30)

block_fn = functools.partial(MyAttentionBlock, attn_fn=softcap_attn)
print(f'Attention with Softcap')
train(block_fn)
print(f'Compiled Attention with Softcap')
train_compile(block_fn)
flex_fn = functools.partial(flex_attention, score_mod=tanh_softcap)
compiled_flex_fn = functools.partial(compiled_flex, score_mod=tanh_softcap)
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=flex_fn)
compiled_block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=compiled_flex_fn)
print(f'Flex Attention with Softcap')
train(compiled_block_fn)
print(f'Compiled Flex Attention with Softcap')
train_compile(block_fn)


实验的结果记录在下表中:


19


Flash Attention内核的影响显而易见,它在急切模式下提供了大约3.5倍的性能提升,在编译模式下提供了1.5倍的性能提升。


Mask Mod示例:邻域屏蔽

我们通过向注意力分数应用稀疏掩码来评估mask_mod功能。回忆一下,我们序列中的每个标记都代表二维输入图像中的一个补丁。我们修改了内核,使得每个标记只关注对应二维标记数组中5x5窗口内的其他标记。


# convert the token id to a 2d index
def seq_indx_to_2d(idx):
    n_row_patches = IMG_SIZE // PATCH_SIZE
    r_ind = idx // n_row_patches
    c_ind = idx % n_row_patches
    return r_ind, c_ind
# only attend to tokens in a 5x5 surrounding window in our 2D token array
def mask_mod(b, h, q_idx, kv_idx):
    q_r, q_c = seq_indx_to_2d(q_idx)
    kv_r, kv_c = seq_indx_to_2d(kv_idx)
    return torch.logical_and(torch.abs(q_r-kv_r)<5, torch.abs(q_c-kv_c)<5)


作为我们实验的基线,我们使用了PyTorch的SDPA(它支持传入注意力掩码)。以下代码块首先包含了带掩码的SDPA实验,然后是Flex Attention的实现:


# materialize the mask to use in SDPA
mask = create_mask(mask_mod, 1, 1, SEQ_LEN, SEQ_LEN, device='cuda')
set_sdpa_backend('all')
masked_sdpa = functools.partial(sdpa, attn_mask=mask)
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=masked_sdpa)
print(f'Masked SDPA Attention')
train(block_fn)
print(f'Compiled Masked SDPA Attention')
train_compile(block_fn)
block_mask = create_block_mask(mask_mod, None, None, SEQ_LEN, SEQ_LEN)
flex_fn = functools.partial(flex_attention, block_mask=block_mask)
compiled_flex_fn = functools.partial(compiled_flex, block_mask=block_mask)
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=flex_fn)
compiled_block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=compiled_flex_fn)
print(f'Masked Flex Attention')
train(compiled_block_fn)
print(f'Compiled Masked Flex Attention')
train_compile(block_fn)


实验结果记录如下:


20


再次强调,Flex Attention提供了显著的性能提升,在急切模式下提升了2.19倍,在编译模式下提升了2.59倍。


Flex Attention的局限性

尽管我们已经成功展示了Flex Attention的力量和潜力,但还是有一些局限性需要注意:


  1. 修改范围有限:使用Flex Attention(截至本文撰写时)只能修改注意力分数(即查询和键标记之间点积的结果)。它不支持在注意力计算的其他阶段进行修改。
  2. 依赖于torch.compile:由于依赖于torch.compile,因此必须注意避免过多的重新编译,因为这可能会极大地降低运行时性能。例如,虽然文档屏蔽的支持非常吸引人,但只有当所有文档的长度之和保持不变时,它才能按预期工作。
  3. score_mod不支持可训练参数:在本文撰写时,Flex Attention不支持包含可训练参数的score_mod实现。例如,尽管文档强调了支持相对位置编码,但这些编码通常使用可训练参数(而不是固定值)来实现,而Flex Attention目前无法支持这一点。


总结

随着机器学习模型中对Transformer架构和注意力层的依赖增加,对这些组件进行优化所需的工具和技术的需求也在增加。在本文中,我们探索了许多注意力内核变体,每个变体都有其独特的属性、功能和局限性。重要的是,并非一种尺寸适合所有情况——不同的模型和使用场景将需要使用不同的内核和不同的优化策略。这凸显了拥有各种优化注意力层的工具和技术的重要性。



文章来源:https://medium.com/towards-data-science/increasing-transformer-model-efficiency-through-attention-layer-optimization-fefa6f87b1d6
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消