tensorflow - 当我将列表对象添加到 keras 子类模型时,`tf.model_to_estimator` 引发 AttributeError
问题描述
我已经像这样创建了一个 keras 子类模型:
class SubModel(tf.Keras.Model):
def __init__(self, features, **kwargs):
"""Init function of Model.
Args:
features: A list of SparseFeature and DenseFeature.
"""
assert len(features) > 0
super(SubModel, self).__init__(name='SubModel', **kwargs)
self.features = features
请注意,__init__
函数中有一个特性将在call
此模型的方法中使用。当我用 keras 风格训练和评估模型时,一切都很好。
但是,现在我想使用tf.keras.model_to_estimator
函数将此模型转换为估计器。它引发了一个错误:AttributeError: '_ListWrapper' object has no attribute 'get_config'
.
根据我的调试,features
这是添加到模型中的属性导致此错误。在转换为估计器时,它把特征作为模型的一个,并在克隆模型时layer
尝试调用该函数。get_config
似乎添加到模型中的所有属性都将被视为layer
克隆模型时。
但我真的很想features
用作模型的一部分,以便可以通过该模型的其他功能访问它,例如call
. 还有其他方法可以解决这个问题吗?
解决方案
我认为tf.keras.model_to_estimator
与Keras 模型完美兼容,但与模型很差,Sequential
尤其是在子类中实现复杂操作时。Functional API
Subclass
因此,如果您已经定义了一个子类 keras 模型,并且想将其转换为 estimator,最好的方法是定义model_fn
函数,并将 keras 模型放入其中,如下面的代码:
def model_fn(features, labels, mode):
model = SubModel()
outputs = model(features)
loss = tf.keras.losses.xx(labels, outputs)
return tf.estimator.EstimatorSpec(...)
推荐阅读
- javascript - 如何在我的 javascript 类(ES6)中实现事件处理程序
- apache - 具有跨源 (CORS) 的 Apache 反向代理:CentOS 7 / 实施 SSO(单点登录)
- python - 如何在 Python 中对 numpy 矩阵和列表中的字母数字字符串进行排序?
- java - 为什么 bufferedReader 比 Java 中的 Scanner 类高效得多?
- c# - 如果 TcpClient 崩溃,请重新连接到服务器
- mysql - mysql - 每月和每年的总和
- javascript - 需要碰撞脚本方面的帮助
- sql - 如何在 SQL Server 中比较具有相同 ID 的不同行中的相同列值
- javascript - 如何更改 Ant 表头复选框
- c++ - 带有 boost asio ssl 的未定义引用