r - Why does Keras treat my input_shape as three-dimensional when I only have two inputs?
Question
This is the ValueError I am currently getting:
ValueError: Input 0 of layer sequential_10 is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: [None, 9]
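Note that I am working in R with Keras, Caret, and Tidyverse.
I am currently building a neural network model on Kaggle's Titanic dataset.
This link has all of the code I have written so far, along with the full dataset I am using.
From the ValueError, I understand that there is a problem with how I describe the shape of the X.train data. However, I am not sure how to shape this data so that my model runs smoothly.
Here is how I started building the model: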
#Build Model
input_shape <- shape(ncol(X.train), nrow(X.train))
model <- keras_model_sequential()
model %>%
  layer_batch_normalization(input_shape = input_shape) %>% #Normalization Layer
  layer_dense(units = 256, activation = 'relu') %>% #First Layer
  layer_batch_normalization() %>%
  layer_dropout(rate = 0.3) %>%
  layer_dense(units = 256, activation = 'relu') %>% #Second layer
  layer_batch_normalization() %>%
  layer_dropout(rate = 0.3) %>%
  layer_dense(units = 256, activation = 'relu') %>% #Third layer
  layer_batch_normalization() %>%
  layer_dropout(rate = 0.3) %>%
  layer_dense(units = 1, activation = 'sigmoid') %>% #Output Layer
  compile(
    loss = 'binary_crossentropy',
    optimizer = 'adam',
    metrics = c('accuracy'))
Here is the full error I get:
ValueError: Input 0 of layer sequential_10 is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: [None, 9]
Detailed traceback:
File "/usr/local/share/.virtualenvs/r-reticulate/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
return method(self, *args, **kwargs)
File "/usr/local/share/.virtualenvs/r-reticulate/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit
tmp_logs = train_function(iterator)
File "/usr/local/share/.virtualenvs/r-reticulate/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
result = self._call(*args, **kwds)
File "/usr/local/share/.virtualenvs/r-reticulate/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 823, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/usr/local/share/.virtualenvs/r-reticulate/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 697, in _initialize
*args, **kwds))
File "/usr/local/share/.virtualenvs/r-reticulate/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/usr/local/share/.virtualenvs/r-reticulate/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/usr/local/share/.virtualenvs/r-reticulate/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/usr/local/share/.virtualenvs/r-reticulate/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/usr/local/share/.virtualenvs/r-reticulate/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/usr/local/share/.virtualenvs/r-reticulate/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
raise e.ag_error_metadata.to_exception(e)
Traceback:
1. model %>% fit(X.train, y.train, epochs = 500, batch_size = 200,
. validation_split = 0.3, callbacks = list(callback_early_stopping(monitor = "val_loss",
. mode = "auto", patience = 5, min_delta = 0.001, restore_best_weights = TRUE)))
2. withVisible(eval(quote(`_fseq`(`_lhs`)), env, env))
3. eval(quote(`_fseq`(`_lhs`)), env, env)
4. eval(quote(`_fseq`(`_lhs`)), env, env)
5. `_fseq`(`_lhs`)
6. freduce(value, `_function_list`)
7. withVisible(function_list[[k]](value))
8. function_list[[k]](value)
9. fit(., X.train, y.train, epochs = 500, batch_size = 200, validation_split = 0.3,
. callbacks = list(callback_early_stopping(monitor = "val_loss",
. mode = "auto", patience = 5, min_delta = 0.001, restore_best_weights = TRUE)))
10. fit.keras.engine.training.Model(., X.train, y.train, epochs = 500,
. batch_size = 200, validation_split = 0.3, callbacks = list(callback_early_stopping(monitor = "val_loss",
. mode = "auto", patience = 5, min_delta = 0.001, restore_best_weights = TRUE)))
11. do.call(object$fit, args)
12. (structure(function (...)
. {
. dots <- py_resolve_dots(list(...))
. result <- py_call_impl(callable, dots$args, dots$keywords)
. if (convert)
. result <- py_to_r(result)
. if (is.null(result))
. invisible(result)
. else result
. }, class = c("python.builtin.method", "python.builtin.object"
. ), py_object = <environment>))(batch_size = 200L, epochs = 500L,
. verbose = 1L, callbacks = list(<environment>, <environment>),
. validation_split = 0.3, shuffle = TRUE, class_weight = NULL,
. sample_weight = NULL, initial_epoch = 0L, x = <environment>,
. y = <environment>)
13. py_call_impl(callable, dots$args, dots$keywords)
Error in py_call_impl(callable, dots$args, dots$keywords): ValueError: in user code: /usr/local/share/.virtualenvs/r-reticulate/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:806 train_function * return step_function(self, iterator) /usr/local/share/.virtual
Thank you, I appreciate your feedback.
Solution
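A likely cause, judging from the error message alone: input_shape describes a single sample and leaves out the batch dimension. shape(ncol(X.train), nrow(X.train)) therefore declares two dimensions per sample, so Keras expects rank-3 batches, while X.train arrives as rank-2 batches of shape [None, 9]. The sketch below passes only the feature count instead. The random X.train is a hypothetical stand-in for the real data (assumed here to be a numeric matrix with 9 columns), and the network is trimmed to a single hidden block for brevity.

```r
library(keras)

# Hypothetical stand-in for the real X.train: 100 rows, 9 feature columns.
set.seed(1)
X.train <- matrix(rnorm(100 * 9), nrow = 100, ncol = 9)

# input_shape covers one sample only; the batch (row) dimension is implied,
# so only the number of feature columns is passed.
input_shape <- ncol(X.train)

model <- keras_model_sequential()
model %>%
  layer_batch_normalization(input_shape = input_shape) %>%  # batches of shape [None, 9]
  layer_dense(units = 256, activation = 'relu') %>%
  layer_batch_normalization() %>%
  layer_dropout(rate = 0.3) %>%
  layer_dense(units = 1, activation = 'sigmoid') %>%
  compile(
    loss = 'binary_crossentropy',
    optimizer = 'adam',
    metrics = c('accuracy'))
```

If this diagnosis is right, the remaining layers and the fit() call from the question should not need to change.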