python - 切片类似于 numpy np.ix_ 的 2D 张量
问题描述
我在这里学习了如何在一维上切割张量。
我已经学习了如何对 2D 张量进行切片,并在此处给出特定值的 1D 张量。
两者都使用tf.gather()
,但我很确定我需要tf.gather_nd()
,尽管我显然用错了。
在 numpy 中,我有一个 5x5 2D 数组,我可以通过使用np.ix_()
行和列索引来切片一个 2x2 数组(我总是需要相同的行和列索引,从而产生一个方阵):
import numpy as np
a = np.array([[1,2,3,4,5],[2,1,6,7,8],[3,6,1,9,10],[4,7,9,1,11],[5,8,10,11,1]])
a
array([[ 1, 2, 3, 4, 5], [ 2, 1, 6, 7, 8], [ 3, 6, 1, 9, 10], [ 4, 7, 9, 1, 11], [ 5, 8, 10, 11, 1]])
a[np.ix_([1,3], [1,3])]
array([[1, 7], [7, 1]])
阅读tf.gather_nd()
文档我认为这是在 TF 中执行此操作的方法,但我使用错了:
import tensorflow as tf
a = tf.constant([[1,2,3,4,5],[2,1,6,7,8],[3,6,1,9,10],[4,7,9,1,11],[5,8,10,11,1]])
tf.gather_nd(a, [[1,3], [1,3]])
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([7, 7])>
我将不得不做类似的事情:
tf.gather_nd(a, [[[1,1], [1,3]],[[3,1],[3,3]]])
<tf.Tensor: shape=(2, 2), dtype=int32, numpy= array([[1, 7], [7, 1]])>
这使我陷入了另一个我不喜欢的兔子洞。当然,我的索引向量要长得多。
顺便说一句,我的索引本身就是一维整数张量。所以底线我想a
用与我相同的行和列索引来切片np._ix()
,我的索引类似于:
idx = tf.constant([1, 3])
# tf.gather_nd(a, indices = "something with idx")
解决方案
要使用长度为 d 的 1D 张量对 nxn 2D 数组进行切片,从而生成具有指定索引的 dxd 2D 数组,可以使用tf.repeat
,tf.tile
然后来完成tf.stack
:
n = 5
a = tf.constant(np.arange(n * n).reshape(n, n)) # 2D nxn array
idx = [1,2,4] # 1D tensor with length d
d = tf.shape(idx)[0]
ix_ = tf.reshape(tf.stack([tf.repeat(idx,d),tf.tile(idx,[d])],1),[d,d,2])
target = tf.gather_nd(a,ix_) # 2D dxd array
print(a)
print(target)
预期产出:
tf.Tensor(
[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 18 19]
[20 21 22 23 24]], shape=(5, 5), dtype=int64)
tf.Tensor(
[[ 6 7 9]
[11 12 14]
[21 22 24]], shape=(3, 3), dtype=int64)
推荐阅读
- r - 了解我的 Ramsey RESET 测试的输出
- ios - 单选按钮在 ColectionView 单元格中不起作用,如果我选择一个按钮,另一个按钮不会取消选择
- vue.js - 捆绑服务器文件(Vue)并启动生产节点服务器的最佳实践
- julia - 为什么在 Julia 中尝试使用 ggplot 时出现错误?
- php - 在 stream_context_create 上更改 curl
- python - 使用“Pymongo”的“聚合”方法和“MongoDB”的“$project”命令从字典中提取键值并将它们转储到父字典中
- c# - 在二维数组中移动值
- javascript - 如何通知移动键盘有关 html 输入控件所需的输入?
- android - 如何在 listView 中获取单个文件大小
- ansible - 使用 ansible 按顺序克隆 git repo