LSTM not working in a Keras embedding model

Problem description

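I have a neural network that uses a Keras Embedding layer as the model's input. The embeddings are generated by word2vec, then organized through the embedding layer and fed into the model. I am fairly confident that this part of my code works correctly. When I run the model with a single Dense layer, everything behaves as expected and my validation accuracy is about 75%. However, when I replace the Dense layer with an LSTM, the accuracy comes back at about 50% and does not change across epochs, not even in the decimal places.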

The code for the embeddings and the model is as follows:

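# Imports assumed from usage; the original post does not show them.
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Embedding, LSTM, TextVectorization
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.text import Tokenizer

vectorizer = TextVectorization(max_tokens=MAX_TOKENS, output_sequence_length=INPUT_LENGTH)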
text_ds = tf.data.Dataset.from_tensor_slices(train_samples).batch(128)
vectorizer.adapt(text_ds)
voc = vectorizer.get_vocabulary()
word_index = dict(zip(voc, range(len(voc))))



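# Second tokenization pass with the legacy Keras Tokenizer.
tokenizer = Tokenizer()
tokenizer.fit_on_texts(samples)
sequences = tokenizer.texts_to_sequences(samples)

# Note: this rebinds word_index, replacing the TextVectorization mapping built above.
word_index = tokenizer.word_index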

embeddings_index = {}
with open(os.path.join("Data/", 'model.csv'), encoding="latin-1") as f:
    for line in f:
        values = line.split()
        word = values[0]
        try:
            coefs = np.asarray(values[1:], dtype='float32')
        except ValueError:
            # Fall back to a zero vector when a row cannot be parsed as floats.
            coefs = np.zeros(100,)
        embeddings_index[word] = coefs



embedding_dim = 100
hits = 0
misses = 0

embedding_matrix = np.zeros((MAX_TOKENS+2, embedding_dim))
for word, i in word_index.items():
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is not None:
        embedding_matrix[i] = embedding_vector
        hits += 1
    else:
        # Words not found in the embedding index stay all-zeros.
        # This includes the representation for "padding" and "OOV".
        misses += 1



x_train = vectorizer(np.array([[s] for s in train_samples])).numpy()
x_val = vectorizer(np.array([[s] for s in val_samples])).numpy()

y_train = np.array(train_labels).astype("float32")
y_val = np.array(val_labels).astype("float32")



model = Sequential()
model.add(Embedding(MAX_TOKENS + 2, embedding_dim, weights=[embedding_matrix], input_length=INPUT_LENGTH, trainable=False))
model.add(LSTM(32))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# summary() prints the table itself; wrapping it in print() only adds a trailing "None".
model.summary()
# Fit the model
history = model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=3, batch_size=32, verbose=2)

Tags: python, keras, lstm

Solution
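
A possible cause, offered as an assumption rather than a confirmed diagnosis: in the question's code, `word_index` is first built from the `TextVectorization` vocabulary and then replaced by `tokenizer.word_index`, while `x_train` and `x_val` are still produced by `vectorizer`. If the two mappings assign different ids to the same word, each token id would look up some other word's vector. The sketch below builds the matrix directly from one vocabulary list so that rows and token ids agree. `build_embedding_matrix` and the toy data are hypothetical names used only for illustration, and the example runs on plain NumPy without Keras.

```python
import numpy as np


def build_embedding_matrix(vocabulary, embeddings_index, embedding_dim):
    """Build a matrix whose row i holds the vector for vocabulary[i].

    `vocabulary` is assumed to be the list returned by
    vectorizer.get_vocabulary(), so row numbers match the token ids
    that the same vectorizer writes into x_train and x_val.
    """
    matrix = np.zeros((len(vocabulary), embedding_dim), dtype="float32")
    hits = 0
    for i, word in enumerate(vocabulary):
        vector = embeddings_index.get(word)
        # Skip missing words and vectors of the wrong length; their rows stay zero.
        if vector is not None and len(vector) == embedding_dim:
            matrix[i] = vector
            hits += 1
    return matrix, hits


# Toy data standing in for the real vocabulary and the word2vec file.
toy_vocabulary = ["", "[UNK]", "good", "bad"]
toy_embeddings = {
    "good": np.array([1.0, 0.0], dtype="float32"),
    "bad": np.array([0.0, 1.0], dtype="float32"),
}
toy_matrix, toy_hits = build_embedding_matrix(toy_vocabulary, toy_embeddings, 2)
print(toy_matrix)
print("hits:", toy_hits)
```

If this approach were used with the question's model, the first argument of `Embedding` would need to equal the number of rows in the matrix. Printing the hit count is also a quick way to see how much of the vocabulary the word2vec file actually covers.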

