tensorflow - 无法加载已保存的 Keras 自定义模型
问题描述
我是 Keras 的新手。我尝试运行此链接中的示例模型。由于代码很长,完整代码请通过上面的链接查看。
模型:
class Transformer(keras.Model):
    """Encoder-decoder Transformer with custom training and evaluation steps.

    The encoder is a ``keras.Sequential`` made of a ``SpeechFeatureEmbedding``
    followed by ``num_layers_enc`` ``TransformerEncoder`` layers. The decoder
    is a ``TokenEmbedding`` followed by ``num_layers_dec``
    ``TransformerDecoder`` layers, and a final ``Dense`` layer produces
    ``num_classes`` logits per target position.

    Every constructor argument is stored on the instance and exposed through
    ``get_config`` / ``from_config``. This is intended to let
    ``tf.keras.models.load_model(..., custom_objects={...})`` rebuild this
    class (and therefore its custom ``train_step`` / ``test_step``) instead
    of falling back to a generic model.

    Args:
        num_hid: Forwarded as ``num_hid`` to both embedding layers and as the
            first argument of every encoder/decoder layer.
        num_head: Forwarded as the second argument of every encoder/decoder
            layer.
        num_feed_forward: Forwarded as the third argument of every
            encoder/decoder layer.
        source_maxlen: ``maxlen`` passed to ``SpeechFeatureEmbedding``.
        target_maxlen: ``maxlen`` passed to ``TokenEmbedding``; also bounds
            the number of decoding steps in ``generate``.
        num_layers_enc: Number of ``TransformerEncoder`` layers.
        num_layers_dec: Number of ``TransformerDecoder`` layers.
        num_classes: Vocabulary size of ``TokenEmbedding``, width of the
            output ``Dense`` layer, and depth of the one-hot targets.
    """

    def __init__(
        self,
        num_hid=64,
        num_head=2,
        num_feed_forward=128,
        source_maxlen=100,
        target_maxlen=100,
        num_layers_enc=4,
        num_layers_dec=1,
        num_classes=10,
    ):
        super().__init__()
        self.loss_metric = keras.metrics.Mean(name="loss")

        # Keep every constructor argument so get_config() can reproduce it.
        self.num_hid = num_hid
        self.num_head = num_head
        self.num_feed_forward = num_feed_forward
        self.source_maxlen = source_maxlen
        self.target_maxlen = target_maxlen
        self.num_layers_enc = num_layers_enc
        self.num_layers_dec = num_layers_dec
        self.num_classes = num_classes

        self.enc_input = SpeechFeatureEmbedding(num_hid=num_hid, maxlen=source_maxlen)
        self.dec_input = TokenEmbedding(
            num_vocab=num_classes, maxlen=target_maxlen, num_hid=num_hid
        )

        self.encoder = keras.Sequential(
            [self.enc_input]
            + [
                TransformerEncoder(num_hid, num_head, num_feed_forward)
                for _ in range(num_layers_enc)
            ]
        )

        # Decoder layers are stored as attributes named dec_layer_<i>;
        # decode() looks them up by the same names.
        for i in range(num_layers_dec):
            setattr(
                self,
                f"dec_layer_{i}",
                TransformerDecoder(num_hid, num_head, num_feed_forward),
            )

        self.classifier = layers.Dense(num_classes)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this model."""
        return {
            "num_hid": self.num_hid,
            "num_head": self.num_head,
            "num_feed_forward": self.num_feed_forward,
            "source_maxlen": self.source_maxlen,
            "target_maxlen": self.target_maxlen,
            "num_layers_enc": self.num_layers_enc,
            "num_layers_dec": self.num_layers_dec,
            "num_classes": self.num_classes,
        }

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuild the model from the dict produced by ``get_config``.

        ``custom_objects`` is accepted for signature compatibility with
        ``keras.Model.from_config`` and is not used here.
        """
        return cls(**config)

    def decode(self, enc_out, target):
        """Embed ``target`` and run it through every decoder layer.

        Each decoder layer is called as ``layer(enc_out, y)``.
        """
        y = self.dec_input(target)
        for i in range(self.num_layers_dec):
            y = getattr(self, f"dec_layer_{i}")(enc_out, y)
        return y

    def call(self, inputs):
        """Return logits for ``inputs = [source, target]``."""
        source = inputs[0]
        target = inputs[1]
        x = self.encoder(source)
        y = self.decode(x, target)
        return self.classifier(y)

    @property
    def metrics(self):
        """Metrics tracked by this model: only the running mean loss."""
        return [self.loss_metric]

    def _masked_loss(self, batch):
        """Compute the compiled loss for one ``{"source", "target"}`` batch.

        The decoder is fed ``target[:, :-1]`` and scored against
        ``target[:, 1:]``. Positions whose target id equals 0 receive a
        False sample weight (presumably 0 is the padding id -- confirm
        against the vectorizer).
        """
        source = batch["source"]
        target = batch["target"]
        dec_input = target[:, :-1]
        dec_target = target[:, 1:]
        preds = self([source, dec_input])
        one_hot = tf.one_hot(dec_target, depth=self.num_classes)
        mask = tf.math.logical_not(tf.math.equal(dec_target, 0))
        return self.compiled_loss(one_hot, preds, sample_weight=mask)

    def train_step(self, batch):
        """Processes one batch inside model.fit()."""
        # The forward pass must run under the tape so gradients can be taken.
        with tf.GradientTape() as tape:
            loss = self._masked_loss(batch)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.loss_metric.update_state(loss)
        return {"loss": self.loss_metric.result()}

    def test_step(self, batch):
        """Processes one batch inside model.evaluate() / validation."""
        loss = self._masked_loss(batch)
        self.loss_metric.update_state(loss)
        return {"loss": self.loss_metric.result()}

    def generate(self, source, target_start_token_idx):
        """Performs inference over one batch of inputs using greedy decoding.

        Starts every sequence with ``target_start_token_idx`` and appends the
        arg-max token of the last position for ``target_maxlen - 1`` steps.

        Returns:
            An int32 tensor of token ids with ``target_maxlen`` columns,
            including the start token.
        """
        bs = tf.shape(source)[0]
        enc = self.encoder(source)
        dec_input = tf.ones((bs, 1), dtype=tf.int32) * target_start_token_idx
        for _ in range(self.target_maxlen - 1):
            dec_out = self.decode(enc, dec_input)
            logits = self.classifier(dec_out)
            token_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)
            # Only the newest position is kept; earlier ones were already emitted.
            next_token = tf.expand_dims(token_ids[:, -1], axis=-1)
            dec_input = tf.concat([dec_input, next_token], axis=-1)
        return dec_input
该示例工作正常,问题是当我使用以下代码保存模型时:
# Target path for the export; the log output reports assets written under it.
filepath = "SavedModel"
# Persist the trained model to `filepath`.
model.save(filepath)
输出显示以下警告:
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.embeddings.Embedding object at 0x7f8843636e50>, because it is not built.
WARNING:absl:Found untraced functions such as conv1d_layer_call_fn, conv1d_layer_call_and_return_conditional_losses, conv1d_1_layer_call_fn, conv1d_1_layer_call_and_return_conditional_losses, conv1d_2_layer_call_fn while saving (showing 5 of 345). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: SavedModel/assets
INFO:tensorflow:Assets written to: SavedModel/assets
然后,当我加载模型并尝试用数据对其进行训练(调用 fit)时:
# Map each custom class name to its implementation so Keras can resolve
# the names when loading the saved model.
custom_objects = dict(
    TokenEmbedding=TokenEmbedding,
    SpeechFeatureEmbedding=SpeechFeatureEmbedding,
    TransformerEncoder=TransformerEncoder,
    TransformerDecoder=TransformerDecoder,
    Transformer=Transformer,
    CustomSchedule=CustomSchedule,
)
# Load the model previously written to `filepath`.
loaded_model = tf.keras.models.load_model(
    filepath,
    custom_objects=custom_objects,
)
它给了我一堆错误信息:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-40-f4e926fd01d6> in <module>()
----> 1 loaded_model.fit(ds, validation_data=val_ds, callbacks=[display_cb], epochs=1)
9 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
992 except Exception as e: # pylint:disable=broad-except
993 if hasattr(e, "ag_error_metadata"):
--> 994 raise e.ag_error_metadata.to_exception(e)
995 else:
996 raise
ValueError: in user code:
/usr/local/lib/python3.7/dist-packages/keras/engine/training.py:853 train_function *
return step_function(self, iterator)
/usr/local/lib/python3.7/dist-packages/keras/engine/training.py:842 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
/usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:1286 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:2849 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:3632 _call_for_each_replica
return fn(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/keras/engine/training.py:835 run_step **
outputs = model.train_step(data)
/usr/local/lib/python3.7/dist-packages/keras/engine/training.py:787 train_step
y_pred = self(x, training=True)
/usr/local/lib/python3.7/dist-packages/keras/engine/base_layer.py:1020 __call__
input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
/usr/local/lib/python3.7/dist-packages/keras/engine/input_spec.py:185 assert_input_compatibility
(name, list(inputs.keys()), names))
ValueError: Missing data for input "input_2". You passed a data dictionary with keys ['source', 'target']. Expected the following keys: ['input_2']
我使用 google colab 运行它,我尝试搜索这个错误很长时间,但我看不到任何解决方案。任何人都可以帮助我吗?谢谢!
解决方案
推荐阅读
- c# - 元素无法通过键盘进行交互,因为它不可聚焦
- c# - ServiceStack.OrmLite:在 SqlExpression 中为匿名类型使用别名
- vert.x - 如何让 HOCON 配置文件格式在部署为 fat jar 的 Vert.x 中工作?
- python - 如何在不使用 apt-get 的情况下安装 libasound2-dev 32 位?
- r - 绘制图表
- ios - URL(string:) 为包含西里尔字母的路径返回 nil
- python - UnicodeEncodeError:带有阿拉伯数据
- r - 使用 docker-compose 和 RStudio/Jupyter 的持久卷
- amazon-ec2 - Netflix eureka 客户端在 AWS ec2 中使用私有 DNS 而不是公共 DNS 或公共 IP 向 eureka 发现服务器注册
- sql - 在 SQL Server 表中存储日期时间值