首页 > 解决方案 > 如果使用 tf.while_loop 的输出张量,网络将不会训练

问题描述

我有一个自定义损失,它使用tf.while_loop(). 代码如下。如果我使用 tf.while_loop() 的输出,我的网络将无法训练,似乎永远卡住了。但是,如果我使用原始变量名,它就可以正常工作。为什么会这样?我们不应该为 tf.while_loop() 的输出张量使用新的变量名吗?

k = tf.constant(0)
i = tf.constant(val0)
sum_t = tf.constant(0,dtype=tf.float32)


while_condition = lambda k,i,sum_t: tf.math.less(k,val1)

def body(k,i,sum_t):
    tf.add(sum_t,val2)
    def f1(k,i):
        lambda:tf.add(k,1)
        lambda:tf.math.subtract(i,val1)
        return(k,i)
    def f2(k,i):
        lambda:(tf.add(i,1))
        return(k,i)
    tf.cond(tf.math.greater_equal(i,dim_t),lambda:f1(k,i),lambda:f2(k,i))
    return(k,i,sum_t)


new_k,new_i,loss2 = tf.while_loop(while_condition, body, [k,i,sum_t])

loss = loss2 #network won't train with this
loss = sum_t #network trains fine with this

标签: pythontensorflow

解决方案


你的循环没有做任何事情,它只是永远循环。您正在声明一些 lambda 函数,但您没有使用它们,并且您对输入张量进行了一些操作,但不使用这些操作的结果。body 函数只返回它接收到的相同的东西。我认为你想要的是这样的:

k = tf.constant(0)
i = tf.constant(val0)
sum_t = tf.constant(0, dtype=tf.float32)

while_condition = lambda k, i, sum_t: k < val1

def body(k, i, sum_t):
    def f1(k, i):
        return k + 1, i - val1
    def f2(k, i):
        return k, i + 1
    k, i = tf.cond(i >= dim_z, lambda: f1(k, i), lambda: f2(k, i))
    return k, i, sum_t + val2

new_k, new_i, loss2 = tf.while_loop(while_condition, body, [k, i, sum_t])

推荐阅读