python - 填充 = 1 的 PyTorch MaxPool2D 意外行为
问题描述
我在玩MaxPool2D
inPyTorch
并在设置时发现了奇怪的行为padding=1
。这是我得到的:
代码:
import torch
from torch.nn.functional import max_pool2d
TEST = 1
def test_maxpool(negative=False, tnsr_size=2, kernel_size=2, stride=2, padding=0):
"""Test MaxPool2D.
"""
global TEST
print(f'=== TEST {TEST} ===')
print(*[f'{i[0]}: {i[1]}' for i in locals().items()], sep=' | ')
inp = torch.arange(1., tnsr_size ** 2 + 1).reshape(1, tnsr_size, tnsr_size)
inp = -inp if negative else inp
print('In:')
print(inp)
out = max_pool2d(inp, kernel_size, stride, padding=padding)
print('Out:')
print(out)
print()
TEST += 1
test_maxpool()
test_maxpool(True)
test_maxpool(padding=1)
test_maxpool(True, padding=1)
出去:
=== TEST 1 ===
negative: False | tnsr_size: 2 | kernel_size: 2 | stride: 2 | padding: 0
In:
tensor([[[1., 2.],
[3., 4.]]])
Out:
tensor([[[4.]]])
=== TEST 2 ===
negative: True | tnsr_size: 2 | kernel_size: 2 | stride: 2 | padding: 0
In:
tensor([[[-1., -2.],
[-3., -4.]]])
Out:
tensor([[[-1.]]])
=== TEST 3 ===
negative: False | tnsr_size: 2 | kernel_size: 2 | stride: 2 | padding: 1
In:
tensor([[[1., 2.],
[3., 4.]]])
Out:
tensor([[[1., 2.],
[3., 4.]]])
=== TEST 4 ===
negative: True | tnsr_size: 2 | kernel_size: 2 | stride: 2 | padding: 1
In:
tensor([[[-1., -2.],
[-3., -4.]]])
Out:
tensor([[[-1., -2.],
[-3., -4.]]])
测试 1、2、3很好,但测试 4很奇怪,我希望得到[[0 0], [0 0]]
张量:
In:
[[-1 -2]
[-3 -4]]
+ padding ->
[[ 0 0 0 0]
[ 0 -1 -2 0]
[ 0 -3 -4 0]
[ 0 0 0 0]]
-> kernel_size=2, stride=2 ->
[[0 0]
[0 0]]
根据测试 3,使用了零填充,但测试 4产生了有争议的结果。
那是什么样的填充物(如果有的话)?为什么MaxPool2D
会有这样的行为?
pytorch 1.3.1
解决方案
这是预期的行为,因为默认情况下会完成负无穷填充。
MaxPool 的文档现已修复。请参阅此 PR:修复 MaxPool 默认垫文档 #59404。
推荐阅读
- flutter - 存根函数总是返回 null
- spring-boot - 如何在超过特定日期后使实体自动更新?
- reactjs - 如何删除 Form.input antd design react 中的冒号?
- firebase - 这是构建我的firebase数据库的好方法吗?
- string - 如何在 Flutter 的 Text Widget 中将 double 转换为字符串
- javascript - 为什么 `return()` 和 `onsubmit()` 在 html 中创建登录表单以欢迎页面时不起作用
- wordpress - 在 WP Bakery /Wordpress 上显示
- c# - 如何使用 C# 中的宏从 Revit 中的链接中获取元素?
- r - R Plot_ly:如何更改颜色条调色板和颜色阈值?
- django - 无法导入“django.db”pylint(导入错误)