python - 如何在 tf.layers.max_pooling2d Tensorflow 中处理 0 批量大小?
问题描述
我正在使用 Tensorflow fold 来编写模型,并且 Tensorflow fold 通常有 0 个批量大小(github 上的错误)。这会导致某些 Tensorflow 操作出现问题,并显示如下错误:F tensorflow/stream_executor/cuda/cuda_dnn.cc:466] could not set cudnn tensor descriptor: CUDNN_STATUS_BAD_PARAM
但是,可以通过编写自定义梯度操作来解决此问题,如下所述。
# TensorFlow Fold can generate zero-size batch for conv layer
# which will crash cuDNN on backward pass. So use this
# for arbitrary convolution in modules to avoid the crash.
def _conv_safe(inputs, filters, kernel_size, strides, activation):
g = tf.get_default_graph()
with g.gradient_override_map({'Conv2D': 'Conv2D_handle_empty_batch'}):
return tf.layers.conv2d(inputs=inputs, filters=filters, kernel_size=kernel_size,strides=strides, activation=activation)
@tf.RegisterGradient('Conv2D_handle_empty_batch')
def _Conv2DGrad(op, grad):
with tf.device('/cpu:0'):
return [tf.nn.conv2d_backprop_input(
tf.shape(op.inputs[0]), op.inputs[1], grad, op.get_attr('strides'),
op.get_attr('padding'), op.get_attr('use_cudnn_on_gpu'),
op.get_attr('data_format')),
tf.nn.conv2d_backprop_filter(op.inputs[0],
tf.shape(op.inputs[1]), grad,
op.get_attr('strides'),
op.get_attr('padding'),
op.get_attr('use_cudnn_on_gpu'),
op.get_attr('data_format'))]
我现在想知道如何在使用tf.layers.max_pooling2d
操作或任何其他形式的最大池时做类似的事情来避免这种崩溃。您可以在示例中看到tf.layers.conv2d
,我们可以通过自定义实现渐变来处理 0 批量大小来绕过它。我该怎么做tf.layers.max_pooling2d
?
注意:我使用的是 Tensorflow 1.0,因为这是 Tensorflow Fold 所支持的。
谢谢
解决方案
我认为我们可以这样做:
from tensorflow.python.ops import gen_nn_ops
def max_pooling_zero_batch(inputs, pool_size, strides, name):
g = tf.get_default_graph()
with g.gradient_override_map({'MaxPool': 'MaxPool_handle_empty_batch'}):
return tf.layers.max_pooling2d(inputs=inputs, pool_size=pool_size, strides=strides, name=name)
@tf.RegisterGradient("MaxPool_handle_empty_batch")
def _MaxPoolGrad(op, grad):
with tf.device('/cpu:0'):
return gen_nn_ops._max_pool_grad(op.inputs[0], op.outputs[0], grad, op.get_attr("ksize"), op.get_attr("strides"), padding=op.get_attr("padding"), data_format=op.get_attr("data_format"))
它似乎适用于 0 批量大小。
推荐阅读
- react-native - 单击react-native中的菜单项时如何打开页脚选项卡屏幕之一?
- angular - 如何在 Angular 中创建动态菜单
- powershell - New-QADUser:找不到与参数名称“proxyAddresses”匹配的参数
- android - Android RecyclerView 自定义图片显示
- javascript - 如何在nodeJS中将地图对象添加到mysql?
- angular - 将 *ngFor 与异步管道一起使用,并在 angular8 中的 html 中进行条件显示
- algorithm - 对 removeMin() 的调用将在 7 叉树的最大堆中进行多少次比较?
- c - 我收到此警告“常数 '10' 与布尔表达式的比较总是正确的”
- javascript - Angular 5 应用程序在控制台 SCRIPT1010 中显示空白屏幕 IE 11 并出现此错误:预期标识符(Windows 10)
- amazon-web-services - 如何在 AWS EC2 实例上设置 WSS Websockets?