首页 > 解决方案 > 包含scatter_nd_update的tensorflow中while_loop的实现

问题描述

好的,这个问题与上一个问题有关(如果它有效,它也可以被认为是该问题的答案)。无论如何,我正在尝试根据tf.while_loop我的需要实现一个。我需要在其中应用一个tf.scatter_nd_update函数,但由于某种原因它会引发错误。

一个揭示问题的小脚本是这样的:

import tensorflow as tf

ref = tf.Variable([[0, 1, 0, 2],
                   [0, 1, 2, 2],
                   [1, 2, 1, 3]], dtype=tf.int32)
true_array = tf.Variable([[1, 1, 1, 1]])
false_array = tf.Variable([[1, 0, 1, 0]])
num_iters = tf.Variable(3, dtype=tf.int32)  # 3


def body(ref, true_array, false_array, j, num_iters):
    samples = tf.cond(tf.equal(tf.reduce_sum(ref[j, :], axis=0), 1), lambda: true_array, lambda: false_array)
    ref = tf.scatter_nd_update(ref, [[j]], samples)
     j = tf.add(j, 1)
    return ref, true_array, false_array, j, num_iters


cond = lambda ref, true_array, false_array, j, num_iters: tf.less(j, num_iters)
j = tf.Variable(0, dtype=tf.int32)  # tf.constant(0)
ref, true_array, false_array, j, num_iters = tf.while_loop(cond, body, [ref, true_array, false_array, j, num_iters])
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print('ref', sess.run(ref))
    print('j', sess.run(j))

引发错误

AttributeError:“张量”对象没有属性“_lazy_read”

排队ref = tf.scatter_nd_update(ref, [[j]], samples)。出于某种原因,我的一个对象是一个Tensor不包含_lazy_read属性的对象。这里有一个类似的问题建议将张量转换为变量,但在我的情况下它不起作用。我试图ref = lambda: tf.Variable(ref)在函数内部使用,body但这会引发另一个错误:

AttributeError:“函数”对象没有属性“dtype”

有谁知道如何实现while_loop这个scatter_nd_update

值得注意的是,上面的代码在急切模式下工作得很好。

tf 版本 1.15

编辑:
删除了急切的执行模式(被错误地插入)。

标签: pythontensorflow

解决方案


经过多次挫折和死胡同的搜索,我找到了适合我的组合的可行解决方案。

我应该只使用tf.tensor_scatter_nd_update代替tf.scatter_nd_update它似乎工作正常。不过,我无法具体说明这两者之间的区别。第一条指出:

根据索引将更新分散到现有张量中。

而第二个:

对变量中的单个值或切片应用稀疏更新。

但对我来说,他们似乎完成了同样的工作。无论如何,也许 tensorflow 团队有类似的想法,因为我认为这两个只有第一个仍然在 tf 2.0 版本中。

import tensorflow as tf

ref = tf.Variable([[0, 1, 0, 2],
                   [0, 1, 2, 2],
                   [1, 2, 1, 3]], dtype=tf.int32)
true_array = tf.Variable([[1, 1, 1, 1]])
false_array = tf.Variable([[1, 0, 1, 0]])
num_iters = tf.Variable(3, dtype=tf.int32)  # 3


def body(ref, true_array, false_array, j, num_iters):
    samples = tf.cond(tf.equal(tf.reduce_sum(ref[j, :], axis=0), 1), lambda: true_array, lambda: false_array)
    ref = tf.tensor_scatter_nd_update(ref, [[j]], samples)
     j = tf.add(j, 1)
    return ref, true_array, false_array, j, num_iters


cond = lambda ref, true_array, false_array, j, num_iters: tf.less(j, num_iters)
j = tf.Variable(0, dtype=tf.int32)  # tf.constant(0)
ref, true_array, false_array, j, num_iters = tf.while_loop(cond, body, [ref, true_array, false_array, j, num_iters])
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print('ref', sess.run(ref))
    print('j', sess.run(j))

输出为:

update [[1 1 1 1]
 [1 0 1 0]
 [1 0 1 0]]
samples 3

推荐阅读