python - 了解 torch.nn.Flatten
问题描述
我知道 Flatten 会删除除一个之外的所有维度。例如,我理解flatten():
> t = torch.ones(4, 3)
> t
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]])
> flatten(t)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
但是,我不明白,特别是我没有从文档Flatten
中得到这个片段的含义:
>>> input = torch.randn(32, 1, 5, 5)
>>> m = nn.Sequential(
>>> nn.Conv2d(1, 32, 5, 1, 1),
>>> nn.Flatten()
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([32, 288])
我觉得输出应该有大小[160]
,因为32*5=160
.
Q1。那么为什么它输出 size[32,288]
呢?
Q2。我也没有得到shape
文档中给出的信息的含义:
Q3。还有参数的含义:
解决方案
这是默认行为的差异。torch.flatten
默认展平所有尺寸,同时torch.nn.Flatten
展平从第二维(索引 1)开始的所有尺寸。
start_dim
您可以在和end_dim
参数的默认值中看到此行为。参数表示要展平的start_dim
第一个维度(零索引),end_dim
参数表示要展平的最后一个维度。因此,当start_dim=1
是 的默认值时torch.nn.Flatten
,第一个维度(索引 0)没有展平,但它包含在 时start_dim=0
,这是 的默认值torch.flatten
。
这种差异背后的原因可能是因为torch.nn.Flatten
旨在与 一起使用torch.nn.Sequential
,其中通常对一批输入执行一系列操作,其中每个输入都独立于其他输入进行处理。例如,如果您有一批图像并调用torch.nn.Flatten
,典型的用例是分别展平每个图像,而不是展平整批图像。
如果您确实想使用 展平所有尺寸torch.nn.Flatten
,您可以简单地将对象创建为torch.nn.Flatten(start_dim=0)
。
最后,文档中的形状信息仅涵盖了张量的形状将如何受到影响,说明第一个(索引 0)维度保持原样。因此,如果您有一个 shape 的输入张量(N, *dims)
,其中*dims
是任意维度序列,则输出张量将具有 shape (N, product of *dims)
,因为除了批量维度之外的所有维度都被展平。例如, shape 的输入(3,10,10)
将具有 shape 的输出(3, 10 x 10) = (3, 100)
。
推荐阅读
- html - 如何在两种颜色之间创建平滑的闪烁渐变?
- python - PyTorch 值阈值和归零所有其他值
- typescript - 为什么 Lerna 总是发布我所有的包?
- nlog - NLog - 使用过滤器忽略某些消息
- laravel - Laravel:将类名从控制器传递给作业
- r - 无法在 blogdown 中转义 LaTeX 美元符号`$`
- sybase - 无法启动指定数据库:自动启动数据库失败 -Sybase DB
- python - Python Pandas - 如果ID相同并且其中一行等于一个数字,则选择多行
- google-cloud-firestore - 向云函数中的文档添加数据:如何停止循环回调?
- dns - 如何在具有本地 DNS 服务器的本地网络上的 ansible awx 中生成动态主机清单