首页 > 解决方案 > 从张量中提取所有 3x3 块

问题描述

如果我有一个 5x5 张量,那么,我如何从中获取所有 9 个 3x3 块,以便生成的张量具有 [9, 3, 3] 形状,或者如果这些 3x3 块被展平,则为 [9, 9] 形状。例如,

x = torch.randn(5, 5)

假设 x 是

tensor([[ 0.5756,  0.2463,  1.3940,  0.8473, -0.8371],
        [ 0.9690,  1.4913, -0.2129,  0.8331, -0.6322],
        [-0.0348, -1.6920, -0.0157,  0.6159,  0.1038],
        [-1.0790,  1.4303,  0.3861,  0.1293,  0.4582],
        [ 0.2815, -1.1944, -0.7612,  0.6595,  1.4611]])

那么产生的张量应该是这样的,

tensor([[0.5756,  0.2463,  1.3940, 0.9690,  1.4913, -0.2129, -0.0348, -1.6920, -0.0157],
 [0.2463, 1.3940,  0.8473, 1.4913, -0.2129,  0.8331, -1.6920, -0.0157,  0.6159],
...
[-0.0157,  0.6159,  0.1038, 0.3861,  0.1293,  0.4582, -0.7612,  0.6595,  1.4611]])

标签: pytorch

解决方案


一个非常天真的实现可以是

y = torch.randn(5, 5)
x = torch.zeros((9, 3, 3))
count = 0
for i in range(3) :
    for j in range(3) :
        x[count] = y[i : i + 3, j : j + 3]
        count += 1

样本输出:

y = tensor([[ 0.0361, -0.4931, -1.1977, -0.5224, -3.4067],
        [ 0.2380, -1.1042, -0.0696, -2.0487, -0.4123],
        [ 0.6567, -0.2485, -0.3954, -0.8197, -0.4903],
        [ 1.0073,  1.4759,  0.3532,  0.3565, -1.5257],
        [-0.8493, -0.0532,  1.0918,  1.2715, -0.1775]])

x = tensor([[[ 0.0361, -0.4931, -1.1977],
         [ 0.2380, -1.1042, -0.0696],
         [ 0.6567, -0.2485, -0.3954]],

        [[-0.4931, -1.1977, -0.5224],
         [-1.1042, -0.0696, -2.0487],
         [-0.2485, -0.3954, -0.8197]],

        [[-1.1977, -0.5224, -3.4067],
         [-0.0696, -2.0487, -0.4123],
         [-0.3954, -0.8197, -0.4903]],

        [[ 0.2380, -1.1042, -0.0696],
         [ 0.6567, -0.2485, -0.3954],
         [ 1.0073,  1.4759,  0.3532]],

        [[-1.1042, -0.0696, -2.0487],
         [-0.2485, -0.3954, -0.8197],
         [ 1.4759,  0.3532,  0.3565]],

        [[-0.0696, -2.0487, -0.4123],
         [-0.3954, -0.8197, -0.4903],
         [ 0.3532,  0.3565, -1.5257]],

        [[ 0.6567, -0.2485, -0.3954],
         [ 1.0073,  1.4759,  0.3532],
         [-0.8493, -0.0532,  1.0918]],

        [[-0.2485, -0.3954, -0.8197],
         [ 1.4759,  0.3532,  0.3565],
         [-0.0532,  1.0918,  1.2715]],

        [[-0.3954, -0.8197, -0.4903],
         [ 0.3532,  0.3565, -1.5257],
         [ 1.0918,  1.2715, -0.1775]]])

推荐阅读