tensorflow - LSTM 需要更多时间来训练
问题描述
我正在使用以下简单的架构来训练我的模型,但是当我还使用带有填充输入的掩码输入时,我的模型显示每个时期的经过时间为 2-3 小时,为什么会这样。
请为我的模型找到以下代码
class lstm_raw(tf.keras.Model):
def __init__(self,name='spectrogram'):
super().__init__(name=name)
self.lstm = tf.keras.layers.LSTM(32,activation="tanh",kernel_initializer=tf.keras.initializers.he_uniform(seed=45),kernel_regularizer=tf.keras.regularizers.l2())
self.dense1 = tf.keras.layers.Dense(64,activation="relu",kernel_initializer=tf.keras.initializers.he_uniform(seed=45))
self.dense2 = tf.keras.layers.Dense(10,kernel_initializer=tf.keras.initializers.he_uniform(seed=45))
def call(self,X):
lstm_output = self.lstm(X[0],mask=X[1])
dense1 = self.dense1(lstm_output)
dense2 = self.dense2(dense1)
return dense2
with tf.device('/device:GPU:0'):
model1.fit(x=[X_train_pad_seq_test,X_train_mask_test],y=y_train,epochs=20,batch_size=4,steps_per_epoch=len(X_train_pad_seq_test)//4)
我的输入形状如下
((1400, 17640, 1), (1400, 17640, 1))
解决方案
代码中的罪魁祸首是activation="relu"
LSTM 层。
当且仅当激活设置为 时,Tensorflow 才使用 CuDNN 加速 LSTM 单元tanh
。
替换relu
为tanh
,然后查看您的模型起飞!
推荐阅读
- tree - APEX 树区域转义特殊字符
- ubuntu - Terraform 数据源 DataSourceNone 与 Ubuntu 20.04
- sql - 选择带有浇头 SQL 的比萨饼
- python - 'await' 方法调用上的 SyntaxError
- php - 使用 symfony formbuider 上的表单生成器生成有关 NM 关系的单选按钮
- javascript - 如何允许用户在本机反应中输入文件?
- android - 如果我使用 BroadcastReceiver,Android CallScreeningService 放在哪里?
- ios - 无法在 Objective C 中分配子 VC 属性
- telerik-reporting - Telerik Report 参数中的验证
- junit - 当我使用@ExtendWith(MockitoExtension.class) 时如何修复错误?