斯坦福推出FlashAttention-2:提速长文本语言模型
2023年07月21日 由 Alex 发表
584095
0
在过去的一年中,自然语言处理领域取得了显著的进展,新型的语言模型具备了更长的上下文。在这些模型中,GPT-4的上下文长度达到32k,MosaicML的MPT具有65k的上下文长度,Anthropic的Claude则引人注目地达到了100k的上下文长度。随着长文档查询和故事创作等应用的不断发展,对于具备扩展上下文的语言模型的需求变得更加明显。然而,挑战在于扩展Transformer模型的上下文长度,因为它们的注意力层在输入序列长度上的计算和内存需求呈二次方增长。
为了应对这一挑战,一年前发布的创新算法FlashAttention在各个组织和研究实验室中迅速得到了采用。该算法成功地加速了注意力计算,并减少了内存占用,而不会在牺牲准确性或结果近似方面有所妥协。在初始发布时,FlashAttention的性能比经过优化的基准线快2-4倍,证明其具有突破性的进展。然而,它仍有未开发的潜力,因为它没有达到经过优化的矩阵乘法(GEMM)运算的极快速度,这在A100 GPU上可以达到124 TFLOPs/s。
为了实现进一步的飞跃,FlashAttention的开发人员现在推出了FlashAttention-2,这是一款重塑版本,显著超越了其前身。借助Nvidia的CUTLASS 3.x和CuTe核心库,FlashAttention-2实现了令人瞩目的2倍加速,在A100 GPU上的性能达到了高达230 TFLOPs/s。此外,在GPT风格语言模型的端到端训练中,FlashAttention-2实现了高达225 TFLOPs/s的训练速度,模型FLOP利用率达到了令人印象深刻的72%。
FlashAttention-2的关键改进在于其更好的并行性和工作分配策略。最初,FlashAttention通过批大小和头数进行并行化,有效利用了GPU上的计算资源。然而,对于具有较小批大小或较少头数的长序列,FlashAttention-2现在通过序列长度维度进行并行化,从而在这些场景中实现了显著的加速。
另一个改进涉及在每个线程块内有效地将工作分配给不同的warp。在FlashAttention中,将K和V分割成四个warp,同时保持Q对所有warp可访问(称为“sliced-K”方案),导致了不必要的共享内存读写,从而降低了计算速度。FlashAttention-2采用了不同的方法,现在将Q分割成四个warp,同时保持K和V对所有warp可访问。这消除了warp之间的通信需求,并显著减少了共享内存的读写操作,进一步提升性能。
FlashAttention-2引入了几个新功能,以拓宽其适用范围并增强其能力。它现在支持最多256个头维度,适用于像GPT-J、CodeGen、CodeGen2和StableDiffusion 1.x这样的模型,提供了更多的加速和节省内存的机会。此外,FlashAttention-2采用了多查询注意力(Multi-Query Attention,MQA)和分组查询注意力(Grouped-Query Attention,GQA)变体,其中查询的多个头可以关注同一个键和值的头,从而提高推理吞吐量和性能表现。
FlashAttention-2的性能确实令人印象深刻。在A100 80GB SXM4 GPU上进行基准测试,与其前身相比,它的速度提高了约2倍,与PyTorch中的标准注意力实现相比,速度提高了高达9倍。此外,当用于端到端训练GPT风格的模型时,FlashAttention-2在A100 GPU上可达到高达225 TFLOPs/s的性能,相比已经高度优化的具有FlashAttention的模型,端到端加速比达到了1.3倍。
FlashAttention-2的潜在应用前景非常广阔。通过与之前的8k上下文模型相同的价格训练具有16k更长上下文的模型,该技术可以帮助分析长篇书籍、报告、高分辨率图像、音频和视频。在包括H100 GPU和AMD GPU在内的设备上拓宽其适用范围,并且对fp8等新数据类型进行优化的计划正在进行中。此外,将FlashAttention-2的低级优化与高级算法变化相结合,可以为训练具有前所未有的更长上下文的AI模型铺平道路。与编译器研究人员的合作,以提升可编程性,也是未来的发展方向,为下一代语言模型带来光明的未来。
来源:https://www.marktechpost.com/2023/07/20/stanford-research-introduces-flashattention-2-a-leap-in-speed-and-efficiency-for-long-context-language-models/