python - Tensoflow2 LSTM - 未使用的参数 input_shape?
问题描述
所以我用以下代码构建了神经网络:
import tensorflow as tf
tf_model = tf.keras.Sequential()
tf_model.add(tf.keras.layers.LSTM(50, activation='relu'))
tf_model.add(tf.keras.layers.Dense(20, activation='relu'))
tf_model.add(tf.keras.layers.Dense(10, activation='relu'))
tf_model.add(tf.keras.layers.Dense(1, activation='linear'))
tf_model.compile(optimizer='Adam', loss='mse')
我的训练集形状如下:
>> ts_train_X.shape
(16469, 3, 21)
我已经阅读了很多关于 stackoverflow 的文章和问题,以便为 LSTM 带来正确的数据框。我发现的几乎每一页都指定了input_shape
参数并将其传递给 LSTM(..) 或 Sequential(..)。
当我查看LSTM API时,我找不到对该参数的引用。我还瞥见了源代码,在我看来,形状似乎是自动推断出来的,但我不确定这一点。
这引出了我的问题:为什么我的代码有效?如果我没有指定 input_shape 参数,作为第一层的 LSTM 层如何知道我的输入的形状?
编辑:根据评论中的建议更改标题。
解决方案
可以将参数input_shape
提供给任何 kerasLayer
子类的构造函数,因为这是定义 API 的方式。
代码之所以有效,是因为input_shape
它作为关键字参数(**kwargs
)传递,然后这些关键字参数由LSTM
构造函数传递给Layer
构造函数,然后构造函数继续存储信息以供以后使用。这实际上意味着input_shape
参数不必在每一层中定义,而是作为关键字参数传递。
我认为问题在于,由于keras
已移至tensorflow
,文档可能不完整。您可以在顺序 API 指南中找到有关该input_shape
参数的更多信息。
推荐阅读
- java - 两遍渲染的模糊内核分离
- php - 在 Codeigniter 中使用 form_helper 时无法填充字段
- javascript - css文件的笑话错误
- angular - 从 ts 文件中读取常量:在 angular2 中找不到 404 文件
- asp.net-core - IdentityServer4 使用 asp.net 核心中的密码授权请求 JWT / 访问承载令牌
- javascript - 如何用 sinon 监视由 EventEmitter 事件触发的回调调用?Javascript,ES6,单元测试,Chai
- hadoop - Hadoop 集群交互式用户的永久 Kerberos 票证
- sbt - %% 不起作用,但 % 在库依赖项中起作用
- c# - 如何将 Button 事件从 WPF 应用程序路由到 wcf 服务
- casperjs - 无法使用 CasperJS 下载文件