Introduction
Residual neural networks (commonly known as ResNets) marked a major breakthrough in deep learning, particularly in computer vision. First introduced by Kaiming He and his colleagues in their seminal 2015 paper, ResNet transformed the field by addressing a key challenge: the degradation problem.
The Degradation Problem and ResNet's Solution
As neural networks become deeper, they should in theory achieve better performance thanks to their increased learning capacity. In practice, however, this is not the case. Deeper networks often suffer from the degradation problem: adding more layers leads to higher training error, not because of overfitting, but because optimizing a deeper network becomes harder. ResNet addresses this with a novel architectural innovation: skip connections, also known as residual connections.
Architecture of a Residual Network
A typical ResNet architecture consists of several stacked "residual blocks." Each block has a shortcut path, or skip connection, that bypasses one or more layers. In a basic residual block, the input is added to the output of the stacked layers (with appropriate dimension matching if necessary). This design allows the network to learn an identity function, ensuring that a deeper model performs no worse than its shallower counterpart.
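In equation form, the block computes y = F(x) + x, where F is the small stack of layers being bypassed. The following minimal Keras sketch illustrates the idea for the simplest case, where the input and the output of F already have matching shapes so the shortcut is a pure identity (the function name identity_residual_block is illustrative, not part of the original example):

import tensorflow as tf
from tensorflow.keras import layers

def identity_residual_block(x, filters):
    # F(x): two 3x3 convolutions, each followed by batch normalization.
    f = layers.Conv2D(filters, 3, padding='same')(x)
    f = layers.BatchNormalization()(f)
    f = layers.ReLU()(f)
    f = layers.Conv2D(filters, 3, padding='same')(f)
    f = layers.BatchNormalization()(f)
    # y = F(x) + x: the skip connection adds the unchanged input back in,
    # so the block can default to an identity mapping if F learns zeros.
    return layers.ReLU()(layers.Add()([f, x]))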
Benefits of Residual Learning
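Residual connections make very deep networks much easier to optimize: the shortcut gives gradients a direct path back to earlier layers, mitigating vanishing gradients, and because each block can fall back to an identity mapping, adding layers does not degrade training accuracy. In practice this allowed networks with dozens or even hundreds of layers to be trained effectively, outperforming their plain (non-residual) counterparts.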
Applications and Impact
ResNets have been widely adopted across domains, particularly for computer vision tasks such as image classification, object detection, and segmentation. The residual-learning principle has also inspired other neural network architectures and played a key role in advances beyond vision, for example in natural language processing.
Future Directions and Limitations
Despite being a major advance, ResNets are not without limitations. Their large size poses challenges in terms of computational resources and memory requirements. Moreover, as research continues, newer architectures are emerging that challenge or build on the ResNet paradigm, offering better performance and efficiency.
Code
Building a residual neural network (ResNet) in Python, together with a synthetic dataset and plots, involves several steps. We will define a simple ResNet architecture using TensorFlow and Keras, generate a synthetic dataset, train the network, and finally plot the results. Given the complexity of ResNet and the limitations of a synthetic dataset, this example is deliberately basic.
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization, ReLU, Add,
                                     MaxPooling2D, GlobalAveragePooling2D, Dense)
from tensorflow.keras.models import Model


def residual_block(x, filters, kernel_size=3, stride=1):
    """A basic residual block: two Conv-BN layers plus a projected shortcut."""
    shortcut = x

    # Main path: two convolutions, each followed by batch normalization.
    x = Conv2D(filters, kernel_size, strides=stride, padding='same')(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(filters, kernel_size, strides=stride, padding='same')(x)
    x = BatchNormalization()(x)

    # Shortcut path: a 1x1 convolution projects the input so its shape
    # matches the main path before the element-wise addition.
    shortcut = Conv2D(filters, kernel_size=1, strides=stride, padding='same')(shortcut)
    shortcut = BatchNormalization()(shortcut)

    # Merge the two paths and apply the final activation.
    x = Add()([x, shortcut])
    x = ReLU()(x)
    return x


def create_resnet(input_shape, num_classes):
    """Builds a simplified ResNet-style model with the Keras functional API."""
    inputs = Input(shape=input_shape)

    # Stem: 7x7 convolution with stride 2, followed by max pooling.
    x = Conv2D(64, 7, strides=2, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = MaxPooling2D(3, strides=2, padding='same')(x)

    # Four stages of residual blocks (simplified version: no downsampling
    # between stages).
    num_blocks_list = [2, 2, 2, 2]
    for filters in [64, 128, 256, 512]:
        for _ in range(num_blocks_list.pop(0)):
            x = residual_block(x, filters)

    # Classification head.
    x = GlobalAveragePooling2D()(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)
    return model


resnet_model = create_resnet(input_shape=(224, 224, 3), num_classes=10)
resnet_model.summary()

# Generate synthetic data: 1000 random samples spread across 10 classes.
num_samples = 1000
num_classes = 10
x_train = np.random.random((num_samples, 224, 224, 3))
y_train = np.random.randint(num_classes, size=(num_samples,))
y_train = tf.keras.utils.to_categorical(y_train, num_classes)

# Compile and train on the synthetic data.
resnet_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
history = resnet_model.fit(x_train, y_train, batch_size=32, epochs=10, validation_split=0.2)

# Plot training and validation accuracy and loss over the epochs.
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.legend()
plt.title('Accuracy over Epochs')

plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.title('Loss over Epochs')
plt.show()
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 224, 224, 3)] 0 []
conv2d (Conv2D) (None, 112, 112, 64) 9472 ['input_1[0][0]']
batch_normalization (Batch (None, 112, 112, 64) 256 ['conv2d[0][0]']
Normalization)
re_lu (ReLU) (None, 112, 112, 64) 0 ['batch_normalization[0][0]']
max_pooling2d (MaxPooling2 (None, 56, 56, 64) 0 ['re_lu[0][0]']
D)
conv2d_1 (Conv2D) (None, 56, 56, 64) 36928 ['max_pooling2d[0][0]']
batch_normalization_1 (Bat (None, 56, 56, 64) 256 ['conv2d_1[0][0]']
chNormalization)
re_lu_1 (ReLU) (None, 56, 56, 64) 0 ['batch_normalization_1[0][0]'
]
conv2d_2 (Conv2D) (None, 56, 56, 64) 36928 ['re_lu_1[0][0]']
conv2d_3 (Conv2D) (None, 56, 56, 64) 4160 ['max_pooling2d[0][0]']
batch_normalization_2 (Bat (None, 56, 56, 64) 256 ['conv2d_2[0][0]']
chNormalization)
batch_normalization_3 (Bat (None, 56, 56, 64) 256 ['conv2d_3[0][0]']
chNormalization)
add (Add) (None, 56, 56, 64) 0 ['batch_normalization_2[0][0]'
, 'batch_normalization_3[0][0]
']
re_lu_2 (ReLU) (None, 56, 56, 64) 0 ['add[0][0]']
conv2d_4 (Conv2D) (None, 56, 56, 64) 36928 ['re_lu_2[0][0]']
batch_normalization_4 (Bat (None, 56, 56, 64) 256 ['conv2d_4[0][0]']
chNormalization)
re_lu_3 (ReLU) (None, 56, 56, 64) 0 ['batch_normalization_4[0][0]'
]
conv2d_5 (Conv2D) (None, 56, 56, 64) 36928 ['re_lu_3[0][0]']
conv2d_6 (Conv2D) (None, 56, 56, 64) 4160 ['re_lu_2[0][0]']
batch_normalization_5 (Bat (None, 56, 56, 64) 256 ['conv2d_5[0][0]']
chNormalization)
batch_normalization_6 (Bat (None, 56, 56, 64) 256 ['conv2d_6[0][0]']
chNormalization)
add_1 (Add) (None, 56, 56, 64) 0 ['batch_normalization_5[0][0]'
, 'batch_normalization_6[0][0]
']
re_lu_4 (ReLU) (None, 56, 56, 64) 0 ['add_1[0][0]']
conv2d_7 (Conv2D) (None, 56, 56, 128) 73856 ['re_lu_4[0][0]']
batch_normalization_7 (Bat (None, 56, 56, 128) 512 ['conv2d_7[0][0]']
chNormalization)
re_lu_5 (ReLU) (None, 56, 56, 128) 0 ['batch_normalization_7[0][0]'
]
conv2d_8 (Conv2D) (None, 56, 56, 128) 147584 ['re_lu_5[0][0]']
conv2d_9 (Conv2D) (None, 56, 56, 128) 8320 ['re_lu_4[0][0]']
batch_normalization_8 (Bat (None, 56, 56, 128) 512 ['conv2d_8[0][0]']
chNormalization)
batch_normalization_9 (Bat (None, 56, 56, 128) 512 ['conv2d_9[0][0]']
chNormalization)
add_2 (Add) (None, 56, 56, 128) 0 ['batch_normalization_8[0][0]'
, 'batch_normalization_9[0][0]
']
re_lu_6 (ReLU) (None, 56, 56, 128) 0 ['add_2[0][0]']
conv2d_10 (Conv2D) (None, 56, 56, 128) 147584 ['re_lu_6[0][0]']
batch_normalization_10 (Ba (None, 56, 56, 128) 512 ['conv2d_10[0][0]']
tchNormalization)
re_lu_7 (ReLU) (None, 56, 56, 128) 0 ['batch_normalization_10[0][0]
']
conv2d_11 (Conv2D) (None, 56, 56, 128) 147584 ['re_lu_7[0][0]']
conv2d_12 (Conv2D) (None, 56, 56, 128) 16512 ['re_lu_6[0][0]']
batch_normalization_11 (Ba (None, 56, 56, 128) 512 ['conv2d_11[0][0]']
tchNormalization)
batch_normalization_12 (Ba (None, 56, 56, 128) 512 ['conv2d_12[0][0]']
tchNormalization)
add_3 (Add) (None, 56, 56, 128) 0 ['batch_normalization_11[0][0]
',
'batch_normalization_12[0][0]
']
re_lu_8 (ReLU) (None, 56, 56, 128) 0 ['add_3[0][0]']
conv2d_13 (Conv2D) (None, 56, 56, 256) 295168 ['re_lu_8[0][0]']
batch_normalization_13 (Ba (None, 56, 56, 256) 1024 ['conv2d_13[0][0]']
tchNormalization)
re_lu_9 (ReLU) (None, 56, 56, 256) 0 ['batch_normalization_13[0][0]
']
conv2d_14 (Conv2D) (None, 56, 56, 256) 590080 ['re_lu_9[0][0]']
conv2d_15 (Conv2D) (None, 56, 56, 256) 33024 ['re_lu_8[0][0]']
batch_normalization_14 (Ba (None, 56, 56, 256) 1024 ['conv2d_14[0][0]']
tchNormalization)
batch_normalization_15 (Ba (None, 56, 56, 256) 1024 ['conv2d_15[0][0]']
tchNormalization)
add_4 (Add) (None, 56, 56, 256) 0 ['batch_normalization_14[0][0]
',
'batch_normalization_15[0][0]
']
re_lu_10 (ReLU) (None, 56, 56, 256) 0 ['add_4[0][0]']
conv2d_16 (Conv2D) (None, 56, 56, 256) 590080 ['re_lu_10[0][0]']
batch_normalization_16 (Ba (None, 56, 56, 256) 1024 ['conv2d_16[0][0]']
tchNormalization)
re_lu_11 (ReLU) (None, 56, 56, 256) 0 ['batch_normalization_16[0][0]
']
conv2d_17 (Conv2D) (None, 56, 56, 256) 590080 ['re_lu_11[0][0]']
conv2d_18 (Conv2D) (None, 56, 56, 256) 65792 ['re_lu_10[0][0]']
batch_normalization_17 (Ba (None, 56, 56, 256) 1024 ['conv2d_17[0][0]']
tchNormalization)
batch_normalization_18 (Ba (None, 56, 56, 256) 1024 ['conv2d_18[0][0]']
tchNormalization)
add_5 (Add) (None, 56, 56, 256) 0 ['batch_normalization_17[0][0]
',
'batch_normalization_18[0][0]
']
re_lu_12 (ReLU) (None, 56, 56, 256) 0 ['add_5[0][0]']
conv2d_19 (Conv2D) (None, 56, 56, 512) 1180160 ['re_lu_12[0][0]']
batch_normalization_19 (Ba (None, 56, 56, 512) 2048 ['conv2d_19[0][0]']
tchNormalization)
re_lu_13 (ReLU) (None, 56, 56, 512) 0 ['batch_normalization_19[0][0]
']
conv2d_20 (Conv2D) (None, 56, 56, 512) 2359808 ['re_lu_13[0][0]']
conv2d_21 (Conv2D) (None, 56, 56, 512) 131584 ['re_lu_12[0][0]']
batch_normalization_20 (Ba (None, 56, 56, 512) 2048 ['conv2d_20[0][0]']
tchNormalization)
batch_normalization_21 (Ba (None, 56, 56, 512) 2048 ['conv2d_21[0][0]']
tchNormalization)
add_6 (Add) (None, 56, 56, 512) 0 ['batch_normalization_20[0][0]
',
'batch_normalization_21[0][0]
']
re_lu_14 (ReLU) (None, 56, 56, 512) 0 ['add_6[0][0]']
conv2d_22 (Conv2D) (None, 56, 56, 512) 2359808 ['re_lu_14[0][0]']
batch_normalization_22 (Ba (None, 56, 56, 512) 2048 ['conv2d_22[0][0]']
tchNormalization)
re_lu_15 (ReLU) (None, 56, 56, 512) 0 ['batch_normalization_22[0][0]
']
conv2d_23 (Conv2D) (None, 56, 56, 512) 2359808 ['re_lu_15[0][0]']
conv2d_24 (Conv2D) (None, 56, 56, 512) 262656 ['re_lu_14[0][0]']
batch_normalization_23 (Ba (None, 56, 56, 512) 2048 ['conv2d_23[0][0]']
tchNormalization)
batch_normalization_24 (Ba (None, 56, 56, 512) 2048 ['conv2d_24[0][0]']
tchNormalization)
add_7 (Add) (None, 56, 56, 512) 0 ['batch_normalization_23[0][0]
',
'batch_normalization_24[0][0]
']
re_lu_16 (ReLU) (None, 56, 56, 512) 0 ['add_7[0][0]']
global_average_pooling2d ( (None, 512) 0 ['re_lu_16[0][0]']
GlobalAveragePooling2D)
dense (Dense) (None, 10) 5130 ['global_average_pooling2d[0][
0]']
==================================================================================================
Total params: 11553418 (44.07 MB)
Trainable params: 11541770 (44.03 MB)
Non-trainable params: 11648 (45.50 KB)
This code provides a basic demonstration of how to implement and train a ResNet model. Because the training data is random noise, the accuracy and loss curves mainly illustrate the training workflow rather than meaningful learning.
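To see the model learn a real signal, the same create_resnet function can be pointed at a real dataset instead of synthetic noise. The sketch below is one way to do this, assuming the code above has already been run (it reuses create_resnet and the tf import); the choice of CIFAR-10 and the 5-epoch budget are illustrative, not part of the original example.

# Minimal sketch: train the same architecture on CIFAR-10 (32x32x3 images).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

# Rebuild the model for the smaller input shape.
cifar_model = create_resnet(input_shape=(32, 32, 3), num_classes=10)
cifar_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
cifar_model.fit(x_train, y_train, batch_size=64, epochs=5, validation_data=(x_test, y_test))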
Conclusion
Residual neural networks represent a cornerstone in the evolution of deep learning architectures. Their introduction not only solved the pressing problem of training very deep networks, but also set a precedent for subsequent architectural innovation in the field. As deep learning continues to evolve, the principles established by ResNets will undoubtedly keep shaping the development of more advanced and efficient models.