使用Keras 3进行多框架AI/ML开发

2024年06月17日 由 alex 发表 112 0

在这篇文章中,我们将重新审视 Keras,并评估它在当前 AI/ML 开发时代提供的价值。我们将通过示例展示它的易用性并指出它的缺点。


谷歌最近发布了名为 Gemma 的开源 NLP 模型系列,并将 Keras 3 作为 API 的核心组件。


为什么使用 Keras 3?

我们认为,Keras 3 提供的最有价值的功能是它的多框架支持。这可能会让一些读者感到惊讶,因为在他们的印象中,Keras 的独特之处在于其用户体验。Keras 3 宣称自己 "简单"、"灵活","专为人类而非机器设计"。事实上,Keras 早期的成功和如日中天的人气都归功于它的用户体验。但现在是 2024 年,有许多高级深度学习应用程序接口都在提供 "降低认知负荷 "的服务。在我们看来,尽管用户体验可能很好,但这已不再是考虑 Keras 而不是其他替代方案的充分理由。它的多框架支持才是。


多框架支持的优点

Keras 3 支持多个用于训练和运行模型的后端。在撰写本文时,这些后端包括 JAX、TensorFlow 和 PyTorch。Keras 3 的公告很好地解释了这一功能的优势。我们将对记录的优势进行扩展,并添加一些自己的特色。


避免选择人工智能/ML 框架的困难:

选择 AI/ML 框架可能是你作为 ML 开发人员需要做出的最重要决定之一。这也是最难的决定之一。这个决定需要考虑很多因素。其中包括用户体验、API 覆盖范围、可编程性、可调试性、支持的输入数据格式和类型、与开发管道上其他组件的一致性(例如,模型部署阶段可能施加的限制),以及最重要的运行时性能。正如我们在以前的许多文章(如这里)中所讨论的,人工智能/ML 模型开发可能非常昂贵,即使是由于选择框架而导致的最小速度提升,对成本的总体影响也可能是巨大的。事实上,在很多情况下,将模型和代码移植到不同的框架和/或甚至维护对多个框架的支持都是值得的。


问题是,要在开始开发之前就知道哪个框架最适合你的模型,即使不是不可能,也是极其困难的。此外,即使你已经选择了一个框架,你也希望随时了解所有框架的演变和发展情况,并不断评估改进模型和/或降低开发成本的潜在机会。人工智能/ML 开发的前景是非常动态的,不断有优化和增强的设计和开发。


Keras 3 解决了框架选择问题,使你能够开发模型,而无需对底层后端做出承诺。通过在多个框架-后端之间切换的选项,你可以专注于模型定义,并在完成后选择最适合你需求的后端。即使 ML 项目的属性发生变化或支持的框架发生演变,Keras 3 也能让你轻松评估更改后端的影响。


通俗地说,Keras 3 可以帮助人类避免他们最讨厌做的事情之一--做出决定并承诺执行。不过,撇开幽默不谈,使用 Keras 3 进行人工智能/ML 模型开发无疑可以避免选择次优框架。


所有优势:

PyTorch、TensorFlow 和 JAX 都有各自独特的优势和差异化特性。例如,JAX 支持即时(JIT)编译,其中模型运算符被转换为中间计算图,然后一起编译为专门针对底层硬件的机器代码。对于许多模型而言,这将大大提高运行时性能。另一方面,PyTorch 通常以运算符立即执行(又称 "急迫执行")的方式使用,通常被认为:具有最 Pythonic 的界面、最易于调试,并提供最佳的整体用户体验。通过使用 Keras 3,你可以享受到这两个方面的优点。在初始模型开发和调试过程中,你可以将后端设置为 PyTorch,而在生产模式下进行训练时,可以切换到 JAX 以获得最佳性能。


与尽可能多的人工智能加速器和运行环境兼容:

我们的目标是与尽可能多的人工智能加速器和运行环境兼容。在人工智能机器容量受限的时代,这一点尤为重要,因为在不同机器类型之间切换的能力是一个巨大的优势。当你使用 Keras 3 及其多后端(multi-backend)支持进行开发时,你将自动增加可在其上训练和运行模型的平台数量。例如,虽然你可能最习惯于在 GPU 上运行 PyTorch,但只需将后端更改为 JAX,你就可以将模型配置为在 Google Cloud TPU 上运行(尽管这可能取决于模型的细节)。


提高模型采用率:

如果你的目标是让其他人工智能/ML 团队使用你的模型,那么你可以通过支持多个框架来增加潜在受众。出于各种原因,一些团队可能会局限于特定的 ML 框架。通过在 Keras 中交付模型,你可以消除采用模型的障碍。


将数据输入管道与模型执行分离:

有些框架鼓励使用特定的数据存储格式和/或数据加载方法。一个典型的例子是 TensorFlow 的 TFRecord 数据格式,用于存储二进制记录序列,通常存储在 .tfrecord 文件中。虽然 TensorFlow 本身支持解析和处理存储在 TFRecord 文件中的数据,但你可能会发现将它们输入 PyTorch 训练循环会比较困难。更适合 PyTorch 训练的格式可能是 WebDataset。但是,创建训练数据可能是一个漫长的过程,而且以多种格式维护数据的成本可能过高。因此,训练数据的存储和维护方式可能会阻碍团队考虑其他框架。


Keras 3 通过将数据输入管道与训练循环完全解耦,帮助团队克服这一障碍。你可以在 PyTorch、TensorFlow、Numpy、Keras 和其他库中定义输入数据管道,而无需考虑将在训练循环中使用的后端。有了 Keras 3,将训练数据存储在 TFRecord 文件中不再是采用 PyTorch 作为后端的障碍。


多框架支持的缺点

与市场上任何其他新的 SW 解决方案一样,了解 Keras 3 的潜在缺点非常重要。SW 开发的一般经验法则是,SW 栈越高,对应用程序行为和性能的控制就越少。在人工智能/ML 领域,成功的程度往往取决于对模型超参数、初始化设置、适当的环境配置等的精确调整,因此这种控制至关重要。以下是一些需要考虑的潜在缺点:


运行时性能可能下降:

使用高级 Keras API 而不是直接使用框架 API,可能会对优化运行时性能造成限制。在我们以分析和优化 PyTorch 模型性能为主题的系列文章中,我们展示了一系列提高训练速度的工具和技术。有时,这些工具和技术需要直接、非中介地使用 PyTorch 的 API。例如,Keras 的 API 目前对 PyTorch 的 JIT 编译选项(通过 jit_compile 设置)的支持非常有限。另一个例子是 PyTorch 对缩放点乘注意的内置支持,而 Keras 则不支持这种支持(截至本文撰写之时)。


跨框架支持的局限性:

虽然 Keras 的跨框架支持非常广泛,但你可能会发现它并非包罗万象。例如,在覆盖范围方面的一个空白(截至本文撰写之时)就是分布式训练。虽然 Keras 引入了 Keras 分布 API 来支持所有后端数据和模型的并行性,但目前仅在 JAX 后端实现了这一功能。要在使用其他后端时运行分布式训练,你需要回到相关框架的标准分布式 API(例如 PyTorch 的分布式数据并行 API)。


维护跨框架兼容性的开销:

Keras 3 支持多种预建模型,你可以重复使用。但不可避免的是,你可能希望引入自己的定制。虽然 Keras 3 支持对模型层、指标、训练循环等进行自定义,但你需要注意不要破坏跨框架兼容性。例如,如果你使用 Keras 的后端无关 API(keras.ops)创建自定义层,你就可以放心地保留多后端支持。不过,有时你可能会选择依赖特定框架的操作。在这种情况下,要保持跨框架兼容性,就需要为每个框架制定专门的实施方案,并根据所使用的后端进行适当的条件编程。当前用于自定义训练步骤和训练循环的方法是针对特定框架的,这意味着它们也需要针对每个后端进行专门的实现,以保持跨框架兼容性。因此,随着模型复杂度的增加,维护这一独特功能所需的开销也会增加。


我们仅指出了 Keras 3 及其多后端(multi-backend)支持的几个潜在缺点。你很可能还会遇到其他问题。虽然多框架产品确实引人注目,但采用它并不一定不需要成本。借用统计推理领域一个著名定理的名称,我们可以说,在选择AI/ML 开发方法时,"天下没有免费的午餐"。


Keras 3 的实践--一个玩具示例

我们将定义的玩具模型是一个由视觉转换器(ViT)支持的分类模型。我们将依赖本 Keras 教程中的参考实现。我们根据 ViT-Base 架构配置了模型(约 8,600 万个参数),将混合精度策略设置为使用 bfloat16,并定义了一个带有随机输入数据的 PyTorch 数据加载器。


下面的模块包括配置设置,以及 ViT 模型核心组件的定义:


import os
# choose backend
backend = 'jax' # 'torch'
os.environ["KERAS_BACKEND"] = backend
import keras
from keras import layers
from keras import ops
# set mixed precision policy
keras.mixed_precision.set_global_policy('mixed_bfloat16')
# use ViT Base settings
num_classes = 1000
image_size = 224
input_shape = (image_size, image_size, 3)
patch_size = 16  # Size of the patches to be extract from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 768
num_heads = 12
transformer_units = [
    projection_dim * 4,
    projection_dim,
]  # Size of the transformer layers
transformer_layers = 12
# set training hyperparams
batch_size = 128
multi_worker = False # toggle to use multiple data loader workers
preproc_workers = 0 if 'jax' else 16
# ViT model components:
# ---------------------
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.activations.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size
    def call(self, images):
        input_shape = ops.shape(images)
        batch_size = input_shape[0]
        height = input_shape[1]
        width = input_shape[2]
        channels = input_shape[3]
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size
        patches = keras.ops.image.extract_patches(images, size=self.patch_size)
        patches = ops.reshape(
            patches,
            (
                batch_size,
                num_patches_h * num_patches_w,
                self.patch_size * self.patch_size * channels,
            ),
        )
        return patches
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )
    def call(self, patch):
        positions = ops.expand_dims(
            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        projected_patches = self.projection(patch)
        encoded = projected_patches + self.position_embedding(positions)
        return encoded


利用核心组件,我们定义了一个由 ViT 支持的 Keras 模型:


# the attention layer we will use in our ViT classifier
attention_layer = layers.MultiHeadAttention
def create_vit_classifier():
    inputs = keras.Input(shape=input_shape)
    # Create patches.
    patches = Patches(patch_size)(inputs)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer.
        attention_output = attention_layer(
            num_heads=num_heads, key_dim=projection_dim//num_heads, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])
    # Create a [batch_size, projection_dim] tensor.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.GlobalAveragePooling1D()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Classify outputs.
    logits = layers.Dense(num_classes)(representation)
    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=logits)
    return model
# create the ViT model
model = create_vit_classifier()
model.summary()


接下来,我们将定义优化器、损失和数据集。


model.compile(compile(
    optimizer=keras.optimizers.SGD(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
def get_data_loader(batch_size):
    import torch
    from torch.utils.data import Dataset, DataLoader
    # create dataset of random image and label data
    class FakeDataset(Dataset):
        def __len__(self):
            return 1000000
        def __getitem__(self, index):
            rand_image = torch.randn([224, 224, 3], dtype=torch.float32)
            label = torch.tensor(data=[index % 1000], dtype=torch.int64)
            return rand_image, label
    ds = FakeDataset()
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=preproc_workers if multi_worker else 0,
        pin_memory=True
    )
    return dl
dl = get_data_loader(batch_size)


最后,我们使用 Keras 的 Model.fit() 函数开始训练:


model.fit(
    dl,
    batch_size=batch_size,
    epochs=11
)


我们在谷歌云平台(GCP)g2-standard-16 虚拟机(配有单个英伟达 L4 GPU)上运行了上述脚本,该虚拟机配有专用的深度学习虚拟机镜像(common-cu121-v20240514-ubuntu-2204-py310),并安装了 PyTorch (2.3.0)、JAX (0.4.28)、Keras (3.3.3) 和 KerasCV (0.9.0)。


 formatted += f" {time_per_unit:.3f}s/{unit_name}"f" {time_per_unit:.3f}s/{unit_name}"


使用后端标志,我们可以在 Keras 支持的后端之间轻松切换,并比较每个后端的运行时性能。例如,当配置 PyTorch 数据加载器时使用 0 个 Worker,我们发现 JAX 后端的性能比 PyTorch 高出约 24%。当设置 16 个 Worker 时,性能下降到 12%。


自定义注意层

现在,我们定义了一个自定义注意力层,用 PyTorch 的闪存注意力实现取代 Keras 的默认注意力计算。请注意,这只有在后端设置为 torch 时才会起作用。


class MyAttention(layers.MultiHeadAttention):
    def _compute_attention(
            self, query, key, value, attention_mask=None, training=None
    ):
        from torch.nn.functional import scaled_dot_product_attention
        query = ops.multiply(
            query, ops.cast(self._inverse_sqrt_key_dim, query.dtype))
        return scaled_dot_product_attention(
            query.transpose(1,2),
            key.transpose(1,2),
            value.transpose(1,2),
            dropout_p=self._dropout if training else 0.
            ).transpose(1,2), None
attention_layer = MyAttention


下表总结了我们的实验结果。请注意,根据模型和运行环境的细节,相对性能结果可能会有很大差异。


10


当使用我们的自定义关注层时,JAX 和 PyTorch 后端之间的差距几乎消失了。这凸显了多后端解决方案的使用是如何以牺牲任何单个框架(在我们的例子中是 PyTorch SDPA)所支持的独特优化为代价的。


Gemma 中的 Keras 3

Gemma 是谷歌最近发布的轻量级开源模型系列。Keras 3 在 Gemma 发布中扮演着重要角色(例如,请参阅此处),它的多框架支持使 Gemma 能够自动供 PyTorch、TensorFlow 和 Jax 等各种人工智能/ML 开发人员使用。有


以下代码大致基于 Gemma 官方微调教程。要运行脚本,请遵循必要的设置说明。


import os
backend = 'jax' #'torch'
os.environ["KERAS_BACKEND"] = backend
num_batches = 1000
batch_size = 4 if backend == 'jax' else 2
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
os.environ["KAGGLE_USERNAME"]="chaimrand"
os.environ["KAGGLE_KEY"]="29abebb28f899a81ca48bec1fb97faf1"
import keras
import keras_nlp
keras.mixed_precision.set_global_policy('mixed_bfloat16')
import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))
# Only use 1000 training batches, to keep it fast.
data = data[:num_batches*batch_size]
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
gemma_lm.compile(
   loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
   optimizer=keras.optimizers.SGD(learning_rate=5e-5),
   weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=batch_size)


在上述相同的 GCP 环境中运行脚本时,我们发现使用 JAX 后端时的运行性能(每秒 6.87 个采样点)与使用 PyTorch 后端时的运行性能(每秒 3.01 个采样点)之间存在显著差异(令人惊讶)。这部分是由于 JAX 后端允许将训练批次大小增加一倍。


在上一个示例中,我们演示了优化 PyTorch 运行时的一种方法,即在脚本顶部预置以下矩阵乘法操作配置:


import torch
torch.set_float32_matmul_precision('high')


在使用 PyTorch 后端运行时,这一简单改动使性能提升了 29%。我们可以再次看到应用特定框架优化所带来的影响。下表总结了实验结果。


11


结论

我们的演示表明,坚持使用与后端无关的 Keras 代码可能会带来显著的运行时性能损失。在每个示例中,我们都看到了针对特定框架的简单优化是如何对我们所选后端的相对性能产生重大影响的。同时,我们所讨论的多框架 AI/ML 开发的论据也相当有说服力。

文章来源:https://medium.com/towards-data-science/multi-framework-ai-ml-development-with-keras-3-cf7be29eb23d
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消