首页 > 解决方案 > 3,4轴pytorch中的矩阵乘法

问题描述

我有两个形状张量a(16,8,8,64)b(64,64). 假设,我将 的最后一个维度提取a到另一个列向量c中,我想计算matmul(matmul(c.T, b), c)。我希望在a. 也就是最终产品应该是成型的(16,8,8,1)。如何在 pytorch 中实现这一点?

标签: pythonpytorchmatrix-multiplication

解决方案


可以按如下方式进行:

row_vec = a[:, :, :, None, :].float()
col_vec = a[:, :, :, :, None].float()
b = (b[None, None, None, :, :]).float()
prod = torch.matmul(torch.matmul(row_vec, b), col_vec)

推荐阅读