image - How to use another library in the tensorflow graph?
问题描述
I just read this article. The article says that the resize algorithm of tensorflow has some bugs. Now I want to use scipy.misc.imresize
instead of tf.image.resize_images
. And I wonder what is the best way to implement the scipy resize algorithm.
Let`s consider the following layer:
def up_sample(input_tensor, new_height, new_width):
_up_sampled = tf.image.resize_images(input_tensor, [new_height, new_width])
_conv = tf.layers.conv2d(_up_sampled, 32, [3,3], padding="SAME")
return _conv
How can I use the scipy algorithm in this layer?
Edit:
An example can be this function:
input_tensor = tf.placeholder("float32", [10, 200, 200, 8])
output_shape = [32, 210, 210, 8]
def up_sample(input_tensor, output_shape):
new_array = np.zeros(output_shape)
for batch in range(input_tensor.shape[0]):
for channel in range(input_tensor.shape[-1]):
new_array[batch, :, :, channel] = misc.imresize(input_tensor[batch, :, :, channel], output_shape[1:3])
But obviously scipy raises a ValueError that the the tf.Tensor object does not have the right shape. I read that during the a tf.Session the Tensors are accessible as numpy arrays. How can I use the scipy function only during a session and omit the execution in when creating the protocol buffer?
And is there a faster way than looping over all batches and channels?
解决方案
一般来说,您需要的工具是tf.map_fn
和的组合tf.py_func
。
tf.py_func
允许您将标准 python 函数包装到插入到图形中的 tensorflow 操作中。tf.map_fn
允许您在批处理样本上重复调用函数,当该函数无法对整个批处理进行操作时——图像函数通常就是这种情况。
在目前的情况下,我可能会建议scipy.ndimage.zoom
在它可以直接在 4D 张量上操作的基础上使用,这会使事情变得更简单。另一方面,它需要输入缩放因子,而不是大小,所以我们需要计算它们。
import tensorflow as tf
sess = tf.InteractiveSession()
# unimportant -- just a way to get an input tensor
batch_size = 13
im_size = 7
num_channel=5
x = tf.eye(im_size)[None,...,None] + tf.zeros((batch_size, 1, 1, num_channel))
new_size = 17
from scipy import ndimage
new_x = tf.py_func(
lambda a: ndimage.zoom(a, (1, new_size/im_size, new_size/im_size, 1)),
[x], [tf.float32], stateful=False)[0]
print(new_x.eval().shape)
# (13, 17, 17, 5)
您可以使用其他函数(例如 OpenCV's cv2.resize
、 Scikit-image's transform.image
、 Scipy's misc.imresize
),但没有一个可以直接对 4D 张量进行操作,因此使用起来更加冗长。zoom
如果您想要除基于样条的插值之外的插值,您可能仍想使用它们。
但是,请注意以下事项:
Python 函数在主机上执行。因此,如果您在图形卡等设备上执行图形,它需要停止,将张量复制到主机内存,调用您的函数,然后将结果复制回设备上。如果内存传输很重要,这可能会完全破坏您的计算时间。
渐变不通过 python 函数。例如,如果您的节点用于网络的升级部分,则上游层将不会收到任何梯度(或只有部分梯度,如果您有跳过连接),这会影响您的训练。
出于这两个原因,我建议仅在 CPU 上进行预处理且不使用梯度时将这种重采样应用于输入。
tf.image.resize_image
如果您确实想使用这个高档节点在设备上进行训练,那么我认为除了坚持使用 buggy或自己编写之外别无选择。
推荐阅读
- java - 春季启动 - restTemplate.postForObject - 参数为空
- ruby-on-rails - 升级到 ruby 3/rails 6.1 后应用程序无法访问(提供了演示应用程序)
- segmentation-fault - 相同的错误代码在 mpiexec 时重复多次
- vba - 尝试使用 VBA 在 Word 中自动进行文档拆分
- python-3.x - ValueError:应定义输入的通道维度。找到`无`
- sql - 用于计算总和的 SQL Server 查询
- zpl - 皇家邮政 CMDM 条码
- elasticsearch - elasticsearch 错误或缺失查询
- google-sheets - 返回所有 DEC YYYY 和结果的 Google 表格公式?
- javascript - 是否有可能将 async/await 与 map 一起使用,