Dynamic Spatial Convolution is not supported on TPU

Problem description

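I'm trying to fit a Keras (TF 2.3.1) model for image classification with multiple binary labels as output. The model consists of an Xception CNN + an attention layer + a dense classifier, and it fails only on some TPUs with the error: UnimplementedError: {{function_node __inference_train_function_644557}} Compilation failure: Dynamic Spatial Convolution is not supported. This fails on Kaggle TPUs but not on Colab; both were tested on TF version 2.3.1.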

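I was looking here, but the suggested solution implied that the image dimensions were not set, which is not the case here. train_df is of type <PrefetchDataset shapes: ((None, 750, 750, 3), (None, 11)), types: (tf.float32, tf.int64)>, so every image has size 750x750x3. As the model summary below shows, every layer has a defined output shape, so the layers that follow should infer their input shapes correctly.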
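
For reference, the static shapes of a dataset can be inspected through element_spec. The sketch below is not the actual input pipeline: make_dataset, the 32x32 stand-in images and the batch size of 8 are assumptions made up for illustration. It only shows that the leading None in the shapes above is the batch dimension, which tf.data leaves unknown unless batch is called with drop_remainder=True. Whether that dynamic batch dimension is what the TPU compiler objects to is a guess at this point.

```python
import tensorflow as tf


def make_dataset(num_examples=20, batch_size=8, drop_remainder=False):
    """Hypothetical stand-in for the real pipeline (small all-zero images)."""
    images = tf.zeros((num_examples, 32, 32, 3), dtype=tf.float32)
    labels = tf.zeros((num_examples, 11), dtype=tf.int64)
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    # Without drop_remainder the last batch may be smaller, so the
    # batch dimension cannot be known statically.
    return ds.batch(batch_size, drop_remainder=drop_remainder).prefetch(1)


dynamic_ds = make_dataset(drop_remainder=False)
static_ds = make_dataset(drop_remainder=True)

# element_spec is a tuple of TensorSpec objects: (images, labels)
print(dynamic_ds.element_spec[0].shape)  # batch dimension is None
print(static_ds.element_spec[0].shape)   # batch dimension is fixed
```

The image height, width and channel count are static in both cases; only the batch dimension differs.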

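From the error, the problem seems to be in attn_layer = LocallyConnected2D(.... Passing implementation = 2 is a workaround that lets training complete, but it is not viable for large models (see the LocallyConnected2D documentation).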
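
To make clear what that layer computes: with a 1x1 kernel, a locally connected layer behaves like a 1x1 convolution whose weights are not shared between spatial positions. Below is a minimal NumPy sketch of that math, assuming the 24x24x16 input listed in the model summary. locally_connected_1x1 is a hypothetical helper written for illustration and does not mirror how Keras stores the weights internally.

```python
import numpy as np


def locally_connected_1x1(x, kernel, bias):
    """Per-position linear map followed by a sigmoid.

    x:      (batch, H, W, C_in)
    kernel: (H, W, C_in, C_out), one weight matrix per spatial position
    bias:   (H, W, C_out)
    """
    z = np.einsum("bhwc,hwcf->bhwf", x, kernel) + bias
    return 1.0 / (1.0 + np.exp(-z))


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 24, 24, 16))       # stand-in for the conv2d_68 output
kernel = rng.normal(size=(24, 24, 16, 1))  # unshared weights
bias = np.zeros((24, 24, 1))

out = locally_connected_1x1(x, kernel, bias)
print(out.shape)  # (2, 24, 24, 1)
```

An ordinary Conv2D with the same kernel size would use a single (1, 1, 16, 1) kernel for all 24x24 positions.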

Modelling code:

import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.applications import Xception
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout, Flatten, Input, Conv2D, multiply, LocallyConnected2D, Lambda, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.metrics import mean_absolute_error

def create_model():
    input_shape = (TARGET_SIZE, TARGET_SIZE, 3)
    in_lay = Input(input_shape)
    conv_base = Xception(include_top = False, weights = 'imagenet', input_shape = input_shape)
    pt_features = conv_base(in_lay)
    bn_features = BatchNormalization()(pt_features)

    # here we do an attention mechanism to turn pixels in the GAP on and off
    attn_layer = Conv2D(64, kernel_size = (1,1), padding = 'same', activation = 'relu')(bn_features)
    attn_layer = Conv2D(16, kernel_size = (1,1), padding = 'same', activation = 'relu')(attn_layer)
    attn_layer = LocallyConnected2D(1, kernel_size = (1,1), padding = 'valid', activation = 'sigmoid')(attn_layer)
    # fan it out to all of the channels
    pt_depth = conv_base.get_output_shape_at(0)[-1]
    up_c2_w = np.ones((1, 1, 1, pt_depth))
    up_c2 = Conv2D(pt_depth, kernel_size = (1,1), padding = 'same', 
                activation = 'linear', use_bias = False, weights = [up_c2_w])
    up_c2.trainable = False
    attn_layer = up_c2(attn_layer)

    mask_features = multiply([attn_layer, bn_features])
    gap_features = GlobalAveragePooling2D()(mask_features)
    gap_mask = GlobalAveragePooling2D()(attn_layer)
    # to account for missing values from the attention model
    gap = Lambda(lambda x: x[0]/x[1], name = 'RescaleGAP')([gap_features, gap_mask])
    gap_dr = Dropout(0.5)(gap)
    dr_steps = Dropout(0.25)(Dense(1024, activation = 'elu')(gap_dr))
    out_layer = Dense(11, activation = 'sigmoid')(dr_steps)
    model = Model(inputs = [in_lay], outputs = [out_layer])
    model.compile(optimizer = Adam(lr = 0.002), loss = 'binary_crossentropy', metrics = ["AUC"])
    return model


with tpu_strategy.scope():
    model = create_model()
model.summary()

history = model.fit(
    train_df,
    epochs = EPOCHS,
    steps_per_epoch = STEPS_PER_EPOCH,
    validation_data = valid_df,
    validation_steps = VALIDATION_STEPS
)

Resulting model summary:

Model: "model_8"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_19 (InputLayer)           [(None, 750, 750, 3) 0                                            
__________________________________________________________________________________________________
xception (Model)                (None, 24, 24, 2048) 20861480    input_19[0][0]                   
__________________________________________________________________________________________________
batch_normalization_49 (BatchNo (None, 24, 24, 2048) 8192        xception[1][0]                   
__________________________________________________________________________________________________
conv2d_67 (Conv2D)              (None, 24, 24, 64)   131136      batch_normalization_49[0][0]     
__________________________________________________________________________________________________
conv2d_68 (Conv2D)              (None, 24, 24, 16)   1040        conv2d_67[0][0]                  
__________________________________________________________________________________________________
locally_connected2d_9 (LocallyC (None, 24, 24, 1)    9792        conv2d_68[0][0]                  
__________________________________________________________________________________________________
conv2d_69 (Conv2D)              (None, 24, 24, 2048) 2048        locally_connected2d_9[0][0]      
__________________________________________________________________________________________________
multiply_9 (Multiply)           (None, 24, 24, 2048) 0           conv2d_69[0][0]                  
                                                                 batch_normalization_49[0][0]     
__________________________________________________________________________________________________
global_average_pooling2d_23 (Gl (None, 2048)         0           multiply_9[0][0]                 
__________________________________________________________________________________________________
global_average_pooling2d_24 (Gl (None, 2048)         0           conv2d_69[0][0]                  
__________________________________________________________________________________________________
RescaleGAP (Lambda)             (None, 2048)         0           global_average_pooling2d_23[0][0]
                                                                 global_average_pooling2d_24[0][0]
__________________________________________________________________________________________________
dropout_18 (Dropout)            (None, 2048)         0           RescaleGAP[0][0]                 
__________________________________________________________________________________________________
dense_17 (Dense)                (None, 1024)         2098176     dropout_18[0][0]                 
__________________________________________________________________________________________________
dropout_19 (Dropout)            (None, 1024)         0           dense_17[0][0]                   
__________________________________________________________________________________________________
dense_18 (Dense)                (None, 11)           11275       dropout_19[0][0]                 
==================================================================================================
Total params: 23,123,139
Trainable params: 23,062,467
Non-trainable params: 60,672
__________________________________________________________________________________________________
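
As a sanity check on the summary, the parameter counts can be reproduced by hand from the layer shapes. This is a small sketch in plain Python; the Xception count is copied from the summary rather than derived, and the dictionary keys are shortened layer names.

```python
# Parameters per layer: weights + biases
layer_params = {
    "xception": 20_861_480,                       # copied from the summary
    "batch_normalization": 4 * 2048,              # gamma, beta, moving mean, moving variance
    "conv2d_67": 1 * 1 * 2048 * 64 + 64,          # 1x1 conv, 2048 -> 64
    "conv2d_68": 1 * 1 * 64 * 16 + 16,            # 1x1 conv, 64 -> 16
    "locally_connected2d": 24 * 24 * 16 + 24 * 24,  # unshared weights per position
    "conv2d_69": 1 * 1 * 1 * 2048,                # use_bias = False
    "dense_17": 2048 * 1024 + 1024,
    "dense_18": 1024 * 11 + 11,
}

for name, count in layer_params.items():
    print(f"{name:22s} {count:>12,d}")

total = sum(layer_params.values())
print(total)  # 23123139
```

The locally connected layer holds 9,792 parameters because each of the 24x24 positions has its own 16 weights and one bias.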

Full stack trace + error message:

---------------------------------------------------------------------------
UnimplementedError                        Traceback (most recent call last)
<ipython-input-53-5130a0bcf331> in <module>
     19     validation_data = valid_df,
     20     validation_steps = VALIDATION_STEPS,
---> 21     callbacks = [model_save, early_stop, reduce_lr]
     22 )

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
     64   def _method_wrapper(self, *args, **kwargs):
     65     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
---> 66       return method(self, *args, **kwargs)
     67 
     68     # Running inside `run_distribute_coordinator` already.

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
    853                 context.async_wait()
    854               logs = tmp_logs  # No error, now safe to assign to logs.
--> 855               callbacks.on_train_batch_end(step, logs)
    856         epoch_logs = copy.copy(logs)
    857 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in on_train_batch_end(self, batch, logs)
    387     """
    388     if self._should_call_train_batch_hooks:
--> 389       logs = self._process_logs(logs)
    390       self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
    391 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in _process_logs(self, logs)
    263     """Turns tensors into numpy arrays or Python scalars."""
    264     if logs:
--> 265       return tf_utils.to_numpy_or_python_type(logs)
    266     return {}
    267 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/tf_utils.py in to_numpy_or_python_type(tensors)
    521     return t  # Don't turn ragged or sparse tensors to NumPy.
    522 
--> 523   return nest.map_structure(_to_single_numpy_or_python_type, tensors)
    524 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/nest.py in map_structure(func, *structure, **kwargs)
    615 
    616   return pack_sequence_as(
--> 617       structure[0], [func(*x) for x in entries],
    618       expand_composites=expand_composites)
    619 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/nest.py in <listcomp>(.0)
    615 
    616   return pack_sequence_as(
--> 617       structure[0], [func(*x) for x in entries],
    618       expand_composites=expand_composites)
    619 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/tf_utils.py in _to_single_numpy_or_python_type(t)
    517   def _to_single_numpy_or_python_type(t):
    518     if isinstance(t, ops.Tensor):
--> 519       x = t.numpy()
    520       return x.item() if np.ndim(x) == 0 else x
    521     return t  # Don't turn ragged or sparse tensors to NumPy.

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in numpy(self)
    959     """
    960     # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
--> 961     maybe_arr = self._numpy()  # pylint: disable=protected-access
    962     return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr
    963 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _numpy(self)
    927       return self._numpy_internal()
    928     except core._NotOkStatusException as e:
--> 929       six.raise_from(core._status_to_exception(e.code, e.message), None)
    930 
    931   @property

/opt/conda/lib/python3.7/site-packages/six.py in raise_from(value, from_value)

UnimplementedError: {{function_node __inference_train_function_644557}} Compilation failure: Dynamic Spatial Convolution is not supported: %convolution.30660 = f32[<=8,24,24,2048]{3,2,1,0} convolution(f32[<=8,24,24,1]{3,2,1,0} %add.30633, f32[1,1,1,2048]{3,2,1,0} %get-tuple-element.354), window={size=1x1}, dim_labels=b01f_01io->b01f, metadata={op_type="Conv2D" op_name="model_8/conv2d_69/Conv2D"}
    TPU compilation failed
     [[{{node tpu_compile_succeeded_assert/_17367812259898276239/_5}}]]

Tags: python, python-3.x, tensorflow, keras, tpu
