python - 具有多个输入的 Keras 顺序模型
问题描述
我正在制作一个 MLP 模型,它需要两个输入并产生一个输出。
我有两个输入数组(每个输入一个)和 1 个输出数组。神经网络有 1 个隐藏层和 2 个神经元。每个数组有 336 个元素。
model0 = keras.Sequential([
keras.layers.Dense(2, input_dim=2, activation=keras.activations.sigmoid, use_bias=True),
keras.layers.Dense(1, activation=keras.activations.relu, use_bias=True),
])
# Compile the neural network #
model0.compile(
optimizer = keras.optimizers.RMSprop(lr=0.02,rho=0.9,epsilon=None,decay=0),
loss = 'mean_squared_error',
metrics=['accuracy']
)
我尝试了两种方法,它们都给出了错误。
model0.fit(numpy.array([array_1, array_2]),output, batch_size=16, epochs=100)
ValueError:检查输入时出错:预期的dense_input具有形状(2,)但得到的数组形状为(336,)
第二种方式:
model0.fit([array_1, array_2],output, batch_size=16, epochs=100)
ValueError:检查模型输入时出错:您传递给模型的 Numpy 数组列表不是模型预期的大小。预计会看到 1 个数组,但得到了以下 2 个数组的列表:
类似的问题。但不使用顺序模型。
解决方案
要解决此问题,您有两种选择。
1. 使用顺序模型
在馈送到网络之前,您可以将两个数组连接成一个。假设这两个数组的形状为 (Number_data_points, ),现在可以使用numpy.stack
方法合并数组。
merged_array = np.stack([array_1, array_2], axis=1)
model0 = keras.Sequential([
keras.layers.Dense(2, input_dim=2, activation=keras.activations.sigmoid, use_bias=True),
keras.layers.Dense(1, activation=keras.activations.relu, use_bias=True),
])
model0.fit(merged_array,output, batch_size=16, epochs=100)
2.使用功能API。
当模型有多个输入时,这是最推荐使用的方法。
input1 = keras.layers.Input(shape=(1, ))
input2 = keras.layers.Input(shape=(1,))
merged = keras.layers.Concatenate(axis=1)([input1, input2])
dense1 = keras.layers.Dense(2, input_dim=2, activation=keras.activations.sigmoid, use_bias=True)(merged)
output = keras.layers.Dense(1, activation=keras.activations.relu, use_bias=True)(dense1)
model10 = keras.models.Model(inputs=[input1, input2], output=output)
现在您可以使用您尝试适合模型的第二种方法
model0.fit([array_1, array_2],output, batch_size=16, epochs=100)
推荐阅读
- node.js - 为什么 Node.js 不能导入模块?错误 [ERR_MODULE_NOT_FOUND]
- java - 有没有任何方法可以使用 selenium 处理身份验证警报,而无需任何操作系统工具,如 autoIT 和机器人
- html - 将 .HTML Jupyter Markdown 上传到 Git 页面
- woocommerce - 以编程方式将订单注释和项目元注释添加到现有的 woocommerce 订单
- postgresql - 如何提高sql查询速度并快速获取记录
- omnet++ - 如何控制数据包生成率和发送间隔
- node.js - 在 REST 中使用 Facebook Flow 登录(带有 TypeScript + Passport.js 的 Express.js)- 几个问题
- c# - 如何在 C# 中解析这些 JSON 数据,简单地切换到 javascript 会更有益吗?
- appdynamics - Appdynamics - 体验旅程地图 - 下车率
- c# - 如何从 ServiceStack.HttpResponseExtensionsInternal 中捕获特殊的 ServiceStack 异常“response.FlushAsync()”?