首页 > 解决方案 > 如何在 tf.while_loop() 中存储张量的值

问题描述

我有一个用tf.while_loop()手工设计的循环神经网络。我想在训练期间存储一些输出的所有值,看看它是如何工作的。但我想为每一种功能都这样做。

这意味着我必须在训练期间将相应的输入与其隐藏状态相匹配。为此,我声明了一个全局变量Dic = {}并在循环体中使用它:

def body(t1, t2, h, x_array, l_middle):
    global Dic
    x = x_array.read(t1)
    if x.eval() in Dic:
        Dic[str(x.eval())] += 1
    else:
        Dic[str(x.eval())] = 1
    h = tf.multiply(h, x)
    print(type(h))
    h.set_shape([1, 24])
    l_middle = tf.concat([l_middle, h], axis=0)
    t1 = tf.add(t1, 1)
    return [t1, t2, h, x_array, l_middle]

def iterr2(L, ll, N):
    x_array = TensorArr.unstack(x)
    L = tf.unstack(ll)
    h = tf.ones([1, n_hiddens]) * 0.01
    for i in range(11):
        l_middle = tf.zeros([1, 24])
        aux = L[i]
        right = L[i + 1]
        s = L[i].get_shape()
        T, _, g, _, l = tf.while_loop(
            cond,
            body,
            [aux, right, h, x_array, l_middle],
            shape_invariants=[
                s,
                s,
                h.get_shape(),
                tf.TensorShape([]),
                tf.TensorShape([None, 24])])
    return T, g, l

这不起作用,因为ValueError: Operation u'while/TensorArrayReadV3' has been marked as not fetchable.

如果我不 eval() 它,那么我的字典中只有 10 位先生:Tensor("while_2/TensorArrayReadV3:0", dtype=float32)

如果您需要完整的错误日志或一些代码以使其运行,我可以提供它,但我认为它会提出一个太长的问题。

标签: pythondictionarytensorflowdeep-learning

解决方案


推荐阅读