python - PyTorch: How to multiply via broadcasting of two tensors with different shapes
Problem description
I have the following two PyTorch tensors A and B.
import numpy as np
import torch
A = torch.tensor(np.array([40, 42, 38]), dtype = torch.float64)
tensor([40., 42., 38.], dtype=torch.float64)
B = torch.tensor(np.array([[[1,2,3,4,5],[1,2,3,4,5],[1,2,3,4,5],[1,2,3,4,5],[1,2,3,4,5]], [[4,5,6,7,8],[4,5,6,7,8],[4,5,6,7,8],[4,5,6,7,8],[4,5,6,7,8]], [[7,8,9,10,11],[7,8,9,10,11],[7,8,9,10,11],[7,8,9,10,11],[7,8,9,10,11]]]), dtype = torch.float64)
tensor([[[ 1., 2., 3., 4., 5.],
[ 1., 2., 3., 4., 5.],
[ 1., 2., 3., 4., 5.],
[ 1., 2., 3., 4., 5.],
[ 1., 2., 3., 4., 5.]],
[[ 4., 5., 6., 7., 8.],
[ 4., 5., 6., 7., 8.],
[ 4., 5., 6., 7., 8.],
[ 4., 5., 6., 7., 8.],
[ 4., 5., 6., 7., 8.]],
[[ 7., 8., 9., 10., 11.],
[ 7., 8., 9., 10., 11.],
[ 7., 8., 9., 10., 11.],
[ 7., 8., 9., 10., 11.],
[ 7., 8., 9., 10., 11.]]], dtype=torch.float64)
Tensor A is of shape:
torch.Size([3])
Tensor B is of shape:
torch.Size([3, 5, 5])
How do I multiply tensor A with tensor B (using broadcasting) so that, for example, the first value in tensor A (i.e. 40.) is multiplied with all the values in the first 'nested' tensor in tensor B, i.e.
tensor([[[ 1., 2., 3., 4., 5.],
[ 1., 2., 3., 4., 5.],
[ 1., 2., 3., 4., 5.],
[ 1., 2., 3., 4., 5.],
[ 1., 2., 3., 4., 5.]],
and so on for the other 2 values in tensor A and the other two nested tensors in tensor B, respectively.
I could do this multiplication (via broadcasting) with NumPy arrays if A and B were both arrays of shape (3,), i.e. A*B, but I can't seem to figure out a counterpart of this with PyTorch tensors. Any help would really be appreciated.
Solution
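When applying broadcasting in PyTorch (as well as in NumPy), the shapes are compared starting at the last (trailing) dimension and moving left (see https://pytorch.org/docs/stable/notes/broadcasting.html). At each position the sizes must either be equal, or one of them must be 1, or one of them must not exist. If this does not hold, you need to reshape one of the tensors. In your case the shapes can't be broadcast directly: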
      [3]  # the trailing sizes (3 and 5) differ and neither is 1
[3, 5, 5]
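A quick way to confirm the mismatch is to attempt the multiplication directly. This is a minimal sketch that, as an assumption for brevity, uses `torch.ones` placeholders with the same shapes as A and B instead of the question's values (only the shapes matter here). The exact error messages differ between versions, so they are printed rather than reproduced.

```python
import numpy as np
import torch

# Placeholders with the question's shapes; the values are irrelevant for this check.
A = torch.ones(3, dtype=torch.float64)        # shape (3,)
B = torch.ones(3, 5, 5, dtype=torch.float64)  # shape (3, 5, 5)

# Aligned from the right, A's size 3 meets B's trailing size 5:
# they are not equal and neither is 1, so PyTorch refuses to broadcast.
try:
    A * B
    torch_ok = True
except RuntimeError as err:
    torch_ok = False
    print("PyTorch:", err)

# NumPy applies the same rule and raises a ValueError instead.
try:
    A.numpy() * B.numpy()
    numpy_ok = True
except ValueError as err:
    numpy_ok = False
    print("NumPy:", err)
```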
Instead, you can redefine A = A[:, None, None] (each None inserts a new dimension of size 1) before multiplying, so that you get the shapes
[3, 1, 1]
[3, 5, 5]
which satisfies the conditions for broadcasting.
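Here is a minimal sketch of the fix applied to the question's data. The `rows` tensor and the `repeat` call are only a compact, hypothetical way to rebuild B and are not part of the original answer. The sketch also shows a few equivalent ways to get A into shape (3, 1, 1) (`view`, `unsqueeze`, and `torch.einsum` with no summed index) and cross-checks the result against an explicit per-slice loop.

```python
import torch

A = torch.tensor([40.0, 42.0, 38.0], dtype=torch.float64)
# Rebuild B: slice i repeats one row of `rows` five times -> shape (3, 5, 5).
rows = torch.tensor([[1, 2, 3, 4, 5],
                     [4, 5, 6, 7, 8],
                     [7, 8, 9, 10, 11]], dtype=torch.float64)
B = rows[:, None, :].repeat(1, 5, 1)

# A[:, None, None] has shape (3, 1, 1) and broadcasts over B's last two dims.
C = A[:, None, None] * B
print(C.shape)           # torch.Size([3, 5, 5])
print(C[0, 0].tolist())  # [40.0, 80.0, 120.0, 160.0, 200.0]

# Equivalent spellings of the same reshape.
C_view = A.view(-1, 1, 1) * B
C_unsq = A.unsqueeze(-1).unsqueeze(-1) * B
C_einsum = torch.einsum("i,ijk->ijk", A, B)

# Reference: scale each 5x5 slice of B by the matching scalar of A.
C_loop = torch.stack([A[i] * B[i] for i in range(A.shape[0])])
assert torch.equal(C, C_loop)
```

`A[:, None, None]` and `A.view(-1, 1, 1)` return views of A, and broadcasting expands the size-1 dimensions on the fly, so no (3, 5, 5) copy of A is created.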