首页 > 解决方案 > 识别不属于 CNN 训练的类别的图像

问题描述

假设我有一个带有 CNN 训练的应用程序来预测猫和狗,并且用户输入了一张苹果的图像,我的 CNN 将苹果预测为猫或狗。如何让我的应用程序说用户输入了错误的图像而不是做出预测?

很想知道你们是如何处理类似情况的。

标签: tensorflowconv-neural-network

解决方案


好吧,可能有很多方法可以实现,但是,这里有一些简单的方法。

方法:1

如果您的数据集中有 2 个类,则只需添加第三个类,即“ unknown class”,代表一些随机数据,这意味着您可以在该类中添加您想要的所有垃圾(异常值),例如苹果、花、鸟的照片等(从哪里获得这些图像?只需下载并连接另一个数据集,如 cifar-10 等)。然后你可以在这个数据上训练你的网络unknown class,当它既不是狗也不是猫时,网络将预测第三类“ ”。

方法:2

但是,更好的方法(在我看来)是在输出层中使用 sigmoid 激活 2 个神经元,一个代表成为猫的概率,另一个代表成为狗的概率。所以在训练之后,当你喂一张狗的图像时,你可能会在你的输出层得到这样的值[0.9, 0.07](狗的概率为 90%,猫的概率为 7%)。所以直观地说,当你输入一个随机图像(比如苹果的图像)时,输出层可能会产生以下输出[0.3, 0.27],因此网络似乎无法确定它是猫还是狗。现在你可以很容易地设置一个阈值,比如 60%,所以只有当模型以超过 60% 的概率预测猫或狗时,你才会将输出/预测发送给用户,否则你会发送“未知类别”


推荐阅读