python - 二维数组作为 Pytorch 中的索引
问题描述
我想使用一组规则“增长”一个矩阵。
规则示例:
0->[[1,1,1],[0,0,0],[2,2,2]],
1->[[2,2,2],[2,2,2],[2,2,2]],
2->[[0,0,0],[0,0,0],[0,0,0]]
增长矩阵的示例:
[[0]]->[[1,1,1],[0,0,0],[2,2,2]]->
[[2,2,2,2,2,2,2,2,2],[2,2,2,2,2,2,2,2,2],[2,2,2,2,2,2,2,2,2],
[1,1,1,1,1,1,1,1,1],[0,0,0,0,0,0,0,0,0],[2,2,2,2,2,2,2,2,2],
[0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0]]
这是我一直试图在 Pytorch 中工作的代码
rules = np.random.randint(256,size=(10,256,3,3,3))
rules_tensor = torch.randint(256,size=(10,
256, 3, 3, 3),
dtype=torch.uint8, device = torch.device('cuda'))
rules = rules[0]
rules_tensor = rules_tensor[0]
seed = np.array([[128]])
seed_tensor = seed_tensor = torch.cuda.ByteTensor([[128]])
decode = np.empty((3**3, 3**3, 3))
decode_tensor = torch.empty((3**3,
3**3, 3), dtype=torch.uint8,
device = torch.device('cuda'))
for i in range(3):
grow = seed
grow_tensor = seed_tensor
for j in range(1,4):
grow = rules[grow,:,:,i].reshape(3**j,-1)
grow_tensor = rules_tensor[grow_tensor,:,:,i].reshape(3**j,-1)
decode[..., i] = grow
decode_tensor[..., i] = grow_tensor
在这一行中,我似乎无法像在 Numpy 中那样选择索引:
grow = rules[grow,:,:,i].reshape(3**j,-1)
有没有办法在 Pytorch 中执行以下操作?
解决方案
您可以考虑使用torch.index_select()
,在重塑结果之前展平您的索引张量:
代码:
import torch
import numpy as np
rules_np = np.array([
[[1,1,1],[0,0,0],[2,2,2]], # for value 0
[[2,2,2],[2,2,2],[2,2,2]], # for value 1
[[0,0,0],[0,0,0],[0,0,0]]]) # for value 2, etc.
rules = torch.from_numpy(rules_np).long()
rule_shape = rules[0].shape
seed = torch.zeros(1).long()
num_growth = 2
print("Seed:")
print(seed)
grow = seed
for i in range(num_growth):
grow = (torch.index_select(rules, 0, grow.view(-1))
.view(grow.shape + rule_shape)
.squeeze())
print("Growth #{}:".format(i))
print(grow)
日志:
Seed:
tensor([ 0])
Growth #0:
tensor([[ 1, 1, 1], [ 0, 0, 0], [ 2, 2, 2]])
Growth #1:
tensor([[[[ 2, 2, 2], [ 2, 2, 2], [ 2, 2, 2]],
[[ 2, 2, 2], [ 2, 2, 2], [ 2, 2, 2]],
[[ 2, 2, 2], [ 2, 2, 2], [ 2, 2, 2]]],
[[[ 1, 1, 1], [ 0, 0, 0], [ 2, 2, 2]],
[[ 1, 1, 1], [ 0, 0, 0], [ 2, 2, 2]],
[[ 1, 1, 1], [ 0, 0, 0], [ 2, 2, 2]]],
[[[ 0, 0, 0], [ 0, 0, 0], [ 0, 0, 0]],
[[ 0, 0, 0], [ 0, 0, 0], [ 0, 0, 0]],
[[ 0, 0, 0], [ 0, 0, 0], [ 0, 0, 0]]]])
推荐阅读
- spring - 在不同控制器的视图之间共享modelAttribute
- angular - 无法通过代理 Angular 传递 http 标头
- javascript - Firebase 错误:photoURL 字段必须是有效的 URL
- javascript - 比较数组元素时出现问题(Javascript)
- apache-nifi - 按指定顺序合并流文件
- python - 如何处理 NoneType 错误
- python - 在 python 3 中编码字符串“tamilnadu”
- scala - 如何检查对象的参数是否在 ArrayBuffer 中填充了 Scala 中的对象?
- jekyll - 带有 site.data 的 Jekyll “where”过滤器似乎没有按预期工作
- c# - 用 C# 开发 Windows 位置传感器