python - Numpy 3D 数组索引:适用于 2D,如何处理 3D?
问题描述
我有 3 个numpy
数组,如下所示。
import numpy as np
key_idx = np.array([1, 2, 1]) # both have same shape
out_idx = np.array([0, 3, 0])
max_out = out_idx.max()
output = np.zeros(shape=(len(key_idx), max_out + 1))
# output =
# array([[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]])
我想增加索引给出的值,如下所示:
key_idx = key_idx[np.newaxis, :] # convert to 2D
out_idx = out_idx[np.newaxis, :]
idx = (key_idx, out_idx)
np.add.at(output, idx, 1)
# output =
# array([[0., 0., 0., 0.],
# [2., 0., 0., 0.],
# [0., 0., 0., 1.]])
然后应用如下变换:
np.sum(np.amax(output, axis=1))
#3.0
但现在我想为 3D 输出数组执行此操作,其中key_idx2D
2D 数组的第一个维度表示table_id
. 请参考下图:
我试过的
key_idx2D = np.array([[1, 2, 1], [2, 2, 2]])
output3D = np.zeros(shape=(key_idx2D.shape[0], len(key_idx), max_out + 1))
key_idx2D = key_idx[np.newaxis, :] # convert to 3D
out_idx = out_idx[np.newaxis, :]
idx3D = (key_idx2D, out_idx)
np.add.at(output3D, idx3D, 1)
#IndexError: index 2 is out of bounds for axis 0 with size 2
我怎样才能为 3D 案例做到这一点?任何帮助表示赞赏。它应该为每个返回一个值数组,table_id
如图所示。
注意:我可以用循环来做,但它会很慢。我需要更快的东西。
编辑:
key_idx2D
有axis 0 = table_id
和axis 1 = key_id
。
out_idx
有axis 0 = out_id
。两者都key_idx2D
包含out_idx
ndarrayoutput
中唯一需要np.add.at()
应用于它们的索引。我已经更新了该图以澄清这一点。
解决方案
我发布答案以防有人发现它有用。
key_idx2D = np.array([[1, 2, 1], [2, 2, 2]])
output3D = np.zeros(shape=(key_idx2D.shape[0], key_idx2D.shape[1], max_out + 1))
output3D.shape
#(2, 3, 4)
我所需要的只是为第一个轴(即轴 0)创建一个索引数组。
table_idx = np.array([0, 1]).reshape(-1, 1)
out_idx = np.array([0, 3, 0])
table_idx.shape, key_idx2D.shape, out_idx.shape
#((2, 1), (2, 3), (3,))
然后将所有索引数组以np.add.at
元组的形式发送到。
np.add.at(output3D, (table_idx, key_idx2D, out_idx), 1)
output3D
# array([[[0., 0., 0., 0.],
# [2., 0., 0., 0.],
# [0., 0., 0., 1.]],
# [[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [2., 0., 0., 1.]]])
np.sum(np.amax(output3D, axis=2), axis=1)
#array([3., 2.])
推荐阅读
- javascript - 将 FullCalendar 集成到 Apache Royale
- linux - 使用什么系统调用通过 Internet 发送数据?
- scala - 如何检查空 JSON
- keras - tf.keras HDF5 模型和 Keras HDF5 模型
- php - 在 xampp localhost 不起作用,本地网络 ip 适用于项目
- terraform - 使用 terraform 在 digitalOcean 中带液滴的体积附件
- javascript - 无法将 rtmp 流式传输到网页
- c - 类型检查ANSI C中的任意长度数组
- python - 在 HTML 字符串中选择和剥离 img src
- javascript - Javascript 碰撞检测生成的分数超过了应有的分数