TensorFlow Checkpoint Custom Map

Problem description

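I am subclassing a Keras model with custom layers. Each layer wraps a dictionary of parameters that is used when building that layer's sub-layers. It seems these parameter dictionaries are not set before TensorFlow creates a training checkpoint but afterwards, which causes an error. I don't know how to fix this, because the ValueError's message is also out of date: it suggests tf.contrib, which no longer exists.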

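ValueError: Unable to save the object {'units': 32, 'activation': 'tanh', 'recurrent_initializer': 'glorot_uniform', 'dropout': 0, 'return_sequences': True} (a dictionary wrapper constructed automatically on attribute assignment). The wrapped dictionary was modified outside the wrapper (its final value was {'units': 32, 'activation': 'tanh', 'recurrent_initializer': 'glorot_uniform', 'dropout': 0, 'return_sequences': True}, its value when a checkpoint dependency was added was None), which breaks restoration on object creation.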

If you don't need this dictionary checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency object; it will be automatically un-wrapped and subsequently ignored.

Here is an example of a layer that raises this error:

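from typing import Any, Dict, List

import tensorflow as tf
from tensorflow.keras import layers

# BaseLayer is shown below; ModeKeys and stack_layers are imported from
# elsewhere in the author's project (not shown in the question).

class RecurrentConfig(BaseLayer):
    '''Basic configurable recurrent layer'''
    def __init__(self, params: Dict[Any, Any], mode: ModeKeys, layer_name: str = '', **kwargs):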
        self.layer_name = layer_name
        self.cell_name = params.pop('cell', 'GRU')
        self.num_layers = params.pop('num_layers', 1)
        kwargs['name'] = layer_name
        super().__init__(params, mode, **kwargs)
        if layer_name == '':
            self.layer_name = self.cell_name
        self.layers: List[layers.Layer] = stack_layers(self.params,
                                                       self.num_layers,
                                                       self.cell_name)

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        '''Apply the stacked sub-layers to the inputs, one after another
        Args:
            inputs: Tensor to be processed by this layer
        Returns:
            The inputs after passing through every sub-layer in the stack'''
        processed = inputs
        for layer in self.layers:
            processed = layer(processed)
        return processed

    @staticmethod
    def default_params() -> Dict[Any, Any]:
        return {
            'units': 32,
            'recurrent_initializer': 'glorot_uniform',
            'dropout': 0,
            'recurrent_dropout': 0,
            'activation': 'tanh',
            'return_sequences': True
        }

BaseLayer.py

'''Basic ABC for a keras style layer'''

from typing import Dict, Any

from tensorflow.keras import layers
from mosaix_py.mosaix_learn.configurable import Configurable

class BaseLayer(Configurable, layers.Layer):
    '''Base configurable Keras layer'''
    def get_config(self) -> Dict[str, Any]:
        '''Return configuration dictionary as part of keras serialization'''
        config = super().get_config()
        config.update(self.params)
        return config

    @staticmethod
    def default_params() -> Dict[Any, Any]:
        raise NotImplementedError('Layer does not implement default params')


Solution


The problem was that I was popping items out of the dictionary that gets passed to the layer, in these lines of the layer:

    self.cell_name = params.pop('cell', 'GRU')
    self.num_layers = params.pop('num_layers', 1)

Once a dictionary has been passed into a layer and TensorFlow is tracking it for checkpointing, it must not be modified.
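The underlying pitfall is ordinary Python aliasing: `dict.pop` mutates the very object that the caller (and, once it is stored on the layer, TensorFlow's tracking wrapper) still holds a reference to. The TensorFlow-free sketch below illustrates only that side effect; both function names are hypothetical and the "safe" variant simply works on a shallow copy.

```python
from typing import Any, Dict, Tuple


def read_options_mutating(params: Dict[str, Any]) -> Tuple[str, int]:
    # Mirrors the question's code: pop() removes keys from the caller's dict.
    cell = params.pop('cell', 'GRU')
    num_layers = params.pop('num_layers', 1)
    return cell, num_layers


def read_options_safely(params: Dict[str, Any]) -> Tuple[str, int]:
    # Work on a shallow copy so the caller's dict is left untouched.
    local = dict(params)
    cell = local.pop('cell', 'GRU')
    num_layers = local.pop('num_layers', 1)
    return cell, num_layers


shared = {'cell': 'LSTM', 'num_layers': 2, 'units': 32}
read_options_mutating(shared)
print(shared)  # {'units': 32}: the shared dict has changed

shared = {'cell': 'LSTM', 'num_layers': 2, 'units': 32}
read_options_safely(shared)
print(shared)  # {'cell': 'LSTM', 'num_layers': 2, 'units': 32}: unchanged
```

A shallow copy is enough here because only top-level keys are removed; nested mutable values would still be shared.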

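My solution was to move the parameter parsing one level further out and pass the final, already-parsed dictionary into the layer.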
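The answer does not show the refactored code, so the following is only a sketch of the idea under assumptions: a hypothetical helper `split_params` resolves the layer-level options (`cell`, `num_layers`, the keys the question pops) up front and returns a fresh dictionary holding only the per-cell arguments. That final dictionary is what would be handed to the layer, and nothing modifies it afterwards. `split_params` and `LAYER_KEYS` are illustrative names, not part of Keras.

```python
from typing import Any, Dict, Tuple

# Keys consumed by the wrapper layer itself rather than by each recurrent cell
# (taken from the question's code; illustrative only).
LAYER_KEYS = ('cell', 'num_layers')


def split_params(params: Dict[str, Any],
                 defaults: Dict[str, Any]) -> Tuple[str, int, Dict[str, Any]]:
    """Resolve layer-level options and return a fresh per-cell parameter dict.

    Neither `params` nor `defaults` is modified.
    """
    merged = {**defaults, **params}  # new dict; user values override defaults
    cell_name = merged.get('cell', 'GRU')
    num_layers = merged.get('num_layers', 1)
    cell_params = {k: v for k, v in merged.items() if k not in LAYER_KEYS}
    return cell_name, num_layers, cell_params


# Usage: parse first, then construct the layer from the resolved values.
defaults = {'units': 32, 'activation': 'tanh', 'return_sequences': True}
user = {'cell': 'LSTM', 'num_layers': 2, 'units': 64}
cell_name, num_layers, cell_params = split_params(user, defaults)
# cell_name == 'LSTM', num_layers == 2
# cell_params == {'units': 64, 'activation': 'tanh', 'return_sequences': True}
```

With this split, the layer's constructor would receive `cell_name`, `num_layers` and `cell_params` already resolved, so it never needs to call `pop` on a dictionary it stores.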

