python - tensorflow scatter_nd 与空张量?
问题描述
我正在尝试将两个张量混合在一起。 scatter_nd
非常适合这个场合,我编写了以下函数来完成我的任务。它基本上只是将 2 个 scatter_nds 广告放在一起。
def tf_munge(t, i, r, j, axis=0):
#insert tensor t at indices i and tensor r at indices j on axis `axis`.
#requires: i.shape[0] == t.shape[axis] && j.shape[0] == r.shape[axis] && t.shape[k] == r.shape[k] ∀k != axis
i = tf.expand_dims(i, -1)
j = tf.expand_dims(j, -1)
rank_indices = tf.range(tf.rank(t))
roller = tf.roll(rank_indices, -axis, 0)
rolled_t = tf.transpose(t, roller)
rolled_r = tf.transpose(r, roller)
scatter_shape = tf.concat((tf.shape(i)[0:1] + tf.shape(j)[0:1], tf.shape(rolled_t)[1:]), axis=0)
scattered = tf.scatter_nd(i, rolled_t, scatter_shape) + tf.scatter_nd(j, rolled_r, scatter_shape)
return tf.transpose(scattered, tf.roll(rank_indices, axis, 0))
它通常按预期工作。但是,只要两者都沿某个轴为空,它就会r
失败t
。我有两个代码“路径”,具体取决于布尔值,其中我拆分张量并根据该布尔值是真还是假执行不同的操作。有时,对于 0 行,该布尔值是错误的。在这种情况下,我最终会对一个空张量做一些事情。其中之一就是这种尝试的散射。该错误实际上引用了输出形状(scatter_shape
在上面的代码中)声称:
ValueError:为“ScatterNd_4”(操作:“ScatterNd”)的空输出形状指定的索引和更新
,输入形状为:[3,1]、[3,0,2]、[3],输入张量计算为部分形状:输入[2] = [5,0,2]。
请注意,空的轴与我散布的轴不同。这是一个工作示例:
foo = tf.ones((3,1,2))
bar = tf.ones((2,1,2))*2
i = tf.constant([1,3,4])
j = tf.constant([0,2])
tf_munge(foo,i,bar,j,axis=0)
#Output: <tf.Tensor 'transpose_13:0' shape=(5, 1, 2) dtype=float32>
这是一个失败的例子:
foo = tf.ones((3,0,2))
bar = tf.ones((2,0,2))*2
tf_munge(foo,i,bar,j,axis=0)
#Output: The error above
这里的预期输出显然是 shape 的空张量(5,0,2)
。
我考虑过对输入的形状使用条件,但同时tf.cond
执行两个路径。当我有一个空的张量时,我该如何处理这种情况scatter_nd
?
解决方案
您可以通过tf.gather
适用于所有情况的方式更简单地做到这一点:
import tensorflow as tf
def tf_munge(t, i, r, j, axis=0):
tr = tf.concat([t, r], axis=axis)
idx = tf.argsort(tf.concat([i, j], axis=0))
return tf.gather(tr, idx, axis=axis)
with tf.Graph().as_default(), tf.Session() as sess:
foo = tf.ones((3, 1, 2))
bar = tf.ones((2, 1, 2)) * 2
i = tf.constant([1, 3, 4])
j = tf.constant([0, 2])
out = tf_munge(foo, i, bar, j, axis=0)
print(sess.run(out))
# [[[2. 2.]]
#
# [[1. 1.]]
#
# [[2. 2.]]
#
# [[1. 1.]]
#
# [[1. 1.]]]
foo2 = tf.ones((3, 0, 2))
bar2 = tf.ones((2, 0, 2)) * 2
out2 = tf_munge(foo2, i, bar2, j, axis=0)
print(sess.run(out2))
# []
推荐阅读
- csv - 具有大量列的数据框 - 导入时会在没有数据的列中创建 NA
- go - 在发出 HTTP 请求时,Go 是否有某种方式使用 DefaultNetworkCredentials 和代理?
- php - 如何将 alt 属性放在循环内的 img cmb2 中?
- delphi - 数值返回与实际不同
- python - 在实现 CountVectorizer() 时出现错误“空词汇;也许文档仅包含停用词”
- javascript - Matter.js — 如何防止 mouseConstraint 捕获滚动事件?
- java - For if loop wont loop:任何指向为什么会这样的指针?
- apache-spark - 在 Spark 表列名中保留特殊字符
- javascript - 防止点击元素
- python - 使用 Pandas 数据帧查找多个索引上所有列之间的差异(增量)