python - Python - 直接来自教科书的代码错误 - 神经网络
问题描述
下面的代码给出了一个错误,这直接来自教科书“Machine Learning with Python Cookbook”。任何人都可以看到这个问题吗?
代码:
import numpy as np
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense,Dropout,Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.utils import np_utils
from keras import backend as K
K.set_image_data_format("channels_first")
np.random.seed(0)
channels=1
height=28
width=28
(data_train,target_train),(data_test,target_test)=mnist.load_data()
data_train=data_train.reshape(data_train.shape[0],channels,height,width)
data_test=data_test.reshape(data_test.shape[0],channels,height,width)
features_train=data_train/255
features_test=data_test/255
target_train=np_utils.to_categorical(target_train)
target_test=np_utils.to_categorical(target_test)
number_of_classes=target_test.shape[1]
net=Sequential()
net.add(Conv2D(filters=64, kernel_size=(5,5), input_shape=(channels,width,height), activation="relu" ))
net.add(MaxPooling2D(pool_size=(2,2)))
net.add(Dropout(0.5))
net.add(Flatten())
net.add(Dense(128,activation="relu"))
net.add(Dropout(0.5))
net.add(Dense(number_of_classes,activation="softmax"))
net.compile(loss="categorical_crossentropy", optimizer="rmsprop",metrics=["accuracy"])
net.fit(features_train,target_train,epochs=2, verbose=0, batch_size=1000,validation_data=(features_test,target_test))
这是日志。错误的底部看起来与 MaxPooling2D 层有关,但该错误对我来说没有多大意义。“默认 MaxPoolingOp 仅支持设备类型 CPU 上的 NHWC”
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-33-7b4b0fc491f3> in <module>
36 net.add(Dense(number_of_classes,activation="softmax"))
37 net.compile(loss="categorical_crossentropy", optimizer="rmsprop",metrics=["accuracy"])
---> 38 net.fit(features_train,target_train,epochs=2, verbose=0, batch_size=1000,validation_data=(features_test,target_test))
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs)
106 def _method_wrapper(self, *args, **kwargs):
107 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
--> 108 return method(self, *args, **kwargs)
109
110 # Running inside `run_distribute_coordinator` already.
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1096 batch_size=batch_size):
1097 callbacks.on_train_batch_begin(step)
-> 1098 tmp_logs = train_function(iterator)
1099 if data_handler.should_sync:
1100 context.async_wait()
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
778 else:
779 compiler = "nonXla"
--> 780 result = self._call(*args, **kwds)
781
782 new_tracing_count = self._get_tracing_count()
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
838 # Lifting succeeded, so variables are initialized and we can run the
839 # stateless function.
--> 840 return self._stateless_fn(*args, **kwds)
841 else:
842 canon_args, canon_kwds = \
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
2827 with self._lock:
2828 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2829 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
2830
2831 @property
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in _filtered_call(self, args, kwargs, cancellation_manager)
1841 `args` and `kwargs`.
1842 """
-> 1843 return self._call_flat(
1844 [t for t in nest.flatten((args, kwargs), expand_composites=True)
1845 if isinstance(t, (ops.Tensor,
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1921 and executing_eagerly):
1922 # No tape is watching; skip to running the function.
-> 1923 return self._build_call_outputs(self._inference_function.call(
1924 ctx, args, cancellation_manager=cancellation_manager))
1925 forward_backward = self._select_forward_and_backward_functions(
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
543 with _InterpolateFunctionError(self):
544 if cancellation_manager is None:
--> 545 outputs = execute.execute(
546 str(self.signature.name),
547 num_outputs=self._num_outputs,
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
57 try:
58 ctx.ensure_initialized()
---> 59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
InvalidArgumentError: Default MaxPoolingOp only supports NHWC on device type CPU
[[node sequential_78/max_pooling2d_3/MaxPool (defined at <ipython-input-33-7b4b0fc491f3>:38) ]] [Op:__inference_train_function_1036301]
Function call stack:
train_function
谢谢
非常感谢,这似乎已经解决了最初的问题。但是我现在似乎得到了一个不同的错误
ValueError: Negative dimension size caused by subtracting 5 from 1 for '{{node conv2d/Conv2D}} = Conv2D[T=DT_FLOAT, data_format="NHWC", dilations=[1, 1, 1, 1], explicit_paddings=[], padding="VALID", strides=[1, 1, 1, 1], use_cudnn_on_gpu=true](conv2d_input, conv2d/Conv2D/ReadVariableOp)' with input shapes: [?,1,28,28], [5,5,28,64].
我所做的一切都按照建议进行。
解决方案
推荐阅读
- r - 使用 map2_dfr 将数据绑定在一起
- azure - 从 BLOB 到 Azure Postgres Gen5 8 核心的 MS Azure 数据工厂 ADF 复制活动失败,连接因主机错误而关闭
- android - AndroidX 迁移
- cql - 我可以选择带有 CQL 的 Blob 列表吗?
- r - 程序启动时R项目“意外的字符串常量”错误
- kubernetes - 使用 kubectl expose 时,是否可以选择通过 nodeport 服务发布服务的主机端口?
- swift - 如何使用 XCode 的 Interface Builder 在 Messenger 中创建气泡聊天 TableViewCell?
- laravel - 重定向不适用于自定义中间件
- ios - Swift 框架更新导致错误的布局更新
- ruby-on-rails - 未获得 CSV 文件中每一行的正确值