【指南】Wasserstei GAN生成建模的新范式

2024年02月27日 由 alex 发表 285 0

简介

Wasserstein生成对抗网络(WGAN)是一项举足轻重的创新,它解决了传统生成对抗网络(GAN)经常面临的稳定性和收敛性等基础难题。WGAN 由 Arjovsky 等人于 2017 年提出,通过利用 Wasserstein 距离彻底改变了生成模型的训练方法,提供了一个稳健的框架,既提高了生成样本的质量,又增强了样本的多样性。本文将深入探讨 WGAN 的概念基础、优势和实际意义,说明其在更广泛的生成模型背景下的重要意义。


27


WGAN 的概念框架

WGAN 与其前辈的区别在于用 Wasserstein 距离代替 Jensen-Shannon 散度作为其损失函数。瓦瑟斯坦距离,直观地理解为推土机距离,量化了将一种概率分布转换为另一种概率分布所需的最小成本。该指标赋予 WGAN 在训练过程中更平滑、更可靠的梯度信号,即使在真实数据分布和生成数据分布不重叠的情况下,也有助于生成更高质量的样本。


28


与传统 GAN 的一个重要区别是用批评者取代了判别器。与将输入分类为真实或虚假的判别器不同,WGAN 框架中的批评者评估真实样本和生成样本的分布之间的 Wasserstein 距离。这种从分类到估计的转变标志着生成模型处理学习过程的方式发生了根本性变化,从而实现了更细致、更有效的训练动态。


相对于传统 GAN 的优势

WGAN 提供了几个引人注目的优势,可以解决传统 GAN 框架的局限性。首先,它们表现出改进的训练稳定性,降低了对超参数设置和架构选择的敏感性。这种稳定性源于 Wasserstein 距离的特性,即使真实分布和生成分布之间没有重叠,它也能提供有用的梯度信息——这是一个可能阻碍传统 GAN 训练的常见问题。


此外,WGAN 还缓解了模式崩溃问题,即生成器学习产生有限范围的输出,从而无法捕获真实数据分布的多样性的现象。Wasserstein 距离的连续且更有意义的损失景观鼓励生成器探索更广泛的输出,从而增强生成样本的多样性。


WGAN 中损失度量的可解释性也代表了重大进步。与传统 GAN(判别器的准确性不一定与生成样本的质量相关)不同,WGAN 中的批评者损失提供了更直接的收敛性衡量标准,为训练过程和生成数据的质量提供了有价值的见解。


挑战与未来发展

尽管有其优点,WGAN 也带来了新的挑战,主要与计算效率有关。WGAN 的最初实现需要权重裁剪来强制执行 Lipschitz 约束,这对于 Wasserstein 距离的理论属性至关重要。然而,权重裁剪可能会导致优化困难和容量利用率不足。为了解决这个问题,引入带有梯度惩罚的 WGAN (WGAN-GP) 提出了一种替代方法来强制实施 Lipschitz 约束,而无需进行权重裁剪,从而提高训练稳定性和模型性能。


代码

为 Wasserstein 生成对抗网络 (WGAN) 创建完整的代码示例涉及几个步骤,包括定义生成器和批评者的模型架构、准备合成数据集、训练模型以及通过指标和图评估性能。此示例将说明使用 TensorFlow 和 Keras 的基本实现,并使用简单的合成数据集以便于理解。


import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
def build_critic():
    model = keras.Sequential([
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ])
    return model
def build_generator(latent_dim):
    model = keras.Sequential([
        keras.Input(shape=(latent_dim,)),
        layers.Dense(7 * 7 * 128),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, (7, 7), padding='same', activation='sigmoid'),
    ])
    return model
class WGAN(keras.Model):
    def __init__(self, critic, generator, latent_dim):
        super(WGAN, self).__init__()
        self.critic = critic
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
        self.g_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
        self.critic_loss_tracker = keras.metrics.Mean(name="critic_loss")
        self.generator_loss_tracker = keras.metrics.Mean(name="generator_loss")
    @property
    def metrics(self):
        return [self.critic_loss_tracker, self.generator_loss_tracker]
    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn
    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        # Sampling random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Decoding them to fake images
        generated_images = self.generator(random_latent_vectors)
        # Combining them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)
        # Assembling labels, discriminating real from fake images
        labels = tf.concat(
            [tf.ones((batch_size, 1)), -tf.ones((batch_size, 1))], axis=0
        )
        # Adding random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))
        # Training the critic
        with tf.GradientTape() as tape:
            predictions = self.critic(combined_images)
            d_loss = self.d_loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.critic.trainable_variables)
        self.d_optimizer.apply_gradients(
            zip(grads, self.critic.trainable_variables)
        )
        # Sampling random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Assembling labels that say "all real images"
        misleading_labels = -tf.ones((batch_size, 1))
        # Training the generator (via the critic's model)
        with tf.GradientTape() as tape:
            predictions = self.critic(self.generator(random_latent_vectors))
            g_loss = self.g_loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_variables)
        )
        # Update metrics
        self.critic_loss_tracker.update_state(d_loss)
        self.generator_loss_tracker.update_state(g_loss)
        return {
            "critic_loss": self.critic_loss_tracker.result(),
            "generator_loss": self.generator_loss_tracker.result(),
        }
latent_dim = 128
# Prepare the dataset
(x_train, _), (_, _) = keras.datasets.mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_train = np.expand_dims(x_train, axis=-1)
# Instantiate the critic and generator models
critic = build_critic()
generator = build_generator(latent_dim)
# Instantiate the WGAN model
wgan = WGAN(critic=critic, generator=generator, latent_dim=latent_dim)
# Compile the WGAN model
wgan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    d_loss_fn=keras.losses.MeanSquaredError(),
    g_loss_fn=keras.losses.MeanSquaredError(),
)
wgan.fit(x_train, batch_size=32, epochs=100)
def generate_and_save_images(model, epoch, test_input):
    predictions = model.generator(test_input, training=False)
    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
# Generate latent points
random_latent_vectors = tf.random.normal(shape=(16, latent_dim))
generate_and_save_images(wgan, 0, random_latent_vectors)


1875/1875 [==============================] - 28s 15ms/step - critic_loss: 0.5405 - generator_loss: 2.4530
Epoch 99/100
1875/1875 [==============================] - 28s 15ms/step - critic_loss: 0.5408 - generator_loss: 2.4463
Epoch 100/100
1875/1875 [==============================] - 28s 15ms/step - critic_loss: 0.5384 - generator_loss: 2.4411


29


这段代码使用一个简单的数据集,为使用TensorFlow和Keras实现WGAN提供了一个基础框架。对于实际应用程序,您可能需要调整数据集、体系结构和训练参数以满足您的特定需求。


结论

Wasserstein 生成对抗网络代表了生成建模领域的重大飞跃。通过将 Wasserstein 距离集成到 GAN 框架中,WGAN 为训练生成模型提供了更稳定、可靠和可解释的方法。尽管存在与计算需求和 Lipschitz 约束的执行相关的挑战,但 WGAN 及其后续迭代(如 WGAN-GP)所带来的进步继续影响着生成模型的发展。随着该领域研究的进展,WGAN 有望进一步释放生成模型在从图像合成到自然语言生成等众多应用中的潜力,预示着人工智能驱动的创造力和创新的新时代。

文章来源:https://medium.com/the-modern-scientist/unveiling-the-potential-of-wasserstein-generative-adversarial-networks-a-new-paradigm-in-6823fd7c82ca
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
写评论取消
回复取消