首页 > 解决方案 > 为什么 tf.cond 中的两个分支都被执行?为什么即使条件仍然成立,tf.while_loop 也会完成循环?

问题描述

我现在使用 keras 有一段时间了,但通常我不必使用自定义层或执行一些更复杂的流控制,所以我很难理解一些东西。

我正在建模一个顶部有一个自定义层的神经网络。这个自定义层调用另一个函数 ( search_sigma),在这个函数tf.while_loop中我执行并在其中tf.while_loop执行tf.cond

我不明白为什么这些条件不起作用。

有人可以帮助我了解我所缺少的吗?

我已经尝试为真张量更改 tf.cond 和 tf.while_loop 条件,只是想看看会发生什么。行为(完全相同的错误)保持不变。

我还尝试在不实现类的情况下编写此代码(仅使用函数)。没有改变。

我试图通过查看 tensorflow 文档、其他堆栈溢出问题以及谈论 tf.while_loop 和 tf.cond 的网站来寻找解决方案。

print()在代码主体中留下了一些 s 以尝试跟踪正在发生的事情。

class find_sigma:
    
    def __init__ (self, t_inputs,  inputs,  expected_perp=10. ):       
        self.sigma, self.cluster = t_inputs
        self.inputs = inputs
        self.expected_perp = expected_perp
        self.min_sigma=tf.constant([0.01],tf.float32)
        self.max_sigma=tf.constant([50.],tf.float32)
 

    def search_sigma(self):

        
        def cond(s,sigma_not_found): return sigma_not_found


        def body(s,sigma_not_found):   

            print('loop')
            pi = K.exp( - K.sum( (K.expand_dims(self.inputs, axis=1) - self.cluster)**2, axis=2  )/(2*s**2) )        
            pi = pi / K.sum(pi)
            MACHINE_EPSILON = np.finfo(np.double).eps
            pi = K.maximum(pi, MACHINE_EPSILON)
            H = - K.sum ( pi*(K.log(pi)/K.log(2.)) , axis=0 )
            perp = 2**H

            print('0')

            l1 = tf.logical_and (tf.less(perp , self.expected_perp), tf.less(0.01, self.max_sigma-s))
            l2 = tf.logical_and (tf.less(  self.expected_perp , perp) , tf.less(0.01, s-self.min_sigma) )
    
            def f1():
                print('f1')
                self.min_sigma = s 
                s2 = (s+self.max_sigma)/2 
                return  [s2, tf.constant([True])]
                

            def f2(l2): 
                tf.cond( l2, true_fn=f3 , false_fn = f4)

            def f3(): 
                print('f3')
                self.max_sigma = s 
                s2 = (s+self.min_sigma)/2
                return [s2, tf.constant([True])]

            def f4(): 
                print('f4')
                return [s, tf.constant([False])]
            
            output = tf.cond( l1, f1 ,  f4 ) #colocar f2 no lugar de f4

            s, sigma_not_found = output
            
            print('sigma_not_found = ',sigma_not_found)
            return [s,sigma_not_found]

        print('01')

        sigma_not_found = tf.constant([True])

        new_sigma,sigma_not_found=sigma_not_found = tf.while_loop(
            cond , body, loop_vars=[self.sigma,sigma_not_found]
        )

        print('saiu')
        
        print(new_sigma)

        return new_sigma

调用上述代码的一段代码是:

self.sigma = tf.map_fn(fn=lambda t: find_sigma(t,  inputs).search_sigma() , elems=(self.sigma,self.clusters), dtype=tf.float32)

'inputs' 是一个(None, 10)大小张量

'self.sigma' 是一个(10,)尺寸张量

'self.clusters' 是一个(N, 10)大小张量

标签: pythontensorflowkerasflow-control

解决方案


首先,您的第一个问题非常出色!很多信息!

tf.while_loop 非常令人困惑,这也是 tf 转向急切执行的原因之一。你不再需要这样做了。

无论如何,回到你的 2 个问题。两者的答案都是一样的,你永远不会执行你的图表,你只是在构建它。在构建执行图时,tensorflow 需要跟踪您的 python 代码,这就是您认为 tf.conf 正在运行 f1 和 f2 的原因。它是“某种运行”,因为它需要进入内部以确定哪些张量/操作将添加到图中。

这同样适用于您关于 tf.while_loop 的问题。它永远不会执行那个。

我建议进行一些小的更改,这可能会帮助您理解我在说什么并解决您的问题。从 body 方法中删除该 tf.while_loop 。创建另一个方法,比如说 run() 并将循环移到那里。有点像这样

def run(self):
   out = tf.while_loop(cond, body, loop_vars)

然后,调用 run()。它将强制执行图表。


推荐阅读