使用自定义内核加速 PyTorch

2025年01月10日 由 alex 发表 609 0

PyTorch 提供了非凡的灵活性,允许你在几秒钟内编写复杂的 GPU 加速操作。然而,这种便利性是有代价的。PyTorch 按顺序执行你的代码,导致性能不佳。这会使得模型训练变慢,进而影响你实验的迭代周期、团队的稳健性、财务影响等等。


在这篇文章中,我将探讨三种加速 PyTorch 操作的策略。每种方法都使用 softmax 作为我们的“Hello World”示例,但你可以替换成任何你喜欢的函数,所讨论的方法仍然适用。


我们将从 torch.compile 开始,接着编写自定义 Triton 内核,最后深入探讨设计 CUDA 内核。


torch.compile —— 快速提升性能的方法

torch.compile 是 PyTorch 中一个相对较新的 API,它在底层使用运行时图捕获和内核融合。只需一个装饰器,你通常就能看到速度提升,而无需对代码进行重大更改。


简单来说,例如,我们可以通过将操作合并为一个 GPU 函数来加速计算,从而消除单独 GPU 调用的开销。或者更好的是,通过用一个等效操作替换一系列操作来优化它们!


在常规的 PyTorch 执行模式(eager)中,这样的优化是不可能的,因为它会按照代码中调用的顺序执行操作。


使用torch.compile的Softmax实现

下面是一个简单的示例,展示了如何使用 torch.compile 实现和编译一个 softmax 函数。将其替换到你模型的 forward 传递中,你的代码(希望)就会运行得更快。


import torch
# Our softmax function in PyTorch land
def softmax_pytorch(x):
    # Avoid numerical instability by subtracting max
    x_max = torch.max(x, dim=-1, keepdim=True).values
    x_exp = torch.exp(x - x_max)
    return x_exp / torch.sum(x_exp, dim=-1, keepdim=True)
# Let's compile it with torch.compile
@torch.compile
def compiled_softmax(x):
    return softmax_pytorch(x)
if __name__ == "__main__":
    # Example usage:
    input_tensor = torch.randn((2, 4), device="cuda")
    output = compiled_softmax(input_tensor)
    print("Input:", input_tensor)
    print("Compiled Softmax Output:", output)


注意,如果你编译整个模型而不仅仅是单个操作,你将获得更大的速度提升。


优点:

  • 只需一行代码即可启用编译器。
  • 无需复杂的魔法仪式(除了可能需要处理的动态形状)。


缺点:

  • 首次编译时可能会较慢,之后速度会提升。
  • 并非所有模型都能显著加速,如果你的代码过于复杂,有时可能会导致问题。
  • 在处理动态形状时仍存在问题。


当输入形状发生变化且我们不想为每个特定大小重新编译代码时,需要使用动态形状编译模式。


Triton 代码 —— 用 Python 轻松编写 GPU 内核


为什么使用 Triton?

Triton 是一种语言,它可以编译成高效的 GPU 内核,同时允许你编写类似 Python 的代码。它在 PyTorch 的 dynamo/inductor 栈的底层使用,但你也可以编写自己的自定义操作!对于许多矩阵/张量操作(如 softmax)来说,你可以获得巨大的速度提升。因为既然可以自己编写内核,为什么还要等待官方的 PyTorch 内核呢?


Triton中的Softmax

以下是一个简短的代码片段,展示了如何在 Triton 中实现一个简单的 softmax 前向操作。为了演示,我将保持简短明了。在实际项目中,你可能会进行更高级的分块和块管理。


这看起来可能有些复杂,但只要你熟悉了 Triton,就会开始理解其逻辑。


import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config(
            kwargs=dict(
                BLOCK_SIZE_ROWS=BLOCK_SIZE_ROWS,
                num_stages=num_stages,
            ),
            num_warps=num_warps,
            num_stages=num_stages,
        )
        for BLOCK_SIZE_ROWS in (16, 32, 64, 128)
        for num_stages in (2, 3, 4)
        for num_warps in (2, 4, 8)
    ],
    key=['N_COLS'],
)
@triton.heuristics(
    values=dict(
        BLOCK_SIZE_COLS=lambda args: triton.next_power_of_2(args['N_COLS'])
    )
)
@triton.jit
def softmax_kernel(
    input_ptr: tl.tensor,
    output_ptr: tl.tensor,
    input_row_stride: int,
    output_row_stride: int,
    n_rows: int,
    N_COLS: tl.constexpr,
    BLOCK_SIZE_ROWS: tl.constexpr,
    BLOCK_SIZE_COLS: tl.constexpr,
    num_stages: tl.constexpr
):
    input_ptr = tl.make_block_ptr(
        base=input_ptr,
        shape=(n_rows, N_COLS),
        strides=(input_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_ROWS, BLOCK_SIZE_COLS),
        order=(1, 0),
    )
    output_ptr = tl.make_block_ptr(
        base=output_ptr,
        shape=(n_rows, N_COLS),
        strides=(output_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_ROWS, BLOCK_SIZE_COLS),
        order=(1, 0),
    )
    cols_mask = tl.arange(0, BLOCK_SIZE_COLS) < N_COLS
    row_idx = tl.program_id(0) * BLOCK_SIZE_ROWS
    in_tile_ptr = tl.advance(input_ptr, (row_idx, 0))
    row = tl.load(pointer=in_tile_ptr, boundary_check=(0, 1))
    # Subtract maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=1, keep_dims=True)
    row_minus_max = tl.where(cols_mask, row_minus_max, -float('inf'))
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=1, keep_dims=True)
    softmax_output = numerator / denominator
    out_tile_ptr = tl.advance(output_ptr, (row_idx, 0))
    tl.store(out_tile_ptr, softmax_output, boundary_check=(0, 1))

def softmax(x: torch.Tensor):
    x_orig_shape = x.shape
    x = x.view(-1, x_orig_shape[-1])
    n_rows, n_cols = x.shape
    y = torch.empty_like(x, memory_format=torch.contiguous_format)
    grid = lambda args: (
        triton.cdiv(n_rows, args['BLOCK_SIZE_ROWS']),
        1,
        1
    )
    softmax_kernel[grid](
        input_ptr=x,
        output_ptr=y,
        input_row_stride=x.stride(0),
        output_row_stride=y.stride(0),
        n_rows=n_rows,
        N_COLS=n_cols,
    )
    return y.view(*x_orig_shape)


确实,它看起来很复杂。但算法的核心只用几行代码就概括出来了。


    row_minus_max = row - tl.max(row, axis=1, keep_dims=True)
    row_minus_max = tl.where(cols_mask, row_minus_max, -float('inf'))
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=1, keep_dims=True)
    
    softmax_output = numerator / denominator


其他部分都只是数据管理和一些额外的工作。


如果我们对不同数据长度进行基准测试,我们会发现我们的实现与 torch.nn.functional.softmax 的性能相当(后者是一个高度优化的内核!),并且远超简单的 PyTorch 实现。


3


优点:

  • 通过融合操作和优化内存访问模式,可能获得巨大的速度提升。
  • 比 torch.compile 提供更多的控制权。
  • 易于编写高效代码(我们的实现与 torch 相当!)
  • 也易于编写低效代码(如果你不知道自己在做什么)。


缺点:

  • 你现在成了内核开发者,这意味着如果出现问题,你需要负责调试。这真的很难。
  • 如果你进一步自定义反向传播,你可能需要再来一杯咖啡……或者更多。因为 torch 无法为 Triton 使用 autograd,所以你需要自己定义反向传播。


纯 CUDA(即Going Hardcore)

有时即使 Triton 也不够用,或者你就是喜欢挑战极限。在这种情况下,你可以用 C++ 编写自定义 CUDA 内核,编译它,并通过自定义扩展将其集成到 PyTorch 中。像这样的项目展示了人们如何构建专用内核以实现最大速度。


自定义 CUDA 中的 Softmax

你通常会有一个 setup.py 文件,用于编译 .cu 或 .cpp 文件,并将一个 Python 函数作为扩展暴露出来。


在这篇文章中,我不会提供这种方法的代码,因此这一事实本身就说明了一切。这种方法相当复杂,需要充分的理由,并且通常是你应该尝试的最后一种方法。


很容易写出低效、有错误、不安全的代码。


优点:

  • 最大控制权。“如果你想把事情做好,就自己动手。”
  • 如果优化得当,有潜力实现最快的内核。


缺点:

  • 需要深入理解 CUDA。
  • 内存管理、块大小、共享内存——这些都很难!
  • 维护开销可能非常高。


结论

在加速 PyTorch 操作方面,你可以选择逐步更复杂的方法:

  1. torch.compile:所需代码更改最少。
  2. Triton 内核:对内核行为有更多控制权,编码仍然相对容易。
  3. 纯 CUDA:最大优化潜力,但复杂度要高得多。


如果你寻求最简单的改进,请从 torch.compile 开始。如果这还不够,可以探索 Triton。对于高级用户,编写自定义 CUDA 内核可以带来进一步的性能提升,但这需要深厚的 GPU 编程技能。


文章来源:https://medium.com/towards-data-science/speed-up-pytorch-with-custom-kernels-but-it-gets-progressively-darker-e5a057796269
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消