python - 如何将 tensorflow 占位符变量转换为 numpy 数组?
问题描述
我想在 tensorflow 代码中使用 scipy 插值函数。
这是与我的情况类似的示例片段。
import tensorflow as tf
from scipy import interpolate
def interpolate1D(Xval,Fval,inp):
Xval = np.array(Xval)
Fval = np.array(Fval)
f = interpolate.interp1d(Xval, Fval, fill_value="extrapolate")
z = f(inp)
return z
properties = {
'xval': [200,400,600,800,1100],
'fval': [100.0,121.6,136.2,155.3,171.0]
}
tensor = tf.placeholder("float")
interpolate = interpolate1D(properties['xval'],properties['fval'], tensor)
一旦我得到了,interpolate
我会使用它把它转换成张量tf.convert_to_tensor(interpolate)
这里interpolate.interp1d
只是一个例子。我将使用其他插值方法,这些方法的输出将被输入另一个神经元。
我知道placeholder
是空变量,所以从技术上讲它不可能转换成 numpy 数组。另外,我不能在张量流图之外使用这个插值函数,因为在某些情况下,我需要使用神经网络的输出作为插值函数的输入。
总的来说,我想在张量图中使用 scipy 插值函数。
解决方案
您可以tf.py_func
在图表中使用 SciPy 函数,但更好的选择是在 TensorFlow 中实现插值。库中没有开箱即用的功能,但实现起来并不难。
import tensorflow as tf
# Assumes Xval is sorted
def interpolate1D(Xval, Fval, inp):
# Make sure input values are tensors
Xval = tf.convert_to_tensor(Xval)
Fval = tf.convert_to_tensor(Fval)
inp = tf.convert_to_tensor(inp)
# Find the interpolation indices
c = tf.count_nonzero(tf.expand_dims(inp, axis=-1) >= Xval, axis=-1)
idx0 = tf.maximum(c - 1, 0)
idx1 = tf.minimum(c, tf.size(Xval, out_type=c.dtype) - 1)
# Get interpolation X and Y values
x0 = tf.gather(Xval, idx0)
x1 = tf.gather(Xval, idx1)
f0 = tf.gather(Fval, idx0)
f1 = tf.gather(Fval, idx1)
# Compute interpolation coefficient
x_diff = x1 - x0
alpha = (inp - x0) / tf.where(x_diff > 0, x_diff, tf.ones_like(x_diff))
alpha = tf.clip_by_value(alpha, 0, 1)
# Compute interpolation
return f0 * (1 - alpha) + f1 * alpha
properties = {
'xval': [200.0, 400.0, 600.0, 800.0, 1100.0],
'fval': [100.0, 121.6, 136.2, 155.3, 171.0]
}
with tf.Graph().as_default(), tf.Session() as sess:
tensor = tf.placeholder("float")
interpolate = interpolate1D(properties['xval'], properties['fval'], tensor)
print(sess.run(interpolate, feed_dict={tensor: [40.0, 530.0, 800.0, 1200.0]}))
# [100. 131.09 155.3 171. ]
推荐阅读
- elixir - Elixir 遍历一个列表并将其中的值附加到一个新列表中
- javascript - 在 NestJS 嵌套模式中忽略 @Prop 和 mongoose 选项
- ruby-on-rails - Carrierwave 上传文件将文件保存到 AWS 但重新加载后数据不保留
- oracle - 使用 PL/SQL SDK Dbms_cloud 从 MinIO 获取存储桶列表
- python - 如何运行 CatBoostClassifier?
- php - 使用 SMTP PHPMailer 将不同的消息发送到不同的电子邮件帐户
- vue.js - 部署asp.net core + vue.js 环境变量设置
- azure - Azure CI/CD 不承认 Wix 工具集安装项目的项目设置
- coq - Coq:消除`forall`?
- python - 试图从 praw 中的 2 个 subreddits 创建交替的帖子列表