python - 保存和加载 keras 模型
问题描述
我正在研究使用通用句子嵌入对提供的句子进行编码的 Keras 模型。但是,当我保存模型以供将来使用时,会引发上述错误。 NameError: name 'embed' is not defined
UniversalEmbedding(x)
使用函数将句子转换为嵌入。整个模型的代码取自这个链接。
!wget https://raw.githubusercontent.com/Tony607/Keras-Text-Transfer-Learning/master/train_5500.txt
!wget https://raw.githubusercontent.com/Tony607/Keras-Text-Transfer-Learning/master/test_data.txt
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re
import seaborn as sns
import keras.layers as layers
from keras.models import Model
from keras import backend as K
np.random.seed(10)
def get_dataframe(filename):
lines = open(filename, 'r').read().splitlines()
data = []
for i in range(0, len(lines)):
label = lines[i].split(' ')[0]
label = label.split(":")[0]
text = ' '.join(lines[i].split(' ')[1:])
text = re.sub('[^A-Za-z0-9 ,\?\'\"-._\+\!/\`@=;:]+', '', text)
data.append([label, text])
df = pd.DataFrame(data, columns=['label', 'text'])
df.label = df.label.astype('category')
return df
df_train = get_dataframe('train_5500.txt')
df_train = get_dataframe('test_data.txt')
category_counts = len(df_train.label.cat.categories)
module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/3"
embed = hub.Module(module_url)
embed_size = embed.get_output_info_dict()['default'].get_shape()[1].value
def UniversalEmbedding(x):
return embed(tf.squeeze(tf.cast(x, tf.string)), signature="default", as_dict=True)["default"]
input_text = layers.Input(shape=(1,), dtype='string')
embedding = layers.Lambda(UniversalEmbedding, output_shape=(embed_size,))(input_text)
dense = layers.Dense(256, activation='relu')(embedding)
pred = layers.Dense(category_counts, activation='softmax')(dense)
model = Model(inputs=[input_text], outputs=pred)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
train_text = df_train['text'].tolist()
train_text = np.array(train_text, dtype=object)[:, np.newaxis]
train_label = np.asarray(pd.get_dummies(df_train.label), dtype = np.int8)
df_test = get_dataframe('test_data.txt')
test_text = df_test['text'].tolist()
test_text = np.array(test_text, dtype=object)[:, np.newaxis]
test_label = np.asarray(pd.get_dummies(df_test.label), dtype = np.int8)
with tf.Session() as session:
K.set_session(session)
session.run(tf.global_variables_initializer())
session.run(tf.tables_initializer())
history = model.fit(train_text,
train_label,
validation_data=(test_text, test_label),
epochs=2,
batch_size=32)
model.save_weights('./model.h5')
model.save('mod.h5')
当我尝试加载模型时
from keras.models import load_model
load_model('mod.h5')
解决方案
当您尝试使用 keras 的 load_model 加载模型时,该方法会显示错误,因为embed
它不是 keras 内置的,要解决此问题,您可能必须在使用 load_model 加载模型之前在代码中再次定义它。
请尝试在您提供的链接( https://www.dlology.com/blog/keras-meets-universal-sentence-encoder-transfer-learning-for-text-dataembed = hub.Module(module_url)
中写入所需的库和 url / ) 在尝试加载模型之前。
推荐阅读
- php - Laravel 表保存错误的数据
- java - 比较和设置操作Java
- macos - MacOS - 当 VkPhysicalDeviceFeatures wideLines = VK_TURE 并且也不支持 vkCmdSetLineWidth API 时,Vulkan 在运行时 vkCreateDevice() 失败
- amazon-web-services - 如何为 AWS Codepipeline 选择区域?
- firebase - 如何通过 Google 身份验证防止 Firebase 垃圾邮件注册?
- node.js - 在节点中使用 RSA 公钥验证 JWT
- firebase - 当firebase HTTP请求大小大于10MB时如何处理错误?
- r - 从开始日期和结束日期中提取包含 R 中特定值的行
- html - 如何始终在html中证明左侧的货币符号和右侧的数字?
- python - 如何在 google colab 中升级到 pytorch-nightly?