首页 > 解决方案 > Tensorflow 中的 dropout 层会影响哪些层?

问题描述

考虑迁移学习,以便在 keras/tensorflow 中使用预训练模型。对于每个旧层,trained将参数设置为false使其权重在训练期间不会更新,而最后一层已被新层替换,并且必须对其进行训练。512特别是添加了两个具有神经元和1024relu 激活函数的全连接隐藏层。在这些层之后,Dropout 层与rate 0.2. 这意味着在20%神经元的每个训练时期都被随机丢弃。

这个 dropout 层会影响哪些层?它会影响所有网络,包括layer.trainable=false已设置的预训练层还是仅影响新添加的层?还是它只影响前一层(即有1024神经元的层)?

换句话说,在每个 epoch 中被 dropout 关闭的神经元属于哪一层?

import os

from tensorflow.keras import layers
from tensorflow.keras import Model
  
from tensorflow.keras.applications.inception_v3 import InceptionV3

local_weights_file = 'weights.h5'

pre_trained_model = InceptionV3(input_shape = (150, 150, 3), 
                                include_top = False, 
                                weights = None)

pre_trained_model.load_weights(local_weights_file)

for layer in pre_trained_model.layers:
  layer.trainable = False
  
# pre_trained_model.summary()

last_layer = pre_trained_model.get_layer('mixed7')
last_output = last_layer.output

# Flatten the output layer to 1 dimension
x = layers.Flatten()(last_output)
# Add two fully connected layers with 512 and 1,024 hidden units and ReLU activation
x = layers.Dense(512, activation='relu')(x)
x = layers.Dense(1024, activation='relu')(x)
# Add a dropout rate of 0.2
x = layers.Dropout(0.2)(x)                  
# Add a final sigmoid layer for classification
x = layers.Dense  (1, activation='sigmoid')(x)           

model = Model( pre_trained_model.input, x) 

model.compile(optimizer = RMSprop(lr=0.0001), 
              loss = 'binary_crossentropy', 
              metrics = ['accuracy'])

标签: pythontensorflowkerastransfer-learningdropout

解决方案


只有一层的神经元被“关闭”,但所有层都在反向传播方面“受到影响”。

  • 后层:Dropout 的输出是下一层的输入,所以下一层的输出会改变,next-next 的输出也会改变,等等。
  • 先前的层:随着 pre-Dropout 层的“有效输出”发生变化,它的梯度也会发生变化,因此任何后续的梯度也会发生变化。在 的极端情况下Dropout(rate=1),零梯度将流动。

另外,请注意,只有当 Dense 的输入是 2D 时,整个神经元才会被丢弃(batch_size, features);Dropout 对所有维度应用随机统一掩码(相当于在 2D 情况下丢弃整个神经元)。要删除整个神经元,请设置Dropout(.2, noise_shape=(batch_size, 1, features))(3D 案例)。要在所有样本中删除相同noise_shape=(1, 1, features)的神经元,请使用(或(1, features)用于 2D)。


推荐阅读