python - 在 Python 中具有相同精度的 TensorFlow
问题描述
我只是按照《 Hands-On Machine Learning with Scikit-Learn and TensorFlow 》一书中的一个TensorFlow示例进行操作,但得到了奇怪的结果。
这个例子:
import tensorflow as tf
from tensorflow import keras
tf.__version__
keras.__version__
fashion_mnist = keras.datasets.fashion_mnist
(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist.load_data()
X_valid, X_train = X_train_full[:5000] / 255.0, X_train_full[5000:] / 255.0
y_valid, y_train = y_train_full[:5000] / 255.0, y_train_full[5000:] / 255.0
class_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
model = keras.models.Sequential([
keras.layers.Flatten(input_shape=[28, 28]),
keras.layers.Dense(300, activation="relu"),
keras.layers.Dense(100, activation="relu"),
keras.layers.Dense(10, activation="softmax")
])
model.compile(loss="sparse_categorical_crossentropy",
optimizer='sgd',
metrics=['accuracy'])
history = model.fit(X_train, y_train, epochs=50, validation_data=(X_valid, y_valid))
随着时代的发展,我们应该看到书中指出的准确性有所提高:
Train on 55000 samples, validate on 5000 samples
Epoch 1/30
55000/55000 [==========] - 3s 55us/sample - loss: 1.4948 - acc: 0.5757 - val_loss: 1.0042 - val_acc: 0.7166
Epoch 2/30
55000/55000 [==========] - 3s 55us/sample - loss: 0.8690 - acc: 0.7318 - val_loss: 0.7549 - val_acc: 0.7616
[...]
Epoch 50/50
55000/55000 [==========] - 4s 72us/sample - loss: 0.3607 - acc: 0.8752 - acc: 0.8752 -val_loss: 0.3706 - val_acc: 0.8728
但是当我跑步时,我得到了以下信息:
Epoch 1/30
1719/1719 [==============================] - 3s 2ms/step - loss: 0.0623 - accuracy: 0.1005 - val_loss: 0.0011 - val_accuracy: 0.0914
Epoch 2/30
1719/1719 [==============================] - 3s 2ms/step - loss: 8.7637e-04 - accuracy: 0.1011 - val_loss: 5.2079e-04 - val_accuracy: 0.0914
Epoch 3/30
1719/1719 [==============================] - 3s 2ms/step - loss: 4.9200e-04 - accuracy: 0.1019 - val_loss: 3.4211e-04 - val_accuracy: 0.0914
[...]
Epoch 49/50
1719/1719 [==============================] - 3s 2ms/step - loss: 3.1710e-05 - accuracy: 0.0992 - val_loss: 3.2966e-05 - val_accuracy: 0.0914
Epoch 50/50
1719/1719 [==============================] - 3s 2ms/step - loss: 2.7711e-05 - accuracy: 0.1022 - val_loss: 3.1833e-05 - val_accuracy: 0.0914
因此,正如您所看到的,再现的准确度大大降低,但并没有提高:它保持在0.0914而不是0.8728。
我的 TensorFlow 安装、设置甚至代码中是否有问题?
解决方案
你不能划分y
如y_valid, y_train = y_train_full[:5000] / 255.0, y_train_full[5000:] / 255.0
。完成的代码如下:
import tensorflow as tf
from tensorflow import keras
tf.__version__
keras.__version__
fashion_mnist = keras.datasets.fashion_mnist
(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist.load_data()
X_train_full = X_train_full / 255.0
X_test = X_test / 255.0
class_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
model = keras.models.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation="relu"),
keras.layers.Dense(10, activation="softmax")
])
model.compile(loss="sparse_categorical_crossentropy",
optimizer='sgd',
metrics=['accuracy'])
history = model.fit(X_train_full, y_train_full, epochs=5, validation_data=(X_test, y_test))
它会给 acc 像:
Epoch 1/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.9880 - accuracy: 0.6923 - val_loss: 0.5710 - val_accuracy: 0.8054
Epoch 2/5
1875/1875 [==============================] - 2s 944us/step - loss: 0.5281 - accuracy: 0.8227 - val_loss: 0.5112 - val_accuracy: 0.8228
Epoch 3/5
1875/1875 [==============================] - 2s 913us/step - loss: 0.4720 - accuracy: 0.8391 - val_loss: 0.4782 - val_accuracy: 0.8345
Epoch 4/5
1875/1875 [==============================] - 2s 915us/step - loss: 0.4492 - accuracy: 0.8462 - val_loss: 0.4568 - val_accuracy: 0.8410
Epoch 5/5
1875/1875 [==============================] - 2s 935us/step - loss: 0.4212 - accuracy: 0.8550 - val_loss: 0.4469 - val_accuracy: 0.8444
此外,优化器adam
可能会给出比sgd
.
推荐阅读
- oracle - Oracle:从程序内部更新不起作用
- r - 组合长和短刻面标签时裁剪条背景
- r - 无法通过 Rstudio 中的控制台安装 ggplot2
- python - Python 函数返回“列表列表”而不是平面列表
- python - 使用python3从兄弟姐妹中找不到模块
- c - 将 libIEC61850 库与 GTK3+ for C 链接时出现问题
- php - 如何在 CakePHP 4 的翻译值中添加配置参数
- azure - 错误 - 访问被拒绝 - 部署到 Azure 应用服务
- javascript - React Native map() 和数组中的 JSX 元素
- python - MSC 构建的 Python 和 GCC 构建的 Python 有什么区别?