python - 将 Lasagne 转换为 Keras 代码 (CNN -> LSTM)
问题描述
我想转换这个千层面代码:
et = {}
net['input'] = lasagne.layers.InputLayer((100, 1, 24, 113))
net['conv1/5x1'] = lasagne.layers.Conv2DLayer(net['input'], 64, (5, 1))
net['shuff'] = lasagne.layers.DimshuffleLayer(net['conv1/5x1'], (0, 2, 1, 3))
net['lstm1'] = lasagne.layers.LSTMLayer(net['shuff'], 128)
在 Keras 代码中。目前我想出了这个:
multi_input = Input(shape=(1, 24, 113), name='multi_input')
y = Conv2D(64, (5, 1), activation='relu', data_format='channels_first')(multi_input)
y = LSTM(128)(y)
但我得到了错误:Input 0 is incompatible with layer lstm_1: expected ndim=3, found ndim=4
解决方案
解决方案
from keras.layers import Input, Conv2D, LSTM, Permute, Reshape
multi_input = Input(shape=(1, 24, 113), name='multi_input')
print(multi_input.shape) # (?, 1, 24, 113)
y = Conv2D(64, (5, 1), activation='relu', data_format='channels_first')(multi_input)
print(y.shape) # (?, 64, 20, 113)
y = Permute((2, 1, 3))(y)
print(y.shape) # (?, 20, 64, 113)
# This line is what you missed
# ==================================================================
y = Reshape((int(y.shape[1]), int(y.shape[2]) * int(y.shape[3])))(y)
# ==================================================================
print(y.shape) # (?, 20, 7232)
y = LSTM(128)(y)
print(y.shape) # (?, 128)
解释
我把 Lasagne 和 Keras 的文档放在这里,方便大家相互参考:
循环层可以与前馈层类似地使用,除了输入形状应该是
(batch_size, sequence_length, num_inputs)
输入形状
具有形状的 3D 张量
(batch_size, timesteps, input_dim)
。
基本上 API 是相同的,但 Lasagne 可能会为你重塑(我需要稍后检查源代码)。这就是您收到此错误的原因:
Input 0 is incompatible with layer lstm_1: expected ndim=3, found ndim=4
, 因为之后的张量形状Conv2D
是(?, 64, 20, 113)
ndim=4
因此,解决方案是将其重塑为(?, 20, 7232)
.
编辑
用千层面源代码确认,它可以为您解决问题:
num_inputs = np.prod(input_shape[2:])
所以作为 LSTM 输入的正确张量形状是(?, 20, 64 * 113)
=(?, 20, 7232)
笔记
Permute
在 Keras 中是多余的,因为无论如何你都必须重塑。我把它放在这里的原因是为了有一个从 Lasagne 到 Keras 的“完整翻译”,它做了DimshuffleLaye
在 Lasagne 中所做的事情。
DimshuffleLaye
然而,由于我在Edit中提到的原因,Lasagne 中需要它,Lasagne LSTM 创建的新维度来自“最后两个”维度的乘积。
推荐阅读
- kubernetes - 如何在 kubernetes 中使用 consul 作为默认 dns 或服务发现
- android - 运行时 IllegalStateException 预期开始数组,但为字符串
- javascript - 生成 10 个唯一的随机整数
- html - 3 列固定标题,可扩展高度以适应内容
- angular - 导航后角度数据不更新
- reactjs - 传入 React.Component 类型的 prop 时出现 TS 错误
- c# - EF Core 2.1 GROUP BY 并选择每个组中的第一项
- python - 当我想使用 python pandas 仅过滤列中的部分值时,如何在 Pandas 中使用过滤器?
- android - 通过 adb 阻止单个 apk 通知
- python - python xlwt,写入下一个可用行