This article walks through image deblurring using generative adversarial networks (GANs), implemented in Keras.
Code: https://github.com/RaphaelMeudec/deblur-gan
In a generative adversarial network, two networks train against each other: the generator misleads the discriminator by creating realistic fake inputs, while the discriminator tells whether an input is real or fake.
The GAN training process
GAN training involves three main steps:
- Use the generator to create fake inputs from noise.
- Train the discriminator on both real and fake inputs.
- Train the whole model: the discriminator is chained onto the generator to constrain it.
Note that the discriminator's weights are frozen during the third step.
The two networks are chained because there is no natural feedback on the generator's outputs alone: our only measure is whether the discriminator accepts the generated samples.
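To make the freezing step concrete, here is a minimal toy sketch of the pattern (the tiny Dense models are purely illustrative, not the deblurring networks): Keras records the trainable flag when a model is compiled, so the discriminator keeps training on its own while staying frozen inside the chained model.

from keras.layers import Dense, Input
from keras.models import Model

# Toy generator and discriminator, just to illustrate the freezing pattern
z = Input(shape=(8,))
generator = Model(z, Dense(16)(z))
img = Input(shape=(16,))
discriminator = Model(img, Dense(1, activation='sigmoid')(img))

# Step 2: compile the discriminator while it is trainable
discriminator.trainable = True
discriminator.compile(optimizer='adam', loss='binary_crossentropy')

# Step 3: freeze it, then compile the chained model; only the generator's
# weights are updated when the chained model trains
discriminator.trainable = False
gan = Model(z, discriminator(generator(z)))
gan.compile(optimizer='adam', loss='binary_crossentropy')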
In this tutorial, we use a GAN for image deblurring, so the generator's input is not noise but a blurred image.
We use the GOPRO dataset; a light version (9GB) and a full version (35GB) are available for download. It contains blurred images from various street scenes, organized by scene into subfolders.
We first split the images into two folders, A (blurred) and B (sharp); a sketch of this step is shown below.
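As a rough sketch of that step (assuming the public GOPRO layout, where each scene folder contains blur/ and sharp/ subfolders; the folder names here are illustrative):

import os
import shutil

def organize_gopro(source_dir, output_dir):
    """Copy GOPRO images into A (blurred) and B (sharp) folders."""
    for sub in ('A', 'B'):
        os.makedirs(os.path.join(output_dir, sub), exist_ok=True)
    for scene in os.listdir(source_dir):
        scene_dir = os.path.join(source_dir, scene)
        if not os.path.isdir(scene_dir):
            continue
        for src_sub, dst_sub in (('blur', 'A'), ('sharp', 'B')):
            folder = os.path.join(scene_dir, src_sub)
            for name in os.listdir(folder):
                # Prefix with the scene name to avoid filename collisions
                shutil.copy(os.path.join(folder, name),
                            os.path.join(output_dir, dst_sub, scene + '_' + name))

organize_gopro('./GOPRO_Large/train', './images/train')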
The training process then remains the same. First, let's take a look at the neural network architecture!
The generator aims to reproduce a sharp image. The network is based on ResNet blocks, and it keeps track of the transformations applied to the original blurred image.
Architecture of the DeblurGAN generator network
At its core are nine ResNet blocks applied to a resampled version of the original image. Let's look at the Keras implementation.
from keras.layers import Input, Conv2D, Activation, BatchNormalization
from keras.layers.merge import Add
from keras.layers.core import Dropout

def res_block(input, filters, kernel_size=(3,3), strides=(1,1), use_dropout=False):
    """
    Instantiate a Keras ResNet block using the functional API.

    :param input: Input tensor
    :param filters: Number of filters to use
    :param kernel_size: Shape of the kernel for the convolution
    :param strides: Shape of the strides for the convolution
    :param use_dropout: Boolean value to determine the use of dropout
    :return: Keras Model
    """
    # ReflectionPadding2D is a custom layer (a sketch of it is given below)
    x = ReflectionPadding2D((1,1))(input)
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               strides=strides,)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    if use_dropout:
        x = Dropout(0.5)(x)

    x = ReflectionPadding2D((1,1))(x)
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               strides=strides,)(x)
    x = BatchNormalization()(x)

    # Two convolution layers followed by a direct connection between input and output
    merged = Add()([input, x])
    return merged
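Note that ReflectionPadding2D is not a stock Keras layer; it comes from the repository's layer_utils module. A minimal sketch of such a layer, assuming a TensorFlow backend and channels-last data, wraps tf.pad in 'REFLECT' mode:

import tensorflow as tf
from keras.engine.topology import Layer

class ReflectionPadding2D(Layer):
    """Pad height and width by reflecting the image at its borders."""
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        # Height and width grow by twice the padding; batch and channels are unchanged
        return (input_shape[0],
                input_shape[1] + 2 * self.padding[0],
                input_shape[2] + 2 * self.padding[1],
                input_shape[3])

    def call(self, x):
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')

Reflection padding tends to avoid the border artifacts that zero padding produces on image-to-image tasks, which is why it is used here instead of the built-in ZeroPadding2D.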
from keras.layers import Input, Activation, Add
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.core import Lambda
from keras.layers.normalization import BatchNormalization
from keras.models import Model
from layer_utils import ReflectionPadding2D, res_block
ngf = 64
input_nc = 3
output_nc = 3
input_shape_generator = (256, 256, input_nc)
n_blocks_gen = 9
def generator_model():
    """Build generator architecture."""
    # Current version : ResNet block
    inputs = Input(shape=input_shape_generator)

    x = ReflectionPadding2D((3, 3))(inputs)
    x = Conv2D(filters=ngf, kernel_size=(7,7), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Increase filter number
    n_downsampling = 2
    for i in range(n_downsampling):
        mult = 2**i
        x = Conv2D(filters=ngf*mult*2, kernel_size=(3,3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # Apply 9 ResNet blocks
    mult = 2**n_downsampling
    for i in range(n_blocks_gen):
        x = res_block(x, ngf*mult, use_dropout=True)

    # Decrease filter number to 3 (RGB)
    for i in range(n_downsampling):
        mult = 2**(n_downsampling - i)
        x = Conv2DTranspose(filters=int(ngf * mult / 2), kernel_size=(3,3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    x = ReflectionPadding2D((3,3))(x)
    x = Conv2D(filters=output_nc, kernel_size=(7,7), padding='valid')(x)
    x = Activation('tanh')(x)

    # Add direct connection from input to output and recenter to [-1, 1]
    outputs = Add()([x, inputs])
    outputs = Lambda(lambda z: z/2)(outputs)

    model = Model(inputs=inputs, outputs=outputs, name='Generator')
    return model
Architecture of the generator, implemented in Keras
As planned, the nine ResNet blocks are applied to a resampled version of the input. We add a connection from the input to the output and divide the result by 2 to keep it normalized to [-1, 1].
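A quick sanity check (a hypothetical snippet, not part of the original script) is to build the model and confirm that the output keeps the input resolution:

g = generator_model()
print(g.output_shape)  # (None, 256, 256, 3): same resolution as the input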
That's it for the generator; let's now look at the discriminator's architecture.
The goal of the discriminator is to determine whether an input image is fake. Its architecture is therefore convolutional, ending in a single output value.
from keras.layers import Input, Activation
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.layers.core import Dense, Flatten
from keras.layers.normalization import BatchNormalization
from keras.models import Model

ndf = 64
output_nc = 3
input_shape_discriminator = (256, 256, output_nc)

def discriminator_model():
    """Build discriminator architecture."""
    n_layers, use_sigmoid = 3, False
    inputs = Input(shape=input_shape_discriminator)

    x = Conv2D(filters=ndf, kernel_size=(4,4), strides=2, padding='same')(inputs)
    x = LeakyReLU(0.2)(x)

    nf_mult, nf_mult_prev = 1, 1
    for n in range(n_layers):
        nf_mult_prev, nf_mult = nf_mult, min(2**n, 8)
        x = Conv2D(filters=ndf*nf_mult, kernel_size=(4,4), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)

    nf_mult_prev, nf_mult = nf_mult, min(2**n_layers, 8)
    x = Conv2D(filters=ndf*nf_mult, kernel_size=(4,4), strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)

    x = Conv2D(filters=1, kernel_size=(4,4), strides=1, padding='same')(x)
    if use_sigmoid:
        x = Activation('sigmoid')(x)

    x = Flatten()(x)
    x = Dense(1024, activation='tanh')(x)
    x = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=x, name='Discriminator')
    return model
Implementation of the discriminator architecture with Keras
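Again as a hypothetical sanity check, the model can be built and inspected to confirm that it collapses each image into one score:

d = discriminator_model()
print(d.output_shape)  # (None, 1): a single real/fake score per image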
The last step is to build the full GAN model. A particularity of this GAN is that its input consists of real images rather than noise, so we get direct feedback on the generator's outputs.
from keras.layers import Input
from keras.models import Model

image_shape = (256, 256, 3)

def generator_containing_discriminator_multiple_outputs(generator, discriminator):
    inputs = Input(shape=image_shape)
    generated_images = generator(inputs)
    outputs = discriminator(generated_images)
    model = Model(inputs=inputs, outputs=[generated_images, outputs])
    return model
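Inspecting the chained model's outputs (another hypothetical check) shows the two heads that the two losses described next attach to:

d_on_g = generator_containing_discriminator_multiple_outputs(g, d)
# Two outputs: the deblurred image and the discriminator's score
print(d_on_g.output_shape)  # [(None, 256, 256, 3), (None, 1)]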
We extract losses at two levels: at the output of the generator and at the output of the full model.
The first is a perceptual loss, computed directly on the generator's output. It ensures that the GAN model is oriented toward the deblurring task: it compares the activations of an early VGG16 convolution layer (block3_conv3 in the code below) for the target and generated images.
import keras.backend as K
from keras.applications.vgg16 import VGG16
from keras.models import Model

image_shape = (256, 256, 3)

def perceptual_loss(y_true, y_pred):
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
    loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    loss_model.trainable = False
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))
The second loss is the Wasserstein loss, applied to the output of the full model. It is simply the mean of the element-wise product of the labels and the predictions.

import keras.backend as K

def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true*y_pred)
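With labels of 1 for real images and -1 for generated ones (the convention used in the training script below), minimizing this loss pushes the critic's scores for real and fake images apart. A tiny numeric illustration:

import numpy as np

y_true = np.array([1., 1., -1., -1.])      # labels: 1 = real, -1 = fake
y_pred = np.array([-0.9, -0.4, 0.8, 0.3])  # critic scores
# wasserstein_loss = mean(y_true * y_pred)
print(np.mean(y_true * y_pred))            # -0.6: well-separated scores give a low loss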
As a first step, we load the data and initialize all the models. We use our custom function to load the dataset, and add Adam optimizers for the models. We set the Keras trainable option to prevent the discriminator from being updated while the combined model trains: since Keras records the flag at compile time, the discriminator is compiled on its own first.
from keras.optimizers import Adam

# Load dataset (load_images is a custom helper from the repository)
data = load_images('./images/train', n_images)
y_train, x_train = data['B'], data['A']
# Initialize models
g = generator_model()
d = discriminator_model()
d_on_g = generator_containing_discriminator_multiple_outputs(g, d)
# Initialize optimizers
g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
# Compile models
d.trainable = True
d.compile(optimizer=d_opt, loss=wasserstein_loss)
d.trainable = False
loss = [perceptual_loss, wasserstein_loss]
loss_weights = [100, 1]
d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
d.trainable = True
Then we launch the GAN training, splitting the dataset into batches.
# Wasserstein labels: 1 for real images, -1 for generated ones
output_true_batch, output_false_batch = np.ones((batch_size, 1)), -np.ones((batch_size, 1))

for epoch in range(epoch_num):
    print('epoch: {}/{}'.format(epoch, epoch_num))
    print('batches: {}'.format(x_train.shape[0] / batch_size))

    # Randomize images into batches
    permutated_indexes = np.random.permutation(x_train.shape[0])

    for index in range(int(x_train.shape[0] / batch_size)):
        batch_indexes = permutated_indexes[index*batch_size:(index+1)*batch_size]
        image_blur_batch = x_train[batch_indexes]
        image_full_batch = y_train[batch_indexes]

        # Generate fake inputs
        generated_images = g.predict(x=image_blur_batch, batch_size=batch_size)

        # Train the discriminator multiple times on real and fake inputs
        for _ in range(critic_updates):
            d_loss_real = d.train_on_batch(image_full_batch, output_true_batch)
            d_loss_fake = d.train_on_batch(generated_images, output_false_batch)
            d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

        d.trainable = False

        # Train the generator only on the discriminator's decision and the generated images
        d_on_g_loss = d_on_g.train_on_batch(image_blur_batch, [image_full_batch, output_true_batch])

        d.trainable = True
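Once training is done, applying the generator to a new image is straightforward. Here is a hedged sketch (the helper and file names are hypothetical; images are assumed to be 256x256 RGB, normalized to [-1, 1] as during training):

import numpy as np
from PIL import Image

def deblur_image(generator, path):
    """Run the trained generator on a single image file."""
    img = np.array(Image.open(path).convert('RGB').resize((256, 256)), dtype=np.float32)
    img = img / 127.5 - 1                                     # normalize to [-1, 1]
    generated = generator.predict(np.expand_dims(img, axis=0))[0]
    generated = ((generated + 1) / 2 * 255).astype(np.uint8)  # back to [0, 255]
    return Image.fromarray(generated)

deblur_image(g, 'blurred_street.png').save('deblurred_street.png')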
The full code is available on GitHub: https://www.github.com/raphaelmeudec/deblur-gan
I used an AWS instance (p2.xlarge) with the Deep Learning AMI (version 3.0). Training on the GOPRO dataset took about 5 hours (50 epochs).
From left to right: the sharp original, the blurred image, and the GAN output
The images above are results of our Keras deblur GAN. Even under heavy blur, the network reduces the blur and produces a more convincing image; the car lights and tree branches are noticeably sharper.
Left: GOPRO test image; right: GAN output
We can see artifacts at the top of the image (stripe-like patterns), which may be caused by using VGG activations as the loss.
Left: GOPRO test image; right: GAN output
Left: GOPRO test image; right: GAN output