首页 > 解决方案 > 使用索引提取张量的行和列

问题描述

我有一个 2-D tensor x,它是一个 placeholder x = tf.placeholder(tf.int32, shape=[None, 100])。现在在图中,我想随机提取50%x 行。这是我尝试过的:

idxs = tf.range(tf.shape(x)[0])
ridxs = tf.random_shuffle(idxs)[: int(tf.shape(x)[0] * 0.5)]
x = tf.gather(x, ridxs, axis=0)

它抛出以下错误:

Traceback (most recent call last):
  File "main_gcn_random_sampling.py", line 76, in <module>
    model = GCN_RANDOM_SAMPLING(FLAGS.learning_rate, num_input, num_classes, hidden_dimensions=hidden_sizes, act=tf.nn.relu)
  File "/home/tiendh/Projects/EPooling/gcn_random_sampling.py", line 63, in __init__
    ridxs = tf.random_shuffle(idxs)[: int(tf.shape(x)[0] * 0.5)]
TypeError: unsupported operand type(s) for *: 'Tensor' and 'float'

我可以为预定义的索引声明另一个占位符,但它看起来很难看。还有什么建议吗?

标签: pythontensorflow

解决方案


这几乎是正确的,但是:

  1. tf.shape(x)[0]在乘以之前,您必须强制转换为浮动0.5
  2. 您不能使用int转换张量数据类型,您需要使用tf.cast.
idxs = tf.range(tf.shape(x)[0])
num_half_idx = tf.cast(tf.cast(tf.shape(x)[0], tf.float32) * 0.5, tf.int32)
ridxs = tf.random_shuffle(idxs)[: num_half_idx]
x = tf.gather(x, ridxs, axis=0)

请注意,如果您总是获得 50% 的行,您也可以简单地执行以下操作:

idxs = tf.range(tf.shape(x)[0])
ridxs = tf.random_shuffle(idxs)[: tf.shape(x)[0] // 2]
x = tf.gather(x, ridxs, axis=0)

推荐阅读