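tensorflow2.0 - Attempting to capture an EagerTensor without building a function in tf.keras.models.save_model
Problem description
I have the same error. I added a DenseHashTable to my DNNModel to store pretrained embeddings. Here is the code: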
```python
class DNNModel(tf.keras.Model):
    """A DNN Model."""

    def __init__(self, ...):
        super(DNNModel, self).__init__(name=name, **kwargs)
        vocabulary_list, embeddings = run_data.get_pretrain_vocabs_embeddings()
        self.table = self._create_embedding_table(vocabulary_list, embeddings)

    def _create_embedding_table(self, vocab_list, embeddings):
        dimension = embeddings.shape[1]
        table = tf.lookup.experimental.DenseHashTable(tf.string, tf.float32, [2.0]*dimension, 'empty_key', 'deleted_key')
        # For comparison: check whether commenting out the lines below makes any difference
        keys = tf.constant([i for i in vocab_list], tf.string)
        values = tf.convert_to_tensor(embeddings, tf.float64)
        values = tf.cast(values, tf.float32)
        table.insert(keys, values)
        return table

    def call(self, inputs, training=None, mask=None):
        fc_embeddings = self._input_layer(inputs)
        bert_embeddings = self._look_up(inputs)
        net = tf.concat([fc_embeddings, bert_embeddings], axis=1)
        ....
```
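Everything works at runtime, but when I export the model with tf.keras.models.save_model(model, FLAGS.servable_model_dir), it raises: RuntimeError: Attempting to capture an EagerTensor without building a function.
I debugged the code, and it looks like serializing the DenseHashTable inside DNNModel is what raises the exception: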
```
  File "/Users/jiananliu/work/neirongrecom/model/ctr_model/wide_n_deep/wide_n_deep_keras_main.py", line 1021, in run
    tf.keras.models.save_model(model, FLAGS.servable_model_dir, include_optimizer=False)
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py", line 138, in save_model
    signatures, options)
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/save.py", line 78, in save
    save_lib.save(model, filepath, signatures, options)
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 951, in save
    obj, export_dir, signatures, options, meta_graph_def)
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 1027, in _build_meta_graph
    options.namespace_whitelist)
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 595, in _fill_meta_graph_def
    object_map, resource_map, asset_info = saveable_view.map_resources()
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 270, in map_resources
    new_resource = new_obj._create_resource()
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/ops/lookup_ops.py", line 1945, in _create_resource
    name=self._name)
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/ops/gen_lookup_ops.py", line 1113, in mutable_dense_hash_table_v2
    max_load_factor=max_load_factor, name=name)
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py", line 470, in _apply_op_helper
    preferred_dtype=default_dtype)
  File "/Users/jiananliu/anaconda3/envs/transformer/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1307, in convert_to_tensor
    raise RuntimeError("Attempting to capture an EagerTensor without "
RuntimeError: Attempting to capture an EagerTensor without building a function.
```
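Please help!
Solution
I found that replacing tf.lookup.experimental.DenseHashTable with tensorflow.python.ops.lookup_ops.MutableHashTable fixes this error. The same class is also exported publicly as tf.lookup.experimental.MutableHashTable.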
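```python
import tensorflow as tf
from tensorflow.python.ops.lookup_ops import MutableHashTable


def create_table():
    # MutableHashTable replaces the DenseHashTable below, which raised the error on save
    table = MutableHashTable(key_dtype=tf.int32, value_dtype=tf.int32, default_value=[0])
    # table = tf.lookup.experimental.DenseHashTable(
    #     key_dtype=tf.int32,
    #     value_dtype=tf.int32,
    #     default_value=-1,
    #     empty_key=0,
    #     deleted_key=-1)
    return table
```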
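The sketch below applies the same swap to the question's use case (string keys mapped to float32 embedding vectors) and exports the result as a SavedModel. It is a minimal sketch assuming TensorFlow 2.x and NumPy. The names create_embedding_table and EmbeddingModule are hypothetical, and a plain tf.Module stands in for the Keras model because tf.keras.models.save_model delegates to the same save_lib.save call shown in the traceback. One plausible reason the swap helps is visible in the traceback. Saving fails inside DenseHashTable._create_resource, where the creation op (mutable_dense_hash_table_v2) takes empty_key and deleted_key as tensor inputs that were created eagerly. The op behind MutableHashTable takes only attributes, so nothing eager has to be captured when the resource is re-created during export.

```python
import tempfile

import numpy as np
import tensorflow as tf


def create_embedding_table(vocab_list, embeddings):
    """Hypothetical helper: maps a string token to a float32 embedding vector."""
    dimension = embeddings.shape[1]
    # MutableHashTable instead of DenseHashTable; unknown keys get a zero vector.
    table = tf.lookup.experimental.MutableHashTable(
        key_dtype=tf.string,
        value_dtype=tf.float32,
        default_value=[0.0] * dimension,
    )
    table.insert(
        tf.constant(vocab_list, dtype=tf.string),
        tf.constant(embeddings, dtype=tf.float32),
    )
    return table


class EmbeddingModule(tf.Module):
    """Hypothetical stand-in for the model that owns the table."""

    def __init__(self, vocab_list, embeddings):
        super().__init__()
        # Assigning the table to an attribute lets SavedModel track it.
        self.table = create_embedding_table(vocab_list, embeddings)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
    def lookup(self, tokens):
        # Returns a [batch, dimension] float32 tensor.
        return self.table.lookup(tokens)


vocab = ["apple", "banana"]
embeddings = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
module = EmbeddingModule(vocab, embeddings)

export_dir = tempfile.mkdtemp()
# The export step that raised the error for DenseHashTable in the question.
tf.saved_model.save(module, export_dir)

restored = tf.saved_model.load(export_dir)
print(restored.lookup(tf.constant(["banana", "unknown"])).numpy())
```

For the question's DNNModel, the corresponding change is to build the table in _create_embedding_table with MutableHashTable and leave the rest of the model as it is.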