首页 > 解决方案 > BatchToSpaceND 实际是如何工作的?

问题描述

我试图弄清楚BatchToSpaceND如何排列输入矩阵。示例之一如下:

(3)对于以下形状为[4,2,2,1]和block_size为2的输入:

x = [[[[1], [3]], [[9], [11]]],
     [[[2], [4]], [[10], [12]]],
     [[[5], [7]], [[13], [15]]],
     [[[6], [8]], [[14], [16]]]]

输出张量的形状为 [1, 4, 4, 1] 和值:

x = [[[1],   [2],  [3],  [4]],
     [[5],   [6],  [7],  [8]],
     [[9],  [10], [11],  [12]],
     [[13], [14], [15],  [16]]]

有人知道输出张量是如何得出的吗?为什么第一行是[[1], [2], [3], [4]]而不是[[1], [3], [9], [11]]相反?我也尝试了一些代码:

import tensorflow as tf
sess = tf.InteractiveSession()

a = [[[[1], [3]], [[9], [11]]],
     [[[2], [4]], [[10], [12]]],
     [[[5], [7]], [[13], [15]]],
     [[[6], [8]], [[14], [16]]]]
b = [2, 2, 1, 2, 2, 1]
a = tf.reshape(a, b)

b = [1, 2, 2, 2, 2, 1]
a = tf.reshape(a, b)

b = [1, 4, 4, 1]
a = tf.reshape(a, b)

print(a.eval())

[[[[ 1]
   [ 3]
   [ 9]
   [11]]

  [[ 2]
   [ 4]
   [10]
   [12]]

  [[ 5]
   [ 7]
   [13]
   [15]]

  [[ 6]
   [ 8]
   [14]
   [16]]]]

这不是文档中的结果。

标签: pythonpython-3.xtensorflow

解决方案


让我们考虑文档的参数部分:

input:一个Tensor。ND 与 shape input_shape = [batch] + spatial_shape + remaining_shape,其中spatial_shapeM尺寸。

所以对于具体的例子,这意味着我们有批量维度4、空间形状(2, 2)和剩余形状(1,)。在这里考虑一个现实世界的例子是有启发性的。让我们把这个输入张量看作是一组 4 个 2x2 图像的批次,具有 1 个通道(例如灰度)。由于操作没有修改,remaining_shape我们可以忽略它以进行进一步的探索。也就是说,输入有效地包含以下 2x2“图像”:

 1   3
 9  11
--------
 2   4
10  12
--------
 5   7
13  15
--------
 6   8
14  16

现在该操作要求将批量维度重塑为空间维度,类似于将a大小为 的一维数组重塑batcha.reshape(-1, *block_shape). 如果我们考虑批量索引[0, 1, 2, 3],它们将被重新整形[[0, 1], [2, 3]](省略新的大小为 1 的批量维度)。实际上,这意味着我们应该将四个 2x2 图像并排放置,block_shape指示布局,以便创建一个 4x4 图像。然而,此时我们还没有完成,因为还有一个额外的步骤,即空间维度是交错的,如文档所示:

此操作将“批量”维度 0 重新M + 1整形为 shape 维度block_shape + [batch]将这些块交错回由空间维度定义的网格中[1, ..., M],以获得与输入具有相同等级的结果。

那就是在我们拥有的网格中布置图像:

 1   3     2   4
 9  11    10  12

 5   7     6   8
13  15    14  16

现在我们只剩下交错各个图像的行和列维度,以得到最终结果:

        -------⅂
       |       |
    -------⅂   |
   |   |   |   |
   v   v   |   |

 1   3     2   4
                  <---⅂
 9  11    10  12      |
                  <---|---⅂
                      |   |
                      |   |
 5   7     6   8   ---⅃   |
                          |
13  15    14  16   -------⅃

这使:

 1   2     3   4
 5   6     7   8
 9  10    11  12
13  14    15  16

示例的实际输出具有形状(1, 4, 4, 1),因为它包含附加值remaining_shape(为了示例,我们已将其省略)并且它保留了批处理维度(在本例中为 1)。

等效代码示例

import numpy as np
import tensorflow as tf

sess = tf.InteractiveSession()

a = np.array([[[[1], [3]], [[ 9], [11]]],
              [[[2], [4]], [[10], [12]]],
              [[[5], [7]], [[13], [15]]],
              [[[6], [8]], [[14], [16]]]])

block_shape = (2, 2)

new_batch_size = a.shape[0] // np.prod(block_shape)
b = tf.reshape(a, 
    block_shape
    + (new_batch_size,)
    + a.shape[1:]
)
# Hard-coded version:
# b = tf.transpose(b, [2, 3, 0, 4, 1, 5])
# Generic version:
b = tf.transpose(b,
    (len(block_shape),)
    + tuple(j for i in range(len(block_shape)) for j in (i + len(block_shape) + 1, i))
    + tuple(i + 2*len(block_shape) + 1 for i in range(len(a.shape) - len(block_shape) - 1))
)
b = tf.reshape(b,
    (new_batch_size,)
    + tuple(i*j for i, j in zip(block_shape, a.shape[1:]))
    + a.shape[1+len(block_shape):]
)
print(b.eval())
print(b.shape)

推荐阅读