首页 > 解决方案 > 我目前正在编写 Xpredict 函数,它是一般所有 keras 模型的 keras.predict function() 的包装器

问题描述

我想知道如何找到预测的相应类名?

Generator.class_indices 适用于数据来自生成器的少数模型。但是,对于数据不是来自生成器的少数模型,它会抛出一个错误说

AttributeError:“NumpyArrayIterator”对象没有属性“class_indices”

def xpredict(self, img_path, batch_size = None, verbose = 0, steps = None):
        x = image.load_img(img_path, target_size = Sequential.input_shape)
        x = image.img_to_array(x)
        x = np.expand_dims(x, axis = 0)
        result = self.predict(x, batch_size = batch_size, verbose = verbose, steps = None)
        for key, value in Sequential.generator.class_indices.items():
            if value == result:
                return key

预期的:

我想知道如何编写一个通用函数来从任何 keras 通用模型中预测 class_name。

实际的:

仅适用于使用 generator.class_indices 来自生成器的 training_data。

标签: pythonkerasdeep-learningconv-neural-networkrecurrent-neural-network

解决方案


不,一般情况下你不能这样做,类名是开发人员必须提供的,Keras 不知道这一点,它也不存储在任何地方。


推荐阅读