首页 > 解决方案 > 如何在 PyTorch conv2d 函数中批量使用组参数?

问题描述

接下来是如何在 PyTorch conv2d 函数中使用组参数中的问题

我可以知道输入批次大小是否= 4,对于每个批次,它都有独立的过滤器来与之转换,我将代码修改如下,

import torch
import torch.nn.functional as F

filters = torch.autograd.Variable(torch.randn(3,4,3,3))
inputs = torch.autograd.Variable(torch.randn(4,3,10,10))
out = F.conv2d(inputs, filters, padding=1, groups=3)

我有另一个错误 RuntimeError: Given groups=3, weight of size [3, 4, 3, 3], expected input[4, 3, 10, 10] 有 12 个通道,但有 3 个通道 如何解决?

标签: pytorch

解决方案


当您有过滤器时,shape (3,4,3,3)预计通道数为 12

这应该工作

import torch
import torch.nn.functional as F
inputs = torch.autograd.Variable(torch.randn(3,12,10,10))
filters = torch.autograd.Variable(torch.randn(3,4,3,3))
out = F.conv2d(inputs, filters, padding=1, groups=3)

推荐阅读