Using "sparse categorical crossentropy"

Problem description

I can't understand why sparse categorical crossentropy doesn't work with the SVHN dataset.

import tensorflow as tf
from scipy.io import loadmat
import numpy as np

train = loadmat('data/train_32x32.mat')
test = loadmat('data/test_32x32.mat')

x_train = train['X']
y_train = train['y']
x_train = x_train.astype('float64')
y_train = y_train.astype('int64')

x_test = test['X']
x_test = x_test.astype('float64')
y_test = test['y']
y_test = y_test.astype('int64')

# reorder data so the sample axis comes first
x_train = np.moveaxis(x_train, -1, 0)
x_test = np.moveaxis(x_test, -1, 0)

def colored_to_gray(x):
    '''
    input shape: n_sample, n_x, n_y, n_channel
    output shape: n_sample, n_x, n_y, 1
    this is a rudimentary way of converting a colored image into a gray image
    '''
    x = np.mean(x, axis=-1, keepdims=True)
    return x

def normalize_data(x):
    '''
    normalize data so that values are between 0 to 1
    '''
    x = x / 255.0
    return x

x_train = colored_to_gray(x_train)
x_test = colored_to_gray(x_test)

x_train = normalize_data(x_train)
x_test = normalize_data(x_test)
print("Shape of Training Data: {}".format(x_train.shape))
print("Shape of Training Labels: {}".format(y_train.shape))
print("Shape of Testing Data: {}".format(x_test.shape))
print("Shape of Testing Labels: {}".format(y_test.shape))

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense

model = Sequential([
    Flatten(name='Flatten_Input', input_shape=x_train.shape[1:]),
    Dense(units=1024, activation='relu', name='Dense_1'),
    Dense(units=512, activation='relu', name='Dense_2'),
    Dense(units=256, activation='relu', name='Dense_3'),
    Dense(units=32, activation='relu', name='Dense_4'),
    Dense(units=10, activation='softmax', name='Output')
])

opt = tf.keras.optimizers.Adam(learning_rate=0.0001)

model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, epochs=2, batch_size=256)

With this model.fit call, I expected the model to train on the 10 classes. Instead, I get 'nan' as the loss output and 0 for the accuracy.

Thanks,

Tags: python, keras, deep-learning

Solution


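Sparse is for the case where each image belongs to exactly one class. In the SVHN dataset that is not the case; for example, images 3213 and 21 are multi-class. Change the loss to categorical_crossentropy and it should work. Also, you are not using accuracy as the metric.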
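
One thing that may be worth checking before switching losses: sparse_categorical_crossentropy expects integer labels in the range 0 to num_classes - 1. As far as I know, the cropped-digit SVHN .mat files store labels 1 to 10, with 10 standing for the digit 0. If that holds for this data, label 10 is out of range for a 10-unit softmax, which could explain the nan loss. Below is a minimal sketch under that assumption; remap_svhn_labels and labels_in_range are hypothetical helper names, not part of any library, and y_demo is a toy stand-in for train['y'].

```python
import numpy as np

def remap_svhn_labels(y, num_classes=10):
    """Flatten labels to shape (n_sample,) and map label `num_classes` to 0.

    Assumption: the label 10 encodes the digit 0, as in the cropped-digit SVHN files.
    """
    y = np.asarray(y).reshape(-1).astype('int64')
    return np.where(y == num_classes, 0, y)

def labels_in_range(y, num_classes=10):
    """Return True if every label lies in [0, num_classes - 1]."""
    y = np.asarray(y)
    return bool(y.min() >= 0 and y.max() < num_classes)

# Toy stand-in for train['y'], which has shape (n_sample, 1)
y_demo = np.array([[1], [10], [3], [10], [9]])

print(labels_in_range(y_demo))   # False: label 10 is outside 0..9
y_fixed = remap_svhn_labels(y_demo)
print(y_fixed)                   # [1 0 3 0 9]
print(labels_in_range(y_fixed))  # True
```

With the real data this would be applied as y_train = remap_svhn_labels(train['y']), and likewise for y_test, before calling model.fit. Note that categorical_crossentropy would instead need one-hot labels, for example via tf.keras.utils.to_categorical.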

