首页 > 解决方案 > 每次迭代在所有类上使用 CNN / ResNet 进行训练/预测 - 输入数据的串联 + 匈牙利算法

问题描述

因此,我有一个简单的 pytorch 示例,说明如何通过此链接训练 ResNet CNN 以学习 MNIST 标记:

https://zablo.net/blog/post/using-resnet-for-mnist-in-pytorch-tutorial/index.html

它工作得很好,但我想稍微修改一下,让它做两件事。首先,它不是预测数字,而是预测我正在从事的项目的动物形状/颜色。这已经很好地工作了,并且对此感到满意。

其次,我想破解训练(可能还有层),以便一次在多个图像上并行完成预测。在 MNIST 示例中,基本上预测(或输出)将对我连接的一次具有 10 位数字的图像进行。为清楚起见,每个 10 图像输入将具有数字 0-9,每个仅出现一次。这里的关键是 10 位数字中的每一个都从 CNN/ResNet 获得一个唯一的类/标签,并且每个类都被分配一次。并且具有高置信度的数字将阻止其他置信度较低的数字使用该标签(匈牙利算法类型的方法)。

因此,在我的用例中,我想训练连接图像(而不是单个图像),如下图 A 所示,并强制分类器学习预测每个连接图像的最佳唯一标签,并一次性完成所有操作。这种方法应该优于单一图像分类 - 它对我的动物分类特别有用,因为否则 CNN 有时会为多个动物返回相同的 ID,这在我的应用程序中是不可能的。

我已经可以按下面的图 B 进行系列预测。实际上,查看每个预测的置信度,我能够实施一种类似于匈牙利算法的方法后预测,以在每批 4 只动物中分配最佳(最有信心)的唯一 ID。但这并不总是有效,我想知道 ResNet 是否也可以尝试学习贪婪的匈牙利作业。

在此处输入图像描述

特别是,尚不清楚实现 A 只需要增加数据输入,训练集中的标签会自动完成 - 因为我不知道如何惩罚或禁止为每组图像两次返回相同的标签。所以现在我可以像这样生成这些训练数据集:

print (train_loader.dataset.data.shape)
print (train_loader.dataset.targets.shape)
torch.Size([60000, 28, 28])
torch.Size([60000])

我想我希望目标是 [60000, 10]。每个输入图像将是 [1, 28, 28, 10]?但我不确定正确的方法是什么。

任何建议或可用链接?

我认为这是一种特定类型的培训,但我忘记了名称。

标签: pytorchconv-neural-networkmnistresnet

解决方案


推荐阅读