python - 包含scatter_nd_update的tensorflow中while_loop的实现
问题描述
好的,这个问题与上一个问题有关(如果它有效,它也可以被认为是该问题的答案)。无论如何,我正在尝试根据tf.while_loop
我的需要实现一个。我需要在其中应用一个tf.scatter_nd_update
函数,但由于某种原因它会引发错误。
一个揭示问题的小脚本是这样的:
import tensorflow as tf
ref = tf.Variable([[0, 1, 0, 2],
[0, 1, 2, 2],
[1, 2, 1, 3]], dtype=tf.int32)
true_array = tf.Variable([[1, 1, 1, 1]])
false_array = tf.Variable([[1, 0, 1, 0]])
num_iters = tf.Variable(3, dtype=tf.int32) # 3
def body(ref, true_array, false_array, j, num_iters):
samples = tf.cond(tf.equal(tf.reduce_sum(ref[j, :], axis=0), 1), lambda: true_array, lambda: false_array)
ref = tf.scatter_nd_update(ref, [[j]], samples)
j = tf.add(j, 1)
return ref, true_array, false_array, j, num_iters
cond = lambda ref, true_array, false_array, j, num_iters: tf.less(j, num_iters)
j = tf.Variable(0, dtype=tf.int32) # tf.constant(0)
ref, true_array, false_array, j, num_iters = tf.while_loop(cond, body, [ref, true_array, false_array, j, num_iters])
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print('ref', sess.run(ref))
print('j', sess.run(j))
引发错误
AttributeError:“张量”对象没有属性“_lazy_read”
排队ref = tf.scatter_nd_update(ref, [[j]], samples)
。出于某种原因,我的一个对象是一个Tensor
不包含_lazy_read
属性的对象。这里有一个类似的问题建议将张量转换为变量,但在我的情况下它不起作用。我试图ref = lambda: tf.Variable(ref)
在函数内部使用,body
但这会引发另一个错误:
AttributeError:“函数”对象没有属性“dtype”
有谁知道如何实现while_loop
这个scatter_nd_update
?
值得注意的是,上面的代码在急切模式下工作得很好。
tf 版本 1.15
编辑:
删除了急切的执行模式(被错误地插入)。
解决方案
经过多次挫折和死胡同的搜索,我找到了适合我的组合的可行解决方案。
我应该只使用tf.tensor_scatter_nd_update
代替tf.scatter_nd_update
它似乎工作正常。不过,我无法具体说明这两者之间的区别。第一条指出:
根据索引将更新分散到现有张量中。
而第二个:
对变量中的单个值或切片应用稀疏更新。
但对我来说,他们似乎完成了同样的工作。无论如何,也许 tensorflow 团队有类似的想法,因为我认为这两个只有第一个仍然在 tf 2.0 版本中。
import tensorflow as tf
ref = tf.Variable([[0, 1, 0, 2],
[0, 1, 2, 2],
[1, 2, 1, 3]], dtype=tf.int32)
true_array = tf.Variable([[1, 1, 1, 1]])
false_array = tf.Variable([[1, 0, 1, 0]])
num_iters = tf.Variable(3, dtype=tf.int32) # 3
def body(ref, true_array, false_array, j, num_iters):
samples = tf.cond(tf.equal(tf.reduce_sum(ref[j, :], axis=0), 1), lambda: true_array, lambda: false_array)
ref = tf.tensor_scatter_nd_update(ref, [[j]], samples)
j = tf.add(j, 1)
return ref, true_array, false_array, j, num_iters
cond = lambda ref, true_array, false_array, j, num_iters: tf.less(j, num_iters)
j = tf.Variable(0, dtype=tf.int32) # tf.constant(0)
ref, true_array, false_array, j, num_iters = tf.while_loop(cond, body, [ref, true_array, false_array, j, num_iters])
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print('ref', sess.run(ref))
print('j', sess.run(j))
输出为:
update [[1 1 1 1]
[1 0 1 0]
[1 0 1 0]]
samples 3
推荐阅读
- python - 解释正则表达式
- sqlite - 升级 DBProvider Flutter
- html - 如何对齐文字与
代替 - java - Android:FusedLocationProvider,每隔几秒获取一次位置
- python - 如何从C访问python字典?
- javascript - 基于节点红色温度的灯或风扇
- javascript - 如何在 (window).resize 之后完全删除 jquery 函数
- vb.net - 循环中的sqlite更新需要很长时间
- elasticsearch - ElasticSearch - 列表元素中的模糊搜索
- dji-sdk - 如何使用 DJI mobile sdk 连接 DJI 产品