python - TensorFlow 检查点自定义地图
问题描述
我正在使用自定义图层对 keras 模型进行子类化。每个层都包装了一个参数字典,在生成它们的层时使用该字典。似乎这些参数字典不是在 Tensorflow 中进行训练检查点之前设置的,而是在之后设置的,这会导致错误。我不知道如何解决这个问题,因为ValueError
提出的信息也提供了过时的信息(tf.contrib
不再存在)。
ValueError: Unable to save the object {'units': 32, 'activation': 'tanh', 'recurrent_initializer': 'glorot_uniform', 'dropout': 0, 'return_sequences': True} (在属性上自动构建的字典包装器任务)。包装的字典在包装器之外被修改(它的最终值为 {'units': 32, 'activation': 'tanh', 'recurrent_initializer': 'glorot_uniform', 'dropout': 0, 'return_sequences': True},它的添加检查点依赖项时的值是无),这会破坏对象创建的恢复。
如果您不需要这个字典检查点,请将其包装在 tf.contrib.checkpoint.NoDependency 对象中;它将被自动解包并随后被忽略。
这是引发此问题的层的示例:
class RecurrentConfig(BaseLayer):
'''Basic configurable recurrent layer'''
def __init__(self, params: Dict[Any, Any], mode: ModeKeys, layer_name: str = '', **kwargs):
self.layer_name = layer_name
self.cell_name = params.pop('cell', 'GRU')
self.num_layers = params.pop('num_layers', 1)
kwargs['name'] = layer_name
super().__init__(params, mode, **kwargs)
if layer_name == '':
self.layer_name = self.cell_name
self.layers: List[layers.Layer] = stack_layers(self.params,
self.num_layers,
self.cell_name)
def call(self, inputs: np.ndarray) -> layers.Layer:
'''This function is a sequential/functional call to this layers logic
Args:
inputs: Array to be processed within this layer
Returns:
inputs processed through this layer'''
processed = inputs
for layer in self.layers:
processed = layer(processed)
return processed
@staticmethod
def default_params() -> Dict[Any, Any]:
return{
'units': 32,
'recurrent_initializer': 'glorot_uniform',
'dropout': 0,
'recurrent_dropout': 0,
'activation': 'tanh',
'return_sequences': True
}
BaseLayer.py
'''Basic ABC for a keras style layer'''
from typing import Dict, Any
from tensorflow.keras import layers
from mosaix_py.mosaix_learn.configurable import Configurable
class BaseLayer(Configurable, layers.Layer):
'''Base configurable Keras layer'''
def get_config(self) -> Dict[str, Any]:
'''Return configuration dictionary as part of keras serialization'''
config = super().get_config()
config.update(self.params)
return config
@staticmethod
def default_params() -> Dict[Any, Any]:
raise NotImplementedError('Layer does not implement default params')
解决方案
我面临的问题是我正在从字典中弹出项目,并传递给图层。图层
self.cell_name = params.pop('cell', 'GRU')
self.num_layers = params.pop('num_layers', 1)
将字典传递到层时,它必须在跟踪时保持不变。
我的解决方案是进一步抽象出参数解析并传入最终的字典。
推荐阅读
- graphql - GQL 联合不重新调整值
- gradle - “src/main/kotlin”中的 Kotlin 源文件无法解析其他 gradle 模块类
- curl - 使用 wget 或 curl 下载时,VS Code 的扩展不起作用
- java - 需要在 Java 中决定一个球队的进球数(足球模拟器)
- wordpress - 在 ECS 上运行 Wordpress 的 docker 映像时,如何将我的数据库保存在 Wordpress 中?
- python - 使用 Python 和 OpenCV 测量图像中不规则有界对象的像素之间的距离
- javascript - 未捕获的类型错误:无法设置未定义的属性“左”
- javascript - 如何从无渲染类 React-Native 调用函数
- c# - 在 ASP.NET MVC 控制器方法中实例化 WinForm 的缺点
- typescript - 只在声明文件中为 Typescript 文件定义接口是典型的吗?