首页 > 解决方案 > the_model = TheModelClass(*args, **kwargs) 是什么意思?

问题描述

我正在使用 PyTorch 进行图像分类。训练最多后,我想保存训练好的模型。

我不明白是什么意思

the_model = TheModelClass(*args, **kwargs)

这行代码由 PyTorch 网站 ( https://pytorch.org/docs/master/notes/serialization.html ) 提供。

标签: pythonpytorch

解决方案


这个问题the_model = TheModelClass(*args, **kwargs) 意味着您必须首先定义一个 ModelClass 对象。然后您可以使用模型对象来加载磁盘顺序对象。例如:

in_feats = data.x.shape[1]
n_hidden = params["n_hidden"]
n_classes = 2
best_model = OwnGCN(in_c=in_feats, hid_c=n_hidden, out_c=n_classes)
best_model.load_state_dict(torch.load(PATH))

推荐阅读