OOM error when implementing an image captioning model (especially the LSTM)

Problem description

I am implementing an image captioning model.

So far, I have implemented the following model:

import tensorflow as tf

class Attention(tf.keras.Model):
    def __init__(self, units):
        super(Attention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, features, hidden):
        # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)

        # hidden shape == (batch_size, hidden_size)
        # hidden_with_time_axis shape == (batch_size, 1, hidden_size)
        hidden_with_time_axis = tf.expand_dims(hidden, 1)

        # score shape == (batch_size, 64, hidden_size)
        score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))

        # attention_weights shape == (batch_size, 64, 1)
        # you get 1 at the last axis because you apply self.V to the score
        attention_weights = tf.nn.softmax(self.V(score), axis=1)

        # context_vector shape after sum == (batch_size, embedding_dim)
        context_vector = attention_weights * features
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights

class CNN_Encoder(tf.keras.Model):
    # Since you have already extracted the features and dumped them using pickle,
    # this encoder just passes those features through a fully connected layer
    def __init__(self, embedding_dim):
        super(CNN_Encoder, self).__init__()
        # shape after fc == (batch_size, 49, embedding_dim)
        self.fc = tf.keras.layers.Dense(embedding_dim)

    def call(self, x):
        x = self.fc(x)
        x = tf.nn.relu(x)
        # shape of x == (batch_size, 49, embedding_dim)
        return x

class RDN_Decoder(tf.keras.Model):
    def __init__(self, embedding_dim, units, vocab_size):
        super(RDN_Decoder, self).__init__()
        self.units = units

        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.lstm1 = tf.keras.layers.LSTM(self.units,
                                          return_sequences=True,
                                          return_state=True,
                                          recurrent_initializer='glorot_uniform')
        self.lstm2 = tf.keras.layers.LSTM(self.units,
                                          return_sequences=True,
                                          return_state=True,
                                          recurrent_initializer='glorot_uniform')
        self.fc1 = tf.keras.layers.Dense(self.units)
        self.fc2 = tf.keras.layers.Dense(vocab_size)

        self.visual_attention = Attention(self.units)
        self.reflective_attention = Attention(self.units)

    def call(self, x, features, hidden_state1, hidden_state2):
        # x shape after passing through embedding == (batch_size, 1, embedding_dim)
        x = self.embedding(x)

        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
        x = tf.concat([tf.expand_dims(hidden_state1, 1), x], axis=-1)

        # passing through lstm
        output1, hidden_state1, cell_state1 = self.lstm1(x)

        # visual attention as a separate model
        context_vector_v, attention_weights_v = self.visual_attention(features, hidden_state1)

        # change hidden state dimension
        hidden_state2 = tf.concat([tf.expand_dims(hidden_state2, 1), x], axis=-1)

        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
        x = tf.concat([tf.expand_dims(context_vector_v, 1), hidden_state2], axis=-1)

        # passing the concatenated vector to the lstm
        output2, hidden_state2, cell_state2 = self.lstm2(x)

        # reflective attention as a separate model
        context_vector_r, attention_weights_r = self.reflective_attention(hidden_state2, hidden_state1)

        # shape == (batch_size, max_length, hidden_size)
        x = self.fc1(output2)

        # x shape == (batch_size * max_length, hidden_size)
        x = tf.reshape(x, (-1, x.shape[2]))

        # output shape == (batch_size * max_length, vocab)
        x = self.fc2(x)

        # pass through softmax
        x = tf.nn.softmax(x)

        return x, hidden_state1, hidden_state2, attention_weights_v, attention_weights_r

    def reset_state(self, batch_size):
        return tf.zeros((batch_size, self.units))

encoder = CNN_Encoder(embedding_dim)
decoder = RDN_Decoder(embedding_dim, units, vocab_size)
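
The full data pipeline and training loop live in the Colab notebook linked below. Based on the traceback, train_step calls the decoder once per target token, roughly like this (a simplified sketch only: tokenizer, the optimizer choice, and the exact loss come from the notebook and may differ):

optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False,
                                                            reduction='none')

def loss_function(real, pred):
    # mask out padding (token id 0) before averaging the per-token loss
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)
    loss_ *= tf.cast(mask, dtype=loss_.dtype)
    return tf.reduce_mean(loss_)

def train_step(img_tensor, target):
    loss = 0
    # reset both LSTM hidden states at the start of every batch
    hidden_state1 = decoder.reset_state(batch_size=target.shape[0])
    hidden_state2 = decoder.reset_state(batch_size=target.shape[0])
    # first decoder input is the <start> token for every caption in the batch
    dec_input = tf.expand_dims([tokenizer.word_index['<start>']] * target.shape[0], 1)

    with tf.GradientTape() as tape:
        features = encoder(img_tensor)
        for i in range(1, target.shape[1]):
            # passing the features through the decoder
            predictions, hidden_state1, hidden_state2, _, _ = decoder(
                dec_input, features, hidden_state1, hidden_state2)
            loss += loss_function(target[:, i], predictions)
            # teacher forcing: feed the ground-truth token back in
            dec_input = tf.expand_dims(target[:, i], 1)

    total_loss = loss / int(target.shape[1])
    trainable_variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))
    return loss, total_loss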

When I train with this model, I get the following error:

---------------------------------------------------------------------------
ResourceExhaustedError                    Traceback (most recent call last)
<ipython-input-63-e33dbe296f4b> in <module>()
     12 
     13     for (batch, (img_tensor, target)) in enumerate(dataset):
---> 14         batch_loss, t_loss = train_step(img_tensor, target)
     15         total_loss += t_loss
     16 

13 frames
<ipython-input-62-b355d0692cf8> in train_step(img_tensor, target)
     15         for i in range(1, target.shape[1]):
     16             # passing the features through the decoder
---> 17             predictions, hidden_state1, hidden_state2, _, _ = decoder(dec_input, features, hidden_state1, hidden_state2)
     18 
     19             loss += loss_function(target[:, i], predictions)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
    966           with base_layer_utils.autocast_context_manager(
    967               self._compute_dtype):
--> 968             outputs = self.call(cast_inputs, *args, **kwargs)
    969           self._handle_activity_regularization(inputs, outputs)
    970           self._set_mask_metadata(inputs, outputs, input_masks)

<ipython-input-57-83f30c4f738b> in call(self, x, features, hidden_state1, hidden_state2)
     80 
     81         # passing the concatenated vector to the lstm
---> 82         output2, hidden_state2, cell_state2 = self.lstm2(x)
     83 
     84         # reflective attention as a separate model

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/recurrent.py in __call__(self, inputs, initial_state, constants, **kwargs)
    652 
    653     if initial_state is None and constants is None:
--> 654       return super(RNN, self).__call__(inputs, **kwargs)
    655 
    656     # If any of `initial_state` or `constants` are specified and are Keras

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
    966           with base_layer_utils.autocast_context_manager(
    967               self._compute_dtype):
--> 968             outputs = self.call(cast_inputs, *args, **kwargs)
    969           self._handle_activity_regularization(inputs, outputs)
    970           self._set_mask_metadata(inputs, outputs, input_masks)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/recurrent_v2.py in call(self, inputs, mask, training, initial_state)
   1179         if can_use_gpu:
   1180           last_output, outputs, new_h, new_c, runtime = gpu_lstm(
-> 1181               **gpu_lstm_kwargs)
   1182         else:
   1183           last_output, outputs, new_h, new_c, runtime = standard_lstm(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/recurrent_v2.py in gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask, time_major, go_backwards, sequence_lengths)
   1390       biases=array_ops.split(full_bias, 8),
   1391       shape=constant_op.constant([-1]),
-> 1392       transpose_weights=True)
   1393 
   1394   if mask is not None:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/recurrent_v2.py in _canonical_to_params(weights, biases, shape, transpose_weights)
   1234     return array_ops.transpose(w) if transpose_weights else w
   1235 
-> 1236   weights = [array_ops.reshape(convert(x), shape) for x in weights]
   1237   biases = [array_ops.reshape(x, shape) for x in biases]
   1238   return array_ops.concat(weights + biases, axis=0)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/recurrent_v2.py in <listcomp>(.0)
   1234     return array_ops.transpose(w) if transpose_weights else w
   1235 
-> 1236   weights = [array_ops.reshape(convert(x), shape) for x in weights]
   1237   biases = [array_ops.reshape(x, shape) for x in biases]
   1238   return array_ops.concat(weights + biases, axis=0)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/recurrent_v2.py in convert(w)
   1232   """
   1233   def convert(w):
-> 1234     return array_ops.transpose(w) if transpose_weights else w
   1235 
   1236   weights = [array_ops.reshape(convert(x), shape) for x in weights]

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py in transpose(a, perm, name, conjugate)
   2127     else:
   2128       perm = np.arange(rank - 1, -1, -1, dtype=np.int32)
-> 2129     return transpose_fn(a, perm, name=name)
   2130 
   2131 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_array_ops.py in transpose(x, perm, name)
  11176         pass  # Add nodes to the TensorFlow graph.
  11177     except _core._NotOkStatusException as e:
> 11178       _ops.raise_from_not_ok_status(e, name)
  11179   # Add nodes to the TensorFlow graph.
  11180   _, _, _op, _outputs = _op_def_library._apply_op_helper(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   6651   message = e.message + (" name: " + name if name is not None else "")
   6652   # pylint: disable=protected-access
-> 6653   six.raise_from(core._status_to_exception(e.code, message), None)
   6654   # pylint: enable=protected-access
   6655 

/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

ResourceExhaustedError: OOM when allocating tensor with shape[1024,1024] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:Transpose]

Please tell me how to implement this correctly (I think the main problem is the LSTM implementation). Any help would be greatly appreciated.

If you want to try it yourself, please use this Google Colab link (the data preparation and training code are long, so posting everything here would be unwieldy). You just need to run the cells in order.

Tags: python, tensorflow, keras, lstm, attention-model

Solution

There is not much that can be done about the OOM error itself; as I see it, you have two options:

  1. Get a better machine, or rent a cloud service that gives you more RAM / GPU memory.
  2. Shrink your network. I tried the following hyperparameters and training worked:

BATCH_SIZE = 8, embedding_dim = 512, units = 512. All other hyperparameters stay the same.
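
For concreteness, the smaller configuration would be plugged in roughly like this (a minimal sketch; top_k and the dataset pipeline are placeholders for whatever the notebook already defines):

# reduced hyperparameters that fit in GPU memory on my machine (yours may differ)
BATCH_SIZE = 8
embedding_dim = 512
units = 512
vocab_size = top_k + 1  # placeholder: built the same way as before

# rebuild the models with the smaller sizes
encoder = CNN_Encoder(embedding_dim)
decoder = RDN_Decoder(embedding_dim, units, vocab_size)

# placeholder: re-batch the tf.data pipeline wherever the notebook batches it
dataset = dataset.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)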

You will have to experiment to find the largest network that fits on your machine.
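
If you want to probe that limit without full training runs, one rough way (a sketch, assuming 49 feature locations of size 2048 coming out of the feature extractor; adjust to whatever the notebook actually produces) is to build the model at each candidate size and run one dummy decoder step, catching the OOM error:

import tensorflow as tf

def fits_in_memory(units, embedding_dim=512, vocab_size=5001,  # placeholder vocab size
                   batch_size=8, num_locations=49, feature_dim=2048):
    """Build encoder/decoder at a given size and try one forward pass."""
    try:
        encoder = CNN_Encoder(embedding_dim)
        decoder = RDN_Decoder(embedding_dim, units, vocab_size)
        # fake image features with the expected shape
        features = encoder(tf.random.uniform((batch_size, num_locations, feature_dim)))
        hidden1 = decoder.reset_state(batch_size)
        hidden2 = decoder.reset_state(batch_size)
        dec_input = tf.zeros((batch_size, 1), dtype=tf.int32)
        decoder(dec_input, features, hidden1, hidden2)
        return True
    except tf.errors.ResourceExhaustedError:
        return False

for units in (256, 512, 768, 1024):
    print(units, 'fits' if fits_in_memory(units) else 'OOM')

Keep in mind this only exercises a single forward step, so it is an optimistic lower bound; training with gradients needs noticeably more memory. Also, TensorFlow does not always release GPU memory between trials, so in practice you may need to restart the runtime between attempts.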

