GAN实践:从零创建人脸指南

2025年01月21日 由 alex 发表 1572 0

在人工智能领域,很少有创新能像生成对抗网络(GAN)一样吸引研究人员和创作者的想象力。Ian Goodfellow在2014年提出了GAN,它彻底改变了机器生成数据的方式,从超真实图像到合成音乐,甚至包括深度伪造视频。GAN的核心是两个神经网络(生成器和判别器)之间令人着迷的相互作用,它们在一个博弈论的舞蹈中相互竞争,试图智胜对方。


如果你曾经好奇过AI是如何创作出栩栩如生的艺术,或者将简单的草图转变为照片级真实的风景,那么你正在见证GAN的力量。在这篇文章中,我们将揭开GAN背后的神秘面纱,并指导你使用PyTorch(最受欢迎的深度学习框架之一)来实现它们。无论你是AI爱好者还是初出茅庐的开发者,这段循序渐进的旅程都将为你提供开始构建自己生成模型所需的知识。


8


GAN由两个神经网络组成:一个生成器和一个判别器,它们在博弈论的框架下协同工作。

  • 生成器:从随机噪声中创建合成数据。它学习模仿真实数据的分布。
  • 判别器:充当评论家,区分真实数据(来自训练集)和假数据(由生成器生成)。


这两个网络相互竞争:

  • 生成器不断改进,以生成看起来更真实的数据。
  • 判别器提高其区分真假的能力。


这个对抗过程推动两个网络不断迭代改进,最终生成器能够产生高度真实的输出。


让我们以人脸为例来看看这是如何工作的。


1. 数据集:

你需要一个人脸数据集,比如CelebA数据集,其中包含数千张名人面孔。


2. 初始化:

  • 生成器开始通过创建随机噪声(例如,一个随机数向量)作为其输入。
  • 判别器使用真实人脸(来自数据集)和假人脸(由生成器生成)进行训练。


3. 训练:

  • 生成器从噪声中创建一张脸,然后传递给判别器。
  • 判别器评估这张脸是真实的还是假的,并提供反馈(一个概率分数)。
  • 两个网络调整其参数:
  • 生成器学习创建更能欺骗判别器的脸。
  • 判别器提高其检测假脸的能力。


4. 对抗博弈:

经过多次迭代,生成器在生成真实人脸方面变得更好,而判别器则变得更难被欺骗。这个“游戏”持续进行,直到生成器生成的人脸与真实人脸无法区分。


生成器逐渐学习数据集中的复杂模式,例如:

  • 形状:人脸的结构。
  • 细节:眼睛、鼻子、嘴巴和其他特征。
  • 纹理和光照:肤色、头发纹理和阴影。


例如,在训练初期,生成的人脸可能模糊或特征扭曲。随着训练的进行,人脸变得更清晰、更真实,反映了训练数据集的多样性和特征。


GAN广泛应用于以下应用:

  • 深度伪造:逼真的换脸。
  • 艺术和设计:生成艺术品或头像。
  • 数据增强:扩展用于训练其他模型的数据集。
  • 修复:重建损坏或低分辨率的照片。


通过在人脸数据集上训练GAN,你可以生成完全新的、逼真的人脸,这些人脸不属于任何真实的人——这是生成建模领域的一个令人着迷的飞跃。


首先,让我们为这个项目创建一个新的环境!


导入库


import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from PIL import Image
import os
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np


然后创建一个类来加载数据集。


class CelebADataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        
        # Get all image file paths from the directory
        self.image_paths = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
    def __len__(self):
        return len(self.image_paths)
    def __getitem__(self, idx):
        # Load image
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        
        # Apply the transform if provided
        if self.transform:
            image = self.transform(image)
        
        return image


这段代码定义了一个自定义的数据集类 CelebADataset,用于加载和处理图像,特别适用于与 PyTorch 的数据加载工具(如 DataLoader)一起使用。


# Define transformations (resize, crop, convert to tensor, normalize)
transform = transforms.Compose([
    transforms.Resize(64),  # Resize images to 64x64
    transforms.CenterCrop(64),  # Crop center to 64x64
    transforms.ToTensor(),  # Convert images to tensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])
# Load CelebA dataset from the specified directory
dataset_path = r'C:\Users\Harish\Documents\Github\GAN\GAN_Tutorial\img_align_celeba'
dataset = CelebADataset(root_dir=dataset_path, transform=transform)
# Create DataLoader
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
# Check how many images are loaded
print(f"Total number of images loaded: {len(dataset)}")


9


下一步是创建生成器和判别器。


# Generator and Discriminator classes (same as previously described)
class Generator(nn.Module):
    def __init__(self, z_dim=100, img_channels=3):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, img_channels * 64 * 64),
            nn.Tanh()
        )
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 3, 64, 64)  # Reshape to image format
        return img
class Discriminator(nn.Module):
    def __init__(self, img_channels=3):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(img_channels * 64 * 64, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, img):
        return self.model(img)


生成器类

生成器负责从随机噪声中创建合成数据(在此情况下为图像)。


关键特性

  • 输入:大小为 z_dim 的随机噪声向量 z。
  • 输出:尺寸为 (3, 64, 64) 的合成图像(3 个通道用于 RGB,64x64 像素)。


__init__ 方法

目的:定义生成器模型的层。


nn.Linear(z_dim, 256):

全连接层,将噪声向量(z_dim)映射到大小为 256 的隐藏空间。


nn.ReLU(True):

应用 ReLU 激活函数以引入非线性,使网络能够建模复杂模式。参数 True 启用就地操作,节省内存。


nn.Linear(256, 512) → nn.Linear(512, 1024) → nn.Linear(1024, img_channels * 64 * 64):

额外的全连接层逐步增加维度,将噪声向量转换为适合图像生成的高维空间。


nn.Tanh():

最终激活函数将输出图像的像素值缩放到范围 [-1, 1],这是归一化图像数据的常见做法。


forward 方法

目的:定义如何将输入噪声向量转换为图像。


img = self.model(z):

将噪声向量 z 通过网络传递,得到一个表示像素值的展平输出。


img = img.view(img.size(0), 3, 64, 64):

将展平输出重塑为图像格式 (batch_size, 3, 64, 64)。


img.size(0):指批次大小。


3, 64, 64:所需的图像尺寸(RGB,64x64 像素)。


输出

表示合成图像的张量。


判别器类

判别器通过输出概率(真实 = 1,虚假 = 0)来区分真实图像和虚假图像。


关键特性

  • 输入:尺寸为 (3, 64, 64) 的图像张量。
  • 输出:批次中每个图像的标量概率值。


__init__ 方法

目的:定义判别器模型的层。


nn.Flatten():

将输入图像张量 (3, 64, 64) 转换为大小为 3 * 64 * 64(12,288 个值)的展平向量。

这对于全连接层的处理是必要的。


nn.Linear(img_channels * 64 * 64, 1024):

第一个全连接层,将输入维度从 12,288 减少到 1,024。


nn.LeakyReLU(0.2, inplace=True):

LeakyReLU 激活函数引入非线性,为负输入提供小梯度(通过 0.2 的斜率)以避免神经元死亡。

就地操作节省内存。


后续层:

  • 1024→512→256→1\text{1024} \to \text{512} \to \text{256} \to \text{1}1024→512→256→1:全连接层逐步降低维数。
  • 目的:捕捉层次特征以区分真假。


nn.Sigmoid():

最终激活函数输出范围 [0, 1] 内的概率,其中接近 1 的值表示图像是真实的,接近 0 的值表示图像是虚假的。


forward 方法

目的:定义如何将图像分类为真实或虚假。


self.model(img):

通过网络处理输入图像,为每个图像生成一个标量概率。


输出

批次中概率 ∈ [0, 1] 的张量。


它们如何协同工作

训练流程:

  • 生成器从随机噪声创建合成图像。
  • 判别器评估真实图像和合成图像,预测每个图像是真实还是虚假。


对抗目标:

  • 生成器学习改进其输出以“欺骗”判别器。
  • 判别器学习更好地识别虚假图像。


损失函数:

  • 生成器:优化以最大化判别器将虚假图像预测为“真实”的可能性。
  • 判别器:优化以正确预测所有输入的真实与虚假。


# Loss function and optimizers
adversarial_loss = nn.BCELoss()
generator = Generator(z_dim=100)
discriminator = Discriminator()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Define device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = generator.to(device)
discriminator = discriminator.to(device)


提供的代码片段为训练生成对抗网络(GAN)设置了损失函数、优化器和设备配置。


使用二元交叉熵损失来衡量预测概率与真实二元标签之间的误差。


def save_generated_images(generator, epoch, device, num_images=16):
    z = torch.randn(num_images, 100).to(device)
    generated_imgs = generator(z).detach().cpu()
    grid = torchvision.utils.make_grid(generated_imgs, nrow=4, normalize=True)
    plt.imshow(np.transpose(grid, (1, 2, 0)))
    plt.title(f"Epoch {epoch}")
    plt.axis('off')
    plt.show()
# Training loop
def train(generator, discriminator, dataloader, epochs=5):
    for epoch in range(epochs):
        for i, imgs in enumerate(dataloader):
            real_imgs = imgs.to(device)
            batch_size = real_imgs.size(0)
            valid = torch.ones(batch_size, 1).to(device)
            fake = torch.zeros(batch_size, 1).to(device)
            # Train Discriminator
            optimizer_D.zero_grad()
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(generator(torch.randn(batch_size, 100).to(device)).detach()), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()
            # Train Generator
            optimizer_G.zero_grad()
            g_loss = adversarial_loss(discriminator(generator(torch.randn(batch_size, 100).to(device))), valid)
            g_loss.backward()
            optimizer_G.step()
            if i % 50 == 0:
                print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")
        # Optionally, save generated images at each epoch
        save_generated_images(generator, epoch, device)


训练循环在训练判别器和生成器之间交替进行。判别器通过最小化真实图像(标记为真实)和假图像(标记为假)的损失,被训练来区分真实图像和假图像。生成器则通过欺骗判别器,使假图像被标记为真实时损失最小化,从而被训练来创建逼真的假图像。


这个对抗过程会重复多个周期,并使用各自的优化器更新两个模型。定期保存生成的图像,以跟踪生成器的进展。


生成的图像:


10


11


我们需要训练多个周期,让模型能够理解复杂的细节,并像原始图像一样重新创建出来。为了这篇文章,我将训练停止在10个周期。



文章来源:https://medium.com/gitconnected/creating-human-faces-from-scratch-a-hands-on-guide-to-gans-2b374c173a65
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消