python - 'tf' is not defined on load_model() - using lambda
问题描述
I have a Keras
model that I am trying to export and use in a different python code.
Here is my code:
from keras.models import Sequential
from keras.layers import Dense, Embedding, LSTM, GRU, Flatten, Dropout, Lambda
from keras.layers.embeddings import Embedding
import tensorflow as tf
EMBEDDING_DIM = 100
model = Sequential()
model.add(Embedding(vocab_size, 300, weights=[embedding_matrix], input_length=max_length, trainable=False))
model.add(Lambda(lambda x: tf.reduce_mean(x, axis=1)))
model.add(Dense(8, input_dim=4, activation='relu'))
model.add(Dense(3, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train_pad, y_train, batch_size=128, epochs=25, validation_data=(X_val_pad, y_val), verbose=2)
model.save('my_model.h5')
In another file, when I import my_model.h5
:
from keras.models import load_model
from keras.layers import Lambda
import tensorflow as tf
def learning(test_samples):
model = load_model('my_model.h5')
#ERROR HERE
#rest of the code
The error is the following:
in <lambda>
model.add(Lambda(lambda x: tf.reduce_mean(x, axis=1)))
NameError: name 'tf' is not defined
After research, I got that the fact that I used lambda
in my model is the reason for this problem, but I added these references and it didn't help:
from keras.models import load_model
from keras.layers import Lambda
import tensorflow as tf
What could be the problem?
Thank you
解决方案
加载模型时,您需要显式处理自定义对象或自定义图层(CTRL+f处理自定义图层的文档):
import tensorflow as tf
import keras
model = keras.models.load_model('my_model.h5', custom_objects={'tf': tf})
推荐阅读
- angular - 导航到另一个页面后,角度引导轮播触摸滑动不起作用,但是当您单击控件或重新加载页面时它开始工作
- cakephp-2.0 - CakePHP 2 中的嵌套表
- python-3.x - 替换 Pandas 中的相似搜索词
- spring-boot - 使用spring + MongoRepository在mongodb cloud(atlas)上查询需要很长时间
- database - 后端 Heroku 和前端 Vercel 的部署 CORS 问题
- c - 即使两个字符串相等,strcmp() 也不返回 0
- pygame - 有没有办法在pygame中删除最小化应用程序或最大化它的选项?
- java - 在Scala中,给定一个子类,如何检查它是否使用反射覆盖其父接口(在Java中定义)的默认方法?
- java - 在 AEM 中找不到 org.apache.cxf.jaxws.spi.ProviderImpl 问题
- c# - Autofac Keyed Factory:不同枚举值的相同具体实现