首页 > 解决方案 > 如何在 tf.layers.max_pooling2d Tensorflow 中处理 0 批量大小?

问题描述

我正在使用 Tensorflow fold 来编写模型,并且 Tensorflow fold 通常有 0 个批量大小(github 上的错误)。这会导致某些 Tensorflow 操作出现问题,并显示如下错误:F tensorflow/stream_executor/cuda/cuda_dnn.cc:466] could not set cudnn tensor descriptor: CUDNN_STATUS_BAD_PARAM

但是,可以通过编写自定义梯度操作来解决此问题,如下所述。

# TensorFlow Fold can generate zero-size batch for conv layer
# which will crash cuDNN on backward pass. So use this
# for arbitrary convolution in modules to avoid the crash.
def _conv_safe(inputs, filters, kernel_size, strides, activation):
    g = tf.get_default_graph()
    with g.gradient_override_map({'Conv2D': 'Conv2D_handle_empty_batch'}):
        return tf.layers.conv2d(inputs=inputs, filters=filters, kernel_size=kernel_size,strides=strides, activation=activation)

@tf.RegisterGradient('Conv2D_handle_empty_batch')
def _Conv2DGrad(op, grad):
    with tf.device('/cpu:0'):
        return [tf.nn.conv2d_backprop_input(
                tf.shape(op.inputs[0]), op.inputs[1], grad, op.get_attr('strides'),
                op.get_attr('padding'), op.get_attr('use_cudnn_on_gpu'),
                op.get_attr('data_format')),
                tf.nn.conv2d_backprop_filter(op.inputs[0],
                                             tf.shape(op.inputs[1]), grad,
                                             op.get_attr('strides'),
                                             op.get_attr('padding'),
                                             op.get_attr('use_cudnn_on_gpu'),
                                             op.get_attr('data_format'))]

我现在想知道如何在使用tf.layers.max_pooling2d操作或任何其他形式的最大池时做类似的事情来避免这种崩溃。您可以在示例中看到tf.layers.conv2d,我们可以通过自定义实现渐变来处理 0 批量大小来绕过它。我该怎么做tf.layers.max_pooling2d

注意:我使用的是 Tensorflow 1.0,因为这是 Tensorflow Fold 所支持的。

谢谢

标签: pythontensorflow

解决方案


我认为我们可以这样做:

from tensorflow.python.ops import gen_nn_ops

def max_pooling_zero_batch(inputs, pool_size, strides, name):

    g = tf.get_default_graph()
    with g.gradient_override_map({'MaxPool': 'MaxPool_handle_empty_batch'}):  
        return tf.layers.max_pooling2d(inputs=inputs, pool_size=pool_size, strides=strides, name=name)

@tf.RegisterGradient("MaxPool_handle_empty_batch")
def _MaxPoolGrad(op, grad):
    with tf.device('/cpu:0'):
        return gen_nn_ops._max_pool_grad(op.inputs[0], op.outputs[0], grad, op.get_attr("ksize"), op.get_attr("strides"), padding=op.get_attr("padding"), data_format=op.get_attr("data_format"))

它似乎适用于 0 批量大小。


推荐阅读