在人工智能领域,很少有创新能像生成对抗网络(GAN)一样吸引研究人员和创作者的想象力。Ian Goodfellow在2014年提出了GAN,它彻底改变了机器生成数据的方式,从超真实图像到合成音乐,甚至包括深度伪造视频。GAN的核心是两个神经网络(生成器和判别器)之间令人着迷的相互作用,它们在一个博弈论的舞蹈中相互竞争,试图智胜对方。
如果你曾经好奇过AI是如何创作出栩栩如生的艺术,或者将简单的草图转变为照片级真实的风景,那么你正在见证GAN的力量。在这篇文章中,我们将揭开GAN背后的神秘面纱,并指导你使用PyTorch(最受欢迎的深度学习框架之一)来实现它们。无论你是AI爱好者还是初出茅庐的开发者,这段循序渐进的旅程都将为你提供开始构建自己生成模型所需的知识。
GAN由两个神经网络组成:一个生成器和一个判别器,它们在博弈论的框架下协同工作。
这两个网络相互竞争:
这个对抗过程推动两个网络不断迭代改进,最终生成器能够产生高度真实的输出。
让我们以人脸为例来看看这是如何工作的。
1. 数据集:
你需要一个人脸数据集,比如CelebA数据集,其中包含数千张名人面孔。
2. 初始化:
3. 训练:
4. 对抗博弈:
经过多次迭代,生成器在生成真实人脸方面变得更好,而判别器则变得更难被欺骗。这个“游戏”持续进行,直到生成器生成的人脸与真实人脸无法区分。
生成器逐渐学习数据集中的复杂模式,例如:
例如,在训练初期,生成的人脸可能模糊或特征扭曲。随着训练的进行,人脸变得更清晰、更真实,反映了训练数据集的多样性和特征。
GAN广泛应用于以下应用:
通过在人脸数据集上训练GAN,你可以生成完全新的、逼真的人脸,这些人脸不属于任何真实的人——这是生成建模领域的一个令人着迷的飞跃。
首先,让我们为这个项目创建一个新的环境!
导入库
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from PIL import Image
import os
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
然后创建一个类来加载数据集。
class CelebADataset(Dataset):
def __init__(self, root_dir, transform=None):
"""
Args:
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied on a sample.
"""
self.root_dir = root_dir
self.transform = transform
# Get all image file paths from the directory
self.image_paths = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
# Load image
img_path = self.image_paths[idx]
image = Image.open(img_path).convert('RGB')
# Apply the transform if provided
if self.transform:
image = self.transform(image)
return image
这段代码定义了一个自定义的数据集类 CelebADataset,用于加载和处理图像,特别适用于与 PyTorch 的数据加载工具(如 DataLoader)一起使用。
# Define transformations (resize, crop, convert to tensor, normalize)
transform = transforms.Compose([
transforms.Resize(64), # Resize images to 64x64
transforms.CenterCrop(64), # Crop center to 64x64
transforms.ToTensor(), # Convert images to tensor
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # Normalize to [-1, 1]
])
# Load CelebA dataset from the specified directory
dataset_path = r'C:\Users\Harish\Documents\Github\GAN\GAN_Tutorial\img_align_celeba'
dataset = CelebADataset(root_dir=dataset_path, transform=transform)
# Create DataLoader
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
# Check how many images are loaded
print(f"Total number of images loaded: {len(dataset)}")
下一步是创建生成器和判别器。
# Generator and Discriminator classes (same as previously described)
class Generator(nn.Module):
def __init__(self, z_dim=100, img_channels=3):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(z_dim, 256),
nn.ReLU(True),
nn.Linear(256, 512),
nn.ReLU(True),
nn.Linear(512, 1024),
nn.ReLU(True),
nn.Linear(1024, img_channels * 64 * 64),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), 3, 64, 64) # Reshape to image format
return img
class Discriminator(nn.Module):
def __init__(self, img_channels=3):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Flatten(),
nn.Linear(img_channels * 64 * 64, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img):
return self.model(img)
生成器类
生成器负责从随机噪声中创建合成数据(在此情况下为图像)。
关键特性
__init__ 方法
目的:定义生成器模型的层。
nn.Linear(z_dim, 256):
全连接层,将噪声向量(z_dim)映射到大小为 256 的隐藏空间。
nn.ReLU(True):
应用 ReLU 激活函数以引入非线性,使网络能够建模复杂模式。参数 True 启用就地操作,节省内存。
nn.Linear(256, 512) → nn.Linear(512, 1024) → nn.Linear(1024, img_channels * 64 * 64):
额外的全连接层逐步增加维度,将噪声向量转换为适合图像生成的高维空间。
nn.Tanh():
最终激活函数将输出图像的像素值缩放到范围 [-1, 1],这是归一化图像数据的常见做法。
forward 方法
目的:定义如何将输入噪声向量转换为图像。
img = self.model(z):
将噪声向量 z 通过网络传递,得到一个表示像素值的展平输出。
img = img.view(img.size(0), 3, 64, 64):
将展平输出重塑为图像格式 (batch_size, 3, 64, 64)。
img.size(0):指批次大小。
3, 64, 64:所需的图像尺寸(RGB,64x64 像素)。
输出
表示合成图像的张量。
判别器类
判别器通过输出概率(真实 = 1,虚假 = 0)来区分真实图像和虚假图像。
关键特性
__init__ 方法
目的:定义判别器模型的层。
nn.Flatten():
将输入图像张量 (3, 64, 64) 转换为大小为 3 * 64 * 64(12,288 个值)的展平向量。
这对于全连接层的处理是必要的。
nn.Linear(img_channels * 64 * 64, 1024):
第一个全连接层,将输入维度从 12,288 减少到 1,024。
nn.LeakyReLU(0.2, inplace=True):
LeakyReLU 激活函数引入非线性,为负输入提供小梯度(通过 0.2 的斜率)以避免神经元死亡。
就地操作节省内存。
后续层:
nn.Sigmoid():
最终激活函数输出范围 [0, 1] 内的概率,其中接近 1 的值表示图像是真实的,接近 0 的值表示图像是虚假的。
forward 方法
目的:定义如何将图像分类为真实或虚假。
self.model(img):
通过网络处理输入图像,为每个图像生成一个标量概率。
输出
批次中概率 ∈ [0, 1] 的张量。
它们如何协同工作
训练流程:
对抗目标:
损失函数:
# Loss function and optimizers
adversarial_loss = nn.BCELoss()
generator = Generator(z_dim=100)
discriminator = Discriminator()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Define device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = generator.to(device)
discriminator = discriminator.to(device)
提供的代码片段为训练生成对抗网络(GAN)设置了损失函数、优化器和设备配置。
使用二元交叉熵损失来衡量预测概率与真实二元标签之间的误差。
def save_generated_images(generator, epoch, device, num_images=16):
z = torch.randn(num_images, 100).to(device)
generated_imgs = generator(z).detach().cpu()
grid = torchvision.utils.make_grid(generated_imgs, nrow=4, normalize=True)
plt.imshow(np.transpose(grid, (1, 2, 0)))
plt.title(f"Epoch {epoch}")
plt.axis('off')
plt.show()
# Training loop
def train(generator, discriminator, dataloader, epochs=5):
for epoch in range(epochs):
for i, imgs in enumerate(dataloader):
real_imgs = imgs.to(device)
batch_size = real_imgs.size(0)
valid = torch.ones(batch_size, 1).to(device)
fake = torch.zeros(batch_size, 1).to(device)
# Train Discriminator
optimizer_D.zero_grad()
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(generator(torch.randn(batch_size, 100).to(device)).detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
# Train Generator
optimizer_G.zero_grad()
g_loss = adversarial_loss(discriminator(generator(torch.randn(batch_size, 100).to(device))), valid)
g_loss.backward()
optimizer_G.step()
if i % 50 == 0:
print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")
# Optionally, save generated images at each epoch
save_generated_images(generator, epoch, device)
训练循环在训练判别器和生成器之间交替进行。判别器通过最小化真实图像(标记为真实)和假图像(标记为假)的损失,被训练来区分真实图像和假图像。生成器则通过欺骗判别器,使假图像被标记为真实时损失最小化,从而被训练来创建逼真的假图像。
这个对抗过程会重复多个周期,并使用各自的优化器更新两个模型。定期保存生成的图像,以跟踪生成器的进展。
生成的图像:
我们需要训练多个周期,让模型能够理解复杂的细节,并像原始图像一样重新创建出来。为了这篇文章,我将训练停止在10个周期。