PyTorch 提供了非凡的灵活性,允许你在几秒钟内编写复杂的 GPU 加速操作。然而,这种便利性是有代价的。PyTorch 按顺序执行你的代码,导致性能不佳。这会使得模型训练变慢,进而影响你实验的迭代周期、团队的稳健性、财务影响等等。
在这篇文章中,我将探讨三种加速 PyTorch 操作的策略。每种方法都使用 softmax 作为我们的“Hello World”示例,但你可以替换成任何你喜欢的函数,所讨论的方法仍然适用。
我们将从 torch.compile 开始,接着编写自定义 Triton 内核,最后深入探讨设计 CUDA 内核。
torch.compile —— 快速提升性能的方法
torch.compile 是 PyTorch 中一个相对较新的 API,它在底层使用运行时图捕获和内核融合。只需一个装饰器,你通常就能看到速度提升,而无需对代码进行重大更改。
简单来说,例如,我们可以通过将操作合并为一个 GPU 函数来加速计算,从而消除单独 GPU 调用的开销。或者更好的是,通过用一个等效操作替换一系列操作来优化它们!
在常规的 PyTorch 执行模式(eager)中,这样的优化是不可能的,因为它会按照代码中调用的顺序执行操作。
使用torch.compile的Softmax实现
下面是一个简单的示例,展示了如何使用 torch.compile 实现和编译一个 softmax 函数。将其替换到你模型的 forward 传递中,你的代码(希望)就会运行得更快。
import torch
# Our softmax function in PyTorch land
def softmax_pytorch(x):
# Avoid numerical instability by subtracting max
x_max = torch.max(x, dim=-1, keepdim=True).values
x_exp = torch.exp(x - x_max)
return x_exp / torch.sum(x_exp, dim=-1, keepdim=True)
# Let's compile it with torch.compile
@torch.compile
def compiled_softmax(x):
return softmax_pytorch(x)
if __name__ == "__main__":
# Example usage:
input_tensor = torch.randn((2, 4), device="cuda")
output = compiled_softmax(input_tensor)
print("Input:", input_tensor)
print("Compiled Softmax Output:", output)
注意,如果你编译整个模型而不仅仅是单个操作,你将获得更大的速度提升。
优点:
缺点:
当输入形状发生变化且我们不想为每个特定大小重新编译代码时,需要使用动态形状编译模式。
Triton 代码 —— 用 Python 轻松编写 GPU 内核
为什么使用 Triton?
Triton 是一种语言,它可以编译成高效的 GPU 内核,同时允许你编写类似 Python 的代码。它在 PyTorch 的 dynamo/inductor 栈的底层使用,但你也可以编写自己的自定义操作!对于许多矩阵/张量操作(如 softmax)来说,你可以获得巨大的速度提升。因为既然可以自己编写内核,为什么还要等待官方的 PyTorch 内核呢?
Triton中的Softmax
以下是一个简短的代码片段,展示了如何在 Triton 中实现一个简单的 softmax 前向操作。为了演示,我将保持简短明了。在实际项目中,你可能会进行更高级的分块和块管理。
这看起来可能有些复杂,但只要你熟悉了 Triton,就会开始理解其逻辑。
import torch
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config(
kwargs=dict(
BLOCK_SIZE_ROWS=BLOCK_SIZE_ROWS,
num_stages=num_stages,
),
num_warps=num_warps,
num_stages=num_stages,
)
for BLOCK_SIZE_ROWS in (16, 32, 64, 128)
for num_stages in (2, 3, 4)
for num_warps in (2, 4, 8)
],
key=['N_COLS'],
)
@triton.heuristics(
values=dict(
BLOCK_SIZE_COLS=lambda args: triton.next_power_of_2(args['N_COLS'])
)
)
@triton.jit
def softmax_kernel(
input_ptr: tl.tensor,
output_ptr: tl.tensor,
input_row_stride: int,
output_row_stride: int,
n_rows: int,
N_COLS: tl.constexpr,
BLOCK_SIZE_ROWS: tl.constexpr,
BLOCK_SIZE_COLS: tl.constexpr,
num_stages: tl.constexpr
):
input_ptr = tl.make_block_ptr(
base=input_ptr,
shape=(n_rows, N_COLS),
strides=(input_row_stride, 1),
offsets=(0, 0),
block_shape=(BLOCK_SIZE_ROWS, BLOCK_SIZE_COLS),
order=(1, 0),
)
output_ptr = tl.make_block_ptr(
base=output_ptr,
shape=(n_rows, N_COLS),
strides=(output_row_stride, 1),
offsets=(0, 0),
block_shape=(BLOCK_SIZE_ROWS, BLOCK_SIZE_COLS),
order=(1, 0),
)
cols_mask = tl.arange(0, BLOCK_SIZE_COLS) < N_COLS
row_idx = tl.program_id(0) * BLOCK_SIZE_ROWS
in_tile_ptr = tl.advance(input_ptr, (row_idx, 0))
row = tl.load(pointer=in_tile_ptr, boundary_check=(0, 1))
# Subtract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=1, keep_dims=True)
row_minus_max = tl.where(cols_mask, row_minus_max, -float('inf'))
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=1, keep_dims=True)
softmax_output = numerator / denominator
out_tile_ptr = tl.advance(output_ptr, (row_idx, 0))
tl.store(out_tile_ptr, softmax_output, boundary_check=(0, 1))
def softmax(x: torch.Tensor):
x_orig_shape = x.shape
x = x.view(-1, x_orig_shape[-1])
n_rows, n_cols = x.shape
y = torch.empty_like(x, memory_format=torch.contiguous_format)
grid = lambda args: (
triton.cdiv(n_rows, args['BLOCK_SIZE_ROWS']),
1,
1
)
softmax_kernel[grid](
input_ptr=x,
output_ptr=y,
input_row_stride=x.stride(0),
output_row_stride=y.stride(0),
n_rows=n_rows,
N_COLS=n_cols,
)
return y.view(*x_orig_shape)
确实,它看起来很复杂。但算法的核心只用几行代码就概括出来了。
row_minus_max = row - tl.max(row, axis=1, keep_dims=True)
row_minus_max = tl.where(cols_mask, row_minus_max, -float('inf'))
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=1, keep_dims=True)
softmax_output = numerator / denominator
其他部分都只是数据管理和一些额外的工作。
如果我们对不同数据长度进行基准测试,我们会发现我们的实现与 torch.nn.functional.softmax 的性能相当(后者是一个高度优化的内核!),并且远超简单的 PyTorch 实现。
优点:
缺点:
纯 CUDA(即Going Hardcore)
有时即使 Triton 也不够用,或者你就是喜欢挑战极限。在这种情况下,你可以用 C++ 编写自定义 CUDA 内核,编译它,并通过自定义扩展将其集成到 PyTorch 中。像这样的项目展示了人们如何构建专用内核以实现最大速度。
自定义 CUDA 中的 Softmax
你通常会有一个 setup.py 文件,用于编译 .cu 或 .cpp 文件,并将一个 Python 函数作为扩展暴露出来。
在这篇文章中,我不会提供这种方法的代码,因此这一事实本身就说明了一切。这种方法相当复杂,需要充分的理由,并且通常是你应该尝试的最后一种方法。
很容易写出低效、有错误、不安全的代码。
优点:
缺点:
结论
在加速 PyTorch 操作方面,你可以选择逐步更复杂的方法:
如果你寻求最简单的改进,请从 torch.compile 开始。如果这还不够,可以探索 Triton。对于高级用户,编写自定义 CUDA 内核可以带来进一步的性能提升,但这需要深厚的 GPU 编程技能。