python - Keras:输入 0 与层 conv1d_5 不兼容:预期 ndim=3,发现 ndim=2
问题描述
我是 Keras 的初学者。我在这里使用一个包含 20 个特征的简单数据集作为回归模型。我不断收到预期 ndim=3 的错误,发现 conv2d 层的 ndim=2。这是我的代码:
def buildreg():
regressor=Sequential()
regressor.add(Dense(units=170,input_dim=20))
regressor.add(Dense(units=25))
regressor.add(Conv1D(16,5))
regressor.add(Dense(units=100))
regressor.add(Dense(units=55))
regressor.add(Dense(units=10))
regressor.add(Dense(units=60))
regressor.add(Dense(units=1))
regressor.compile(optimizer='adam',loss='mean_absolute_error')
return regressor
from keras.wrappers.scikit_learn import KerasRegressor
model=KerasRegressor(build_fn=buildreg,batch_size=15,epochs=20)
输入是一个具有 20 个特征的数据框。数据集是一个非常小的数据集,大约 1k 行。我完全清楚它会过拟合的事实。处理得当。数据与 MLPRegressor 一起工作得很好。
解决方案
1) 如果您想使用 Conv1D,则使用 Conv1D 作为第一层并指定 input_shape 如下
input_shape=(N_features, 1)
并将您的火车重塑为形状(nb_of_examples,nb_of_features,1)。
因此,您修改后的代码将如下所示,
Processed--XTRAIN = Processed--XTRAIN.reshape(1457,20,1)
def buildreg():
regressor=Sequential()
regressor.add(Conv1D(16,5,input_shape=(20, 1)))
regressor.add(Dense(units=170,input_dim=20))
regressor.add(Dense(units=25))
regressor.add(Dense(units=100))
regressor.add(Dense(units=55))
regressor.add(Dense(units=10))
regressor.add(Dense(units=60))
regressor.add(Dense(units=1))
regressor.compile(optimizer='adam',loss='mean_absolute_error')
return regressor
from keras.wrappers.scikit_learn import KerasRegressor
model=KerasRegressor(build_fn=buildreg,batch_size=15,epochs=20)
2)否则,您可以删除卷积层并构建简单的ANN模型。只需使用 Dense 层来构建您的模型。
你的代码看起来像这样
def buildreg():
regressor=Sequential()
regressor.add(Dense(units=170,input_dim=20))
regressor.add(Dense(units=25))
regressor.add(Dense(units=100))
regressor.add(Dense(units=55))
regressor.add(Dense(units=10))
regressor.add(Dense(units=60))
regressor.add(Dense(units=1))
regressor.compile(optimizer='adam',loss='mean_absolute_error')
return regressor
from keras.wrappers.scikit_learn import KerasRegressor
model=KerasRegressor(build_fn=buildreg,batch_size=15,epochs=20)
推荐阅读
- sql-server - 获取日期时间列中特定时间的最大时间?
- javascript - 禁用所有输入类型编号的滚动
- r - 按顺序在数据框中添加行并适当地填充行
- c++ - 错误:未在此范围内声明“SHGetKnownFolderPath”
- typescript - 通用枚举类型保护
- javascript - 我正在创建乘法表。需要帮助在每个循环之间创建新行
- regex - Text.Regex.Posix 的 =~ 运算符在某些模式下无法获取返回值
- java - 从 Java 中的线程创建进程有意义吗?
- angular - Jasmine 单元测试失败:无法读取未定义的属性“测试”
- java - 如果文本的最大行数超过 4,如何显示阅读更多文本?