python - RNN 从哪里获取批量大小?
问题描述
我正在通过以下方式训练 RNN:
def create_rnn_model(stateful,length):
model = Sequential()
model.add(SimpleRNN(20,return_sequences=False,stateful=stateful,batch_input_shape=(1,length,1)))
adam = optimizers.Adam(lr=0.001)
model.add(Dense(1))
model.compile(loss='mean_absolute_error', optimizer=adam, metrics=[root_mean_squared_error])
print(model.summary())
return model
和适合
model_info = model_rnn_stateful.fit(x=x_train, y=y_train, validation_data=(x_test, y_test), batch_size=1, epochs=10,verbose=1)
并预测
predicted_rnn_stateful = model_rnn_stateful.predict(x_test)
但是当我预测它会抛出一个错误
ValueError:在有状态的网络中,您应该只传递包含多个样本的输入,这些样本可以除以批量大小。发现:200 个样本。批量大小:32。
没有我指定 32 的地方。我不知道它来自哪里。我的批量大小仅为 1。感谢任何帮助。
编辑 我的脚本/IDE 中没有使用断点。谢谢
解决方案
来自Keras 文档
- batch_size:整数或无。每次梯度更新的样本数。如果未指定,batch_size 将默认为 32。
1 可能是 batch_size 的错误值,然后它采用默认值 32。尝试使用 2 或 20 作为 batch_size
推荐阅读
- python - 将二进制文件解压缩为文本
- python - 检查一个单词中是否有各种字符串
- google-apps-script - “脚本没有权限”错误 - 日历 API
- plot - Plot(...) 作为输出而不是 Julia 中的绘图
- coq - 如何一步一步地检查 Coq 中更复杂的策略做什么?
- mysql - 我可以将哪些文件格式导入 MySQLWorkbench?
- mysql - Tricky Triple Sql Join 查找最近的日期
- keras - Keras如何继续保存validation loss的历史
- rx-java - 从 rxjava 中的 Single 提取并返回对象
- javascript - 如何仅使用 es5 重新格式化 JSON 代码。