首页 > 解决方案 > tf.cast() 导致我的程序反向传播失败,我该如何解决这个问题?

问题描述

import tensorflow as tf
tf.compat.v1.disable_eager_execution()

A  = tf.constant([[1,7,3]],dtype=tf.float32)
B = tf.zeros_like([[1,0,0],[0,1,0]])
C = tf.cast(A,dtype=tf.int32)+B

f = tf.gradients(C,A)
with tf.compat.v1.Session() as sess:
    print(sess.run(f))

我正在使用 tensorflow 2.3.0 版本。运行此程序时会提示错误:

TypeError: Fetch argument None has invalid type <class 'NoneType'>

如何在不阻碍反向传播的情况下在tensorflow中完成数据转换?最初A是float类型,B是int类型,这个是不能改的。

标签: tensorflowtensorflow2.0

解决方案


Tensorflow 不通过整数来区分。将您的 int 转换为浮动。

import tensorflow as tf
tf.compat.v1.disable_eager_execution()

A  = tf.constant([[1,7,3]],dtype=tf.float32)
B = tf.zeros_like([[1,0,0],[0,1,0]])
C = A+tf.cast(B,tf.float32)

f = tf.gradients(C,A)
with tf.compat.v1.Session() as sess:
    print(sess.run(f))

结果是 :

[array([[2., 2., 2.]], dtype=float32)]

推荐阅读