首页 > 解决方案 > 是否有更简单的方法来获取张量的切片,如下例所示?

问题描述

我想对张量进行切片,如下面的 numpy 切片。我怎样才能做到这一点?

# numpy array
a = np.reshape(np.arange(60), (3,2,2,5))
idx = np.array([0, 1, 0])
N = np.shape(a)[0]
mask = a[np.arange(N),:,:,idx]


# I have tried several solutions, but only the following success.
# tensors
import tensorflow as tf
import numpy as np


a = tf.cast(tf.constant(np.reshape(np.arange(60), (3,2,2,5))), tf.int32)
idx2 = tf.constant([0, 1, 0])

fn = lambda i: a[i][:,:,idx2[i]]
idx = tf.range(tf.shape(a)[0])
masks = tf.map_fn(fn, idx)
with tf.Session() as sess:
    print(sess.run(a))
    print(sess.run(tf.shape(masks)))
    print(sess.run(masks))

有没有更简单的方法来实现这一点?

我可以使用功能tf.gathertf.gather_nd实现这一点吗?非常感谢!

标签: pythontensorflowslice

解决方案


另一种方法使用tf.gather_nd

import tensorflow as tf
import numpy as np


a = tf.cast(tf.constant(np.reshape(np.arange(60), (3,2,2,5))), tf.int32)
idx = tf.range(tf.shape(a)[0])
idx2 = tf.constant([0,1,0])
indices = tf.stack([idx, idx2], axis=1)
a = tf.transpose(a, [0,3,1,2])
masks = tf.gather_nd(a, indices)

with tf.Session() as sess:
    print(sess.run(a))
    print(sess.run(tf.shape(masks)))
    print(sess.run(masks))

推荐阅读