pytorch - 交换批处理轴对pytorch中的性能有影响吗?
问题描述
我知道通常批次维度是零轴,我想这是有原因的:批次中每个项目的底层内存是连续的。
如果我在第一个轴上有另一个维度,我的模型调用的函数会变得更简单,这样我就可以使用x[k]
而不是x[:, k]
.
算术运算的结果似乎保持相同的内存布局
x = torch.ones(2,3,4).transpose(0,1)
y = torch.ones_like(x)
u = (x + 1)
v = (x + y)
print(x.stride(), u.stride(), v.stride())
当我创建其他变量时,我正在创建它们torch.zeros
然后转置,以便最大的步幅也到达轴 1。
例如
a,b,c = torch.zeros(
(3, x.shape[1], ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).transpose(1,2)
将创建三个具有相同批量大小的张量x.shape[1]
。就内存局部性而言,拥有
a,b,c = torch.zeros(
(x.shape[1], 3, ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).permute(1,2,0, ...)
反而。
我应该关心这个吗?
解决方案
TLDR;切片看似包含较少的信息......但实际上与原始张量共享相同的存储缓冲区。由于 permute 不会影响底层内存布局,因此这两个操作本质上是等效的。
这两者本质上是相同的,底层数据存储缓冲区保持不变,只有元数据, 即您与该缓冲区的交互方式(步幅和形状)发生变化。
让我们看一个简单的例子:
>>> x = torch.ones(2,3,4).transpose(0,1)
>>> x_ptr = x.data_ptr()
>>> x.shape, x.stride(), x_ptr
(3, 2, 4), (4, 12, 1), 94674451667072
我们将“基础”张量的数据指针保存在x_ptr
:
在第二个轴上切片:
>>> y = x[:, 0] >>> y.shape, y.stride(), x_ptr == y.data_ptr() (3, 4), (4, 1), True
如您所见,
x
并且x[:, k]
共享相同的存储空间。排列前两个轴,然后在第一个轴上切片:
>>> z = x.permute(1, 0, 2)[0] >>> z.shape, z.stride(), x_ptr == z.data_ptr() (3, 4), (4, 1), True
在这里,您再次注意到
x.data_ptr
与 相同z.data_ptr
。
事实上,您甚至可以使用以下方式从y
tox
表示torch.as_strided
:
>>> torch.as_strided(y, size=x.shape, stride=x.stride())
tensor([[[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.]]])
与z
:
>>> torch.as_strided(z, size=x.shape, stride=x.stride())
两者都将返回一个副本,x
因为torch.as_strided
正在为新创建的张量分配内存。这两行只是为了说明我们如何仍然可以x
从 的切片中“返回” x
,我们可以通过更改张量的元数据来恢复明显的内容。
推荐阅读
- alibaba-cloud - 阿里巴巴云上的公共 DNS (IPv4) 是否用于其他云?
- java - hibernate 5 org.hibernate.hql.internal.ast.QuerySyntaxException: Emplooye 未映射
- php - 带有 PhpMyAdmin 5.0.4 的 PHP 问题版本
- openssl - 如何检查/比较 openssl 速度
- sql - 替换序列化 PHP 对象 Bia 中的文本 PHPMYADMIN MariaDB SQL 查询
- javascript - 如何防止图像离开html中的区域
- cocoapods - 像在 SPM 中一样在 CocoaPods 中创建多个模块
- python - 如何在游戏中使用元组?
- r - 更新到 Mac OS Big Sur 并在 R 中出现“警告:预期的最小视图高度”错误
- c - 提高排序的左子右兄弟树的时间复杂度