首页 > 解决方案 > 如何在 PyTorch 中将坐标值的 3D 网格分割成块?

问题描述

如果我有一批坐标位置的统一 3D 网格,形状(例如)[1, 32, 32, 32, 3],那么我将其分成多个偶数块的最佳方法是什么,所以我最终可能会得到诸如 [1, 4096, 2, 2, 2, 3] 之类的东西吗?换句话说,我将一个 32 x 32 x 32 的大立方体(每个点都是 x、y、z 坐标位置)拆分为 4096 个较小的 2 x 2 x 2 立方体?一个简单的视图操作在这里有意义吗,还是会抛弃坐标值?我正在研究像 torch.chunk 这样的操作,但它们需要一个特定的维度来拆分,我不确定这是否适用于此。

我的用例是我有一个较小的 [1, 16, 16, 16, 3] 立方体,所以我试图将这个较小形状的点匹配到上采样 [1, 32, 32 , 32, 3] 形状(因为 16^3 形状中的单个坐标点对应 32^3 形状中的 8 个点)。

对于其他上下文,这就是我现在生成 3D 网格的方式:

pxs = torch.linspace(-1, 1, 32)
pys = torch.linspace(-1, 1, 32)
pzs = torch.linspace(-1, 1, 32)

pxs = pxs.view(-1, 1, 1).expand(*shape).contiguous().view(size)
pys = pys.view(1, -1, 1).expand(*shape).contiguous().view(size)
pzs = pzs.view(1, 1, -1).expand(*shape).contiguous().view(size)
points = torch.stack([pxs, pys, pzs], dim=1)

grid_3d = torch.reshape(points, (32, 32, 32, 3))

标签: pythonmatrixpytorch

解决方案


只需reshape使用您想要的尺寸即可:

In [29]: lin = np.linspace(-1, 1, 32)
    ...: cube = np.stack(np.meshgrid(lin, lin, lin), axis=-1)
    ...: cube.shape
Out[29]: (32, 32, 32, 3)

In [30]: new_cube = cube.reshape((1, 4096, 2, 2, 2, 3))
    ...: new_cube.shape
Out[30]: (1, 4096, 2, 2, 2, 3)

推荐阅读