首页 > 解决方案 > TensorFlow scatter_nd 函数不适用于占位符和复杂输入

问题描述

tf.scatter_nd用来更新某个索引处的复杂值。这个函数似乎将实部和虚部以某种方式加在一起。我的问题是如何使它与占位符一起使用。这是变量be应该具有相同值的最小工作示例。

import tensorflow as tf
import numpy as np
tf.reset_default_graph()

update=np.asarray([1.+2j])
idx=tf.constant( [[0]])
shp=tf.constant([1])

# works with constants

a=tf.constant(update)
b=tf.scatter_nd(idx,a,shp)
with tf.Session() as sess:
    print sess.run(b) # correct output: 1.+2j

#Does not work with placeholders

d=tf.placeholder(tf.complex128)
e=tf.scatter_nd(idx,d,shp)
with tf.Session() as sess:
    print sess.run(e,feed_dict={d:update}) # WRONG output: 3.+0j

我正在使用使用 conda 命令安装的 Anaconda python 2.7 + TensorFlow 1.7 GPU 版本。

编辑:

在 GPU 上运行代码时会出现此问题。CPU 版本正常工作。这是使用 Anaconda Python 2.7 安装的 TensorFlow-GPU 1.8 中重现该问题的更新代码。

import tensorflow as tf
import numpy as np
tf.reset_default_graph()

update=np.asarray([1.+2j])
idx=tf.constant( [[0]])
shp=tf.constant([1])

a=tf.placeholder(tf.complex128)

with tf.device("/cpu:0"):
    b=tf.scatter_nd(idx,a,shp)
with tf.device("/gpu:0"):
    c=tf.scatter_nd(idx,a,shp)

with tf.Session() as sess:
    print 'Correct output on CPU', sess.run(b,feed_dict={a:update})
    print 'Wrong output on GPU',sess.run(c,feed_dict={a:update})

我看到了这个线程和这个线程,但找不到如何解决它。是否有任何替代方案可以tf.scatter_nd在 GPU 上运行?

标签: python-2.7tensorflow

解决方案


推荐阅读