python - 如何设置keras LSTM的输入形状
问题描述
我经常遇到这个问题
ValueError: Input 0 of layer sequential is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: [10 ,3]
我搜索了一下,发现
LSTM layer expects inputs to have shape of (batch_size, timesteps, input_dim)
好的,但老实说,我仍然有点困惑。
例如,我有这样的训练数据
x_train (100,3) #it consists of like `[[1,2,3],[3,4,5],[5,6,7]]`
y_train (100,3) #answers
我想使用 10 组 3 对数字并预测下一个 3pair ex [7,8,9]
。
就像从 x_train[1~10] 猜测到 y_train[11]`
下面的代码有效,但我仍然不清楚
在那个地方input_shape=(3,1)
是什么意思1
???。应该是3
(我最终想要得到的维度)
并且batch_size
是 LSTM 请求的第一个参数。
所以,,当我想从过去的 10 个项目中预测一个时,这里设置 10 是否正确?
x_train = np.array(x).reshape(100, 3,1)
y_train = np.array(x).reshape(100, 3,1)
model.add(LSTM(512, activation=None, input_shape=(3, 1), return_sequences=True))
model.add(Dense(1, activation="linear"))
opt = Adam(lr=0.001)
model.compile(loss='mse', optimizer=opt)
model.summary()
history = model.fit(x_train, y_train, epochs=epoch, batch_size=10) // how to set batch size???
解决方案
试试这个代码:
import tensorflow as tf
import numpy as np
x = np.random.uniform(0, 10, [101, 3])
x_train = np.array(x[:-1]).reshape(-1, 5, 3) # your data comprise of 20 sequences
y_train = np.array(x[1:]).reshape(-1, 5, 3)
model = tf.keras.Sequential()
model.add(tf.keras.layers.LSTM(512, activation=None, input_shape=(None, 3), return_sequences=True))
model.add(tf.keras.layers.Dense(1, activation="linear"))
opt = tf.keras.optimizers.Adam(lr=0.001)
model.compile(loss='mse', optimizer=opt)
model.summary()
history = model.fit(x_train, y_train, epochs=10, batch_size=10) # here you can set a batch size (your 20 sequences will be splitted into two batches)
推荐阅读
- java - Selenium (firefox) 在 WebElement.Click 上挂起
- excel - 计算每周性能/Power Query
- javascript - 我如何将 Type Script 中的数字乘以 [{ count: 3, value: 2 },{ count: 7, value: 1 }] 这样的数组,使得答案是 3*2 + 7*1 = 13?
- javascript - 使用 JS 定期检查文本的有效方法
- python - 我的代码循环通过 urls 但不是 urls Python 中的页面
- android - 如何使用 CameraX 在应用程序内显示相机
- python - 未找到 BingAds API 类型问题
- python - 是否可以创建一个 django 查询来查找两个模型中最新的唯一日期?
- c - 在 C 中将浮点数 16 转换为 4 位整数的最有效方法
- flutter - 需要管理列高并将变量作为小部件属性传递