python - 当每批中的观察值具有不同数量的缺失值时使用 Keras 掩蔽层
问题描述
我正在使用 Keras 为具有不同长度的序列构建 RNN。我用 -99 值填充了每个序列的缺失值(我没有使用 0,因为这是我的数据集中的一个有意义的值)。该模型的定义如下:
model = keras.models.Sequential([
keras.layers.Masking(mask_value=-99, input_shape=(n_lags, n_input_vars)),
keras.layers.LSTM(64, return_sequences=True),
keras.layers.LSTM(16),
keras.layers.Dense(3)
])
model.compile(loss="mse", optimizer="adam")
history = model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), batch_size=100)
训练集已被分批成 100 个观察值的块。每个批次可能包含不同长度的观察(序列),例如:
序列 1: [0, 1, 3], [-99, -99, -99], [-99, -99, -99]
序列 2: [1, 5, 9], [6, 7, 10], [-99, -99, -99]
……
序列 100: [8、7、4]、[-99、-99、-99]、[-99、-99、-99]
Keras Masking 层是否允许批量观察的序列长度不同的这种情况?或者我是否需要为每个观察构建具有相同缺失输入值的批次?
解决方案
@ad2004 只是部分正确,因为 return_sequence = True 的 LSTM 可以正确传播掩码,但是,第二个 LSTM 层(默认 return_sequence = False)将丢失掩码。所以掩码实际上从未传播到输出层,因此损失仍将包括填充数据(当然,掩码层将 -99 变为 0,只是损失仍将包括填充值的部分)。为了验证这一点,我们可以简单地打印出每一层的 input_mask 和 output_mask,如果是 None 则意味着没有掩码。
for i, l in enumerate(model.layers):
print(f'layer {i}: {l}')
print(f'has input mask: {l.input_mask}')
print(f'has output mask: {l.output_mask}')
layer 0: <tensorflow.python.keras.layers.core.Masking object at 0x6675b2f98>
has input mask: None
has output mask: Tensor("masking_7/Identity_1:0", shape=(None, 30), dtype=bool)
layer 1: <tensorflow.python.keras.layers.recurrent_v2.LSTM object at 0x66537f278>
has input mask: Tensor("masking_7/Identity_1:0", shape=(None, 30), dtype=bool)
has output mask: Tensor("masking_7/Identity_1:0", shape=(None, 30), dtype=bool)
layer 2: <tensorflow.python.keras.layers.recurrent_v2.LSTM object at 0x6676b4588>
has input mask: Tensor("masking_7/Identity_1:0", shape=(None, 30), dtype=bool)
has output mask: None
layer 3: <tensorflow.python.keras.layers.core.Dense object at 0x6676b6240>
has input mask: None
has output mask: None
推荐阅读
- php - 从表 1 中获取所有名称,其中从表 2 中获取 id(以逗号分隔) - 使用 Codeigniter 点燃的数据表服务器端
- node.js - 如何在不重新加载当前页面的情况下做出响应?
- go - 是否需要指定列名?
- docker - 如何将卷数据的副本复制到桌面上的文件夹中?
- c++ - 如何无休止地循环地图
- javascript - Django:在不重定向的情况下更新页面信息
- android - 我怎样才能在 xml 文件中显示这个?有谁可以给我看这个 xml 的示例代码?
- mysql - 查询:查找购买过店铺提供的各类鲜花的顾客
- django - ImageField upload_to 不适用于 django 更新语句
- sql - postgresql - 行值更改后查询中的“group_id”列?