使用Python中的合成数据集理解和实现残差神经网络 (ResNet)

2023年12月27日 由 alex 发表 245 0

介绍


残差神经网络(通常称为 ResNet)标志着深度学习领域的重大突破,尤其是在计算机视觉领域。ResNet最初由 Kaiming He 及其同事在 2015 年的开创性论文中提出,通过解决一个关键挑战:退化问题,彻底改变了深度学习领域。


1


神经网络退化问题及ResNet的解决方案


随着神经网络变得更深,理论上它们因其增加的学习能力而应该实现更好的性能。然而,在实践中,情况并非如此。更深的网络通常遭受退化问题的困扰,当增加更多层时,会导致更高的训练误差,并不是因为过拟合,而是因为优化更深的网络变得更加困难。ResNet通过一个新颖的架构创新来解决这个问题:跳过连接,也被称为残差连接。


残差网络的架构


典型的ResNet架构由几个叠加的“残差块”组成。每个块都有一个快捷路径或跳过连接,可以绕过一个或多个层。在一个基本的残差块中,输入与叠加层的输出相加(如果必要,进行适当的维度匹配)。这种设计允许网络学习一个恒等函数,确保更深层的模型不会比它的浅层对应物差。


残差学习的好处


  1. 优化的便捷性:ResNet中的跳过连接有助于直接通过网络反向传播梯度,使得训练更深的模型变得更容易。
  2. 解决梯度消失/爆炸问题:通过允许这种交替的快捷路径梯度流动,ResNet缓解了深度网络中常见的梯度消失和爆炸问题。
  3. 实现更深的网络:在ResNet之前,训练具有许多层的深度神经网络是困难的。ResNet使得开发更深度(例如,ResNet-152)的网络成为可能,实现了显著的性能改善。


应用与影响


ResNet已在各个领域得到广泛采用,特别是在计算机视觉任务,如图像分类、物体检测和分割方面。残差学习原理还启发了其他神经网络架构,并且在视觉以外的其他领域(例如自然语言处理)的进展中起到了关键作用。


未来方向和局限性


尽管ResNet是重大进步,但它们并非没有局限性。这些网络庞大的尺寸在计算资源和内存需求方面是一个挑战。此外,随着研究的继续,新兴的架构正在出现,这些架构挑战或建立在ResNet范式之上,提供了更好的性能和效率。


代码


使用Python创建一个残差神经网络(ResNet)的完整实现,以及一个合成数据集和图表,涉及几个步骤。我们将开始定义一个使用TensorFlow和Keras的简单ResNet架构,生成一个合成数据集,训练网络,最后画出结果。鉴于ResNet的复杂性和合成数据集的局限性,这个例子将是相对基础的。


import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, ReLU, Add, MaxPooling2D, GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
def residual_block(x, filters, kernel_size=3, stride=1):
    shortcut = x
    x = Conv2D(filters, kernel_size, strides=stride, padding='same')(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(filters, kernel_size, strides=stride, padding='same')(x)
    x = BatchNormalization()(x)
    shortcut = Conv2D(filters, kernel_size=1, strides=stride, padding='same')(shortcut)
    shortcut = BatchNormalization()(shortcut)
    x = Add()([x, shortcut])
    x = ReLU()(x)
    return x
def create_resnet(input_shape, num_classes):
    inputs = Input(shape=input_shape)
    x = Conv2D(64, 7, strides=2, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = MaxPooling2D(3, strides=2, padding='same')(x)
    num_blocks_list = [2, 2, 2, 2]  # Simplified version
    for filters in [64, 128, 256, 512]:
        for _ in range(num_blocks_list.pop(0)):
            x = residual_block(x, filters)
    
    x = GlobalAveragePooling2D()(x)
    outputs = Dense(num_classes, activation='softmax')(x)
    model = Model(inputs=inputs, outputs=outputs)
    return model
resnet_model = create_resnet(input_shape=(224, 224, 3), num_classes=10)
resnet_model.summary()
# Generating synthetic data: 1000 samples with 10 classes
num_samples = 1000
num_classes = 10
x_train = np.random.random((num_samples, 224, 224, 3))
y_train = np.random.randint(num_classes, size=(num_samples,))
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
resnet_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
history = resnet_model.fit(x_train, y_train, batch_size=32, epochs=10, validation_split=0.2)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.legend()
plt.title('Accuracy over Epochs')
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.title('Loss over Epochs')
plt.show()


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
==================================================================================================
 input_1 (InputLayer)        [(None, 224, 224, 3)]        0         []                            
                                                                                                  
 conv2d (Conv2D)             (None, 112, 112, 64)         9472      ['input_1[0][0]']             
                                                                                                  
 batch_normalization (Batch  (None, 112, 112, 64)         256       ['conv2d[0][0]']              
 Normalization)                                                                                   
                                                                                                  
 re_lu (ReLU)                (None, 112, 112, 64)         0         ['batch_normalization[0][0]'] 
                                                                                                  
 max_pooling2d (MaxPooling2  (None, 56, 56, 64)           0         ['re_lu[0][0]']               
 D)                                                                                               
                                                                                                  
 conv2d_1 (Conv2D)           (None, 56, 56, 64)           36928     ['max_pooling2d[0][0]']       
                                                                                                  
 batch_normalization_1 (Bat  (None, 56, 56, 64)           256       ['conv2d_1[0][0]']            
 chNormalization)                                                                                 
                                                                                                  
 re_lu_1 (ReLU)              (None, 56, 56, 64)           0         ['batch_normalization_1[0][0]'
                                                                    ]                             
                                                                                                  
 conv2d_2 (Conv2D)           (None, 56, 56, 64)           36928     ['re_lu_1[0][0]']             
                                                                                                  
 conv2d_3 (Conv2D)           (None, 56, 56, 64)           4160      ['max_pooling2d[0][0]']       
                                                                                                  
 batch_normalization_2 (Bat  (None, 56, 56, 64)           256       ['conv2d_2[0][0]']            
 chNormalization)                                                                                 
                                                                                                  
 batch_normalization_3 (Bat  (None, 56, 56, 64)           256       ['conv2d_3[0][0]']            
 chNormalization)                                                                                 
                                                                                                  
 add (Add)                   (None, 56, 56, 64)           0         ['batch_normalization_2[0][0]'
                                                                    , 'batch_normalization_3[0][0]
                                                                    ']                            
                                                                                                  
 re_lu_2 (ReLU)              (None, 56, 56, 64)           0         ['add[0][0]']                 
                                                                                                  
 conv2d_4 (Conv2D)           (None, 56, 56, 64)           36928     ['re_lu_2[0][0]']             
                                                                                                  
 batch_normalization_4 (Bat  (None, 56, 56, 64)           256       ['conv2d_4[0][0]']            
 chNormalization)                                                                                 
                                                                                                  
 re_lu_3 (ReLU)              (None, 56, 56, 64)           0         ['batch_normalization_4[0][0]'
                                                                    ]                             
                                                                                                  
 conv2d_5 (Conv2D)           (None, 56, 56, 64)           36928     ['re_lu_3[0][0]']             
                                                                                                  
 conv2d_6 (Conv2D)           (None, 56, 56, 64)           4160      ['re_lu_2[0][0]']             
                                                                                                  
 batch_normalization_5 (Bat  (None, 56, 56, 64)           256       ['conv2d_5[0][0]']            
 chNormalization)                                                                                 
                                                                                                  
 batch_normalization_6 (Bat  (None, 56, 56, 64)           256       ['conv2d_6[0][0]']            
 chNormalization)                                                                                 
                                                                                                  
 add_1 (Add)                 (None, 56, 56, 64)           0         ['batch_normalization_5[0][0]'
                                                                    , 'batch_normalization_6[0][0]
                                                                    ']                            
                                                                                                  
 re_lu_4 (ReLU)              (None, 56, 56, 64)           0         ['add_1[0][0]']               
                                                                                                  
 conv2d_7 (Conv2D)           (None, 56, 56, 128)          73856     ['re_lu_4[0][0]']             
                                                                                                  
 batch_normalization_7 (Bat  (None, 56, 56, 128)          512       ['conv2d_7[0][0]']            
 chNormalization)                                                                                 
                                                                                                  
 re_lu_5 (ReLU)              (None, 56, 56, 128)          0         ['batch_normalization_7[0][0]'
                                                                    ]                             
                                                                                                  
 conv2d_8 (Conv2D)           (None, 56, 56, 128)          147584    ['re_lu_5[0][0]']             
                                                                                                  
 conv2d_9 (Conv2D)           (None, 56, 56, 128)          8320      ['re_lu_4[0][0]']             
                                                                                                  
 batch_normalization_8 (Bat  (None, 56, 56, 128)          512       ['conv2d_8[0][0]']            
 chNormalization)                                                                                 
                                                                                                  
 batch_normalization_9 (Bat  (None, 56, 56, 128)          512       ['conv2d_9[0][0]']            
 chNormalization)                                                                                 
                                                                                                  
 add_2 (Add)                 (None, 56, 56, 128)          0         ['batch_normalization_8[0][0]'
                                                                    , 'batch_normalization_9[0][0]
                                                                    ']                            
                                                                                                  
 re_lu_6 (ReLU)              (None, 56, 56, 128)          0         ['add_2[0][0]']               
                                                                                                  
 conv2d_10 (Conv2D)          (None, 56, 56, 128)          147584    ['re_lu_6[0][0]']             
                                                                                                  
 batch_normalization_10 (Ba  (None, 56, 56, 128)          512       ['conv2d_10[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 re_lu_7 (ReLU)              (None, 56, 56, 128)          0         ['batch_normalization_10[0][0]
                                                                    ']                            
                                                                                                  
 conv2d_11 (Conv2D)          (None, 56, 56, 128)          147584    ['re_lu_7[0][0]']             
                                                                                                  
 conv2d_12 (Conv2D)          (None, 56, 56, 128)          16512     ['re_lu_6[0][0]']             
                                                                                                  
 batch_normalization_11 (Ba  (None, 56, 56, 128)          512       ['conv2d_11[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 batch_normalization_12 (Ba  (None, 56, 56, 128)          512       ['conv2d_12[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 add_3 (Add)                 (None, 56, 56, 128)          0         ['batch_normalization_11[0][0]
                                                                    ',                            
                                                                     'batch_normalization_12[0][0]
                                                                    ']                            
                                                                                                  
 re_lu_8 (ReLU)              (None, 56, 56, 128)          0         ['add_3[0][0]']               
                                                                                                  
 conv2d_13 (Conv2D)          (None, 56, 56, 256)          295168    ['re_lu_8[0][0]']             
                                                                                                  
 batch_normalization_13 (Ba  (None, 56, 56, 256)          1024      ['conv2d_13[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 re_lu_9 (ReLU)              (None, 56, 56, 256)          0         ['batch_normalization_13[0][0]
                                                                    ']                            
                                                                                                  
 conv2d_14 (Conv2D)          (None, 56, 56, 256)          590080    ['re_lu_9[0][0]']             
                                                                                                  
 conv2d_15 (Conv2D)          (None, 56, 56, 256)          33024     ['re_lu_8[0][0]']             
                                                                                                  
 batch_normalization_14 (Ba  (None, 56, 56, 256)          1024      ['conv2d_14[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 batch_normalization_15 (Ba  (None, 56, 56, 256)          1024      ['conv2d_15[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 add_4 (Add)                 (None, 56, 56, 256)          0         ['batch_normalization_14[0][0]
                                                                    ',                            
                                                                     'batch_normalization_15[0][0]
                                                                    ']                            
                                                                                                  
 re_lu_10 (ReLU)             (None, 56, 56, 256)          0         ['add_4[0][0]']               
                                                                                                  
 conv2d_16 (Conv2D)          (None, 56, 56, 256)          590080    ['re_lu_10[0][0]']            
                                                                                                  
 batch_normalization_16 (Ba  (None, 56, 56, 256)          1024      ['conv2d_16[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 re_lu_11 (ReLU)             (None, 56, 56, 256)          0         ['batch_normalization_16[0][0]
                                                                    ']                            
                                                                                                  
 conv2d_17 (Conv2D)          (None, 56, 56, 256)          590080    ['re_lu_11[0][0]']            
                                                                                                  
 conv2d_18 (Conv2D)          (None, 56, 56, 256)          65792     ['re_lu_10[0][0]']            
                                                                                                  
 batch_normalization_17 (Ba  (None, 56, 56, 256)          1024      ['conv2d_17[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 batch_normalization_18 (Ba  (None, 56, 56, 256)          1024      ['conv2d_18[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 add_5 (Add)                 (None, 56, 56, 256)          0         ['batch_normalization_17[0][0]
                                                                    ',                            
                                                                     'batch_normalization_18[0][0]
                                                                    ']                            
                                                                                                  
 re_lu_12 (ReLU)             (None, 56, 56, 256)          0         ['add_5[0][0]']               
                                                                                                  
 conv2d_19 (Conv2D)          (None, 56, 56, 512)          1180160   ['re_lu_12[0][0]']            
                                                                                                  
 batch_normalization_19 (Ba  (None, 56, 56, 512)          2048      ['conv2d_19[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 re_lu_13 (ReLU)             (None, 56, 56, 512)          0         ['batch_normalization_19[0][0]
                                                                    ']                            
                                                                                                  
 conv2d_20 (Conv2D)          (None, 56, 56, 512)          2359808   ['re_lu_13[0][0]']            
                                                                                                  
 conv2d_21 (Conv2D)          (None, 56, 56, 512)          131584    ['re_lu_12[0][0]']            
                                                                                                  
 batch_normalization_20 (Ba  (None, 56, 56, 512)          2048      ['conv2d_20[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 batch_normalization_21 (Ba  (None, 56, 56, 512)          2048      ['conv2d_21[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 add_6 (Add)                 (None, 56, 56, 512)          0         ['batch_normalization_20[0][0]
                                                                    ',                            
                                                                     'batch_normalization_21[0][0]
                                                                    ']                            
                                                                                                  
 re_lu_14 (ReLU)             (None, 56, 56, 512)          0         ['add_6[0][0]']               
                                                                                                  
 conv2d_22 (Conv2D)          (None, 56, 56, 512)          2359808   ['re_lu_14[0][0]']            
                                                                                                  
 batch_normalization_22 (Ba  (None, 56, 56, 512)          2048      ['conv2d_22[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 re_lu_15 (ReLU)             (None, 56, 56, 512)          0         ['batch_normalization_22[0][0]
                                                                    ']                            
                                                                                                  
 conv2d_23 (Conv2D)          (None, 56, 56, 512)          2359808   ['re_lu_15[0][0]']            
                                                                                                  
 conv2d_24 (Conv2D)          (None, 56, 56, 512)          262656    ['re_lu_14[0][0]']            
                                                                                                  
 batch_normalization_23 (Ba  (None, 56, 56, 512)          2048      ['conv2d_23[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 batch_normalization_24 (Ba  (None, 56, 56, 512)          2048      ['conv2d_24[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 add_7 (Add)                 (None, 56, 56, 512)          0         ['batch_normalization_23[0][0]
                                                                    ',                            
                                                                     'batch_normalization_24[0][0]
                                                                    ']                            
                                                                                                  
 re_lu_16 (ReLU)             (None, 56, 56, 512)          0         ['add_7[0][0]']               
                                                                                                  
 global_average_pooling2d (  (None, 512)                  0         ['re_lu_16[0][0]']            
 GlobalAveragePooling2D)                                                                          
                                                                                                  
 dense (Dense)               (None, 10)                   5130      ['global_average_pooling2d[0][
                                                                    0]']                          
                                                                                                  
==================================================================================================
Total params: 11553418 (44.07 MB)
Trainable params: 11541770 (44.03 MB)
Non-trainable params: 11648 (45.50 KB)


这段代码提供了一个基本示范,展示了如何实现和训练一个ResNet模型。


结论


残差神经网络代表了深度学习架构演变中的基石。它们的引入不仅解决了训练非常深层网络的迫切问题,还为未来在该领域的架构创新设定了先例。随着深度学习的不断演进,ResNets所奠定的原则无疑将继续影响更先进和高效模型的开发。

文章来源:https://medium.com/@evertongomede/understanding-and-implementing-residual-neural-networks-resnets-with-a-synthetic-dataset-in-e8411901a3c6
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
写评论取消
回复取消