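python - Sometimes this TensorFlow training works, sometimes it doesn't
Problem description
I'm learning machine learning with TensorFlow in Python. The following example used to work, but now it has suddenly started failing.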
import tensorflow as tf
import numpy as np
from tensorflow import keras
I get the following warning:
Here is my code:
EPOCHS = 200
BATCH_SIZE = 128
VERBOSE = 1
NB_CLASSES = 10
N_HIDDEN = 128
VALIDATION_SPLIT = 0.2
DROPOUT = 0.3
## loading MNIST dataset
# Labels have one-hot representation
mnist = keras.datasets.mnist
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
## X_train is 60000 rows of 28x28 values; we reshape it to 60000 x 784
RESHAPED = 784
#
X_train = X_train.reshape(60000, RESHAPED)
X_test = X_test.reshape(10000, RESHAPED)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
# Normalise inputs within [0,1]
X_train, X_test = X_train / 255, X_test / 255
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')
Y_train = tf.keras.utils.to_categorical(Y_train, NB_CLASSES)
Y_test = tf.keras.utils.to_categorical(Y_test, NB_CLASSES)
# One Hot representation for labels
Y_train = tf.keras.utils.to_categorical(Y_train, NB_CLASSES)
y_test = tf.keras.utils.to_categorical(Y_test, NB_CLASSES)
# Build the model.
model = tf.keras.models.Sequential()
model.add(keras.layers.Dense(N_HIDDEN, input_shape=(RESHAPED,),
                             name='dense_layer', activation='relu'))
model.add(keras.layers.Dropout(DROPOUT))
model.add(keras.layers.Dense(N_HIDDEN,
                             name='dense_layer_2', activation='relu'))
model.add(keras.layers.Dropout(DROPOUT))
model.add(keras.layers.Dense(NB_CLASSES,
                             name='dense_layer_3', activation='softmax'))
# Compile the model
model.compile(optimizer='SGD',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# Training the model
model.fit(X_train, Y_train,
          batch_size=BATCH_SIZE,
          epochs=EPOCHS,
          verbose=VERBOSE,
          validation_split=VALIDATION_SPLIT)
This is where training fails:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-10-51b78dc3a33e> in <module>
5 epochs=EPOCHS,
6 verbose=VERBOSE,
----> 7 validation_split=VALIDATION_SPLIT)
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
778 validation_steps=validation_steps,
779 validation_freq=validation_freq,
--> 780 steps_name='steps_per_epoch')
781
782 def evaluate(self,
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\training_arrays.py in model_iteration(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_freq, mode, validation_in_fit, prepared_feed_values_from_dataset, steps_name, **kwargs)
361
362 # Get outputs.
--> 363 batch_outs = f(ins_batch)
364 if not isinstance(batch_outs, list):
365 batch_outs = [batch_outs]
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\keras\backend.py in __call__(self, inputs)
3290
3291 fetched = self._callable_fn(*array_vals,
-> 3292 run_metadata=self.run_metadata)
3293 self._call_fetch_callbacks(fetched[-len(self._fetches):])
3294 output_structure = nest.pack_sequence_as(
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in __call__(self, *args, **kwargs)
1456 ret = tf_session.TF_SessionRunCallable(self._session._session,
1457 self._handle, args,
-> 1458 run_metadata_ptr)
1459 if run_metadata:
1460 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
InvalidArgumentError: logits and labels must be broadcastable: logits_size=[128,10] labels_size=[1280,10]
[[{{node loss_1/dense_layer_3_loss/softmax_cross_entropy_with_logits}}]]
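The numbers in this error can be reproduced with NumPy alone. The sketch below uses a hypothetical helper `one_hot` that only mimics the shape behaviour of `to_categorical` (it is not the Keras implementation): every integer entry becomes a length-10 vector, so encoding an already one-hot `(n, 10)` array a second time gives `(n, 10, 10)`. A batch of 128 such labels, flattened to rows of 10 as the error message reports, is 1280 rows against 128 rows of logits.

```python
import numpy as np

NB_CLASSES = 10
BATCH_SIZE = 128

def one_hot(y, num_classes):
    """Hypothetical stand-in for to_categorical (shape behaviour only).

    Each integer entry of `y` becomes a one-hot row of length `num_classes`,
    so the result has shape y.shape + (num_classes,).
    """
    y = np.asarray(y, dtype=int)
    return np.eye(num_classes, dtype='float32')[y]

labels = np.random.randint(0, NB_CLASSES, size=60000)  # integer labels, like mnist.load_data()
once = one_hot(labels, NB_CLASSES)   # (60000, 10): what the softmax output expects
twice = one_hot(once, NB_CLASSES)    # (60000, 10, 10): the accidental second encoding

# Flattening one batch of doubly encoded labels to rows of NB_CLASSES
# gives 128 * 10 = 1280 rows, matching labels_size=[1280,10] in the error.
batch_labels = twice[:BATCH_SIZE].reshape(-1, NB_CLASSES)
print(once.shape, twice.shape, batch_labels.shape)
```

This prints `(60000, 10) (60000, 10, 10) (1280, 10)`.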
What am I doing wrong here?
Solution
You included this code twice:
Y_train = tf.keras.utils.to_categorical(Y_train, NB_CLASSES)
Y_test = tf.keras.utils.to_categorical(Y_test, NB_CLASSES)
# One Hot representation for labels
Y_train = tf.keras.utils.to_categorical(Y_train, NB_CLASSES)
y_test = tf.keras.utils.to_categorical(Y_test, NB_CLASSES)
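When it runs the second time, Y_train and Y_test are already in categorical (one-hot) form, so to_categorical encodes every 0/1 entry again and Y_train becomes a 3-D array of shape (60000, 10, 10) instead of (60000, 10). That is consistent with the error: each batch of 128 samples carries 128 × 10 = 1280 rows of labels against 128 rows of logits. Remove the second pair of calls so the labels are encoded only once.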
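A defensive variant, sketched with NumPy only, is to encode the labels once and refuse to re-encode anything that is not 1-D. The helper name `encode_labels` and its `ndim` guard are illustrative assumptions, not part of Keras; with TensorFlow installed, the same guard could wrap `tf.keras.utils.to_categorical` instead of the `np.eye` lookup. In the original code, simply deleting the duplicated calls is enough.

```python
import numpy as np

NB_CLASSES = 10

def encode_labels(y, num_classes):
    """Hypothetical helper: one-hot encode integer labels, refusing to re-encode.

    Integer class labels are 1-D; an already one-hot array is 2-D, so a
    second call raises instead of silently producing a 3-D array.
    """
    y = np.asarray(y)
    if y.ndim != 1:
        raise ValueError(f"expected 1-D integer labels, got shape {y.shape}; "
                         "labels may already be one-hot encoded")
    return np.eye(num_classes, dtype='float32')[y.astype(int)]

Y_train = np.random.randint(0, NB_CLASSES, size=60000)  # stand-in for the MNIST labels
Y_train = encode_labels(Y_train, NB_CLASSES)
# One row per sample, matching the (batch, NB_CLASSES) softmax output.
assert Y_train.shape == (60000, NB_CLASSES)

try:
    encode_labels(Y_train, NB_CLASSES)  # the accidental second call now fails loudly
except ValueError as err:
    print("caught:", err)
```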