首页 > 解决方案 > 如何在 tf.function 内部进行“元素明智”比较?

问题描述

我尝试在 TensorFlow 2 中创建自己的激活函数,函数如下所示:

@tf.function
def f(x):
  r = 2
  if x>=0:
    return (r**2 * x + 1)**(1/r) - 1/r
  else:
    return K.exp(r*x) - 1/r

问题是它不能作为论据tf.constant([2.0, 3.0]),因为条件存在问题。我也尝试过tf.math.qreater_equal(x, 0)导致相同的输出tf.cond()。我对文档示例也没有运气。它返回错误:

InvalidArgumentError:  The second input must be a scalar, but it has shape [2]
     [[{{node cond/switch_pred/_2}}]] [Op:__inference_f_7469065]

谢谢!

标签: pythontensorflow

解决方案


if语句被转换为cond,但它只接受谓词的标量参数(并且不广播)。请尝试where

return tf.where(x >= 0, (r**2 * x + 1)**(1/r) - 1/r, K.exp(r*x) - 1/r))

(目前无法使用 TensorFlow 进行测试,但这至少是 Numpy 的行为方式......)


推荐阅读