"Attempting to capture an EagerTensor without building a function" when calling tf.keras.models.save_model

Problem description

I am hitting the same error. I added a DenseHashTable to my DNNModel to hold pretrained embeddings. Here is the code:

class DNNModel(tf.keras.Model):
    """A DNN Model."""

    def __init__(self, ...):
        super(DNNModel, self).__init__(name, **kwargs)
        vocabulary_list, embeddings = run_data.get_pretrain_vocabs_embeddings()
        self.table = self._create_embedding_table(vocabulary_list, embeddings)

    def _create_embedding_table(self, vocab_list, embeddings):
        dimension = embeddings.shape[1]
        table = tf.lookup.experimental.DenseHashTable(tf.string, tf.float32, [2.0]*dimension, 'empty_key', 'deleted_key')
        # For comparison: test whether commenting out the lines below makes any difference
        keys = tf.constant([i for i in vocab_list], tf.string)
        values = tf.convert_to_tensor(embeddings, tf.float64)
        values = tf.cast(values, tf.float32)
        table.insert(keys, values)
        return table

    def call(self, inputs, training=None, mask=None):
        fc_embeddings = self._input_layer(inputs)
        bert_embeddings = self._look_up(inputs)
        net = tf.concat([fc_embeddings, bert_embeddings], axis=1)
        ....

Everything works at run time, but when I export the model with tf.keras.models.save_model(model, FLAGS.servable_model_dir), it raises: RuntimeError: Attempting to capture an EagerTensor without building a function.

I debugged the code, and it looks like serializing the DenseHashTable inside DNNModel is what raises the exception:

  File "/Users/jiananliu/work/neirongrecom/model/ctr_model/wide_n_deep/wide_n_deep_keras_main.py", line 1021, in run
    tf.keras.models.save_model(model, FLAGS.servable_model_dir, include_optimizer=False)
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py", line 138, in save_model
    signatures, options)
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/save.py", line 78, in save
    save_lib.save(model, filepath, signatures, options)
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 951, in save
    obj, export_dir, signatures, options, meta_graph_def)
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 1027, in _build_meta_graph
    options.namespace_whitelist)
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 595, in _fill_meta_graph_def
    object_map, resource_map, asset_info = saveable_view.map_resources()
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 270, in map_resources
    new_resource = new_obj._create_resource()
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/ops/lookup_ops.py", line 1945, in _create_resource
    name=self._name)
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/ops/gen_lookup_ops.py", line 1113, in mutable_dense_hash_table_v2
    max_load_factor=max_load_factor, name=name)
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py", line 470, in _apply_op_helper
    preferred_dtype=default_dtype)
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1307, in convert_to_tensor
    raise RuntimeError("Attempting to capture an EagerTensor without "
RuntimeError: Attempting to capture an EagerTensor without building a function.

Please help!

Tags: tensorflow2.0

Solution


I found that replacing tf.lookup.experimental.DenseHashTable with tensorflow.python.ops.lookup_ops.MutableHashTable resolves this error.

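import tensorflow as tf
# Private module path, as used in this answer; it may change between TensorFlow versions.
from tensorflow.python.ops.lookup_ops import MutableHashTable

def create_table():
    # Use MutableHashTable in place of the DenseHashTable that failed on save.
    table = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=[0])
    # table = tf.lookup.experimental.DenseHashTable(
    #     key_dtype=tf.int32,
    #     value_dtype=tf.int32,
    #     default_value=-1,
    #     empty_key=0,
    #     deleted_key=-1)
    return table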
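
For the embedding use case in the question, the same swap might look like the sketch below. It is only an illustration under assumptions: the helper name create_embedding_table, the two-word vocabulary and the 3-dimensional vectors are made up for the example, and the [2.0] * dimension default simply mirrors the question's code. It shows string keys mapped to float32 vectors in eager mode; it does not by itself demonstrate that saving succeeds.

```python
import numpy as np
import tensorflow as tf
# Private module path, as in the answer above; may change between TensorFlow versions.
from tensorflow.python.ops.lookup_ops import MutableHashTable


def create_embedding_table(vocab_list, embeddings):
    """Hypothetical helper: build a string -> float32 vector table."""
    dimension = embeddings.shape[1]
    table = MutableHashTable(
        key_dtype=tf.string,
        value_dtype=tf.float32,
        # Returned for keys that are not in the table (same default as the question).
        default_value=[2.0] * dimension,
    )
    keys = tf.constant(list(vocab_list), dtype=tf.string)
    values = tf.cast(tf.convert_to_tensor(embeddings), tf.float32)
    table.insert(keys, values)
    return table


# Made-up data, for illustration only.
vocab = ["apple", "banana"]
vectors = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)

table = create_embedding_table(vocab, vectors)
# One known key and one unknown key; the unknown key gets the default vector.
looked_up = table.lookup(tf.constant(["banana", "missing"]))
```

Unlike DenseHashTable, MutableHashTable takes no empty_key or deleted_key arguments, so no sentinel strings have to be reserved in the vocabulary.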
