首页 > 解决方案 > 将此代码从 tensorflow 1 移植到 tensorflow 2

问题描述

我正在尝试移植我在 stackoverflow.com 上的一个答案中找到的代码:

import tensorflow as tf

x = tf.placeholder(tf.float32,shape=[3,3])

cond1 = tf.where(x > 10, x - 10, tf.zeros_like(x))
cond2 = tf.where(x < 4, x + 60, tf.zeros_like(x))
cond3 = tf.where(tf.logical_and(x >= 4, x <= 10), x, tf.zeros_like(x))
y = cond1 + cond2 + cond3

sample = [[10, 15, 25], [1, 2, 3], [4, 4, 10]]

print(sess.run(y, feed_dict={x: sample}))

到目前为止,我已经完成了:

import tensorflow as tf

x = tf.keras.Input(shape=[3,3], dtype=tf.dtypes.float32)

cond1 = tf.where(x > 10, x - 10, tf.zeros_like(x))
cond2 = tf.where(x < 4, x + 60, tf.zeros_like(x))
cond3 = tf.where(tf.logical_and(x >= 4, x <= 10), x, tf.zeros_like(x))
y = cond1 + cond2 + cond3

sample = [[10, 15, 25], [1, 2, 3], [4, 4, 10]]

但是我找不到打印结果的方法,因为我不能按照移植指南的建议执行 print(f(sample)。

标签: tensorflowtensorflow2.0

解决方案


首先使用创建模型,

from tensorflow.keras.models import Model
model = Model(x, y)

然后做,

res = model.predict(sample)
print(res)

res将是一个 numpy 数组。


推荐阅读