首页 > 解决方案 > 如何将 tensorflow 占位符变量转换为 numpy 数组?

问题描述

我想在 tensorflow 代码中使用 scipy 插值函数。

这是与我的情况类似的示例片段。

import tensorflow as tf
from scipy import interpolate

def interpolate1D(Xval,Fval,inp): 
    Xval = np.array(Xval)
    Fval = np.array(Fval)
    f = interpolate.interp1d(Xval, Fval, fill_value="extrapolate")
    z = f(inp)
    return z

properties = {
    'xval': [200,400,600,800,1100],
    'fval': [100.0,121.6,136.2,155.3,171.0]
}

tensor = tf.placeholder("float")

interpolate = interpolate1D(properties['xval'],properties['fval'], tensor)


一旦我得到了,interpolate我会使用它把它转换成张量tf.convert_to_tensor(interpolate)

这里interpolate.interp1d只是一个例子。我将使用其他插值方法,这些方法的输出将被输入另一个神经元。

我知道placeholder是空变量,所以从技术上讲它不可能转换成 numpy 数组。另外,我不能在张量流图之外使用这个插值函数,因为在某些情况下,我需要使用神经网络的输出作为插值函数的输入。

总的来说,我想在张量图中使用 scipy 插值函数。

标签: pythonnumpytensorflow

解决方案


您可以tf.py_func在图表中使用 SciPy 函数,但更好的选择是在 TensorFlow 中实现插值。库中没有开箱即用的功能,但实现起来并不难。

import tensorflow as tf

# Assumes Xval is sorted
def interpolate1D(Xval, Fval, inp):
    # Make sure input values are tensors
    Xval = tf.convert_to_tensor(Xval)
    Fval = tf.convert_to_tensor(Fval)
    inp = tf.convert_to_tensor(inp)
    # Find the interpolation indices
    c = tf.count_nonzero(tf.expand_dims(inp, axis=-1) >= Xval, axis=-1)
    idx0 = tf.maximum(c - 1, 0)
    idx1 = tf.minimum(c, tf.size(Xval, out_type=c.dtype) - 1)
    # Get interpolation X and Y values
    x0 = tf.gather(Xval, idx0)
    x1 = tf.gather(Xval, idx1)
    f0 = tf.gather(Fval, idx0)
    f1 = tf.gather(Fval, idx1)
    # Compute interpolation coefficient
    x_diff = x1 - x0
    alpha = (inp - x0) / tf.where(x_diff > 0, x_diff, tf.ones_like(x_diff))
    alpha = tf.clip_by_value(alpha, 0, 1)
    # Compute interpolation
    return f0 * (1 - alpha) + f1 * alpha

properties = {
    'xval': [200.0, 400.0, 600.0, 800.0, 1100.0],
    'fval': [100.0, 121.6, 136.2, 155.3, 171.0]
}

with tf.Graph().as_default(), tf.Session() as sess:
    tensor = tf.placeholder("float")
    interpolate = interpolate1D(properties['xval'], properties['fval'], tensor)
    print(sess.run(interpolate, feed_dict={tensor: [40.0, 530.0, 800.0, 1200.0]}))
    # [100.   131.09 155.3  171.  ]

推荐阅读