Wasserstein GAN (WGAN):分析指南

2024年03月25日 由 alex 发表 2409 0

简介

生成对抗网络(GAN)的出现标志着生成建模领域的一个重要里程碑。然而,GAN 经常面临训练稳定性问题,导致 Wasserstein GAN (WGAN) 的诞生。 WGAN,由 Arjovsky 等人提出。 2017 年,通过重新制定用于训练 GAN 的损失函数来解决这些不稳定性,提供了对传统 GAN 架构的理论和实践改进。


1


标准 GAN 的问题

标准 GAN 的训练过程包括一个相互竞争的鉴别器和一个生成器。鉴别器学习如何区分准确数据和虚假数据,而生成器则努力创造与真实数据无异的数据。然而,这种设置往往会导致模式崩溃(生成器产生的输出种类有限)和训练不稳定等问题,从而导致众所周知的梯度消失问题。


Wasserstein距离

WGAN 引入了 Wasserstein 距离的概念,也称为 Earth-Mover (EM) 距离,用于测量数据分布与生成器创建的分布之间的差异。EM 距离能为生成器提供更平滑的梯度信号,因为它能测量将一个分布转化为另一个分布需要移动多少 "质量 "以及移动的距离。在两个分布不重叠或仅有轻微重叠的情况下,这种距离更为有效。


WGAN 方法

WGAN 框架的核心是将传统 GAN 的损失函数替换为最小化 Wasserstein 距离的函数。WGAN 建议对判别器的权重(在 WGAN 术语中称为批评者)进行剪切,以强制执行瓦瑟斯坦距离所需的 Lipschitz 约束。这就对批判者的能力施加了软约束,通过提供更有意义的梯度来帮助稳定训练过程。


WGAN 的结果和优势

WGAN 在训练过程中表现出更高的稳定性,不易出现 GAN 的常见问题(如模式崩溃)。此外,与传统的 GAN 损失函数相比,Wasserstein 距离提供了在训练过程中生成样本质量的有用衡量标准,与生成图像的视觉质量更相关。此外,WGAN 的训练过程趋向于更可靠地收敛,从而形成更平滑的学习曲线。


代码

创建 Wasserstein GAN(WGAN)的完整 Python 实现涉及多个步骤,包括设置合成数据集、定义生成器和批评者(WGAN 版本的判别器)、训练网络和评估结果。


然而,从头开始生成一个 WGAN 是相当复杂的,而且代码可能很长。由于 GAN(尤其是 WGAN)的复杂性,训练过程通常会耗费大量资源和时间,因此可能不适合在这里立即执行。


2


尽管如此,我还是要概述一下步骤,并提供一个简化示例,说明如何用 Python 创建 WGAN。你通常会使用 TensorFlow 或 PyTorch 等机器学习框架来编写完整的可运行代码。在此,由于环境限制,我将提供带有伪代码元素的概念大纲:


用 Python 实现 WGAN 的步骤


生成合成数据集: 使用 numpy 或 scipy 创建一个合成数据集,WGAN 可以从中学习。


定义生成器和批判模型: 使用 TensorFlow 或 PyTorch 等框架定义生成器和批判器的神经网络架构。


定义损失函数和优化器: 损失函数将基于 Wasserstein 距离。最好使用支持梯度剪切的优化器来执行 Lipschitz 约束。


训练循环:


  • 在每次迭代中,训练批判者的次数要多于训练生成器的次数(如 WGAN 论文中建议的那样)。
  • 通过提升随机梯度来更新批判者。
  • 每次梯度更新后,确保批判者的权重被剪切到一个很小的固定范围。
  • 通过降低随机梯度更新生成器。


评估结果: 评估生成器生成的图像质量。如果适用,可使用适合 GAN 的指标,如入门分数 (IS) 或弗雷谢特入门距离 (FID)。


绘制结果图: 可视化随时间变化的损失和生成图像的质量。


import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
# Generate synthetic dataset
def generate_synthetic_data(n_samples=1000):
    X, y = make_moons(n_samples=n_samples, noise=0.05)
    return X, y
# Using the function to generate the dataset
X, y = generate_synthetic_data()
# Visualizing the dataset
plt.figure(figsize=(8, 8))
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='viridis', edgecolor='k')
plt.title('Synthetic Dataset for WGAN')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.grid(True)
plt.show()

from tensorflow.keras import layers, models
# Define the Generator Model
def make_generator_model(input_dim, output_dim):
    model = models.Sequential()
    model.add(layers.Dense(128, input_dim=input_dim, activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(256, activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(512, activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(output_dim, activation='tanh'))  # 'tanh' can be used for normalized data
    return model
# Define the Critic Model
def make_critic_model(input_dim):
    model = models.Sequential()
    model.add(layers.Dense(512, input_dim=input_dim, activation='leaky_relu'))
    model.add(layers.Dropout(0.3))
    model.add(layers.Dense(256, activation='leaky_relu'))
    model.add(layers.Dropout(0.3))
    model.add(layers.Dense(128, activation='leaky_relu'))
    model.add(layers.Dense(1))  # No activation function since this is not a classification problem
    return model
# Input dimensions for the generator
generator_input_dim = 100  # Dimension of the random noise
generator_output_dim = 2   # This should match our data's dimensionality
# Create the generator and critic models
generator = make_generator_model(generator_input_dim, generator_output_dim)
critic = make_critic_model(generator_output_dim)

import tensorflow as tf
# Critic Loss
def critic_loss(real_output, fake_output):
    return tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)
# Generator Loss
def generator_loss(fake_output):
    return -tf.reduce_mean(fake_output)
# Optimizers
generator_optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.0005)
critic_optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.0005)
# For WGAN, the critic's weights need to be clipped to a small range to enforce Lipschitz constraint
# This can be done after each critic update during training like this:
# for w in critic.trainable_variables:
#     w.assign(tf.clip_by_value(w, -clip_value, clip_value))

# Assuming we have defined `generator`, `critic`, `generator_loss`, `critic_loss`,
# `generator_optimizer`, `critic_optimizer`, and the dataset `X`
# Hyperparameters
epochs = 10000
batch_size = 32
critic_iterations = 5  # Number of critic updates per generator update
clip_value = 0.01  # Value for weight clipping in WGAN
# Training Loop
for epoch in range(epochs):
    for _ in range(critic_iterations):
        # Sample a batch of real data
        idx = np.random.randint(0, X.shape[0], batch_size)
        real_data = X[idx]
        # Generate a batch of fake data
        noise = tf.random.normal([batch_size, generator_input_dim])
        fake_data = generator(noise, training=True)
        # Critic update
        with tf.GradientTape() as critic_tape:
            real_output = critic(real_data, training=True)
            fake_output = critic(fake_data, training=True)
            c_loss = critic_loss(real_output, fake_output)
        critic_gradients = critic_tape.gradient(c_loss, critic.trainable_variables)
        critic_optimizer.apply_gradients(zip(critic_gradients, critic.trainable_variables))
        # Apply weight clipping to critic weights to enforce Lipschitz constraint
        for w in critic.trainable_variables:
            w.assign(tf.clip_by_value(w, -clip_value, clip_value))
    # Generator update
    noise = tf.random.normal([batch_size, generator_input_dim])
    with tf.GradientTape() as gen_tape:
        generated_data = generator(noise, training=True)
        gen_output = critic(generated_data, training=True)
        g_loss = generator_loss(gen_output)
    generator_gradients = gen_tape.gradient(g_loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    # Log the progress
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Critic Loss: {c_loss.numpy()}, Generator Loss: {g_loss.numpy()}")


3


你提供的图表似乎显示了两个数据点集群,这可能代表了在合成数据集上训练的 WGAN 的真实数据和生成数据。这两个不同的数据集让人联想到 make_moons 数据集,由于其非线性可分性,make_moons 数据集是机器学习中常见的测试数据集。


在这个散点图中:


  • 紫色显示的是一个数据点群,它可能代表来自 make_moons 函数的真实数据。该数据集群呈新月形。
  • 另一个黄色的数据集群可能是 WGAN 生成的数据。它也形成了一个新月形,似乎与真实数据形成的形状大致接近。


总之,这种可视化分析提供了证据,证明 WGAN 正在发挥作用,可以生成与目标分布相似的数据,但可能需要进一步的训练或超参数调整,才能更准确地复制真实数据分布。


4


这些图表说明了 Wasserstein GAN 的批评者和生成器部分在 10,000 个历时期内的训练损失曲线。


批评者损失曲线:

  • 批判者损失曲线显示为蓝色,并在训练期间波动。这是 WGAN 训练中的预期行为,因为批判者(在其他 GAN 框架中也称为判别器)会不断改进区分真实数据和虚假数据的能力。
  • 批判者损失的波动表明,训练过程是动态的,批判者会根据生成器生成的伪造数据质量的逐步提高进行调整。
  • 批判者损失似乎并没有收敛到一个明显较低的值,这表明批判者的性能随着时间的推移趋于稳定。损失徘徊在一个稳定的范围内,表明批判者在对抗训练过程中有效地发挥了作用。


发电机损耗曲线:

  • 生成器损耗曲线以红色显示,损耗最初急剧增加,表明生成器开始学习并适应批判者的反馈。
  • 在最初的学习阶段之后,信号发生器的损耗趋于稳定,并在略高于零的水平上波动。这表明生成器生成的数据越来越真实,负损失值的减少就是证明。
  • 与批评者相比,生成器损失稳定在较高水平的总体趋势表明,生成器在整个训练过程中保持了生成令人信服的数据的能力。


总体解释: 图中显示了一个典型的对抗训练过程,批判者和生成器都在随着时间的推移不断改进。批判者损失稳定且波动一致,这意味着批判者表现良好,而生成器损失稳定在一个较低但积极的值上,这表明生成器有能力生成与真实数据相似的数据。


就 WGAN 训练而言,这些结果通常被认为是成功的,表明对抗过程正在按预期运行。两条曲线都很稳定,没有出现极端的峰值或谷底,这是一个好迹象,意味着生成器正在成功地学习创建批评者认为越来越难以归类为伪造的数据。这通常是训练 GAN(特别是 WGAN)的目标,其目的是实现平衡,即生成器和批评者都不会明显压倒对方。


结论

Wasserstein GAN 是生成模型领域的重大进步。通过解决与标准 GANs 相关的难题,WGANs 为生成更稳定、更可靠的合成数据铺平了道路。Wasserstein 距离和 Lipschitz 约束的引入在实现这些改进方面发挥了重要作用。WGAN 框架激发了进一步的研究和开发,产生了更稳健、更高效的变体,如 WGAN-GP(梯度惩罚),它用梯度惩罚取代了权重剪切,使训练过程更加有效。从本质上讲,WGAN 解决了 GAN 训练中的关键问题,并让人们对生成式建模的基本动态有了更深入的了解。

文章来源:https://medium.com/@evertongomede/enhancing-generative-modeling-a-comprehensive-analysis-of-wasserstein-gan-wgan-81b4c4a5f333
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消