tensorflow - tf.cast() 导致我的程序反向传播失败,我该如何解决这个问题?
问题描述
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
A = tf.constant([[1,7,3]],dtype=tf.float32)
B = tf.zeros_like([[1,0,0],[0,1,0]])
C = tf.cast(A,dtype=tf.int32)+B
f = tf.gradients(C,A)
with tf.compat.v1.Session() as sess:
print(sess.run(f))
我正在使用 tensorflow 2.3.0 版本。运行此程序时会提示错误:
TypeError: Fetch argument None has invalid type <class 'NoneType'>
如何在不阻碍反向传播的情况下在tensorflow中完成数据转换?最初A是float类型,B是int类型,这个是不能改的。
解决方案
Tensorflow 不通过整数来区分。将您的 int 转换为浮动。
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
A = tf.constant([[1,7,3]],dtype=tf.float32)
B = tf.zeros_like([[1,0,0],[0,1,0]])
C = A+tf.cast(B,tf.float32)
f = tf.gradients(C,A)
with tf.compat.v1.Session() as sess:
print(sess.run(f))
结果是 :
[array([[2., 2., 2.]], dtype=float32)]
推荐阅读
- c++ - Winapi 创建选项卡菜单
- javafx - 组合框中的 JavaFX 组合框
- postgresql - 在某些情况下,UPDATE 事务是否会失败而不是阻塞等待“FOR UPDATE 锁定”?
- php - 如何在从我的数据库写入和读取时防止并发?
- google-chrome - 如何让 Chrome 网络扩展程序读取 Chrome 设置“下载前询问每个文件的保存位置”?
- spring - 在 JWT 和 Spring 中使用范围限制访问
- javascript - 找不到 React/Node 应用程序的 Node sdk。如何将亚马逊支付与 React 应用程序集成
- angular - 反应形式自定义验证器未在 angular6 中显示错误
- .net - FFMPEG 在线程错误:线程被中止
- android - 使用颤振和 vscode 构建 android 应用程序时出错