python - How does pytorch broadcasting work?
问题描述
torch.add(torch.ones(4,1), torch.randn(4))
produces a Tensor with size: torch.Size([4,4])
.
Can someone provide a logic behind this?
解决方案
PyTorchbroadcasting
基于numpy 广播语义,可以通过阅读numpy broadcasting rules
或PyTorch 广播指南来理解。用一个例子来解释这个概念会更直观地理解它。所以,请看下面的例子:
In [27]: t_rand
Out[27]: tensor([ 0.23451, 0.34562, 0.45673])
In [28]: t_ones
Out[28]:
tensor([[ 1.],
[ 1.],
[ 1.],
[ 1.]])
现在torch.add(t_rand, t_ones)
,将其可视化为:
# shape of (3,)
tensor([ 0.23451, 0.34562, 0.45673])
# (4, 1) | | | | | | | | | | | |
tensor([[ 1.],____+ | | | ____+ | | | ____+ | | |
[ 1.],______+ | | ______+ | | ______+ | |
[ 1.],________+ | ________+ | ________+ |
[ 1.]])_________+ __________+ __________+
它应该给出形状张量的输出(4,3)
:
# shape of (4,3)
In [33]: torch.add(t_rand, t_ones)
Out[33]:
tensor([[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673]])
另外,请注意,即使我们以与前一个相反的顺序传递参数,我们也会得到完全相同的结果:
# shape of (4, 3)
In [34]: torch.add(t_ones, t_rand)
Out[34]:
tensor([[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673],
[ 1.23451, 1.34562, 1.45673]])
无论如何,我更喜欢前一种理解方式,以获得更直接的直观性。
为了图示理解,我挑选了更多示例,列举如下:
Example-1:
Example-2:
:
T
和分别F
代表True
和False
指示我们允许广播的维度(来源:Theano)。
Example-3:
下面是一些形状,其中适当地广播了数组以b
尝试 匹配数组的形状a
。
如上所示,广播b
的形状可能仍然与 的形状不匹配,因此只要最终广播的形状不匹配a
,操作就会失败。a + b
推荐阅读
- websocket - XMPP 服务器究竟做了什么?
- javascript - 如何使用 javascript 生成字母表
- javascript - 未找到模块:错误:无法解析 React-js?
- swiftui - SwiftUI 应用程序在两个选择器周围出现 VStack 崩溃
- node.js - Lightsail 上的 Create-React-App:DNS 和 SSL 问题
- python - 根据文本框输入在标签中打印一些文本 - PyQT5
- react-native - navigate to another screen using react native
- amazon-web-services - How can I expose the AWS_WEB_IDENTITY_TOKEN_FILE to docker container that runs on GitLab runner with Kubernetes executor
- excel - How Do I Control The Font Style After Moving Data From Excel To Word?
- php - Trying to get property 'name' of non-object (View: /home/laravel/web/laravel.swt101.eu/public_html/abonamenty/resources/views/products/show.blade.php)