使用JAX进行AI模型训练

2024年05月31日 由 alex 发表 160 0

引擎盖下的 JAX - XLA 编译

让我们直接公开说明这一点:我无意冒犯 JAX,JAX 的真正强大之处在于它使用了XLA编译。JAX 所展示的惊人的运行时性能来自于 XLA 所实现的特定于硬件的优化。许多通常与 JAX 相关的特性和功能,如即时编译(JIT)和 “函数式编程 ”范例,实际上都源自 XLA。事实上,XLA 编译并非 JAX 独有,TensorFlow 和 PyTorch 都支持使用 XLA 的选项。不过,与其他流行框架不同的是,JAX 是自底向上设计来使用 XLA 的。这使得其 JIT、自动微分(grad)、矢量化(vmap)、并行化(pmap)、分片(shard_map)和其他功能(所有这些功能都非常值得尊敬)的设计和实现与底层 XLA 库紧密耦合。


XLA JIT 编译器会对与模型相关的计算图进行全面分析,将连续的张量运算融合到单个内核中,删除多余的计算图组件,并输出对底层加速器最适合的机器代码。这就减少了每个训练步骤的总体机器级运算次数(FLOPS),降低了主机到加速器的通信开销(例如,需要加载到加速器的内核数量减少),减少了内存占用,提高了专用加速器引擎的利用率,等等。


除运行时性能优化外,XLA 的另一个重要特点是其可插拔的基础架构,可将其支持扩展到更多的人工智能加速器。XLA 是 OpenXLA 项目的一部分,由 ML 领域的多个参与者合作构建。


JAX 实际应用--玩具示例

在本文中,我们将演示如何在(单个)GPU 上用 JAX 训练一个玩具人工智能模型,并与 PyTorch 进行比较。现在有很多高级 ML 开发平台都包含多个 ML 框架的后端。这样就可以将 JAX 的性能与其他框架进行比较。在本文中,我们将使用 HuggingFace 的 Transformers 库,其中包括许多常见 Transformer 支持模型的 PyTorch 和 JAX 实现。更具体地说,我们将使用 ViTForImageClassification 模块和 FlaxViTForImageClassification 模块,分别为 PyTorch 和 JAX 实现定义一个视觉转换器(ViT)支持的分类模型。下面的代码块包含模型定义:


import torch
import jax, flax, optax
import jax.numpy as jnp
def get_model(use_jax=False):
    from transformers import ViTConfig
    if use_jax:
        from transformers import FlaxViTForImageClassification as ViTModel
    else:
        from transformers import ViTForImageClassification as ViTModel
    vit_config = ViTConfig(
        num_labels = 1000,
        _attn_implementation = 'eager'  # this disables flash attention
    )
    
    return ViTModel(vit_config)


由于我们在这篇文章中关注的是运行时性能,因此我们将在随机生成的数据集上训练我们的模型。我们将利用 JAX 支持使用 PyTorch 数据加载器这一事实:


def get_data_loader(batch_size, use_jax=False):
    from torch.utils.data import Dataset, DataLoader, default_collate
    # create dataset of random image and label data
    class FakeDataset(Dataset):
        def __len__(self):
            return 1000000
        def __getitem__(self, index):
            if use_jax: # use nhwc
                rand_image = torch.randn([224, 224, 3], dtype=torch.float32)
            else: # use nchw
                rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
            label = torch.tensor(data=[index % 1000], dtype=torch.int64)
            return rand_image, label
    ds = FakeDataset()
    
    if use_jax:  # convert torch tensors to numpy arrays
        def numpy_collate(batch):
            from jax.tree_util import tree_map
            import jax.numpy as jnp
            return tree_map(jnp.asarray, default_collate(batch))
        collate_fn = numpy_collate
    else:
        collate_fn = default_collate
 
    ds = FakeDataset()
    dl = DataLoader(ds, batch_size=batch_size,
                    collate_fn=collate_fn)
    return dl


接下来,我们定义 PyTorch 和 JAX 训练循环。JAX 训练循环依赖于 Flax TrainState 对象,其定义遵循 Flax 中训练 ML 模型的基本教程:


@jax.jit
def train_step_jax(train_state, batch):
    with jax.default_matmul_precision('tensorfloat32'):
        def forward(params):
            logits = train_state.apply_fn({'params': params}, batch[0])
            loss = optax.softmax_cross_entropy(
                logits=logits.logits, labels=batch[1]).mean()
            return loss
        grad_fn = jax.grad(forward)
        grads = grad_fn(train_state.params)
        train_state = train_state.apply_gradients(grads=grads)
        return train_state
def train_step_torch(batch, model, optimizer, loss_fn, device):
    inputs = batch[0].to(device=device, non_blocking=True)
    label = batch[1].squeeze(-1).to(device=device, non_blocking=True)
    outputs = model(inputs)
    loss = loss_fn(outputs.logits, label)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


现在,让我们将所有内容整合在一起。在下面的脚本中,我们使用 torch.compile 和 torch_xla 来控制 PyTorch 基于图形的 JIT 编译选项:


def train(batch_size, mode, compile_model):
    print(f"Mode: {mode} \n"
          f"Batch size: {batch_size} \n"
          f"Compile model: {compile_model}")
    # init model and data loader
    use_jax = mode == 'jax'
    use_torch_xla = mode == 'torch_xla'
    model = get_model(use_jax)
    train_loader = get_data_loader(batch_size, use_jax)
    if use_jax:
        # init jax settings
        from flax.training import train_state
        params = model.module.init(jax.random.key(0), 
                                   jnp.ones([1, 224, 224, 3]))['params']
        optimizer = optax.sgd(learning_rate=1e-3)
        state = train_state.TrainState.create(apply_fn=model.module.apply,
                                              params=params, tx=optimizer)
    else:
        if use_torch_xla:
            import torch_xla
            import torch_xla.core.xla_model as xm
            import torch_xla.distributed.parallel_loader as pl
            torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
                use_full_mat_mul_precision=False)
       
            device = xm.xla_device()
            backend = 'openxla'
        
            # wrap data loader
            train_loader = pl.MpDeviceLoader(train_loader, device)
        else:
            device = torch.device('cuda')
            backend = 'inductor'
    
        model = model.to(device)
        if compile_model:
            model = torch.compile(model, backend=backend)
        model.train()
        optimizer = torch.optim.SGD(model.parameters())
        loss_fn = torch.nn.CrossEntropyLoss()
    import time
    t0 = time.perf_counter()
    summ = 0
    count = 0
    for step, data in enumerate(train_loader):
        if use_jax:
            state = train_step_jax(state, data)
        else:
            train_step_torch(data, model, optimizer, loss_fn, device)
        # capture step time
        batch_time = time.perf_counter() - t0
        if step > 10:  # skip first steps
            summ += batch_time
        count += 1
        t0 = time.perf_counter()
        if step > 50:
            break
    print(f'average step time: {summ / count}')

if __name__ == '__main__':
    import argparse
    torch.set_float32_matmul_precision('high')
    
    parser = argparse.ArgumentParser(description='Toy Training Script.')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='input batch size for training (default: 2)')
    parser.add_argument('--mode', choices=['pytorch', 'jax', 'torch_xla'],
                        default='jax',
                        help='choose training mode')
    parser.add_argument('--compile-model', action='store_true', default=False,
                        help='whether to apply torch.compile to the model')
    args = parser.parse_args()
    train(**vars(args))


关于基准比较的重要说明

在分析基准比较时,最重要的是我们要对基准比较的进行方式保持极其细致和严谨的态度。在人工智能模型开发中尤其如此,因为根据不准确的数据做出的决定可能会带来极其昂贵的后果。在比较训练模型的运行时性能时,有许多因素会对我们的测量产生主导影响,包括浮点类型精度、矩阵乘法(matmul)精度、数据加载方法、闪存/融合注意力的使用等。例如,如果 PyTorch 的默认 matmul 精度是 float32,而 JAX 的默认 matmul 精度是 tensorfloat32,那么我们就无法从它们的性能比较中学到很多东西。这些设置可以通过 jax.default_matmul_precision 和 torch.set_float32_matmul_precision 等 API 进行控制。在我们的脚本中,我们试图隔离这些潜在问题,但并不保证我们确实成功了。


测试结果

我们在两台谷歌云虚拟机上运行了训练脚本,一台是 g2-standard-16虚拟机(配备单个英伟达 L4 GPU),另一台是 a2-highgpu-1g(配备单个英伟达 A100 GPU),在每种情况下,我们都使用了专用的深度学习虚拟机镜像(common-cu121-v20240514-ubuntu-2204-py310),并安装了 PyTorch (2. 3.0)、JAX (0.4.28)、Flax (0.8.4)、Optax (0.2.2) 等。 3.0)、PyTorch/XLA(2.3.0)、JAX(0.4.28)、Flax(0.8.4)、Optax(0.2.2)和 HuggingFace 的 Transformers 库(4.41.1)。


下表记录了一些实验的运行结果。请注意,根据模型架构和运行环境的不同,比较性能可能会有很大变化。此外,对代码的一些小调整也很可能对结果产生显著影响。


3


4


虽然 JAX 在 L4 GPU 上的性能似乎远超其他同类产品,但在 A100 上却与 PyTorch/XLA 不相上下。考虑到通用的 XLA 后端,这并不奇怪。任何由 JAX 生成的 XLA (HLO) 图形(至少在理论上)都应该可以由 PyTorch/XLA 实现。在这两个平台上,torch.compile 选项都不尽如人意。鉴于我们选择的是全精度浮点数,这在一定程度上是意料之中的。


总结

在这篇文章中,我们探讨了新兴的 JAX ML 开发框架。我们介绍了它对 XLA 编译器的依赖,并在一个玩具示例中演示了它的使用。尽管 PyTorch JIT 编译 API(torch.compile 和 PyTorch/XLA)经常以其快速的运行时执行而著称,但它也支持类似的性能优化潜力。每个选项的相对性能在很大程度上取决于模型的细节和运行环境。

文章来源:https://medium.com/towards-data-science/ai-model-training-with-jax-6e407a7d2dc8
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消