tensorflow - Keras model cannot learn anything after switching to the tf.data API
Problem description
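I tried to convert a simple Keras model to use the tf.data API for data loading, but the accuracy stays at around 10% for all 10 epochs.
By contrast, the original code, which does not use the tf.data API, easily reaches about 98% accuracy. Am I doing something wrong?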
Version using the tf.data API
```python
import math
import tensorflow as tf
import numpy as np

batch_size = 32

def load_data():
    mnist = tf.keras.datasets.mnist
    (train_data, train_label), (validation_data, validation_label) = mnist.load_data()
    # Scale pixel values from [0, 255] to [0, 1]
    train_data, validation_data = train_data / 255.0, validation_data / 255.0
    # Integer class labels 0-9, stored as float32
    train_label = train_label.astype(np.float32)
    return train_data, train_label

def build_model():
    class MyModel(tf.keras.Model):
        def __init__(self):
            super(MyModel, self).__init__(name='my_model')
            self.flatten = tf.keras.layers.Flatten()
            self.dense_1 = tf.keras.layers.Dense(512, activation=tf.nn.relu)
            self.dropout = tf.keras.layers.Dropout(0.2)
            self.dense_2 = tf.keras.layers.Dense(10, activation=tf.nn.softmax)

        def call(self, inputs):
            x = self.flatten(inputs)
            x = self.dense_1(x)
            x = self.dropout(x)
            y = self.dense_2(x)
            return y

    model = MyModel()
    model.compile(optimizer=tf.train.AdamOptimizer(),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

train_data, train_label = load_data()
train_sample_count = len(train_data)

# Each dataset element is one (image, scalar label) pair, so each batch
# of labels is a 1-D tensor
train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))
train_dataset = train_dataset.batch(batch_size)
# Repeat indefinitely; steps_per_epoch tells fit() where each epoch ends
train_dataset = train_dataset.repeat()

model = build_model()
model.fit(
    train_dataset,
    epochs=10,
    steps_per_epoch=math.ceil(train_sample_count / batch_size)
)
```
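A NumPy-only sketch of one possible explanation, assuming TF 1.x, where the `sparse_categorical_accuracy` metric (the one `metrics=['accuracy']` resolves to with this loss) reduces `y_true` with a max over its last axis before comparing it with the argmax of the predictions. With the pipeline above each label batch is 1-D, so that reduction would collapse the whole batch to a single number, and the reported accuracy would sit near 10% even while the loss goes down. `tf1_style_sparse_accuracy` and `batches` are hypothetical helpers written for illustration, not TensorFlow APIs.

```python
import numpy as np

def tf1_style_sparse_accuracy(y_true, y_pred):
    """Hypothetical NumPy re-creation of a TF 1.x-style sparse accuracy.

    Assumption: y_true is reduced with max over its last axis, and the
    result is compared with the argmax of the predicted class scores.
    """
    y_true = np.max(y_true, axis=-1)  # (batch, 1) -> (batch,), but (batch,) -> scalar
    y_pred = np.argmax(y_pred, axis=-1).astype(np.float32)
    return np.mean(np.equal(y_true, y_pred))

def batches(values, batch_size):
    """Split values into consecutive batches, like Dataset.batch() without shuffling."""
    for start in range(0, len(values), batch_size):
        yield values[start:start + batch_size]

rng = np.random.default_rng(0)
labels = rng.integers(0, 10, size=64).astype(np.float32)

# A "perfect" model: one-hot scores that always pick the true class
perfect_scores = np.eye(10, dtype=np.float32)[labels.astype(int)]

for batch_labels, batch_scores in zip(batches(labels, 32), batches(perfect_scores, 32)):
    acc_1d = tf1_style_sparse_accuracy(batch_labels, batch_scores)           # labels shape (32,)
    acc_2d = tf1_style_sparse_accuracy(batch_labels[:, None], batch_scores)  # labels shape (32, 1)
    print(f"1-D labels: {acc_1d:.2f}   2-D labels: {acc_2d:.2f}")
```

Even with perfect predictions, the 1-D case only counts the samples whose label equals the batch maximum. If this is what is happening, the weights may be training normally and only the reported accuracy is misleading, so the `loss` column is the more useful thing to compare between the two runs. Giving the labels a trailing axis, for example `train_label.reshape(-1, 1)` in `load_data`, is one thing to try.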
Version without the tf.data API
```python
# load_data and build_model are exactly the same as in the tf.data API version
train_data, train_label = load_data()
model = build_model()
model.fit(
    train_data,
    train_label,
    epochs=10
)
```
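For comparison, a NumPy-only sketch of two things `model.fit` does with in-memory arrays that the tf.data pipeline in the first version does not: it reshuffles the samples every epoch (`shuffle=True` is the default for array inputs), and, assuming TF 1.x behavior, it expands 1-D target arrays to shape `(n, 1)` before batching. `keras_like_batches` and the stand-in data are hypothetical, not part of Keras.

```python
import numpy as np

def keras_like_batches(x, y, batch_size, rng):
    """Hypothetical helper: yield one epoch of (x, y) batches the way fit() treats NumPy input.

    - 1-D targets get a trailing axis, (n,) -> (n, 1)  (assumed TF 1.x behavior)
    - sample order is reshuffled on every call (fit's shuffle=True default)
    """
    if y.ndim == 1:
        y = np.expand_dims(y, axis=1)
    order = rng.permutation(len(x))
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]
        yield x[idx], y[idx]

rng = np.random.default_rng(42)
x = np.arange(100, dtype=np.float32).reshape(100, 1)  # stand-in "images"
y = (np.arange(100) % 10).astype(np.float32)          # stand-in labels, shape (100,)

epoch = list(keras_like_batches(x, y, 32, rng))
print([yb.shape for _, yb in epoch])  # [(32, 1), (32, 1), (32, 1), (4, 1)]
```

On the tf.data side, the rough counterparts would be calling `Dataset.shuffle(buffer_size)` before `batch`, and a `Dataset.map` that adds the trailing label axis; treat that as a suggestion to try rather than a confirmed fix.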
Solution