很少有算法能像 FlashAttention 一样对新一代变压器架构产生如此大的影响。FlashAttention 最初由普林斯顿大学的研究人员开发,其中包括著名的 Tri Dao,FlashAttention 及其后续版本 FlashAttention-2 能够通过最小化读写来提高 GPU 中注意力机制的性能。几乎就在最初的论文发表之后,FlashAttention 就迅速被新一代Transformer 所采用。人们对 FlashAttention 的抱怨并不多,但为数不多的抱怨之一就是它无法充分利用新的硬件架构。例如,FlashAttention-2 在 H100 GPU 中的最大 FLOPs 利用率仅为 35%。
但现在我们有了新版本。
上周,一组来自 Meta、普林斯顿大学、英伟达(NVIDIA)和其他人工智能实验室的人工智能研究人员发表了 FlashAttention-3 的论文和开源代码。新版方法采用了多项技术,利用张量内核的异步性,加快了 H100 GPU 的注意力。结果很简单: FlashAttention-3 的速度快得惊人。新模型在 H100 中实现了 75% 的理论最大 FLOP 利用率,实际性能提高了 1.5-2 倍。新算法还能使用精度更低的数字,从而减少内存占用。
让我们深入了解一些细节,但在此之前,让我们先回顾一下 FlashAttention 的一些细节。
Flash注意力
FlashAttention 的设计目的是通过重新排序步骤以及利用平铺和重新计算来优化注意力机制的计算。这种方法大大加快了处理速度,并将内存使用量从与序列长度相关的二次方降低到线性。该算法利用平铺将输入块从 GPU 内存(HBM)加载到速度更快的高速缓存(SRAM),处理该块内的注意力,并将输出更新回 GPU 内存。通过避免在 HBM 中存储大型中间矩阵,FlashAttention 减少了内存读/写操作,从而将挂机时间提高了 2-4 倍。
在 FlashAttention 前向传递中,平铺和softmax 重缩放允许算法按块运行。这种方法避免了从 HBM 进行大量读/写操作,确保了无近似值的精确输出。
H100 GPU 和注意力
FlashAttention-3 的神奇之处在于利用最新的 H100 功能来提高注意力性能,并解决前代产品的一些局限性。
虽然 FlashAttention-2 在Ampere(A100)图形处理器上实现了高达 70% 的理论最大 FLOPS,但它并没有充分利用 Hopper 图形处理器的新功能。以下是 Hopper GPU 的一些关键特性及其重要意义:
WGMMA(Warpgroup Matrix Multiply-Accumulate): 利用 Hopper GPU 上的新张量核,与 Ampere GPU 中的旧 mma.sync 指令相比,吞吐量要高得多。
TMA(张量内存加速器): 该硬件单元可加快全局内存和共享内存之间的数据传输,处理索引计算和越界预测。它能释放寄存器,提高磁贴大小和效率。
FP8 的低精度:该功能通过使用更少的位来表示浮点数,将张量核的吞吐量提高了一倍(例如,从 FP16 的 989 TFLOPS 提高到 FP8 的 1978 TFLOPS),从而以一定的精度换取速度。
FlashAttention-3
FlashAttention-3 利用英伟达™(NVIDIA®)公司 CUTLASS 库的抽象功能整合了这些新的 Hopper 功能。ThunderKitten 2 和 cuDNN 9 等研究表明,这些硬件特性可显著加快注意力计算速度。通过调整 FlashAttention 以利用这些特性,其性能得到了显著提高(例如,从 FlashAttention-2 FP16 前向传递的 350 TFLOPS 提高到约 540-570 TFLOPS)。Hopper 上的异步指令(WGMMA 和 TMA)进一步为算法优化提供了机会。
FlashAttention-3 引入了三项关键技术,以提高现代 GPU 架构的性能:
1. 生产者-消费者异步:这种方法采用了翘曲专用软件流水线技术,将数据生产者和消费者分成不同的warp。这种分离利用异步执行来更好地隐藏内存和指令问题延迟。
2. 在异步分块 GEMM 下隐藏软最大值: 通过将低吞吐量 softmax 操作与异步 WGMMA 指令重叠,FlashAttention-3 可以规避 softmax 和 GEMM 之间的顺序依赖关系。例如,在两阶段版本中,当 softmax 处理分数矩阵的一个区块时,WGMMA 会计算下一个区块。
3. 硬件加速的低精度 GEMM:这一调整针对 FP8 张量内核的 GEMM,将实测的 TFLOPS/s 提高了近一倍。它涉及通过块量化和不连贯处理来管理 FP32 累加器和 FP8 操作数矩阵的不同布局要求,以减轻精度降低带来的精度损失。
结果
FlashAttention-3 背后的团队测量了不同序列长度的运行时间,并将其与标准 PyTorch 实现、FlashAttention-2、Triton 中的 FlashAttention-2(使用 H100 特定指令)以及来自 cuDNN 的供应商 H100 优化 FlashAttention-2 进行了比较。结果发现,FlashAttention-3 比 FlashAttention-2 快 2 倍,比 Triton 中的 FlashAttention-2 快 1.5 倍,在 H100 GPU 上实现了高达 740 TFLOPS/s 的速度,或理论最大速度的 75%。
FlashAttention-3 是生成式人工智能算法中一个令人兴奋的发展。这种方法几乎肯定会改善 LLM 中的大上下文窗口,并在现代 GPU 架构上实现更好的推理性能。令人印象深刻的进展