python - TPU does not support Dynamic Spatial Convolution
Problem description
I am trying to use a Keras (TF 2.3.1) model for image classification with multiple binary labels as output. The model consists of an Xception CNN + attention layer + dense classifier, and it errors out only on some TPUs:
UnimplementedError: {{function_node __inference_train_function_644557}} Compilation failure: Dynamic Spatial Convolution is not supported
This fails on Kaggle TPUs but runs on Colab - tested on TF version 2.3.1 in both cases.
I looked at this question, but the solution suggested there implies the image dimensions were not set, which is not the case here. train_df is of type
<PrefetchDataset shapes: ((None, 750, 750, 3), (None, 11)), types: (tf.float32, tf.int64)>
so each image has size 750x750x3. According to the model summary below, every layer has a defined output shape, so each subsequent layer should be able to infer its input shape correctly.
Judging from the error, the problem appears to lie in attn_layer = LocallyConnected2D(...
Passing implementation = 2 is a workaround that lets training complete, but it does not scale to large models (see the LocallyConnected2D documentation).
Model code:
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.applications import Xception
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout, Flatten, Input, Conv2D, multiply, LocallyConnected2D, Lambda, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.metrics import mean_absolute_error

def create_model():
    input_shape = (TARGET_SIZE, TARGET_SIZE, 3)
    in_lay = Input(input_shape)
    conv_base = Xception(include_top = False, weights = 'imagenet', input_shape = input_shape)
    pt_features = conv_base(in_lay)
    bn_features = BatchNormalization()(pt_features)
    # here we do an attention mechanism to turn pixels in the GAP on and off
    attn_layer = Conv2D(64, kernel_size = (1,1), padding = 'same', activation = 'relu')(bn_features)
    attn_layer = Conv2D(16, kernel_size = (1,1), padding = 'same', activation = 'relu')(attn_layer)
    attn_layer = LocallyConnected2D(1, kernel_size = (1,1), padding = 'valid', activation = 'sigmoid')(attn_layer)
    # fan it out to all of the channels
    pt_depth = conv_base.get_output_shape_at(0)[-1]
    up_c2_w = np.ones((1, 1, 1, pt_depth))
    up_c2 = Conv2D(pt_depth, kernel_size = (1,1), padding = 'same',
                   activation = 'linear', use_bias = False, weights = [up_c2_w])
    up_c2.trainable = False
    attn_layer = up_c2(attn_layer)
    mask_features = multiply([attn_layer, bn_features])
    gap_features = GlobalAveragePooling2D()(mask_features)
    gap_mask = GlobalAveragePooling2D()(attn_layer)
    # to account for missing values from the attention model
    gap = Lambda(lambda x: x[0]/x[1], name = 'RescaleGAP')([gap_features, gap_mask])
    gap_dr = Dropout(0.5)(gap)
    dr_steps = Dropout(0.25)(Dense(1024, activation = 'elu')(gap_dr))
    out_layer = Dense(11, activation = 'sigmoid')(dr_steps)
    model = Model(inputs = [in_lay], outputs = [out_layer])
    model.compile(optimizer = Adam(lr = 0.002), loss = 'binary_crossentropy', metrics = ["AUC"])
    return model
with tpu_strategy.scope():
    model = create_model()

model.summary()

history = model.fit(
    train_df,
    epochs = EPOCHS,
    steps_per_epoch = STEPS_PER_EPOCH,
    validation_data = valid_df,
    validation_steps = VALIDATION_STEPS
)
Generated model summary:
Model: "model_8"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_19 (InputLayer) [(None, 750, 750, 3) 0
__________________________________________________________________________________________________
xception (Model) (None, 24, 24, 2048) 20861480 input_19[0][0]
__________________________________________________________________________________________________
batch_normalization_49 (BatchNo (None, 24, 24, 2048) 8192 xception[1][0]
__________________________________________________________________________________________________
conv2d_67 (Conv2D) (None, 24, 24, 64) 131136 batch_normalization_49[0][0]
__________________________________________________________________________________________________
conv2d_68 (Conv2D) (None, 24, 24, 16) 1040 conv2d_67[0][0]
__________________________________________________________________________________________________
locally_connected2d_9 (LocallyC (None, 24, 24, 1) 9792 conv2d_68[0][0]
__________________________________________________________________________________________________
conv2d_69 (Conv2D) (None, 24, 24, 2048) 2048 locally_connected2d_9[0][0]
__________________________________________________________________________________________________
multiply_9 (Multiply) (None, 24, 24, 2048) 0 conv2d_69[0][0]
batch_normalization_49[0][0]
__________________________________________________________________________________________________
global_average_pooling2d_23 (Gl (None, 2048) 0 multiply_9[0][0]
__________________________________________________________________________________________________
global_average_pooling2d_24 (Gl (None, 2048) 0 conv2d_69[0][0]
__________________________________________________________________________________________________
RescaleGAP (Lambda) (None, 2048) 0 global_average_pooling2d_23[0][0]
global_average_pooling2d_24[0][0]
__________________________________________________________________________________________________
dropout_18 (Dropout) (None, 2048) 0 RescaleGAP[0][0]
__________________________________________________________________________________________________
dense_17 (Dense) (None, 1024) 2098176 dropout_18[0][0]
__________________________________________________________________________________________________
dropout_19 (Dropout) (None, 1024) 0 dense_17[0][0]
__________________________________________________________________________________________________
dense_18 (Dense) (None, 11) 11275 dropout_19[0][0]
==================================================================================================
Total params: 23,123,139
Trainable params: 23,062,467
Non-trainable params: 60,672
__________________________________________________________________________________________________
Full stack trace + error message:
---------------------------------------------------------------------------
UnimplementedError Traceback (most recent call last)
<ipython-input-53-5130a0bcf331> in <module>
19 validation_data = valid_df,
20 validation_steps = VALIDATION_STEPS,
---> 21 callbacks = [model_save, early_stop, reduce_lr]
22 )
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
64 def _method_wrapper(self, *args, **kwargs):
65 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
---> 66 return method(self, *args, **kwargs)
67
68 # Running inside `run_distribute_coordinator` already.
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
853 context.async_wait()
854 logs = tmp_logs # No error, now safe to assign to logs.
--> 855 callbacks.on_train_batch_end(step, logs)
856 epoch_logs = copy.copy(logs)
857
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in on_train_batch_end(self, batch, logs)
387 """
388 if self._should_call_train_batch_hooks:
--> 389 logs = self._process_logs(logs)
390 self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
391
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in _process_logs(self, logs)
263 """Turns tensors into numpy arrays or Python scalars."""
264 if logs:
--> 265 return tf_utils.to_numpy_or_python_type(logs)
266 return {}
267
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/tf_utils.py in to_numpy_or_python_type(tensors)
521 return t # Don't turn ragged or sparse tensors to NumPy.
522
--> 523 return nest.map_structure(_to_single_numpy_or_python_type, tensors)
524
/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/nest.py in map_structure(func, *structure, **kwargs)
615
616 return pack_sequence_as(
--> 617 structure[0], [func(*x) for x in entries],
618 expand_composites=expand_composites)
619
/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/nest.py in <listcomp>(.0)
615
616 return pack_sequence_as(
--> 617 structure[0], [func(*x) for x in entries],
618 expand_composites=expand_composites)
619
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/tf_utils.py in _to_single_numpy_or_python_type(t)
517 def _to_single_numpy_or_python_type(t):
518 if isinstance(t, ops.Tensor):
--> 519 x = t.numpy()
520 return x.item() if np.ndim(x) == 0 else x
521 return t # Don't turn ragged or sparse tensors to NumPy.
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in numpy(self)
959 """
960 # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
--> 961 maybe_arr = self._numpy() # pylint: disable=protected-access
962 return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr
963
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _numpy(self)
927 return self._numpy_internal()
928 except core._NotOkStatusException as e:
--> 929 six.raise_from(core._status_to_exception(e.code, e.message), None)
930
931 @property
/opt/conda/lib/python3.7/site-packages/six.py in raise_from(value, from_value)
UnimplementedError: {{function_node __inference_train_function_644557}} Compilation failure: Dynamic Spatial Convolution is not supported: %convolution.30660 = f32[<=8,24,24,2048]{3,2,1,0} convolution(f32[<=8,24,24,1]{3,2,1,0} %add.30633, f32[1,1,1,2048]{3,2,1,0} %get-tuple-element.354), window={size=1x1}, dim_labels=b01f_01io->b01f, metadata={op_type="Conv2D" op_name="model_8/conv2d_69/Conv2D"}
TPU compilation failed
[[{{node tpu_compile_succeeded_assert/_17367812259898276239/_5}}]]
Solution
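No answer text survives in this copy of the page, but the HLO shape in the error, f32[<=8,24,24,2048], shows that it is the batch dimension that XLA sees as dynamic (<=8) when compiling conv2d_69, while the spatial dimensions are static. Since TPU/XLA compilation requires fully static shapes, one workaround worth trying (an inference from the error message, not a confirmed fix) is to batch the tf.data pipeline with drop_remainder=True so the batch dimension becomes static:

```python
import tensorflow as tf

BATCH_SIZE = 8  # placeholder; use the pipeline's per-replica batch size

# Without drop_remainder the last partial batch forces a dynamic (None)
# batch dimension; with it, every element has a fully static shape.
ds = tf.data.Dataset.range(20).batch(BATCH_SIZE, drop_remainder=True)
print(ds.element_spec.shape)  # (8,) -- static, instead of (None,)
```

If the real pipeline already drops the remainder, the other angle suggested by the question itself is to avoid LocallyConnected2D entirely, e.g. accept the memory cost of implementation = 2 for this small attention head, or swap in a shared-weight Conv2D(1, (1,1)) attention map (a different layer with per-position weight sharing, so not strictly equivalent).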