首页 > 解决方案 > torch.einsum 的内存使用情况

问题描述

我一直在尝试调试某个模型,该模型torch.einsum在重复几次的层中使用运算符。

在尝试分析模型在训练期间的 GPU 内存使用情况时,我注意到某个Einsum操作显着增加了内存使用量。我正在处理多维矩阵。操作是torch.einsum('b q f n, b f n d -> b q f d', A, B)

还值得一提的是:

我一直想知道为什么这个操作会使用这么多内存,以及为什么在每次迭代该层类型后内存都保持分配状态。

标签: pythonpytorchnumpy-einsum

解决方案


变量 " x" 确实被覆盖了,但是张量数据保存在内存中(也称为层的激活),以供以后在向后传递中使用。

因此,反过来,您可以有效地为 的结果分配新的内存数据,但即使看起来已被覆盖torch.einsum,您也不会替换的内存。x


要将其传递给测试,您可以在torch.no_grad()上下文管理器下计算前向传递(这些激活不会保存在内存中),并与标准推理相比,查看内存使用差异。


推荐阅读