python - 如果使用 tf.while_loop 的输出张量,网络将不会训练
问题描述
我有一个自定义损失,它使用tf.while_loop()
. 代码如下。如果我使用 tf.while_loop() 的输出,我的网络将无法训练,似乎永远卡住了。但是,如果我使用原始变量名,它就可以正常工作。为什么会这样?我们不应该为 tf.while_loop() 的输出张量使用新的变量名吗?
k = tf.constant(0)
i = tf.constant(val0)
sum_t = tf.constant(0,dtype=tf.float32)
while_condition = lambda k,i,sum_t: tf.math.less(k,val1)
def body(k,i,sum_t):
tf.add(sum_t,val2)
def f1(k,i):
lambda:tf.add(k,1)
lambda:tf.math.subtract(i,val1)
return(k,i)
def f2(k,i):
lambda:(tf.add(i,1))
return(k,i)
tf.cond(tf.math.greater_equal(i,dim_t),lambda:f1(k,i),lambda:f2(k,i))
return(k,i,sum_t)
new_k,new_i,loss2 = tf.while_loop(while_condition, body, [k,i,sum_t])
loss = loss2 #network won't train with this
loss = sum_t #network trains fine with this
解决方案
你的循环没有做任何事情,它只是永远循环。您正在声明一些 lambda 函数,但您没有使用它们,并且您对输入张量进行了一些操作,但不使用这些操作的结果。body 函数只返回它接收到的相同的东西。我认为你想要的是这样的:
k = tf.constant(0)
i = tf.constant(val0)
sum_t = tf.constant(0, dtype=tf.float32)
while_condition = lambda k, i, sum_t: k < val1
def body(k, i, sum_t):
def f1(k, i):
return k + 1, i - val1
def f2(k, i):
return k, i + 1
k, i = tf.cond(i >= dim_z, lambda: f1(k, i), lambda: f2(k, i))
return k, i, sum_t + val2
new_k, new_i, loss2 = tf.while_loop(while_condition, body, [k, i, sum_t])
推荐阅读
- java - 如何正确设置 Eclipse(ActionListener 无法解析为类型)?
- r - 如何使用函数将变量归类
- c++ - Sancov象征大文件这么久?
- python - 我试图让我的放弃时间说 5 天提醒 432000 秒
- go - 如何在不编组的情况下将结构从中间件传递到 Gin 中的处理程序?
- google-apps-script - 使用谷歌应用脚本将项目从一个谷歌表单移动到一个新的谷歌表单
- javascript - 使用专门的阅读页面制作每个链接或页面
- node.js - RPi 上的 NodeJs Lan 应用程序到谷歌云平台到自定义域
- c# - 如何解决 ASP.NET.MVC 中的模型问题?
- .net - 没有注册类型为“TodoListGQL.Data.ApiDbContext”的服务 .Net 5、GraphQL 和 Hot Chocolate