Keras model.predict gives the same values

Problem description

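I'm using Keras model.predict to get labels for new sentences that are not in the dataset. But no matter what the sentence is, the prediction always returns the same values.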

Here is my prediction code:

from sklearn.preprocessing import LabelEncoder

maxlen = 300
### PREDICT NEW UNSEEN DATA ###
tokenizer = Tokenizer()
label_enc = LabelEncoder()
label_enc.fit(tar_list)
X_test = ['asdsadav dawd','this is boring', 'wow i like this you did a great job', 'ima cry tht was mean','1 nov 1968 george harrison became the first beatle to release a solo album in the u k with the soundtrack to ']

X_test = tokenizer.texts_to_sequences(X_test)
X_test = sequence.pad_sequences(X_test, maxlen=maxlen)

print(X_test)

a = (model.predict(X_test)>0.5).astype(int).ravel()
print(a)

reverse_pred = label_enc.inverse_transform(a.ravel())
print(reverse_pred)

print(model.predict(X_test))

Here is the output:

 [[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
[1 0 1 0 1 0 1 0 1 0]
[1 0 1 0 1 0 1 0 1 0]
[[0.988675   0.01132498]
 [0.988675   0.01132498]
 [0.988675   0.01132498]
 [0.988675   0.01132498]
 [0.988675   0.01132498]]

As you can see, the predicted probabilities are identical for every sentence.
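Separately from the identical values, the label line has 10 entries for 5 sentences. The model ends in Dense(2, activation='softmax'), so model.predict returns one row of two class probabilities per sentence; thresholding with > 0.5 and calling ravel() keeps both columns of every row. A minimal NumPy sketch using the probabilities printed above, assuming (as is usual with sparse_categorical_crossentropy and integer-encoded labels) that column i corresponds to encoded label i:

```python
import numpy as np

# The five identical rows printed by model.predict(X_test) above.
probs = np.array([[0.988675, 0.01132498]] * 5)

# Thresholding a (5, 2) softmax output and flattening gives 10 values,
# two per sentence, matching the [1 0 1 0 ...] line in the output.
thresholded = (probs > 0.5).astype(int).ravel()
print(thresholded)  # [1 0 1 0 1 0 1 0 1 0]

# argmax over the class axis gives exactly one class index per sentence.
class_ids = probs.argmax(axis=1)
print(class_ids)  # [0 0 0 0 0]
```

An array like class_ids, one encoded label per sentence, is the shape label_enc.inverse_transform expects. This only fixes the label count; it does not explain why every row has the same probabilities.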

My model code is:

model = Sequential()
model.add(Embedding(max_words, 300, input_length=max_len))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(SpatialDropout1D(0.5))
model.add(Conv1D(16, kernel_size=3, activation='relu'))
model.add(Bidirectional(LSTM(16)))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(Dropout(0.5))
model.add(Dense(2, activation='softmax'))
model.summary()
model.compile(loss='sparse_categorical_crossentropy', metrics=['accuracy'], optimizer = 'adam')

Here is the tokenizer fitted on X_train:

max_words = 3000
max_len = 300
tok = Tokenizer(num_words = max_words)
tok.fit_on_texts(X_train)
sequences = tok.texts_to_sequences(X_train)
sequences_matrix = sequence.pad_sequences(sequences, maxlen = max_len)
print(sequences_matrix)
Y_train = np.array(Y_train)
Y_test = np.array(Y_test)

The output is:

[[  0   0   0 ...  11  28  33]
 [  0   0   0 ...   2 125  63]
 [  0   0   0 ...   9 184  91]
 ...
 [  0   0   0 ... 105  22  85]
 [  0   0   0 ...  22  42 512]
 [  0   0   0 ...   9  28 406]]

Tags: python-3.x, tensorflow, keras, nlp

Solution
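The printed matrices point at the likely cause: the training rows contain word ids, but every X_test row is all zeros, so the model receives the same input for every sentence. The prediction code creates a new Tokenizer() that was never fitted, so its word index is empty and texts_to_sequences drops every word; pad_sequences then turns each empty list into a row of zeros. The usual fix is to reuse the tokenizer fitted on X_train (tok, together with max_len) for new text, e.g. tok.texts_to_sequences(X_test), and, if prediction runs in a separate process, to save and reload that fitted tokenizer (for example with Tokenizer.to_json() and tokenizer_from_json). The sketch below does not use Keras: ToyTokenizer and pad_pre are hypothetical, simplified stand-ins that only mimic the relevant behavior (unknown words are dropped, zeros are padded on the left), and the example sentences are made up.

```python
from collections import Counter


class ToyTokenizer:
    """Hypothetical, simplified stand-in for Keras' Tokenizer (not the real class)."""

    def __init__(self):
        self.word_index = {}  # empty until fit_on_texts is called

    def fit_on_texts(self, texts):
        counts = Counter(w for t in texts for w in t.lower().split())
        # Most frequent word gets id 1; id 0 is reserved for padding.
        ranked = sorted(counts, key=lambda w: -counts[w])
        self.word_index = {w: i + 1 for i, w in enumerate(ranked)}

    def texts_to_sequences(self, texts):
        # Words missing from word_index are silently dropped.
        return [[self.word_index[w] for w in t.lower().split() if w in self.word_index]
                for t in texts]


def pad_pre(seqs, maxlen):
    """Left-pad with zeros and keep the last `maxlen` ids ('pre' padding/truncation)."""
    padded = []
    for s in seqs:
        kept = s[max(len(s) - maxlen, 0):]
        padded.append([0] * (maxlen - len(kept)) + kept)
    return padded


train = ["this is great", "this is boring", "what a great job"]
new = ["this is boring", "wow i like this you did a great job"]

fresh = ToyTokenizer()  # like the new Tokenizer() in the prediction code
print(pad_pre(fresh.texts_to_sequences(new), maxlen=6))
# [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]

fitted = ToyTokenizer()
fitted.fit_on_texts(train)  # like tok.fit_on_texts(X_train)
print(pad_pre(fitted.texts_to_sequences(new), maxlen=6))
# [[0, 0, 0, 1, 2, 4], [0, 0, 1, 6, 3, 7]]
```

With the unfitted tokenizer every sentence collapses to the same all-zero row, which is why the model outputs the same probabilities for all of them; with the fitted one, different sentences produce different id sequences.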

