tensorflow - 如何在 lstm 输出的每个时间步应用平均池化?
问题描述
我正在尝试在 lstm 输出的每个时间步应用平均池,请找到我的架构如下
X_input = tf.keras.layers.Input(shape=(64,35))
X= tf.keras.layers.LSTM(512,activation="tanh",return_sequences=True,kernel_initializer=tf.keras.initializers.he_uniform(seed=45),kernel_regularizer=tf.keras.regularizers.l2(0.1))(X_input)
X= tf.keras.layers.LSTM(256,activation="tanh",return_sequences=True,kernel_initializer=tf.keras.initializers.he_uniform(seed=45),kernel_regularizer=tf.keras.regularizers.l2(0.1))(X)
X = tf.keras.layers.GlobalAvgPool1D()(X)
X = tf.keras.layers.Dense(128,activation="relu",kernel_initializer=tf.keras.initializers.he_uniform(seed=45),kernel_regularizer=tf.keras.regularizers.l2(0.1))(X)
X = tf.keras.layers.Dense(64,activation="relu",kernel_initializer=tf.keras.initializers.he_uniform(seed=45),kernel_regularizer=tf.keras.regularizers.l2(0.1))(X)
X = tf.keras.layers.Dense(32,activation="relu",kernel_initializer=tf.keras.initializers.he_uniform(seed=45),kernel_regularizer=tf.keras.regularizers.l2(0.1))(X)
# X = tf.keras.layers.Dense(16,activation="relu",kernel_initializer=tf.keras.initializers.he_uniform(seed=45),kernel_regularizer=tf.keras.regularizers.l2(0.1))(X)
output_layer = tf.keras.layers.Dense(10,activation='softmax', kernel_initializer=tf.keras.initializers.he_uniform(seed=45))(X)
model2 = tf.keras.Model(inputs = X_input,outputs = output_layer)
我想在每个时间步取平均值,而不是在每个单元上例如现在我得到形状 (None,256) 但我想从全局平均池化层得到形状 (None,64),我需要什么为此做。
解决方案
我不确定这是最有效的方法,但你可以试试这个:
X = tf.keras.layers.Reshape(target_shape=(64,256,1))(X)
X = tf.keras.layers.TimeDistributed(tf.keras.layers.GlobalAveragePooling1D())(X)
X = tf.keras.layers.Reshape(target_shape=(64,))(X)
代替 :
X = tf.keras.layers.GlobalAvgPool1D()(X)
现在的摘要是:
Model: "functional_13"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_14 (InputLayer) [(None, 64, 35)] 0
_________________________________________________________________
lstm_26 (LSTM) (None, 64, 512) 1122304
_________________________________________________________________
lstm_27 (LSTM) (None, 64, 256) 787456
_________________________________________________________________
reshape_2 (Reshape) (None, 64, 256, 1) 0
_________________________________________________________________
time_distributed_8 (TimeDist (None, 64, 1) 0
_________________________________________________________________
reshape_3 (Reshape) (None, 64) 0
_________________________________________________________________
dense_61 (Dense) (None, 128) 8320
_________________________________________________________________
dense_62 (Dense) (None, 64) 8256
_________________________________________________________________
dense_63 (Dense) (None, 32) 2080
_________________________________________________________________
dense_64 (Dense) (None, 10) 330
=================================================================
Total params: 1,928,746
Trainable params: 1,928,746
Non-trainable params: 0
推荐阅读
- excel - 在 Excel 中发送 cmd 消息
- r - R逻辑回归 - 错误消息
- python - Pyinstaller python to .exe没有在命令行界面上显示任何输出
- php - 每次打开新终端时如何让laravel自动运行?
- python - 在哪里插入重试循环语句?
- python-3.x - pandas 切割左包边独占 bin 边缘
- cmake - cmake 命令有什么方法可以在文件中使用参数?
- java - RabbitMQ 没有在 Spring Boot 中自动配置
- c++ - C++ 运算符重载中的类型推断
- node.js - 为什么我的 deleteMessage 功能不起作用?