python - 为什么 tf.cond 中的两个分支都被执行?为什么即使条件仍然成立,tf.while_loop 也会完成循环?
问题描述
我现在使用 keras 有一段时间了,但通常我不必使用自定义层或执行一些更复杂的流控制,所以我很难理解一些东西。
我正在建模一个顶部有一个自定义层的神经网络。这个自定义层调用另一个函数 ( search_sigma
),在这个函数tf.while_loop
中我执行并在其中tf.while_loop
执行tf.cond
。
我不明白为什么这些条件不起作用。
tf.while_loop
即使条件 (l1
) 仍然为真,也会停止tf.cond executes
和(可调用对象f1
和)f2
true_fn
false_fn
有人可以帮助我了解我所缺少的吗?
我已经尝试为真张量更改 tf.cond 和 tf.while_loop 条件,只是想看看会发生什么。行为(完全相同的错误)保持不变。
我还尝试在不实现类的情况下编写此代码(仅使用函数)。没有改变。
我试图通过查看 tensorflow 文档、其他堆栈溢出问题以及谈论 tf.while_loop 和 tf.cond 的网站来寻找解决方案。
我print()
在代码主体中留下了一些 s 以尝试跟踪正在发生的事情。
class find_sigma:
def __init__ (self, t_inputs, inputs, expected_perp=10. ):
self.sigma, self.cluster = t_inputs
self.inputs = inputs
self.expected_perp = expected_perp
self.min_sigma=tf.constant([0.01],tf.float32)
self.max_sigma=tf.constant([50.],tf.float32)
def search_sigma(self):
def cond(s,sigma_not_found): return sigma_not_found
def body(s,sigma_not_found):
print('loop')
pi = K.exp( - K.sum( (K.expand_dims(self.inputs, axis=1) - self.cluster)**2, axis=2 )/(2*s**2) )
pi = pi / K.sum(pi)
MACHINE_EPSILON = np.finfo(np.double).eps
pi = K.maximum(pi, MACHINE_EPSILON)
H = - K.sum ( pi*(K.log(pi)/K.log(2.)) , axis=0 )
perp = 2**H
print('0')
l1 = tf.logical_and (tf.less(perp , self.expected_perp), tf.less(0.01, self.max_sigma-s))
l2 = tf.logical_and (tf.less( self.expected_perp , perp) , tf.less(0.01, s-self.min_sigma) )
def f1():
print('f1')
self.min_sigma = s
s2 = (s+self.max_sigma)/2
return [s2, tf.constant([True])]
def f2(l2):
tf.cond( l2, true_fn=f3 , false_fn = f4)
def f3():
print('f3')
self.max_sigma = s
s2 = (s+self.min_sigma)/2
return [s2, tf.constant([True])]
def f4():
print('f4')
return [s, tf.constant([False])]
output = tf.cond( l1, f1 , f4 ) #colocar f2 no lugar de f4
s, sigma_not_found = output
print('sigma_not_found = ',sigma_not_found)
return [s,sigma_not_found]
print('01')
sigma_not_found = tf.constant([True])
new_sigma,sigma_not_found=sigma_not_found = tf.while_loop(
cond , body, loop_vars=[self.sigma,sigma_not_found]
)
print('saiu')
print(new_sigma)
return new_sigma
调用上述代码的一段代码是:
self.sigma = tf.map_fn(fn=lambda t: find_sigma(t, inputs).search_sigma() , elems=(self.sigma,self.clusters), dtype=tf.float32)
'inputs' 是一个(None, 10)
大小张量
'self.sigma' 是一个(10,)
尺寸张量
'self.clusters' 是一个(N, 10)
大小张量
解决方案
首先,您的第一个问题非常出色!很多信息!
tf.while_loop 非常令人困惑,这也是 tf 转向急切执行的原因之一。你不再需要这样做了。
无论如何,回到你的 2 个问题。两者的答案都是一样的,你永远不会执行你的图表,你只是在构建它。在构建执行图时,tensorflow 需要跟踪您的 python 代码,这就是您认为 tf.conf 正在运行 f1 和 f2 的原因。它是“某种运行”,因为它需要进入内部以确定哪些张量/操作将添加到图中。
这同样适用于您关于 tf.while_loop 的问题。它永远不会执行那个。
我建议进行一些小的更改,这可能会帮助您理解我在说什么并解决您的问题。从 body 方法中删除该 tf.while_loop 。创建另一个方法,比如说 run() 并将循环移到那里。有点像这样
def run(self):
out = tf.while_loop(cond, body, loop_vars)
然后,调用 run()。它将强制执行图表。
推荐阅读
- javascript - 如何将 MySQL 服务器与 windows 的电子应用程序包集成?
- mysql - 在案例表达式 Redshift 中摆脱 group by 变量
- path - ENOENT:没有这样的文件或目录 readFileSync
- flutter - 在 Flutter Web 中上传大文件
- google-apps-script - 我可以在 GAS 中有多个 toast 吗?
- azure - Azure Log Analytics:如何同时显示 AppServiceConsoleLogs 和 AppServiceHTTPLogs?
- nestjs - 如何在 Nest JS 中实现 Retrace 分析工具
- python - 使用 np.arange 定义时缺少最终刻度线
- javascript - $refs resetFields 不是函数 AntDesign
- python - 如何使用具有自定义频率和数据折叠的熊猫滚动窗口