pytorch - RuntimeError: shape '[10, 3, 150, 150]' 对于大小为 472500 的输入无效
问题描述
我正在尝试对 covid CT 数据集执行卷积操作并不断收到此错误。我在火车装载机中的图像大小为 (10, 150, 150, 3),我使用 torch.reshape() 将其重新整形为 [10, 3, 150, 150]。谁能帮我解决问题
我的 CNN 代码
class BConv(nn.Module):
def __init__(self, out=3):
super(BConv, self).__init__()
#(10, 150, 150, 3)
self.conv1=nn.Conv2d(in_channels=3,out_channels=12,kernel_size=3,stride=1,padding=1)
self.bn1=nn.BatchNorm2d(num_features=12)
self.relu1=nn.ReLU()
self.pool=nn.MaxPool2d(kernel_size=2)
self.conv2=nn.Conv2d(in_channels=12,out_channels=20,kernel_size=3,stride=1,padding=1)
self.relu2=nn.ReLU()
# self.conv3=nn.Conv2d(in_channels=20,out_channels=32,kernel_size=3,stride=1,padding=1)
# self.bn3=nn.BatchNorm2d(num_features=32)
# self.relu3=nn.ReLU()
self.fc=nn.Linear(in_features= 20*75*75, out_features=3)
def forward(self,input):
output=self.conv1(input)
#print("output 1", output.shape)
output=self.bn1(output)
#print("output 1", output.shape)
output=self.relu1(output)
#print("output 1", output.shape)
output=self.pool(output)
#print("output 1", output.shape)
output=self.conv2(output)
#print("output 1", output.shape)
output=self.relu2(output)
#print("output 1", output.shape)
# output=self.conv3(output)
# output=self.bn3(output)
# output=self.relu3(output)
print(output.shape)
#Above output will be in matrix form, with shape (256,32,75,75)
output=output.view(output.size(0), -1)
output=self.fc(output)
return output
数据预处理
class Ctdataset(Dataset):
def __init__(self, path):
self.data= pd.read_csv(path, delimiter=" ")
data= self.data.values.tolist()
self.image= []
self.labels=[]
for i in data:
self.image.append(i[0])
self.labels.append(i[1])
#print(len(self.image), len(self.labels))
#self.class_map = {"0": 0, "1":1 , "2": 2}
def __len__(self):
return len(self.image)
def __getitem__(self, idx):
img_path = os.path.join("2A_images", self.image[idx])
img= Image.open(img_path).convert("RGB")
img= img.resize((150, 150))
img= np.array(img)
img= img.astype(float)
return img, label
解决方案
在这里,我正在考虑您的整个模型,包括由 、 和 组成的第三conv3
个bn3
块relu3
。有几点需要注意:
重塑与排列轴有很大不同。当你说你有一个输入形状时
(batch_size, 150, 150, 3)
,这意味着通道轴是最后一个。由于 PyTorch 2D 内置层以NHW
您需要置换轴的格式工作:您可以这样做torch.Tensor.permute
:>>> x = torch.rand(10, 150, 150, 3) >>> x.permute(0, 3, 1, 2).shape (10, 3, 150, 150)
假设您的输入是 shape
(batch_size, 3, 150, 150)
,那么 的输出 shaperelu3
将是(32, 75, 75)
。因此,以下全连接层必须具有精确的32*75*75
输入特征。但是,您需要像在代码中那样使用
view
:来展平此张量output = output.view(output.size(0), -1)
。另一种方法是定义一个self.flatten = nn.Flatten()
层并用output = self.flatten(output)
.从 PyTorch v 1.8.0
in_features
开始,在全连接层中设置的替代方法是使用nn.LazyLinear
它将根据第一个推断为您初始化它:>>> self.fc = nn.LazyLinear(out_features=3)
旁注:您不需要使用 、 和 定义单独的 ReLU 层,
relu1
因为它们是非参数函数:relu2
relu3
>>> self.relu = nn.ReLU()
以下是完整代码供参考:
class BConv(nn.Module):
def __init__(self, out=3):
super().__init__()
# input shape (10, 150, 150, 3)
self.conv1 = nn.Conv2d(3, 12,kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(num_features=12)
self.pool = nn.MaxPool2d(kernel_size=2)
self.conv2 = nn.Conv2d(12, 20,kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(20, 32, kernel_size=3, stride=1, padding=1)
self.bn3 = nn.BatchNorm2d(num_features=32)
self.relu = nn.ReLU()
self.flatten = nn.Flatten()
self.fc = nn.Linear(in_features=32*75*75, out_features=out)
def forward(self,input):
output = input.permute(0, 3, 1, 2)
output = self.conv1(output)
output = self.bn1(output)
output = self.relu(output)
output = self.pool(output)
output = self.conv2(output)
output = self.relu(output)
output = self.conv3(output)
output = self.bn3(output)
output = self.relu(output)
output = self.flatten(output)
output = self.fc(output)
return output
推荐阅读
- prolog - 生成总和为给定数字的所有对
- java - Overloading in Java for user input?
- regex - Regex on a substring
- django - Accessing django website hosted on vm with mobile device
- python - 在 Python 3 IDLE 中播放 Raspberry Pi 3B+ 上的音频文件
- c# - 是否可以从与 bin/roslyn 不同的路径使用 roslyn/csc.exe?
- python - ValueError with attention Dimension1 in both shape must be equal
- linux - Rancher - standard_init_linux.go:190: exec user process caused "permission denied"
- python - Python程序循环问题
- python - Is it possible to check if a function is decorated inside another function?