python - 烧瓶中加载的 keras 模型总是预测同一类
问题描述
奇怪的事情发生在我身上。我使用 keras 训练了一个情感分析模型,如下所示:
max_fatures = 2000
tokenizer = Tokenizer(num_words=max_fatures, split=' ')
tokenizer.fit_on_texts(data)
X = tokenizer.texts_to_sequences(data)
X = pad_sequences(X)
with open('tokenizer.pkl', 'wb') as fid:
_pickle.dump(tokenizer, fid)
le = LabelEncoder()
le.fit(["pos", "neg"])
y = le.transform(data_labels)
y = keras.utils.to_categorical(y)
embed_dim = 128
lstm_out = 196
model = Sequential()
model.add(Embedding(max_fatures, embed_dim, input_length=X.shape[1]))
model.add(SpatialDropout1D(0.4))
model.add(LSTM(lstm_out, dropout=0.2, recurrent_dropout=0.2))
model.add(Dense(3, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
batch_size = 32
model.fit(X, y, epochs=10, batch_size=batch_size, verbose=2)
model.save('deep.h5')
当我将它加载到另一个 python 文件中时,一切都很好。但是当我将它加载到烧瓶 Web 应用程序中时,所有预测的类都是正数。出了什么问题?这是我在烧瓶 Web 应用程序中使用的代码:
with open('./resources/model/tokenizer.pkl', 'rb') as handle:
keras_tokenizer = _pickle.load(handle)
K.clear_session()
model = load_model('./resources/model/deep.h5')
model._make_predict_function()
session = K.get_session()
global graph
graph = tf.get_default_graph()
graph.finalize()
stop_words = []
with open('./resources/stopwords.txt', encoding="utf8") as f:
stop_words = f.read().splitlines()
normalizer = Normalizer()
stemmer = Stemmer()
tokenizer = RegexpTokenizer(r'\w+')
def predict_class(text):
tokens = tokenizer.tokenize(text)
temp = ''
for token in tokens:
if token in stop_words:
continue
token = normalizer.normalize(token)
token = stemmer.stem(token)
temp += token + ' '
if not temp.strip():
return None
text = keras_tokenizer.texts_to_sequences(temp.strip())
text = pad_sequences(text, maxlen=41)
le = LabelEncoder()
le.fit(["pos", "neg"])
with session.as_default():
with graph.as_default():
sentiment = model.predict_classes(text)
return le.inverse_transform(sentiment)[0]
解决方案
是的,我有同样的问题。但就我而言,我的预测是正确的。我认为具有模型架构和权重的“.h5”文件是不够的,您需要使用标记器,因为它包含所有唯一标记的单词索引或模型训练所依据的单词。
因此,我强烈推荐 (Eudald Arranz)[https://stackoverflow.com/users/11153431/eudald-arranz] 在此线程上的最后一篇文章 - 以 JSON 格式保存权重和模型架构。
原因这实际上对我有用。
谢谢, 尤达德
推荐阅读
- python - 转置 Pandas 数据框保留索引
- r - 非零退出状态 tidyverse 安装包 Rstudio
- r - 在某些条件下从上面的值中减去值 - 没有循环
- python - 如何查找图像中每个多边形中的总点数
- autodesk-forge - 为什么我的 Forge 设计自动化活动 iLogic 失败?
- jquery - 用于选项卡可访问菜单的菜单可见性和 MS Edge hack
- authentication - 如何使用 OIDC 请求刷新令牌
- c - 使用 MSVC 命令行创建动态库
- python - 从 csv 文件中获取一些数字的平均值作为输入,并将平均值写入 python 3 中的输出 csv 文件
- python - 有条件地将多个列分配给另一个 DataFrame(条件确定分配该行中的哪一组列)