Why doesn't a bidirectional LSTM converge when training on a multivariate data sequence?

Problem description

I am training on a multivariate data sequence that looks like the following.

[[ x x x x x x x]
 [ x x x x x x x]
 [ x x x x x x x]
 [ x x x x x x x]] [ x x x x x x x]
[[ x x x x x x x]
 [ x x x x x x x]
 [ x x x x x x x]
 [ x x x x x x x]] [ x x x x x x x]

The BiLSTM takes a window of four time steps and predicts the next step, i.e. rows 0-3 predict row 4, rows 1-4 predict row 5, and so on. I have an 1800 x 7 data array to train on.

My TensorFlow code is below.

from numpy import array

def split_sequences(sequences, n_steps):
    x, y = list(), list()
    for i in range(len(sequences)):
        # find the end of this window
        end_ix = i + n_steps
        # stop once the target row would fall past the end of the dataset
        if end_ix > len(sequences) - 1:
            break
        # gather the input window and the target row that follows it
        seq_x, seq_y = sequences[i:end_ix, :], sequences[end_ix, :]
        x.append(seq_x)
        y.append(seq_y)
    return array(x), array(y)
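
For reference, a quick shape check of what this split produces (a minimal sketch; the (1800, 7) array here is stand-in random data, not the real dataset):

import numpy as np

demo = np.random.rand(1800, 7)     # stand-in for the real dataset
demo_x, demo_y = split_sequences(demo, n_steps=4)
print(demo_x.shape, demo_y.shape)  # expected: (1796, 4, 7) (1796, 7)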

import keras
from keras.models import Sequential
from keras.layers import Bidirectional, LSTM, Dense

n_steps = 4
x, y = split_sequences(dataset, n_steps)
# print(x.shape)

# sanity check: print the first four (window, target) pairs
for i in range(4):
    print(x[i], y[i])
n_features = x.shape[2]

# define model
def create_BILSTM(units):
    model = Sequential()
    model.add(Bidirectional(LSTM(units, activation='relu'),
                            input_shape=(n_steps, n_features)))
    model.add(Dense(n_features))
    model.compile(optimizer='adam', loss='mse')
    return model

bilstm_model = create_BILSTM(100)
def fit_model(model):
    early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
    history = model.fit(x, y, epochs=200, validation_split=0.2,
                        batch_size=1, shuffle=True, callbacks=[early_stop])
    return history

history_bilstm = fit_model(bilstm_model)
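
Not in the original post, but a short way to eyeball convergence is to plot the recorded history (assumes matplotlib is available):

import matplotlib.pyplot as plt

# plot training vs. validation loss to see whether the model is learning
plt.plot(history_bilstm.history['loss'], label='loss')
plt.plot(history_bilstm.history['val_loss'], label='val_loss')
plt.xlabel('epoch')
plt.ylabel('mse')
plt.legend()
plt.show()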

Training does not converge: the loss plateaus almost immediately, and early stopping halts the run after 28 epochs. What should I do to improve the training?

Epoch 1/200
1015/1015 [==============================] - 5s 4ms/step - loss: 104.6197 - val_loss: 72.3360
Epoch 2/200
1015/1015 [==============================] - 3s 3ms/step - loss: 68.5425 - val_loss: 72.8522
Epoch 3/200
1015/1015 [==============================] - 3s 3ms/step - loss: 71.9564 - val_loss: 71.8446
Epoch 4/200
1015/1015 [==============================] - 3s 3ms/step - loss: 67.9026 - val_loss: 74.8369
Epoch 5/200
1015/1015 [==============================] - 3s 3ms/step - loss: 70.3062 - val_loss: 79.5937
Epoch 6/200
1015/1015 [==============================] - 3s 3ms/step - loss: 69.3122 - val_loss: 73.9319
Epoch 7/200
1015/1015 [==============================] - 3s 3ms/step - loss: 70.8744 - val_loss: 73.5675
Epoch 8/200
1015/1015 [==============================] - 3s 3ms/step - loss: 67.8673 - val_loss: 71.6734
Epoch 9/200
1015/1015 [==============================] - 3s 3ms/step - loss: 66.7135 - val_loss: 71.1467
Epoch 10/200
1015/1015 [==============================] - 3s 3ms/step - loss: 65.6959 - val_loss: 71.6195
Epoch 11/200
1015/1015 [==============================] - 4s 4ms/step - loss: 66.1596 - val_loss: 75.0583
Epoch 12/200
1015/1015 [==============================] - 4s 4ms/step - loss: 63.9930 - val_loss: 72.9204
Epoch 13/200
1015/1015 [==============================] - 3s 3ms/step - loss: 64.9793 - val_loss: 79.2959
Epoch 14/200
1015/1015 [==============================] - 3s 3ms/step - loss: 64.4731 - val_loss: 74.7689
Epoch 15/200
1015/1015 [==============================] - 3s 3ms/step - loss: 66.7180 - val_loss: 71.9217
Epoch 16/200
1015/1015 [==============================] - 3s 3ms/step - loss: 65.0417 - val_loss: 71.0694
Epoch 17/200
1015/1015 [==============================] - 3s 3ms/step - loss: 67.1815 - val_loss: 74.5580
Epoch 18/200
1015/1015 [==============================] - 3s 3ms/step - loss: 68.3193 - val_loss: 71.0372
Epoch 19/200
1015/1015 [==============================] - 4s 3ms/step - loss: 68.6245 - val_loss: 71.9737
Epoch 20/200
1015/1015 [==============================] - 3s 3ms/step - loss: 66.3616 - val_loss: 71.9733
Epoch 21/200
1015/1015 [==============================] - 3s 3ms/step - loss: 71.2269 - val_loss: 72.1401
Epoch 22/200
1015/1015 [==============================] - 4s 3ms/step - loss: 66.2718 - val_loss: 73.4132
Epoch 23/200
1015/1015 [==============================] - 4s 4ms/step - loss: 67.6766 - val_loss: 74.5763
Epoch 24/200
1015/1015 [==============================] - 3s 3ms/step - loss: 66.3503 - val_loss: 71.4974
Epoch 25/200
1015/1015 [==============================] - 3s 3ms/step - loss: 64.1543 - val_loss: 73.8621
Epoch 26/200
1015/1015 [==============================] - 4s 4ms/step - loss: 64.2544 - val_loss: 73.6484
Epoch 27/200
1015/1015 [==============================] - 3s 3ms/step - loss: 66.0953 - val_loss: 73.4449
Epoch 28/200
1015/1015 [==============================] - 3s 3ms/step - loss: 67.1023 - val_loss: 71.2993
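
For what it's worth, an MSE loss that plateaus around 65-70 like this often points at unscaled inputs and targets. Below is a minimal sketch of one commonly suggested first step, scaling each feature to [0, 1] before windowing; this is my assumption rather than something from the original post, and it assumes scikit-learn is installed and that `dataset` is the raw (1800, 7) array:

from sklearn.preprocessing import MinMaxScaler

# scale each of the 7 features to [0, 1] so MSE values are comparable across features
scaler = MinMaxScaler()
dataset_scaled = scaler.fit_transform(dataset)  # assumed: dataset is the raw (1800, 7) array

x, y = split_sequences(dataset_scaled, n_steps=4)
bilstm_model = create_BILSTM(100)
history_bilstm = fit_model(bilstm_model)

# predictions can be mapped back to the original units afterwards:
# preds = scaler.inverse_transform(bilstm_model.predict(x))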

Tags: tensorflow, keras, deep-learning, lstm, recurrent-neural-network
