python - 如何在 pytorch 中使用 checkpoint 模型文件来测试 CIFAR-10 数据集?
问题描述
model = SqueezeNext()
model = model.to(device)
def load_checkpoint(model, optimizer, losslogger, filename='SqNxt_23_1x_Cifar.ckpt'):
# Note: Input model & optimizer should be pre-defined. This routine only updates their states.
start_epoch = 0
if os.path.isfile(filename):
print("=> loading checkpoint '{}'".format(filename))
checkpoint = torch.load(filename)
start_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
losslogger = checkpoint['losslogger']
print("=> loaded checkpoint '{}' (epoch {})"
.format(filename, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(filename))
return model, optimizer, start_epoch, losslogger
model, optimizer, start_epoch, losslogger = load_checkpoint(model, optimizer, losslogger)
TypeError: Traceback (last last call last) in () 41 test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=80, num_workers=8, shuffle=False) 42 ---> 43 model = SqueezeNext() 44 model = model.to(device) 45 def load_checkpoint(model, optimizer, losslogger, filename='SqNxt_23_1x_Cifar.ckpt'): TypeError: init () missing 3 required positional arguments: 'width_x', 'blocks', 和 'num_classes'
我认为我没有以正确的方式实现这一点!!
解决方案
您的错误不是来自您的检查点功能。如果我们查看回溯:
> TypeError: Traceback (most recent call last)
> <ipython-input-51-94c8be648862> in <module>()
> 41 test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=80, num_workers=8, shuffle=False)
> 42
> ---> 43 model = SqueezeNext()
> 44 model = model.to(device)
> 45 def load_checkpoint(model, optimizer, losslogger, filename='SqNxt_23_1x_Cifar.ckpt'): TypeError: __init__() missing 3
> required positional arguments: 'width_x', 'blocks', and 'num_classes'
我们被告知的那一行打破了第 43 行:
> ---> 43 model = SqueezeNext()
错误是:
> required positional arguments: 'width_x', 'blocks', and 'num_classes'
我假设您正在使用SqueezeNext 的此实现,但无论您使用哪种实现,您都没有传递初始化模型所需的所有参数。您需要将代码更改为:
model = SqueezeNext(width_x=1.0, blocks=[6, 6, 8, 1], num_classes=10)
如果您不使用该实现,则需要找到SqueezeNext
模型的源代码,并查看__init__
函数需要哪些参数。你可以试试这个:
import inspect
inspect.signature(SqueezeNext.__init__)
哪个应该给你签名。
推荐阅读
- flutter - SharedPreference 不在我的颤振应用程序上运行
- laravel-6 - count():参数必须是数组或对象,在 Laravel 6 和 php 7.2+ 中实现了 Countable
- html - 我很难在我的网站上的图像下将标题居中
- angular - 如何以编程方式隐藏 ngx-toaster?
- pytorch - RuntimeError:只有浮点 dtype 的张量才能需要渐变
- java - 如何使用 toString 方法在第三个对象中打印两个对象?
- date - 谷歌电子表格 - 谷歌脚本中两个日期之间的工作日数,节假日除外
- git - mac OS catalina 上的 git clone 命令面临问题
- sql - 使用 SQL Server 对学生成绩进行排序
- java - 从 Redis 获取密钥