python - 使用 Bi-LSTM 和 glove 的多类分类
问题描述
我正在使用 BI LSTM 和手套嵌入进行多类分类,当我训练我的模型时,在预测 (model.predict) 上我得到不正确的结果,如下所示,结果不在 0 和 1 之间,有人可以帮我吗?
3916/3916 [==============================] - 17s 4ms/step
[[9.9723792e-01 1.6954101e-03 1.0665554e-03]
[1.6794224e-01 8.6485274e-02 7.4557245e-01]
[9.4370516e-03 1.0848863e-03 9.8947805e-01]
...
[1.3264662e-02 9.7078091e-01 1.5954463e-02]
[1.2019513e-02 9.8711687e-01 8.6356810e-04]
[8.1863362e-01 1.5828104e-01 2.3085352e-02]]
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM,Dense, Dropout,Bidirectional
from tensorflow.keras.layers import SpatialDropout1D
from tensorflow.keras.layers import Embedding
from tensorflow.keras.preprocessing.text import Tokenizer
embedding_vector_length = 100
model_2 = Sequential()
model_2.add(Embedding(len(tokenizer.word_index) + 1, embedding_vector_length,
input_length=409,name="Bi-LSTM") )
model_2.add(SpatialDropout1D(0.3))
model_2.add(Bidirectional(LSTM(64, return_sequences=False, recurrent_dropout=0.4)))
model_2.add(Dropout(0.5))
model_2.add(Dense(3,activation='softmax'))
model_2.compile(loss='categorical_crossentropy',optimizer='adam',
metrics=['accuracy'])
print(model_2.summary())
model_2.layers[0].set_weights([embedding_matrix])
model_2.layers[0].trainable = False
print(model_2.summary())
from keras.callbacks import EarlyStopping
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=3)
history_2=model_2.fit(x_train, y_train,
batch_size=400,
epochs=30,
validation_data=(x_val, y_val),
callbacks=[es])
#We save this model so that we can use in own web app
解决方案
如果打印的矩阵是您的model.predict()
结果 - 它们介于 0 和 1 之间(您需要考虑指数部分)
推荐阅读
- python - 尝试在 Django Post LIKE DISLIKE 功能上使用 Ajax 时找不到页面错误
- time - Influxdb GROUP BY time 聚合错误时间间隔或存储桶上的数据
- pygame - 如何在 pygame 显示中加载 python 维基百科图像?
- c# - C# - 带有 OnClick + href 的链接按钮?
- c# - 拆分单行会导致更多开销吗?
- android - Google Drive REST API:java.lang.IllegalArgumentException:名称不能为空:null
- python - OpenCV-python获取两点(线)之间的像素集
- python - Django TypeError: validate_location() 缺少 2 个必需的位置参数:'location' 和 'parcare_on'
- angularjs - Angularjs HTTP 拦截器异常
- c# - C# LibTiff.Net - MultiTIFF 更改标签值