在2017年的里程碑式论文《Attention Is All You Need》中,Transformer架构被广泛认为是过去十年中最具影响力的科学突破之一。Transformer的核心是注意力机制,这是一种新颖的方法,使AI模型能够根据当前任务关注输入序列的不同部分,从而理解复杂结构。最初在自然语言处理领域得到验证,Transformer架构的成功迅速扩展到许多其他领域,包括语音识别、场景理解、强化学习、蛋白质结构预测等。然而,注意力层对资源要求极高,随着这些层在越来越大的模型中成为标准,与它们的训练和部署相关的成本也急剧上升。这迫切需要能够降低这一核心层计算成本的策略,以提高基于Transformer的AI模型的效率和可扩展性。
在这篇文章中,我们将探索在PyTorch中优化注意力的几种工具。我们的重点将放在保持注意力层准确性的方法上,包括PyTorch SDPA、FlashAttention、TransformerEngine Attention、FlexAttention和xFormer attention。
重要的是,注意力优化是一个活跃的研究领域,新方法层出不穷。我们的目标是提高你对现有解决方案的认识,并为你进一步的探索和实验提供基础。我们下面分享的代码仅用于演示目的。
玩具模型
为了便于讨论,我们使用流行的timm Python包(版本0.9.7)构建了一个基于视觉Transformer(ViT)的分类模型。我们将使用这个模型来说明各种注意力内核对性能的影响。
首先,我们定义了一个简化的Transformer块,它允许通过将其传递给构造函数来编程注意力函数。由于注意力实现假设了特定的输入张量格式,我们还包含了一个控制格式的选项,以确保与我们选择的注意力内核兼容。
# general imports
import os, time, functools
# torch imports
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
# timm imports
from timm.models.vision_transformer import VisionTransformer
from timm.layers import Mlp
IMG_SIZE = 224
BATCH_SIZE = 128
# Define ViT settings
NUM_HEADS = 16
HEAD_DIM = 64
DEPTH = 24
PATCH_SIZE = 16
SEQ_LEN = (IMG_SIZE // PATCH_SIZE)**2 # 196
class MyAttentionBlock(nn.Module):
def __init__(
self,
attn_fn,
format = None,
dim: int = 768,
num_heads: int = 12,
**kwargs
) -> None:
super().__init__()
self.attn_fn = attn_fn
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.proj = nn.Linear(dim, dim)
self.mlp = Mlp(
in_features=dim,
hidden_features=dim * 4,
)
permute = (2, 0, 3, 1, 4)
self.permute_attn = functools.partial(torch.transpose,dim0=1,dim1=2)
if format == 'bshd':
permute = (2, 0, 1, 3, 4)
self.permute_attn = nn.Identity()
self.permute_qkv = functools.partial(torch.permute,dims=permute)
def forward(self, x_in: torch.Tensor) -> torch.Tensor:
x = self.norm1(x_in)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
# permute tensor based on the specified format
qkv = self.permute_qkv(qkv)
q, k, v = qkv.unbind(0)
# use the attention function specified by the user
x = self.attn_fn(q, k, v)
# permute output according to the specified format
x = self.permute_attn(x).reshape(B, N, C)
x = self.proj(x)
x = x + x_in
x = x + self.mlp(self.norm2(x))
return x
我们定义了一个随机生成的数据集,将在训练过程中将其输入到我们的模型中。
# Use random data
class FakeDataset(Dataset):
def __len__(self):
return 1000000
def __getitem__(self, index):
rand_image = torch.randn([3, IMG_SIZE, IMG_SIZE],
dtype=torch.float32)
label = torch.tensor(data=index % 1000, dtype=torch.int64)
return rand_image, label
接下来,我们定义ViT训练函数。虽然我们的示例主要关注的是展示训练工作负载,但必须强调的是,在模型推理过程中,优化注意力层同样(如果不是更加)重要。
我们定义的训练函数接受一个自定义的Transformer块和一个控制是否使用torch.compile的标志。
def train_fn(block_fn, compile):
torch.random.manual_seed(0)
device = torch.device("cuda:0")
torch.set_float32_matmul_precision("high")
# Create dataset and dataloader
train_set = FakeDataset()
train_loader = DataLoader(
train_set, batch_size=BATCH_SIZE,
num_workers=12, pin_memory=True, drop_last=True)
model = VisionTransformer(
img_size=IMG_SIZE,
patch_size=PATCH_SIZE,
embed_dim=NUM_HEADS*HEAD_DIM,
depth=DEPTH,
num_heads=NUM_HEADS,
class_token=False,
global_pool="avg",
block_fn=block_fn
).to(device)
if compile:
model = torch.compile(model)
# Define loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters())
model.train()
t0 = time.perf_counter()
summ = 0
count = 0
for step, data in enumerate(train_loader):
# Copy data to GPU
inputs = data[0].to(device=device, non_blocking=True)
label = data[1].to(device=device, non_blocking=True)
with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
outputs = model(inputs)
loss = criterion(outputs, label)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
# Capture step time
batch_time = time.perf_counter() - t0
if step > 20: # Skip first steps
summ += batch_time
count += 1
t0 = time.perf_counter()
if step > 100:
break
print(f'average step time: {summ / count}')
# define compiled and uncompiled variants of our train function
train = functools.partial(train_fn, compile=False)
train_compile = functools.partial(train_fn, compile=True)
在下面的代码块中,我们定义了一个原生的PyTorch注意力函数,并使用它来训练我们的ViT模型:
def attn_fn(q, k, v):
scale = HEAD_DIM ** -0.5
q = q * scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
x = attn @ v
return x
block_fn = functools.partial(MyAttentionBlock, attn_fn=attn_fn)
print('Default Attention')
train(block_fn)
print('Compiled Default Attention')
train_compile(block_fn)
我们在配备CUDA 12.4和PyTorch 2.5.1的NVIDIA H100上运行了上述代码。未编译版本的平均步长时间为370毫秒(ms),而编译版本则提升至242ms。在考虑其他执行注意力计算的方法时,我们将使用这些结果作为比较的基准。
PyTorch SDPA
在PyTorch中提升注意力层性能的最简单方法之一是使用scaled_dot_product_attention(SDPA)函数。目前处于测试阶段的PyTorch SDPA整合了多个内核级别的优化,并根据输入的特性动态选择最高效的优化方案。目前支持的后端包括:FlashAttention-2、内存高效注意力、基于C++的数学注意力以及CuDNN。这些后端融合了高级操作,同时采用GPU级别的优化,以提高计算效率和内存利用率。
SDPA在不断发展,新的和改进的后端实现会定期推出。紧跟PyTorch的最新发布是获取最新性能提升的关键。例如,PyTorch 2.5引入了更新的CuDNN后端,其中包含一个专门为在NVIDIA Hopper架构GPU上训练而定制的SDPA原语。
在下面的代码块中,我们遍历了受支持的后端列表,并评估了使用每个后端进行训练的运行时性能。我们使用了一个辅助函数set_sdpa_backend来编程设置SDPA后端:
from torch.nn.functional import scaled_dot_product_attention as sdpa
def set_sdpa_backend(backend):
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(False)
torch.backends.cuda.enable_cudnn_sdp(False)
if backend in ['flash_sdp','all']:
torch.backends.cuda.enable_flash_sdp(True)
if backend in ['mem_efficient_sdp','all']:
torch.backends.cuda.enable_mem_efficient_sdp(True)
if backend in ['math_sdp','all']:
torch.backends.cuda.enable_math_sdp(True)
if backend in ['cudnn_sdp','all']:
torch.backends.cuda.enable_cudnn_sdp(True)
for backend in ['flash_sdp', 'mem_efficient_sdp',
'math_sdp', 'cudnn_sdp']:
set_sdpa_backend(backend)
block_fn = functools.partial(MyAttentionBlock,
attn_fn=sdpa)
print(f'PyTorch SDPA - {backend}')
train(block_fn)
print(f'Compiled PyTorch SDPA - {backend}')
train_compile(block_fn)
我们将中期结果总结在下表中
虽然在即时模式下,SDPA后端的选择对性能有显著影响,但模型编译所进行的优化似乎掩盖了注意力内核之间的差异。我们再次提醒,不要从这些结果中得出任何结论,因为不同注意力函数对性能的影响可能因具体模型和使用场景的不同而差异显著。
第三方注意力内核
虽然PyTorch的SDPA是一个很好的起点,但使用第三方注意力内核可以进一步加速你的机器学习工作负载。这些替代方案通常具有更高的灵活性,为注意力提供了更广泛的配置选项。有些还可能包括针对特定硬件加速器或更新GPU架构的优化。
在本节中,我们将探索一些可用的第三方注意力内核,并评估它们对运行时性能的潜在影响。
FlashAttention-3
虽然PyTorch SDPA支持FlashAttention后端,但在flash-attn库中可以找到更高级的FlashAttention实现。在这里,我们将探索FlashAttention-3的测试版,其速度相比FlashAttention-2提高了多达2倍。鉴于其仍处于开发早期阶段,FlashAttention-3只能直接从GitHub仓库安装,并且其使用受限于特定的头维度。此外,它还不支持模型编译。在下面的代码块中,我们将Transformer块配置为使用flash-attn-3,同时将注意力输入格式设置为“bshd”(批次、序列、头、深度),以满足该库的期望。
# flash attention 3
from flash_attn_interface import flash_attn_func as fa3
attn_fn = lambda q,k,v: fa3(q,k,v)[0]
block_fn = functools.partial(MyAttentionBlock,
attn_fn=attn_fn,
format='bshd')
print(f'Flash Attention 3')
train(block_fn)
最终得到的步长为240毫秒,比SDPA的flash-attn快了5%。
Transformer Engine
Transformer Engine(TE)是一个专为在NVIDIA GPU上加速Transformer模型而设计的库。TE定期更新,利用最新NVIDIA硬件和软件的功能进行优化,使用户能够在这些优化被集成到如PyTorch等通用框架之前很久,就能访问到专用的内核。
在下面的代码块中,我们使用了TE版本1.11.0中的DotProductAttention。与PyTorch的SDPA类似,TE支持多种后端,这些后端通过环境变量进行控制。在这里,我们展示了如何使用NVTE_FUSED_ATTN后端。
def set_te_backend(backend):
# must be applied before first use of
# transformer_engine.pytorch.attention
os.environ["NVTE_FLASH_ATTN"] = '0'
os.environ["NVTE_FUSED_ATTN"] = '0'
os.environ["NVTE_UNFUSED_ATTN"] = '0'
if backend == 'flash':
os.environ["NVTE_FLASH_ATTN"] = '1'
if backend == 'fused':
os.environ["NVTE_FUSED_ATTN"] = '1'
if backend == 'unfused':
os.environ["NVTE_UNFUSED_ATTN"] = '1'
from transformer_engine.pytorch.attention import DotProductAttention
set_te_backend('fused')
attn_fn = DotProductAttention(NUM_HEADS, HEAD_DIM, NUM_HEADS,
qkv_format='bshd',
# disable masking (default is causal mask)
attn_mask_type='no_mask')
block_fn = functools.partial(MyAttentionBlock,
attn_fn=attn_fn,
format='bshd')
print(f'Transformer Engine Attention')
train(block_fn)
print(f'Compiled Transformer Engine Attention')
train_compile(block_fn)
TE注意力在即时模式和编译模型变体中的平均步长时间分别为243毫秒和204毫秒。
XFormer Attention
PyTorch SDPA的内存高效后端底层是由xFormers库提供的一个注意力内核。同样,我们可以直接访问源代码,以受益于最新的内核优化和完整的API功能集。在下面的代码块中,我们使用了xFormers版本0.0.28中的memory_efficient_attention操作符。
# xformer memory efficient attention
from xformers.ops import memory_efficient_attention as mea
block_fn = functools.partial(MyAttentionBlock,
attn_fn=mea,
format='bshd')
print(f'xFormer Attention ')
train(block_fn)
print(f'Compiled xFormer Attention ')
train_compile(block_fn)
这个即时模型变体的平均步长时间为246毫秒,比SDPA内存高效内核快了10.5%。编译变体的步长时间为203毫秒。
结果
下表总结了我们的实验:
在即时模型中,获胜者是flash-attn-3,其平均步长时间比我们的基线模型快了54%。这意味着训练成本也相应降低了大约54%。在编译模式下,各个优化内核的性能大致相当,最快的实现达到了202毫秒,相比基线实验提高了20%。
如上所述,具体的节省效果很大程度上取决于模型定义。为了评估这种变异性,我们使用修改后的设置重新进行了实验,将注意力序列长度增加到3136个标记。
IMG_SIZE = 224224
BATCH_SIZE = 8
# Define ViT settings
NUM_HEADS = 12
HEAD_DIM = 64
DEPTH = 6
PATCH_SIZE = 4
SEQ_LEN = (IMG_SIZE // PATCH_SIZE)**2 # 3136
结果总结在下表中:
我们的直接观察是,当序列长度增加时,注意力内核对性能的影响更为显著。再一次,flash-attn-3在即时执行模式中脱颖而出——这次与PyTorch原生函数相比,性能提高了约5倍。对于编译模型,我们看到TE内核以最佳的53毫秒步长时间脱颖而出。
使用FlexAttention定制注意力
到目前为止,我们一直专注于标准的注意力函数。然而,有时我们可能希望使用典型注意力计算的变体,其中我们要么屏蔽中间张量的一些值,要么对它们应用某些操作。这些类型的更改可能会干扰我们使用上面提到的优化注意力块的能力。在本节中,我们将讨论解决此问题的一些方法:
利用高级内核API
许多优化后的注意力内核提供了广泛的API,可以控制自定义注意力计算。在实现新解决方案之前,请探索这些API,以确定它们是否已经支持你所需的功能。
实现自定义内核:
如果现有的API无法满足你的需求,你可以考虑创建自己的自定义注意力实现。实现最佳性能可能非常困难。如果你选择走这条路,一种方法可能是从现有的(最优的)内核开始,并进行最小的更改以集成所需的功能。
使用FlexAttention:
FlexAttention是PyTorch最近添加的功能,它使用户能够实现各种注意力变体,而无需在性能上做出妥协。将查询和键标记的点积结果表示为分数,flex_attention允许对分数张量编程一个score_mod函数或一个block_mask掩码。
FlexAttention通过将score_mod操作符编译到注意力操作符中,从而创建一个融合的内核来工作。它还利用block_mask的稀疏性来避免不必要的计算。FlexAttention文档中报告的基准测试显示,对于各种用例,性能都有显著提高。
Score Mod示例:使用Tanh进行软上限
软上限是一种常用的技术,用于控制logit的大小。以下代码块通过软上限扩展了我们的PyTorch原生注意力内核:
def softcap_attn(q, k, v):
scale = HEAD_DIM ** -0.5
q = q * scale
attn = q @ k.transpose(-2, -1)
# apply soft-capping
attn = 30 * torch.tanh(attn/30)
attn = attn.softmax(dim=-1)
x = attn @ v
return x
在下面的代码块中,我们首先使用PyTorch原生内核训练我们的模型,然后使用优化的Flex Attention API进行训练。这些实验是在序列长度为3136的设置下运行的。
# flex attention imports
from torch.nn.attention.flex_attention import (
create_block_mask,
create_mask,
flex_attention
)
compiled_flex = torch.compile(flex_attention)
# score_mod definition
def tanh_softcap(score, b, h, q_idx, kv_idx):
return 30 * torch.tanh(score/30)
block_fn = functools.partial(MyAttentionBlock, attn_fn=softcap_attn)
print(f'Attention with Softcap')
train(block_fn)
print(f'Compiled Attention with Softcap')
train_compile(block_fn)
flex_fn = functools.partial(flex_attention, score_mod=tanh_softcap)
compiled_flex_fn = functools.partial(compiled_flex, score_mod=tanh_softcap)
block_fn = functools.partial(MyAttentionBlock,
attn_fn=flex_fn)
compiled_block_fn = functools.partial(MyAttentionBlock,
attn_fn=compiled_flex_fn)
print(f'Flex Attention with Softcap')
train(compiled_block_fn)
print(f'Compiled Flex Attention with Softcap')
train_compile(block_fn)
实验的结果记录在下表中:
Flash Attention内核的影响显而易见,在即时模式下性能提升了大约3.5倍,在编译模式下提升了1.5倍。
掩码修改示例——邻域掩码
我们通过向注意力分数应用稀疏掩码来评估mask_mod功能。回顾一下,我们序列中的每个标记都代表二维输入图像中的一个补丁。我们修改了内核,使得每个标记只关注对应二维标记数组中5x5窗口内的其他标记。
# convert the token id to a 2d index
def seq_indx_to_2d(idx):
n_row_patches = IMG_SIZE // PATCH_SIZE
r_ind = idx // n_row_patches
c_ind = idx % n_row_patches
return r_ind, c_ind
# only attend to tokens in a 5x5 surrounding window in our 2D token array
def mask_mod(b, h, q_idx, kv_idx):
q_r, q_c = seq_indx_to_2d(q_idx)
kv_r, kv_c = seq_indx_to_2d(kv_idx)
return torch.logical_and(torch.abs(q_r-kv_r)<5, torch.abs(q_c-kv_c)<5)
作为实验的基准,我们使用了PyTorch的SDPA,它支持传入注意力掩码。以下代码块首先包含了带掩码的SDPA实验,随后是Flex Attention的实现:
# materialize the mask to use in SDPA
mask = create_mask(mask_mod, 1, 1, SEQ_LEN, SEQ_LEN, device='cuda')
set_sdpa_backend('all')
masked_sdpa = functools.partial(sdpa, attn_mask=mask)
block_fn = functools.partial(MyAttentionBlock,
attn_fn=masked_sdpa)
print(f'Masked SDPA Attention')
train(block_fn)
print(f'Compiled Masked SDPA Attention')
train_compile(block_fn)
block_mask = create_block_mask(mask_mod, None, None, SEQ_LEN, SEQ_LEN)
flex_fn = functools.partial(flex_attention, block_mask=block_mask)
compiled_flex_fn = functools.partial(compiled_flex, block_mask=block_mask)
block_fn = functools.partial(MyAttentionBlock,
attn_fn=flex_fn)
compiled_block_fn = functools.partial(MyAttentionBlock,
attn_fn=compiled_flex_fn)
print(f'Masked Flex Attention')
train(compiled_block_fn)
print(f'Compiled Masked Flex Attention')
train_compile(block_fn)
实验结果记录如下:
再次,Flex Attention提供了显著的性能提升,在即时模式下达到了2.19倍,在编译模式下达到了2.59倍。
总结
随着机器学习模型对Transformer架构和注意力层的依赖增加,对这些组件进行优化的工具和技术的需求也在增加。在本文中,我们探索了多种注意力内核变体,每种变体都有其独特的属性、能力和局限性。重要的是,没有一种解决方案适合所有情况——不同的模型和用例将需要使用不同的内核和不同的优化策略。这强调了拥有广泛多样的工具和技术来优化注意力层的重要性。