首页 > 解决方案 > tensorflow scatter_nd 与空张量?

问题描述

我正在尝试将两个张量混合在一起。 scatter_nd非常适合这个场合,我编写了以下函数来完成我的任务。它基本上只是将 2 个 scatter_nds 广告放在一起。

def tf_munge(t, i, r, j, axis=0):
    #insert tensor t at indices i and tensor r at indices j on axis `axis`.
    #requires: i.shape[0] == t.shape[axis] && j.shape[0] == r.shape[axis] && t.shape[k] == r.shape[k] ∀k != axis
    i = tf.expand_dims(i, -1)
    j = tf.expand_dims(j, -1)
    rank_indices = tf.range(tf.rank(t))
    roller = tf.roll(rank_indices, -axis, 0)
    rolled_t = tf.transpose(t, roller)
    rolled_r = tf.transpose(r, roller)
    scatter_shape = tf.concat((tf.shape(i)[0:1] + tf.shape(j)[0:1], tf.shape(rolled_t)[1:]), axis=0)
    scattered = tf.scatter_nd(i, rolled_t, scatter_shape) + tf.scatter_nd(j, rolled_r, scatter_shape)
    return tf.transpose(scattered, tf.roll(rank_indices, axis, 0))

它通常按预期工作。但是,只要两者都沿某个轴为空,它就会r失败t。我有两个代码“路径”,具体取决于布尔值,其中我拆分张量并根据该布尔值是真还是假执行不同的操作。有时,对于 0 行,该布尔值是错误的。在这种情况下,我最终会对一个空张量做一些事情。其中之一就是这种尝试的散射。该错误实际上引用了输出形状(scatter_shape在上面的代码中)声称:

ValueError:为“ScatterNd_4”(操作:“ScatterNd”)的空输出形状指定的索引和更新
,输入形状为:[3,1]、[3,0,2]、[3],输入张量计算为部分形状:输入[2] = [5,0,2]。

请注意,空的轴与我散布的轴不同。这是一个工作示例:

foo = tf.ones((3,1,2))
bar = tf.ones((2,1,2))*2
i = tf.constant([1,3,4])
j = tf.constant([0,2])
tf_munge(foo,i,bar,j,axis=0)
#Output:  <tf.Tensor 'transpose_13:0' shape=(5, 1, 2) dtype=float32>

这是一个失败的例子:

foo = tf.ones((3,0,2))
bar = tf.ones((2,0,2))*2
tf_munge(foo,i,bar,j,axis=0)
#Output: The error above

这里的预期输出显然是 shape 的空张量(5,0,2)

我考虑过对输入的形状使用条件,但同时tf.cond 执行两个路径。当我有一个空的张量时,我该如何处理这种情况scatter_nd

标签: pythontensorflow

解决方案


您可以通过tf.gather适用于所有情况的方式更简单地做到这一点:

import tensorflow as tf

def tf_munge(t, i, r, j, axis=0):
    tr = tf.concat([t, r], axis=axis)
    idx = tf.argsort(tf.concat([i, j], axis=0))
    return tf.gather(tr, idx, axis=axis)

with tf.Graph().as_default(), tf.Session() as sess:
    foo = tf.ones((3, 1, 2))
    bar = tf.ones((2, 1, 2)) * 2
    i = tf.constant([1, 3, 4])
    j = tf.constant([0, 2])
    out = tf_munge(foo, i, bar, j, axis=0)
    print(sess.run(out))
    # [[[2. 2.]]
    # 
    #  [[1. 1.]]
    # 
    #  [[2. 2.]]
    # 
    #  [[1. 1.]]
    # 
    #  [[1. 1.]]]
    foo2 = tf.ones((3, 0, 2))
    bar2 = tf.ones((2, 0, 2)) * 2
    out2 = tf_munge(foo2, i, bar2, j, axis=0)
    print(sess.run(out2))
    # []

推荐阅读