python - Unable to use ModelCheckpoint with MobileNet in Keras
Problem description
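I'm trying to train MobileNet in Keras on dummy data, inside a Docker container on a multi-GPU machine. I originally tried to train Xception, but switched to a smaller model so that people with less powerful machines can reproduce my code. I'm running into a conflict with `ModelCheckpoint` that I can't make sense of.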
```python
import tensorflow as tf
import keras.utils
from keras.applications import MobileNet
from keras.callbacks import ModelCheckpoint
from keras.optimizers import Adam
import numpy as np
import os

height = 224
width = 224
channels = 3
epochs = 10
num_classes = 10

# Generate dummy data
batch_size = 32
n_train = 256
n_test = 64
x_train = np.random.random((n_train, height, width, channels))
y_train = keras.utils.to_categorical(np.random.randint(num_classes, size=(n_train, 1)), num_classes=num_classes)
x_test = np.random.random((n_train, height, width, channels))
y_test = keras.utils.to_categorical(np.random.randint(num_classes, size=(n_test, 1)), num_classes=num_classes)

# Get input shape
input_shape = x_train.shape[1:]

# Instantiate model
model = MobileNet(weights=None,
                  input_shape=input_shape,
                  classes=num_classes)

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# Viewing Model Configuration
model.summary()

# Model file name
filepath = 'model_epoch_{epoch:02d}_loss_{loss:0.2f}_val_{val_loss:.2f}.hdf5'

# Define save_best_only checkpointer
checkpointer = ModelCheckpoint(filepath=filepath,
                               monitor='val_acc',
                               verbose=1,
                               save_best_only=True)

# Let's fit!
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(x_test, y_test),
          callbacks=[checkpointer])
```
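For reference, the checkpoint file name template itself is fine. At the end of each epoch, Keras 2.1.x's `ModelCheckpoint` fills in `filepath` roughly as `filepath.format(epoch=epoch + 1, **logs)`, where `logs` holds that epoch's metrics. The minimal sketch below reproduces just that substitution with a hypothetical `logs` dict whose values are made up; it does not run Keras.

```python
# Template from the question; ModelCheckpoint fills it in with str.format.
filepath = 'model_epoch_{epoch:02d}_loss_{loss:0.2f}_val_{val_loss:.2f}.hdf5'

# Hypothetical end-of-epoch logs (values made up for illustration; in
# Keras 2.1.x the accuracy keys are 'acc' and 'val_acc').
logs = {'loss': 2.3, 'acc': 0.1, 'val_loss': 2.31, 'val_acc': 0.12}
epoch = 0  # the callback receives a 0-based epoch index

# Mirrors the substitution the callback performs (epoch shown 1-based);
# keys the template does not mention, such as 'acc', are ignored by str.format.
name = filepath.format(epoch=epoch + 1, **logs)
print(name)  # model_epoch_01_loss_2.30_val_2.31.hdf5
```

Because `val_loss` only appears in the logs when `validation_data` is supplied, this template relies on the validation set being valid, which ties it back to the error below.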
The error I get is:
```
Traceback (most recent call last):
  File "very_basic_test.py", line 52, in <module>
    callbacks=[checkpointer])
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1650, in fit
    batch_size=batch_size)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1490, in _standardize_user_data
    _check_array_lengths(x, y, sample_weights)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 220, in _check_array_lengths
    'and ' + str(list(set_y)[0]) + ' target samples.')
ValueError: Input arrays should have the same number of samples as target arrays. Found 256 input samples and 64 target samples.
```
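The message comes from a sanity check that `fit` runs before training: all input arrays and all target arrays must have the same size along the first (sample) axis. The sketch below is a simplified, hypothetical re-implementation of that idea in plain NumPy; `check_array_lengths` is a made-up name, not Keras's internal `_check_array_lengths`, and the image size is shrunk because only axis 0 matters.

```python
import numpy as np

def check_array_lengths(inputs, targets):
    """Simplified stand-in for the sample-count check Keras runs in fit()."""
    set_x = {len(x) for x in inputs}   # sample count of each input array
    set_y = {len(y) for y in targets}  # sample count of each target array
    if len(set_x) > 1 or len(set_y) > 1:
        raise ValueError('All input (or target) arrays must have the same length.')
    if set_x and set_y and set_x != set_y:
        raise ValueError('Input arrays should have the same number of samples as '
                         'target arrays. Found %d input samples and %d target samples.'
                         % (next(iter(set_x)), next(iter(set_y))))

# Arrays shaped like the question's validation data: 256 inputs, 64 labels.
x_val = np.zeros((256, 4, 4, 3))  # image size shrunk; only axis 0 matters
y_val = np.zeros((64, 10))

message = None
try:
    check_array_lengths([x_val], [y_val])
except ValueError as err:
    message = str(err)
print(message)
```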
Python, Keras, and TensorFlow versions:
```
python -c 'import keras; import tensorflow; import sys; print(sys.version, 'keras.__version__', 'tensorflow.__version__')'
Using TensorFlow backend.
('2.7.12 (default, Dec 4 2017, 14:50:18) \n[GCC 5.4.0 20160609]', '2.1.6', '1.7.0')
```
Solution
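The problem has nothing to do with the checkpoint callback; it is the data you pass in. Compare the sample counts (the first dimension) of your arrays. `x_train.shape[0]` and `y_train.shape[0]` are both 256, but `x_test` is generated with `n_train` instead of `n_test`, so `x_test.shape[0]` is 256 while `y_test.shape[0]` is 64. `fit` runs the same length check on `validation_data`, and that is where "Found 256 input samples and 64 target samples" comes from. The traceback only points at `callbacks=[checkpointer])` because that is the last line of the multi-line `.fit(...)` call.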
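A minimal sketch of the fix, assuming the rest of the script stays unchanged: create `x_test` with `n_test` rows and assert that the sample counts agree before `fit` is ever called. To keep the sketch runnable without Keras, `one_hot` (a hypothetical helper) uses `np.eye(num_classes)[labels]` as a stand-in for `keras.utils.to_categorical`, and the image size is shrunk from 224 to 32; both are illustrative assumptions.

```python
import numpy as np

height, width, channels = 32, 32, 3  # the question uses 224x224; shrunk to keep this light
num_classes = 10
n_train, n_test = 256, 64

rng = np.random.RandomState(0)

def one_hot(labels, num_classes):
    """Hypothetical stand-in for keras.utils.to_categorical on integer labels."""
    return np.eye(num_classes)[labels.ravel()]

x_train = rng.random_sample((n_train, height, width, channels))
y_train = one_hot(rng.randint(num_classes, size=(n_train, 1)), num_classes)
# The fix: x_test must use n_test, not n_train, so it matches y_test.
x_test = rng.random_sample((n_test, height, width, channels))
y_test = one_hot(rng.randint(num_classes, size=(n_test, 1)), num_classes)

# Fail fast with a readable message instead of deep inside model.fit(...).
assert x_train.shape[0] == y_train.shape[0], (x_train.shape, y_train.shape)
assert x_test.shape[0] == y_test.shape[0], (x_test.shape, y_test.shape)
```

With `x_test` fixed, the original `model.fit(...)` call, including the `ModelCheckpoint` callback, should get past the length check unchanged.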