Keras model fails to learn anything after switching to the tf.data API

Problem description

I am trying to convert a simple Keras model so that it loads its data through the tf.data API, but for some reason the accuracy stays at around 10% for all 10 epochs.

By contrast, the original code, which does not use the tf.data API, easily reaches about 98% accuracy. Am I doing something wrong?

Version using the tf.data API

import math
import tensorflow as tf
import numpy as np

batch_size = 32


def load_data():
    mnist = tf.keras.datasets.mnist
    (train_data, train_label), (validation_data, validation_label) = mnist.load_data()
    train_data, validation_data = train_data / 255.0, validation_data / 255.0
    train_label = train_label.astype(np.float32)
    return train_data, train_label


def build_model():
    class MyModel(tf.keras.Model):

        def __init__(self):
            super(MyModel, self).__init__(name='my_model')
            self.flatten = tf.keras.layers.Flatten()
            self.dense_1 = tf.keras.layers.Dense(512, activation=tf.nn.relu)
            self.dropout = tf.keras.layers.Dropout(0.2)
            self.dense_2 = tf.keras.layers.Dense(10, activation=tf.nn.softmax)

        def call(self, inputs):
            x = self.flatten(inputs)
            x = self.dense_1(x)
            x = self.dropout(x)
            y = self.dense_2(x)
            return y

    model = MyModel()

    model.compile(optimizer=tf.train.AdamOptimizer(),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    return model


train_data, train_label = load_data()
train_sample_count = len(train_data)

train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))

train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.repeat()

model = build_model()
model.fit(
    train_dataset,
    epochs=10,
    steps_per_epoch=math.ceil(train_sample_count / batch_size)
)
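
A note on `steps_per_epoch` in the call above: `repeat()` with no argument makes the dataset endless, so Keras cannot tell on its own where an epoch ends. The value passed is simply the number of batches needed to cover the training set once. A minimal plain-Python sketch of that arithmetic follows; the `batches` helper is hypothetical and only mimics how batching groups sample indices, and 60,000 is the size of the MNIST training split.

```python
import math


def batches(sample_count, batch_size):
    """Yield (start, stop) index ranges, keeping a short final batch
    (hypothetical helper; mirrors batching without drop_remainder)."""
    for start in range(0, sample_count, batch_size):
        yield start, min(start + batch_size, sample_count)


sample_count = 60000  # size of the MNIST training split
batch_size = 32       # same value as in the question

# Number of steps needed to see every sample exactly once.
steps = math.ceil(sample_count / batch_size)
n_batches = sum(1 for _ in batches(sample_count, batch_size))
print(steps, n_batches)  # 1875 1875

# When the batch size does not divide the sample count, ceil() keeps the
# final, shorter batch instead of dropping it.
steps_64 = math.ceil(sample_count / 64)
n_batches_64 = sum(1 for _ in batches(sample_count, 64))
print(steps_64, n_batches_64)  # 938 938
```

With these numbers each epoch runs 1875 steps, so the tf.data version should see the same number of samples per epoch as the plain NumPy version.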

Version without the tf.data API

# load_data and build_model are exactly the same as in the tf.data API version

train_data, train_label = load_data()
model = build_model()
model.fit(
    train_data,
    train_label,
    epochs=10
)

Tags: tensorflow, keras

Solution
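
One hypothesis worth checking, offered as an assumption and not as a confirmed diagnosis: a reported accuracy of almost exactly 10% on a ten-class problem can come from the metric and not from the training itself. Keras resolves the string `'accuracy'` to a concrete metric function, and if it ends up with the one-hot variant (`categorical_accuracy`) while the labels are integer class ids, the argmax of each one-element label is always 0, so only predictions of class 0 count as correct. Whether that happens here depends on the TensorFlow version. The plain-Python sketch below uses made-up data and helper functions written for illustration only; it mirrors the two metric definitions without calling TensorFlow.

```python
def argmax(values):
    """Index of the largest element in a sequence."""
    return max(range(len(values)), key=lambda i: values[i])


def sparse_categorical_accuracy(labels, probs):
    """Labels are class ids: compare each id with the argmax of the prediction."""
    hits = [int(label) == argmax(p) for label, p in zip(labels, probs)]
    return sum(hits) / len(hits)


def categorical_accuracy(label_rows, probs):
    """Labels are one-hot rows: compare argmax of the label row with argmax of the prediction."""
    hits = [argmax(t) == argmax(p) for t, p in zip(label_rows, probs)]
    return sum(hits) / len(hits)


def one_hot(index, size=10):
    """One-hot row of the given size with a 1.0 at position index."""
    return [1.0 if i == index else 0.0 for i in range(size)]


# Hypothetical data: a perfect classifier over the ten digit classes.
labels = [float(d) for d in range(10)]   # float class ids, as produced by load_data()
probs = [one_hot(d) for d in range(10)]  # all probability mass on the true class

correct = sparse_categorical_accuracy(labels, probs)

# Misreading each class id as a one-element "one-hot" row: argmax is always 0,
# so only the sample whose predicted class is 0 counts as a hit.
misread = categorical_accuracy([[label] for label in labels], probs)

print(correct)  # 1.0
print(misread)  # 0.1
```

If this is the cause, the model is learning normally and only the displayed number is wrong. A cheap test is to pass `metrics=['sparse_categorical_accuracy']` explicitly in `model.compile` and see whether the reported accuracy of the tf.data version changes.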

