首页 > 解决方案 > 交换批处理轴对pytorch中的性能有影响吗?

问题描述

我知道通常批次维度是零轴,我想这是有原因的:批次中每个项目的底层内存是连续的。

如果我在第一个轴上有另一个维度,我的模型调用的函数会变得更简单,这样我就可以使用x[k]而不是x[:, k].

算术运算的结果似乎保持相同的内存布局

x = torch.ones(2,3,4).transpose(0,1)
y = torch.ones_like(x)
u = (x + 1)
v = (x + y)
print(x.stride(), u.stride(), v.stride())

当我创建其他变量时,我正在创建它们torch.zeros然后转置,以便最大的步幅也到达轴 1。

例如

a,b,c = torch.zeros(
         (3, x.shape[1], ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).transpose(1,2)

将创建三个具有相同批量大小的张量x.shape[1]。就内存局部性而言,拥有

a,b,c = torch.zeros(
  (x.shape[1], 3, ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).permute(1,2,0, ...)

反而。

我应该关心这个吗?

标签: pytorchstride

解决方案


TLDR;切片看似包含较少的信息......但实际上与原始张量共享相同的存储缓冲区。由于 permute 不会影响底层内存布局,因此这两个操作本质上是等效的。


这两者本质上是相同的,底层数据存储缓冲区保持不变,只有元数据, 您与该缓冲区的交互方式(步幅和形状)发生变化。

让我们看一个简单的例子:

>>> x = torch.ones(2,3,4).transpose(0,1)
>>> x_ptr = x.data_ptr()

>>> x.shape, x.stride(), x_ptr
(3, 2, 4), (4, 12, 1), 94674451667072

我们将“基础”张量的数据指针保存在x_ptr

  1. 在第二个轴上切片:

    >>> y = x[:, 0]
    
    >>> y.shape, y.stride(), x_ptr == y.data_ptr()
    (3, 4), (4, 1), True
    

    如您所见,x并且x[:, k]共享相同的存储空间。

  2. 排列前两个轴,然后在第一个轴上切片:

    >>> z = x.permute(1, 0, 2)[0]
    
    >>> z.shape, z.stride(), x_ptr == z.data_ptr()
    (3, 4), (4, 1), True
    

    在这里,您再次注意到x.data_ptr与 相同z.data_ptr


事实上,您甚至可以使用以下方式从ytox表示torch.as_strided

>>> torch.as_strided(y, size=x.shape, stride=x.stride())
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.]]])

z

>>> torch.as_strided(z, size=x.shape, stride=x.stride())

两者都将返回一个副本x因为torch.as_strided正在为新创建的张量分配内存。这两行只是为了说明我们如何仍然可以x从 的切片中“返回” x,我们可以通过更改张量的元数据来恢复明显的内容。


推荐阅读