python - 无法列出 pytorch 网络的参数
问题描述
我正在尝试构建具有以下结构的嵌套 U-Net(UNet++):
class EFUnet(nn.Module):
    """Nested U-Net (UNet++-style) model built on an EfficientNet-B3 encoder.

    NOTE(review): as written this class cannot work; the defects are flagged
    inline below. The one that produces the reported
    ``ValueError: optimizer got an empty parameter list`` is that the
    constructor is spelled ``__init`` instead of ``__init__``. Python never
    calls it, ``EFUnet()`` falls back to ``nn.Module.__init__``, and no
    submodules (hence no parameters) are ever registered.

    NOTE(review): the encoder is a Keras model (built with ``weights=``,
    ``include_top=``, ``input_shape=`` and loaded from a Keras ``.h5`` file).
    Its ``.input`` / ``get_layer(...).output`` values are Keras tensors, not
    ``torch.nn.Module`` objects. They can neither be called in ``forward``
    nor contribute parameters to a PyTorch optimizer, so a PyTorch encoder is
    needed instead.
    """
    # Built and weight-loaded once, when the class body executes (i.e. at
    # import time), not per instance. input_shape is channels-last.
    backbone = efn.EfficientNetB3(
        weights=None,
        include_top=False,
        input_shape=(256,1600,3)
    )
    backbone.load_weights(('../input/efficientnet-keras-weights-b0b5/'
                           'efficientnet-b3_imagenet_1000_notop.h5'))

    # NOTE(review): should be `__init__`. Without the trailing underscores this
    # is an ordinary (name-mangled) method that nothing calls, which is why
    # model.parameters() comes back empty.
    # NOTE(review): `in_channel` and `out_channels` are never used below;
    # `features` is the base channel width of the decoder nodes.
    def __init(self,in_channel = 3, out_channels = 1, features = 32):
        super(EFUnet, self).__init__()
        # Encoder taps, presumably ordered shallow -> deep.
        # NOTE(review): `backbone` is a class attribute. A bare-name lookup
        # inside a method does not see class scope, so these lines raise
        # NameError (would need `self.backbone`). Even then they are Keras
        # tensors, not PyTorch layers.
        self.conv00 = backbone.input
        self.conv10 = backbone.get_layer('stem_activation').output
        self.conv20 = backbone.get_layer('block2c_add').output
        self.conv30 = backbone.get_layer('block3c_add').output
        self.conv40 = backbone.get_layer('block5e_add').output
        self.conv50 = backbone.get_layer('block7b_add').output
        # Nested decoder nodes. convIJ consumes the J earlier nodes at depth I
        # plus the up-sampled node (I+1, J-1); see the torch.cat calls in forward.
        # NOTE(review): `^` is bitwise XOR in Python, not a power operator.
        # For example, `features*2^0` parses as `(features*2) ^ 0`; use `2**k`.
        # NOTE(review): `feathres` and `feature` in the next four lines are
        # typos for `features` (NameError).
        self.conv01 = _H(features*(2^0+2^1), features*2^0)
        self.conv11 = _H(feathres*(2^1+2^2),feature*2^1)
        self.conv21 = _H(feathres*(2^2+2^3),feature*2^2)
        self.conv31 = _H(feathres*(2^3+2^4),feature*2^3)
        self.conv41 = _H(feathres*(2^4+2^5),feature*2^4)
        self.conv02 = _H(features*(2^0*2+2^1), features*2^0)
        self.conv12 = _H(features*(2^1*2+2^2), features*2^1)
        self.conv22 = _H(features*(2^2*2+2^3), features*2^2)
        self.conv32 = _H(features*(2^3*2+2^4), features*2^3)
        self.conv03 = _H(features*(2^0*3+2^1), features*2^0)
        self.conv13 = _H(features*(2^1*3+2^2), features*2^1)
        self.conv23 = _H(features*(2^2*3+2^3), features*2^2)
        self.conv04 = _H(features*(2^0*4+2^1), features*2^0)
        self.conv14 = _H(features*(2^1*4+2^2), features*2^1)
        self.conv05 = _H(features*(2^0*5+2^1), features*2^0)
        # 1x1 heads producing one channel each; forward applies them to the
        # top-row nodes x01..x05.
        self.final1 = nn.Conv2d(features, 1, kernel_size=1)
        self.final2 = nn.Conv2d(features, 1, kernel_size=1)
        self.final3 = nn.Conv2d(features, 1, kernel_size=1)
        self.final4 = nn.Conv2d(features, 1, kernel_size=1)
        self.final5 = nn.Conv2d(features, 1, kernel_size=1)
        # NOTE(review): `feature` is a typo for `features`, and nn.Conv2d has
        # no `activation` argument (TypeError). Also, forward concatenates
        # five 1-channel head outputs, i.e. 5 channels, not `features*5`.
        self.final = nn.Conv2d(feature*5, 4, (3,3), padding="same", activation="sigmoid")

    def forward(self,input):
        # NOTE(review): conv00..conv50 hold Keras tensors, which are not
        # callable here. `_U` is defined with parameters (in_channels,
        # features, use_gn) and no `self`, so `self._U(x10)` does not match
        # it. Presumably `_U` is module-level and this lookup fails.
        # Up-sampling modules should be built in the constructor and stored
        # on `self`.
        # Nodes are evaluated by anti-diagonal (i + j = 0, 1, 2, ...) so every
        # input of a node exists before the node itself. torch.cat(..., 1)
        # joins along dim 1 (channels, assuming NCHW tensors).
        x00 = self.conv00(input)
        x10 = self.conv10(x00)
        x01 = self.conv01(torch.cat([x00,self._U(x10)],1))
        x20 = self.conv20(x10)
        x11 = self.conv11(torch.cat([x10,self._U(x20)],1))
        x02 = self.conv02(torch.cat([x00,x01,self._U(x11)],1))
        x30 = self.conv30(x20)
        x21 = self.conv21(torch.cat([x20,self._U(x30)],1))
        x12 = self.conv12(torch.cat([x10,x11,self._U(x21)],1))
        x03 = self.conv03(torch.cat([x00,x01,x02,self._U(x12)],1))
        x40 = self.conv40(x30)
        x31 = self.conv31(torch.cat([x30,self._U(x40)],1))
        x22 = self.conv22(torch.cat([x20,x21,self._U(x31)],1))
        x13 = self.conv13(torch.cat([x10,x11,x12,self._U(x22)],1))
        x04 = self.conv04(torch.cat([x00,x01,x02,x03,self._U(x13)],1))
        x50 = self.conv50(x40)
        x41 = self.conv41(torch.cat([x40,self._U(x50)],1))
        x32 = self.conv32(torch.cat([x30,x31,self._U(x41)],1))
        x23 = self.conv23(torch.cat([x20,x21,x22,self._U(x32)],1))
        x14 = self.conv14(torch.cat([x10,x11,x12,x13,self._U(x23)],1))
        x05 = self.conv05(torch.cat([x00,x01,x02,x03,x04,self._U(x14)],1))
        # One 1-channel map per top-row node, concatenated and fused by self.final.
        output1 = self.final1(x01)
        output2 = self.final2(x02)
        output3 = self.final3(x03)
        output4 = self.final4(x04)
        # NOTE(review): presumably meant self.final5; final5 is otherwise unused.
        output5 = self.final4(x05)
        x_out = torch.cat([output1, output2, output3, output4, output5],1)
        x_out = self.final(x_out)
        return x_out
def _H(in_channels, features, use_gn=True):
if use_gn:
norm = torch.nn.GroupNorm(num_channels = 3, num_groups=1)
else:
norm = BatchNormalization(number_features = features)
return nn.Sequential(
OrderedDict(
[
(name + "conv", nn.Conv2D(in_channels, features, (2, 2), padding='same')),
(name + "norm", norm()),
(name + 'LReLU',LeakyReLU(alpha=0.1))
]
)
)
def _U(in_channels, features, use_gn=True):
if use_gn:
norm = torch.nn.GroupNorm(num_channels = 3, num_groups=1)
else:
norm = BatchNormalization(number_features = features)
return nn.Sequential(
OrderedDict(
[
(name + "upconv", nn.ConvTranspose2d(in_channels, features, (2, 2), padding='same')),
(name + "norm", norm()),
(name + 'LReLU',LeakyReLU(alpha=0.1))
]
)
)
当我把模型参数传给 Adam 优化器时,它报错:
ValueError: optimizer got an empty parameter list
于是我用以下代码检查(QC)这个 U-Net 的参数:
# Quick check of the model's parameters.
model = EFUnet()
model = model.cuda()
# NOTE(review): `parameters` is a method. This passes the bound method itself
# to list(), which raises TypeError; call it instead: list(model.parameters()).
print(list(model.parameters))
但是,Python 报错说 `model.parameters` 是一个不可迭代的方法:
TypeError: 'method' object is not iterable
有人能帮忙看看 Python 为什么无法获取这些参数吗?
谢谢!
解决方案
在 PyTorch 中获取参数时,需要调用方法 `model.parameters()`(注意末尾的括号),
它会返回一个生成器对象,可以对其进行迭代。
或者
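更好的方法是使用 `model.named_parameters()`,
它同样返回一个生成器,每一项是 (参数名, 参数) 对,参数名对应其所在层的名称。
另外,`ValueError: optimizer got an empty parameter list` 的根本原因在于构造函数被写成了 `__init`(缺少末尾的两个下划线)。该方法从未被调用,`EFUnet()` 只执行了 `nn.Module.__init__`,因此模型中没有注册任何子模块和参数。即使改正这一点,代码中仍把 Keras EfficientNet 的张量当作 PyTorch 层使用,需要改用 PyTorch 实现的编码器。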
推荐阅读
- json - Convert JSON dict in Bash to Powershell
- typescript - 对象可能是打字稿功能上的空错误
- c# - 在 UWP 应用中打开文件夹时访问被拒绝
- pandas - Pandas 和 SQLAlchemy:在连接期间重命名列
- reactjs - 如何立即更新 react useState?
- c# - c# FirestoreDB Winforms - 无法访问 firebase?
- mysql - 如何使用 DISTINCT 和 COUNT 创建 MySQL 嵌套查询
- python-3.x - 正则表达式捕获两个正则表达式模式之间的 n 行文本
- firebase - 尝试通过 Play 游戏登录时 Unity 应用程序崩溃
- xamarin - 在 Xamarin Forms Shell 中向 GoToAsync URI Navigation 添加查询会导致它停止工作