keras - 如何获取 Keras 模型的运行时批量大小
问题描述
基于这篇文章。我需要一些基本的实施帮助。下面你会看到我的模型使用了 Dropout 层。使用noise_shape 参数时,碰巧最后一批不适合批量大小,从而产生错误(参见其他帖子)。
原始型号:
def LSTM_model(X_train,Y_train,dropout,hidden_units,MaskWert,batchsize):
model = Sequential()
model.add(Masking(mask_value=MaskWert, input_shape=(X_train.shape[1],X_train.shape[2]) ))
model.add(Dropout(dropout, noise_shape=(batchsize, 1, X_train.shape[2]) ))
model.add(Dense(hidden_units, activation='sigmoid', kernel_constraint=max_norm(max_value=4.) ))
model.add(LSTM(hidden_units, return_sequences=True, dropout=dropout, recurrent_dropout=dropout))
现在 Alexandre Passos 建议使用tf.shape获取运行时的批量大小。我尝试以不同的方式将运行时批量大小的想法实现到 Keras 中,但从未奏效。
import Keras.backend as K
def backend_shape(x):
return K.shape(x)
def LSTM_model(X_train,Y_train,dropout,hidden_units,MaskWert,batchsize):
batchsize=backend_shape(X_train)
model = Sequential()
...
model.add(Dropout(dropout, noise_shape=(batchsize[0], 1, X_train.shape[2]) ))
...
但这只是给了我输入张量形状,而不是运行时输入张量形状。
我也尝试使用 Lambda 层
def output_of_lambda(input_shape):
return (input_shape)
def LSTM_model_2(X_train,Y_train,dropout,hidden_units,MaskWert,batchsize):
model = Sequential()
model.add(Lambda(output_of_lambda, outputshape=output_of_lambda))
...
model.add(Dropout(dropout, noise_shape=(outputshape[0], 1, X_train.shape[2]) ))
以及不同的变体。但正如您已经猜到的那样,这根本不起作用。模型定义实际上是正确的地方吗?您能否给我一个提示或更好地告诉我如何获得 Keras 模型的运行批量大小?非常感谢。
解决方案
当前的实现确实会根据运行时批量大小进行调整。从Dropout
层实现代码:
symbolic_shape = K.shape(inputs)
noise_shape = [symbolic_shape[axis] if shape is None else shape
for axis, shape in enumerate(self.noise_shape)]
因此,如果您noise_shape=(None, 1, features)
按照上面的代码给出形状将是 (runtime_batchsize, 1, features)。
推荐阅读
- windows - 执行 npm.cmd 后批处理脚本中断
- javascript - Tabulator.js:获取/选择当前页面上的行
- python - 如何在 Tensorflow 中计算训练 RNN 语言模型的准确率?
- statistics - 超过 100 人到达车站的概率是多少,如果他们基于 2 分钟的指数分布来?
- python-3.x - 通过 Python 在 PowerPoint 中更新链接的 excel 路径
- javascript - 有没有办法插入来自 api 的路由?
- r - 如何遍历数据集减去下面行中的值 - 使用 R
- react-native - 如何让 Android 上的 Chrome 启动我的图像选择器而不是 React Native 中的系统 UI?
- c# - 将父级转换为其通用子级
- css - 如何通过 css 强制所有帖子图像居中对齐?