首页 > 解决方案 > How to use another library in the tensorflow graph?

问题描述

I just read this article. The article says that the resize algorithm of tensorflow has some bugs. Now I want to use scipy.misc.imresize instead of tf.image.resize_images. And I wonder what is the best way to implement the scipy resize algorithm.

Let`s consider the following layer:

def up_sample(input_tensor, new_height, new_width):
    _up_sampled = tf.image.resize_images(input_tensor, [new_height, new_width])
    _conv = tf.layers.conv2d(_up_sampled, 32, [3,3], padding="SAME")
    return _conv

How can I use the scipy algorithm in this layer?

Edit:

An example can be this function:

input_tensor = tf.placeholder("float32", [10, 200, 200, 8])
output_shape = [32, 210, 210, 8]

def up_sample(input_tensor, output_shape):
    new_array = np.zeros(output_shape)
    for batch in range(input_tensor.shape[0]):
        for channel in range(input_tensor.shape[-1]):
            new_array[batch, :, :, channel] = misc.imresize(input_tensor[batch, :, :, channel], output_shape[1:3])

But obviously scipy raises a ValueError that the the tf.Tensor object does not have the right shape. I read that during the a tf.Session the Tensors are accessible as numpy arrays. How can I use the scipy function only during a session and omit the execution in when creating the protocol buffer?

And is there a faster way than looping over all batches and channels?

标签: imagetensorflowscipy

解决方案


一般来说,您需要的工具是tf.map_fn和的组合tf.py_func

  • tf.py_func允许您将标准 python 函数包装到插入到图形中的 tensorflow 操作中。
  • tf.map_fn允许您在批处理样本上重复调用函数,当该函数无法对整个批处理进行操作时——图像函数通常就是这种情况。

在目前的情况下,我可能会建议scipy.ndimage.zoom在它可以直接在 4D 张量上操作的基础上使用,这会使事情变得更简单。另一方面,它需要输入缩放因子,而不是大小,所以我们需要计算它们。

import tensorflow as tf

sess = tf.InteractiveSession()

# unimportant -- just a way to get an input tensor
batch_size = 13
im_size = 7
num_channel=5
x = tf.eye(im_size)[None,...,None] + tf.zeros((batch_size, 1, 1, num_channel))
new_size = 17

from scipy import ndimage
new_x = tf.py_func(
  lambda a: ndimage.zoom(a, (1, new_size/im_size, new_size/im_size, 1)),
  [x], [tf.float32], stateful=False)[0]
print(new_x.eval().shape)
# (13, 17, 17, 5)

您可以使用其他函数(例如 OpenCV's cv2.resize、 Scikit-image's transform.image、 Scipy's misc.imresize),但没有一个可以直接对 4D 张量进行操作,因此使用起来更加冗长。zoom如果您想要除基于样条的插值之外的插值,您可能仍想使用它们。

但是,请注意以下事项:

  1. Python 函数在主机上执行。因此,如果您在图形卡等设备上执行图形,它需要停止,将张量复制到主机内存,调用您的函数,然后将结果复制回设备上。如果内存传输很重要,这可能会完全破坏您的计算时间。

  2. 渐变不通过 python 函数。例如,如果您的节点用于网络的升级部分,则上游层将不会收到任何梯度(或只有部分梯度,如果您有跳过连接),这会影响您的训练。

出于这两个原因,我建议仅在 CPU 上进行预处理且不使用梯度时将这种重采样应用于输入。

tf.image.resize_image如果您确实想使用这个高档节点在设备上进行训练,那么我认为除了坚持使用 buggy或自己编写之外别无选择。


推荐阅读