A tf.Variable created inside my tf.function has been garbage-collected

Problem description

Environment: Google Colab
Main repository: https://github.com/grausof/keras-sincnet
Did I customize the source code: No

I am following the SincNet paper for a speaker recognition task. I implemented my layer, but when I try model.fit I get the errors shown below.

On the first run I see:

ValueError: in user code:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:903 fn_with_cond  *
    raise ValueError(

ValueError: A tf.Variable created inside your tf.function has been garbage-collected. Your code needs to keep Python references to variables created inside `tf.function`s.

A common way to raise this error is to create and return a variable only referenced inside your function:

@tf.function
def f():
  v = tf.Variable(1.0)
  return v

v = f()  # Crashes with this error message!

The reason this crashes is that @tf.function annotated function returns a **`tf.Tensor`** with the **value** of the variable when the function is called rather than the variable instance itself. As such there is no code holding a reference to the `v` created inside the function and Python garbage collects it.

The simplest way to fix this issue is to create variables outside the function and capture them:

v = tf.Variable(1.0)

@tf.function
def f():
  return v

f()  # <tf.Tensor: numpy=1.>
v.assign_add(1.)
f()  # <tf.Tensor: numpy=2.>

On subsequent runs I see:

ValueError: in user code:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:805 train_function  *
    return step_function(self, iterator)
<ipython-input-60-c4b5571aed6d>:112 call  *
    low_pass1 = 2 * self.filt_beg_freq[i] * sinc(self.filt_beg_freq[i] * self.freq_scale, self.t_right)
<ipython-input-60-c4b5571aed6d>:158 sinc  *
    y = K.concatenate([y_left, K.variable(K.ones(1)), y_right])
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper  **
    return target(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:1521 ones
    return variable(v, dtype=dtype, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:983 variable
    constraint=constraint)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:262 __call__
    return cls._variable_v2_call(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
    shape=shape)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:67 getter
    return captured_getter(captured_previous, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:3332 creator
    return next_creator(**kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:67 getter
    return captured_getter(captured_previous, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:3332 creator
    return next_creator(**kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:67 getter
    return captured_getter(captured_previous, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:3332 creator
    return next_creator(**kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:67 getter
    return captured_getter(captured_previous, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:731 invalid_creator_scope
    "tf.function-decorated function tried to create "

ValueError: tf.function-decorated function tried to create variables on non-first call.

My layer is:

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer
from tensorflow.python.keras.utils import conv_utils
import numpy as np
import math

debug = False
from keras import initializers


class LayerNorm(Layer):
    """ Layer Normalization in the style of https://arxiv.org/abs/1607.06450 """

    def __init__(self, scale_initializer='ones', bias_initializer='zeros', **kwargs):
        super(LayerNorm, self).__init__(**kwargs)
        self.epsilon = 1e-6
        self.scale_initializer = initializers.get(scale_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

    def build(self, input_shape):
        self.scale = self.add_weight(shape=(input_shape[-1],),
                                     initializer=self.scale_initializer,
                                     trainable=True,
                                     name='{}_scale'.format(self.name))
        self.bias = self.add_weight(shape=(input_shape[-1],),
                                    initializer=self.bias_initializer,
                                    trainable=True,
                                    name='{}_bias'.format(self.name))
        self.built = True

    def call(self, x, mask=None):
        mean = K.mean(x, axis=-1, keepdims=True)
        std = K.std(x, axis=-1, keepdims=True)
        norm = (x - mean) * (1 / (std + self.epsilon))
        return norm * self.scale + self.bias

    def compute_output_shape(self, input_shape):
        return input_shape


def debug_print(*objects):
    if debug:
        print(*objects)


class SincConv1D(Layer):

    def __init__(
            self,
            N_filt,
            Filt_dim,
            fs,
            **kwargs):
        self.N_filt = N_filt
        self.Filt_dim = Filt_dim
        self.fs = fs

        super(SincConv1D, self).__init__(**kwargs)

    def build(self, input_shape):
        # The filters are trainable parameters.
        self.filt_b1 = self.add_weight(
            name='filt_b1',
            shape=(self.N_filt,),
            initializer='uniform',
            trainable=True)
        self.filt_band = self.add_weight(
            name='filt_band',
            shape=(self.N_filt,),
            initializer='uniform',
            trainable=True)

        # Mel Initialization of the filterbanks
        low_freq_mel = 80
        high_freq_mel = (2595 * np.log10(1 + (self.fs / 2) / 700))  # Convert Hz to Mel
        mel_points = np.linspace(low_freq_mel, high_freq_mel, self.N_filt)  # Equally spaced in Mel scale
        f_cos = (700 * (10 ** (mel_points / 2595) - 1))  # Convert Mel to Hz
        b1 = np.roll(f_cos, 1)
        b2 = np.roll(f_cos, -1)
        b1[0] = 30
        b2[-1] = (self.fs / 2) - 100
        self.freq_scale = self.fs * 1.0
        self.set_weights([b1 / self.freq_scale, (b2 - b1) / self.freq_scale])

        # Get beginning and end frequencies of the filters.
        min_freq = 50.0
        min_band = 50.0
        self.filt_beg_freq = K.abs(self.filt_b1) + min_freq / self.freq_scale
        self.filt_end_freq = self.filt_beg_freq + (K.abs(self.filt_band) + min_band / self.freq_scale)

        # Filter window (hamming).
        n = np.linspace(0, self.Filt_dim, self.Filt_dim)
        window = 0.54 - 0.46 * K.cos(2 * math.pi * n / self.Filt_dim)
        window = K.cast(window, "float32")
        self.window = K.variable(window)
        debug_print("  window", self.window.shape)

        # TODO what is this?
        t_right_linspace = np.linspace(1, (self.Filt_dim - 1) / 2, int((self.Filt_dim - 1) / 2))
        self.t_right = K.variable(t_right_linspace / self.fs)
        debug_print("  t_right", self.t_right)

        super(SincConv1D, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x, **kwargs):
        debug_print("call")
        # filters = K.zeros(shape=(N_filt, Filt_dim))

        # Compute the filters.
        output_list = []
        for i in range(self.N_filt):
            low_pass1 = 2 * self.filt_beg_freq[i] * sinc(self.filt_beg_freq[i] * self.freq_scale, self.t_right)
            low_pass2 = 2 * self.filt_end_freq[i] * sinc(self.filt_end_freq[i] * self.freq_scale, self.t_right)
            band_pass = (low_pass2 - low_pass1)
            band_pass = band_pass / K.max(band_pass)
            output_list.append(band_pass * self.window)
        filters = K.stack(output_list)  # (80, 251)
        filters = K.transpose(filters)  # (251, 80)
        filters = K.reshape(filters, (self.Filt_dim, 1,
                                      self.N_filt))  # (251,1,80) in TF: (filter_width, in_channels, out_channels) in
        # PyTorch (out_channels, in_channels, filter_width)

        '''Given an input tensor of shape [batch, in_width, in_channels] if data_format is "NWC", or [batch, 
        in_channels, in_width] if data_format is "NCW", and a filter / kernel tensor of shape [filter_width, 
        in_channels, out_channels], this op reshapes the arguments to pass them to conv2d to perform the equivalent 
        convolution operation. Internally, this op reshapes the input tensors and invokes tf.nn.conv2d. For example, 
        if data_format does not start with "NC", a tensor of shape [batch, in_width, in_channels] is reshaped to [
        batch, 1, in_width, in_channels], and the filter is reshaped to [1, filter_width, in_channels, out_channels]. 
        The result is then reshaped back to [batch, out_width, out_channels] (where out_width is a function of the 
        stride and padding as in conv2d) and returned to the caller. '''

        # Do the convolution.
        debug_print("call")
        debug_print("  x", x)
        debug_print("  filters", filters)
        out = K.conv1d(
            x,
            kernel=filters
        )
        debug_print("  out", out)

        return out

    def compute_output_shape(self, input_shape):
        new_size = conv_utils.conv_output_length(
            input_shape[1],
            self.Filt_dim,
            padding="valid",
            stride=1,
            dilation=1)
        return (input_shape[0],) + (new_size,) + (self.N_filt,)


def sinc(band, t_right):
    y_right = K.sin(2 * math.pi * band * t_right) / (2 * math.pi * band * t_right)
    # y_left = flip(y_right, 0) TODO remove if useless
    y_left = K.reverse(y_right, 0)
    y = K.concatenate([y_left, K.variable(K.ones(1)), y_right])
    return y

Tags: python, tensorflow, keras

Solution


In the original source code of the Sinc layer there are a few lines that make it incompatible with TensorFlow 2.0 and later; however, the code you posted looks like an update meant to work with tf 2.2. I tried this new version with tf 2.0 and it did not work.

I modified the previous version of the code published by the authors, and it works fine in tf 2.0. For some reason, these lines:

window = K.cast(window, "float32")
window = K.variable(window)

caused some problems when compiling the network, so I simply commented out the line window = K.variable(window) and everything works fine. Below I post the code that works for me in tf 2.0.

class SincConv1D(Layer):

    def __init__(self, N_filt, Filt_dim, fs, **kwargs):
        self.N_filt = N_filt
        self.Filt_dim = Filt_dim
        self.fs = fs
        super(SincConv1D, self).__init__()

    def build(self, input_shape):
        # The filters are trainable parameters.
        self.filt_b1 = self.add_weight(
            name='filt_b1',
            shape=(self.N_filt,),
            initializer='uniform',
            trainable=True)
        self.filt_band = self.add_weight(
            name='filt_band',
            shape=(self.N_filt,),
            initializer='uniform',
            trainable=True)

        # Mel Initialization of the filterbanks
        low_freq_mel = 80
        high_freq_mel = (2595 * np.log10(1 + (self.fs / 2) / 700))  # Convert Hz to Mel
        mel_points = np.linspace(low_freq_mel, high_freq_mel, self.N_filt)  # Equally spaced in Mel scale
        f_cos = (700 * (10 ** (mel_points / 2595) - 1))  # Convert Mel to Hz
        b1 = np.roll(f_cos, 1)
        b2 = np.roll(f_cos, -1)
        b1[0] = 30
        b2[-1] = (self.fs / 2) - 100
        self.freq_scale = self.fs * 1.0
        self.set_weights([b1 / self.freq_scale, (b2 - b1) / self.freq_scale])

        super(SincConv1D, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        # Get beginning and end frequencies of the filters.
        min_freq = 50.0
        min_band = 50.0
        filt_beg_freq = K.abs(self.filt_b1) + min_freq / self.freq_scale
        filt_end_freq = filt_beg_freq + (K.abs(self.filt_band) + min_band / self.freq_scale)

        # Filter window (hamming).
        n = np.linspace(0, self.Filt_dim, self.Filt_dim)
        window = 0.54 - 0.46 * K.cos(2 * math.pi * n / self.Filt_dim)
        window = K.cast(window, "float32")
        # window = K.variable(window)  # commented out: this line broke compilation

        t_right_linspace = np.linspace(1, (self.Filt_dim - 1) / 2, int((self.Filt_dim - 1) / 2))
        t_right = K.variable(t_right_linspace / self.fs)

        # Compute the filters.
        output_list = []
        for i in range(self.N_filt):
            low_pass1 = 2 * filt_beg_freq[i] * sinc(filt_beg_freq[i] * self.freq_scale, t_right)
            low_pass2 = 2 * filt_end_freq[i] * sinc(filt_end_freq[i] * self.freq_scale, t_right)
            band_pass = (low_pass2 - low_pass1)
            band_pass = band_pass / K.max(band_pass)
            output_list.append(band_pass * window)
        filters = K.stack(output_list)      # (80, 251)
        filters = K.transpose(filters)      # (251, 80)
        filters = K.reshape(filters, (self.Filt_dim, 1, self.N_filt))
        # (251, 1, 80) in TF: (filter_width, in_channels, out_channels); in PyTorch: (out_channels, in_channels, filter_width)

        '''
        Given an input tensor of shape [batch, in_width, in_channels] if data_format is "NWC",
        or [batch, in_channels, in_width] if data_format is "NCW", and a filter / kernel tensor of shape
        [filter_width, in_channels, out_channels], this op reshapes the arguments to pass them to conv2d
        to perform the equivalent convolution operation. Internally, this op reshapes the input tensors
        and invokes tf.nn.conv2d. For example, if data_format does not start with "NC", a tensor of shape
        [batch, in_width, in_channels] is reshaped to [batch, 1, in_width, in_channels], and the filter is
        reshaped to [1, filter_width, in_channels, out_channels]. The result is then reshaped back to
        [batch, out_width, out_channels] (where out_width is a function of the stride and padding as in
        conv2d) and returned to the caller.
        '''

        # Do the convolution.
        out = K.conv1d(x, kernel=filters)

        return out
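
One detail worth noting: the code above still calls the sinc helper from the question, and the second traceback points exactly at its K.variable(K.ones(1)) line (K.ones also creates a variable under the hood). A minimal variable-free variant of that helper, assuming the version posted in the question, would be:

def sinc(band, t_right):
    y_right = K.sin(2 * math.pi * band * t_right) / (2 * math.pi * band * t_right)
    y_left = K.reverse(y_right, 0)
    # K.constant returns a plain tensor instead of a tf.Variable, so nothing
    # new has to be created (or kept alive) inside the traced call().
    y = K.concatenate([y_left, K.constant(1.0, shape=(1,)), y_right])
    return y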

I don't know why those lines cause a conflict, but I hope this code works for you.
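
As a likely explanation (not confirmed by the authors): in TF 2.x, model.fit wraps the training step, and therefore Layer.call, in a tf.function, and both K.variable(...) and K.ones(...) create tf.Variable objects. A tf.function may only create variables during its first trace, and it must keep Python references to them afterwards, which is exactly what the two ValueErrors above complain about. A minimal sketch of the rule, with made-up function names:

import tensorflow as tf
from tensorflow.keras import backend as K

@tf.function
def bad():
    # K.variable / K.ones create a tf.Variable inside the traced function;
    # depending on whether a reference to it survives, this trips one of the
    # two ValueErrors quoted in the question.
    return K.variable(K.ones(1)) * 2.0

@tf.function
def good():
    # tf.ones returns a plain tensor, so it is safe to build it on every call.
    return tf.ones((1,)) * 2.0

good()
good()   # fine: no variables are created inside the traced function
# bad()  # would raise a ValueError like the ones above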

