首页 > 解决方案 > Python3函数前'*'的用途

问题描述

我在 Python3 和 PyTorch 中看到了 ResNet CNN 的代码,如下所示:

def resnet_block(input_channels, num_channels, num_residuals,
                 first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels,
                                use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk

要添加模块,请使用以下代码 -

b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))

“*resnet_block()”是什么意思/做什么?

标签: pythonpython-3.xpytorchiterableiterable-unpacking

解决方案


基本上*iterable用于将可迭代对象的项目解包为位置参数。在您的问题resnet_block中返回一个列表,并且该列表的项目被传递给nn.Sequential而不是列表本身。


推荐阅读