python - “DataParallel”对象没有属性“conv1”
问题描述
我正在尝试conv1
根据下面的代码和架构可视化图层的 cnn 网络特征图。它在没有 DataParallel 的情况下正常工作,但是当我激活model = nn.DataParallel(model)
它时出现错误:“DataParallel”对象没有属性“conv1”。任何建议表示赞赏。
class Model(nn.Module):
def __init__(self, kernel, num_filters, res = ResidualBlock):
super(Model, self).__init__()
self.conv0 = nn.Sequential(
nn.Conv2d(4, num_filters, kernel_size = kernel*3,
padding = 4),
nn.BatchNorm2d(num_filters),
nn.ReLU(inplace=True))
self.conv1 = nn.Sequential(
nn.Conv2d(num_filters, num_filters*2, kernel_size = kernel,
stride=2, padding = 1),
nn.BatchNorm2d(num_filters*2),
nn.ReLU(inplace=True))
self.conv2 = nn.Sequential(
nn.Conv2d(num_filters*2, num_filters*4, kernel_size = kernel, stride=2, padding = 1),
nn.BatchNorm2d(num_filters*4),
nn.ReLU(inplace=True))
self.tsconv0 = nn.Sequential(
nn.ConvTranspose2d(num_filters*4, num_filters*2, kernel_size = kernel, padding = 1),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
nn.ReLU(inplace=True),
nn.BatchNorm2d(num_filters*2))
self.tsconv1 = nn.Sequential(
nn.ConvTranspose2d(num_filters*2, num_filters, kernel_size = kernel, padding = 1),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
nn.ReLU(inplace=True),
nn.BatchNorm2d(num_filters))
self.tsconv2 = nn.Sequential(
nn.Conv2d(num_filters, 1, kernel_size = kernel*3, padding = 4, bias=False),
nn.ReLU(inplace=True))
model = Model(kernel, num_filters)
model = nn.DataParallel(model)
特征图可视化的代码:
def get_activation(name):
def hook(model, x_train_batch, y_train_pred):
activation[name] = y_train_pred.detach()
return hook
model.conv3.register_forward_hook(get_activation('conv3'))
x_train_batch[0,0,:,:]
y_train_pred = model(x_train_batch)
act = activation['conv3'].squeeze()
act1 = act.cpu().detach().numpy()
act=act[0,:,:,:]
fig, axarr = plt.subplots(6,16)
k = 0
for idx in range(act.size(0)//16):
for idy in range(act.size(0)//6):
axarr[idx, idy].imshow(act[k])
k += 1
解决方案
使用时DataParallel
,在此处添加一个额外的module
。而不是model.conv3.
简单地写model.module.conv3.
推荐阅读
- c++ - 如何让“敌人”围绕玩家运行?
- android - 错误:没有名称参数名称:颤振中的“trackCameraPosition:true”
- amibroker - For 循环不接受数组。如何解决此 Amibroker 代码?
- c# - 什么是 Cocoa/CSS image.png/image@2x.png 概念的 WPF 等价物?
- html - 如何突出显示任何图像中给出的文本并使用 html、css、js 保存相同的图像?
- android - 如何在改造中发布带有数组数据的数组
- c# - 有没有办法循环这段代码,但是每次给变量一个不同的字符串?
- bash - 使用 tput 时如何在 bash 中使用 printf 格式化列
- c# - LINQ检查列表是否包含另一个列表中的任何项目mysql语法错误
- python - TensorFlow:帮助创建服务输入函数