python - 是否有更简单的方法来获取张量的切片,如下例所示?
问题描述
我想对张量进行切片,如下面的 numpy 切片。我怎样才能做到这一点?
# numpy array
a = np.reshape(np.arange(60), (3,2,2,5))
idx = np.array([0, 1, 0])
N = np.shape(a)[0]
mask = a[np.arange(N),:,:,idx]
# I have tried several solutions, but only the following success.
# tensors
import tensorflow as tf
import numpy as np
a = tf.cast(tf.constant(np.reshape(np.arange(60), (3,2,2,5))), tf.int32)
idx2 = tf.constant([0, 1, 0])
fn = lambda i: a[i][:,:,idx2[i]]
idx = tf.range(tf.shape(a)[0])
masks = tf.map_fn(fn, idx)
with tf.Session() as sess:
print(sess.run(a))
print(sess.run(tf.shape(masks)))
print(sess.run(masks))
有没有更简单的方法来实现这一点?
我可以使用功能tf.gather
或tf.gather_nd
实现这一点吗?非常感谢!
解决方案
另一种方法使用tf.gather_nd
:
import tensorflow as tf
import numpy as np
a = tf.cast(tf.constant(np.reshape(np.arange(60), (3,2,2,5))), tf.int32)
idx = tf.range(tf.shape(a)[0])
idx2 = tf.constant([0,1,0])
indices = tf.stack([idx, idx2], axis=1)
a = tf.transpose(a, [0,3,1,2])
masks = tf.gather_nd(a, indices)
with tf.Session() as sess:
print(sess.run(a))
print(sess.run(tf.shape(masks)))
print(sess.run(masks))
推荐阅读
- java - 如何在java中的.txt文件中写入或存储字符串数组
- c++ - 两个类可以同时继承吗?
- heroku - 下一个 js 和 Heroku 应用程序用户在输入时被定向到 http:// 而不是安全的 https://
- c# - 如何验证 PayPal Webhook 签名?
- oracle - 无法扩展表“SYS.SYSAUTH$”
- python - python中协方差的无限值
- python-3.x - 如何在熊猫中对值组进行排序
- bash - 我们如何将变量作为参数传递(作为列表)
- .htaccess - Codeigniter 自动重定向到友好的 url 不起作用
- sql - 函数 from_tz 在每月的第一天不起作用