python-3.x - Pytorch 试图让 NN 收到无效的参数组合
问题描述
我正在尝试使用 pytroch 构建我的第一个 NN,但遇到了问题。
TypeError: new() 收到了无效的参数组合 - 得到 (float, int, int, int),但预期是以下之一:* (torch.device device) * (torch.Storage storage) * (Tensor other) * (tuple整数大小,torch.device 设备)*(对象数据,torch.device 设备)
现在我知道这是在说什么,因为我没有将正确的类型传递给方法或 init。但我不知道我应该通过什么,因为它看起来对我来说是正确的。
def main():
#Get the time and data
now = datetime.datetime.now()
hourGlassToStack = 2 #Hourglasses to stack
numModules= 2 #Residual Modules for each hourglass
numFeats = 256 #Number of features in each hourglass
numRegModules = 2 #Depth regression modules
print("Creating Model")
model = HourglassNet3D(hourGlassToStack, numModules, numFeats,numRegModules).cuda()
print("Model Created")
这是创建模型的主要方法。然后它调用这个方法。
class HourglassNet3D(nn.Module):
def __init__(self, nStack, nModules, nFeats, nRegModules):
super(HourglassNet3D, self).__init__()
self.nStack = nStack
self.nModules = nModules
self.nFeats = nFeats
self.nRegModules = nRegModules
self.conv1_ = nn.Conv2d(3, 64, bias = True, kernel_size = 7, stride = 2, padding = 3)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace = True)
self.r1 = Residual(64, 128)
self.maxpool = nn.MaxPool2d(kernel_size = 2, stride = 2)
self.r4 = Residual(128, 128)
self.r5 = Residual(128, self.nFeats)
_hourglass, _Residual, _lin_, _tmpOut, _ll_, _tmpOut_, _reg_ = [], [], [], [], [], [], []
for i in range(self.nStack):
_hourglass.append(Hourglass(4, self.nModules, self.nFeats))
for j in range(self.nModules):
_Residual.append(Residual(self.nFeats, self.nFeats))
lin = nn.Sequential(nn.Conv2d(self.nFeats, self.nFeats, bias = True, kernel_size = 1, stride = 1),
nn.BatchNorm2d(self.nFeats), self.relu)
_lin_.append(lin)
_tmpOut.append(nn.Conv2d(self.nFeats, 16, bias = True, kernel_size = 1, stride = 1))
_ll_.append(nn.Conv2d(self.nFeats, self.nFeats, bias = True, kernel_size = 1, stride = 1))
_tmpOut_.append(nn.Conv2d(16, self.nFeats, bias = True, kernel_size = 1, stride = 1))
for i in range(4):
for j in range(self.nRegModules):
_reg_.append(Residual(self.nFeats, self.nFeats))
self.hourglass = nn.ModuleList(_hourglass)
self.Residual = nn.ModuleList(_Residual)
self.lin_ = nn.ModuleList(_lin_)
self.tmpOut = nn.ModuleList(_tmpOut)
self.ll_ = nn.ModuleList(_ll_)
self.tmpOut_ = nn.ModuleList(_tmpOut_)
self.reg_ = nn.ModuleList(_reg_)
self.reg = nn.Linear(4 * 4 * self.nFeats,16 )
然后这叫这个
class Residual(nn.Module):
#set the number ofinput and output for each layer
def __init__(self, numIn, numOut):
super(Residual, self).__init__()
self.numIn = numIn
self.numOut = numOut
self.bn = nn.BatchNorm2d(self.numIn)
self.relu = nn.ReLU(inplace = True)
self.conv1 = nn.Conv2d(self.numIn, self.numOut / 2, bias = True, kernel_size = 1)
self.bn1 = nn.BatchNorm2d(self.numOut / 2)
self.conv2 = nn.Conv2d(self.numOut / 2, self.numOut / 2, bias = True, kernel_size = 3, stride = 1, padding = 1)
self.bn2 = nn.BatchNorm2d(self.numOut / 2)
self.conv3 = nn.Conv2d(self.numOut / 2, self.numOut, bias = True, kernel_size = 1)
if self.numIn != self.numOut:
self.conv4 = nn.Conv2d(self.numIn, self.numOut, bias = True, kernel_size = 1)
所有这一切对我来说都很好,但我不知道如果我做错了我应该如何通过这个。感谢您的任何帮助
解决方案
您可能需要注意您在Residual
课堂上传递给卷积层的内容。默认情况下,Python 3 会将任何除法运算转换为浮点变量。
尝试将变量转换回整数,看看是否有帮助。固定代码Residual
:
class Residual(nn.Module):
#set the number ofinput and output for each layer
def __init__(self, numIn, numOut):
super(Residual, self).__init__()
self.numIn = numIn
self.numOut = numOut
self.bn = nn.BatchNorm2d(self.numIn)
self.relu = nn.ReLU(inplace = True)
self.conv1 = nn.Conv2d(self.numIn, int(self.numOut / 2), bias = True, kernel_size = 1)
self.bn1 = nn.BatchNorm2d(int(self.numOut / 2))
self.conv2 = nn.Conv2d(int(self.numOut / 2), int(self.numOut / 2), bias = True, kernel_size = 3, stride = 1, padding = 1)
self.bn2 = nn.BatchNorm2d(int(self.numOut / 2))
self.conv3 = nn.Conv2d(int(self.numOut / 2), self.numOut, bias = True, kernel_size = 1)
if self.numIn != self.numOut:
self.conv4 = nn.Conv2d(self.numIn, self.numOut, bias = True, kernel_size = 1)
推荐阅读
- ajax - HTTP 同步性
- jekyll - 在 Liquid (Jekyll) 中比较 forloop.index|modulo:4 和 0
- python - 如何在 python 中调用函数并且不再计算?
- android - Activity 和 Fragment 的不同 WindowSoftInputMode
- android - Android 构建错误:属性签名需要 InnerClasses 属性。检查 -keepattributes 指令
- reactjs - 使用 axios 直接将图片上传到 cloudinary
- react-native - 如何在 React Native 上使用 Pusher 创建用户?
- facebook - 如何升级 Facebook API?
- clickhouse - 新版本 18.10.3 中的 clickhouse 回合错误
- angular - 带有输入字段的 Ngb-dropdown 无法按预期工作