python - 在我的 tf.function 中创建的 tf.Variable 已被垃圾收集
问题描述
运行环境:google colab
主存储库:https
://github.com/grausof/keras-sincnet
我是否自定义了源代码:否
我正在使用 SincNet Paper 来实现说话人识别目标。我实现了我的图层,但是当我尝试 model.fit 时,我看到了我将在下面显示的错误。
第一次运行我看到:
ValueError: in user code:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:903 fn_with_cond *
raise ValueError(
ValueError: A tf.Variable created inside your tf.function has been garbage-collected. Your code needs to keep Python references to variables created inside `tf.function`s.
A common way to raise this error is to create and return a variable only referenced inside your function:
@tf.function
def f():
v = tf.Variable(1.0)
return v
v = f() # Crashes with this error message!
The reason this crashes is that @tf.function annotated function returns a **`tf.Tensor`** with the **value** of the variable when the function is called rather than the variable instance itself. As such there is no code holding a reference to the `v` created inside the function and Python garbage collects it.
The simplest way to fix this issue is to create variables outside the function and capture them:
v = tf.Variable(1.0)
@tf.function
def f():
return v
f() # <tf.Tensor: numpy=1.>
v.assign_add(1.)
f() # <tf.Tensor: numpy=2.>
对于下一次运行,我看到:
ValueError: in user code:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:805 train_function *
return step_function(self, iterator)
<ipython-input-60-c4b5571aed6d>:112 call *
low_pass1 = 2 * self.filt_beg_freq[i] * sinc(self.filt_beg_freq[i] * self.freq_scale, self.t_right)
<ipython-input-60-c4b5571aed6d>:158 sinc *
y = K.concatenate([y_left, K.variable(K.ones(1)), y_right])
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper **
return target(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:1521 ones
return variable(v, dtype=dtype, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:983 variable
constraint=constraint)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:262 __call__
return cls._variable_v2_call(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
shape=shape)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:67 getter
return captured_getter(captured_previous, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:3332 creator
return next_creator(**kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:67 getter
return captured_getter(captured_previous, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:3332 creator
return next_creator(**kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:67 getter
return captured_getter(captured_previous, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:3332 creator
return next_creator(**kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:67 getter
return captured_getter(captured_previous, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:731 invalid_creator_scope
"tf.function-decorated function tried to create "
ValueError: tf.function-decorated function tried to create variables on non-first call.
我的图层是:
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer
from tensorflow.python.keras.utils import conv_utils
import numpy as np
import math
debug = False
from keras import initializers
class LayerNorm(Layer):
""" Layer Normalization in the style of https://arxiv.org/abs/1607.06450 """
def __init__(self, scale_initializer='ones', bias_initializer='zeros', **kwargs):
super(LayerNorm, self).__init__(**kwargs)
self.epsilon = 1e-6
self.scale_initializer = initializers.get(scale_initializer)
self.bias_initializer = initializers.get(bias_initializer)
def build(self, input_shape):
self.scale = self.add_weight(shape=(input_shape[-1],),
initializer=self.scale_initializer,
trainable=True,
name='{}_scale'.format(self.name))
self.bias = self.add_weight(shape=(input_shape[-1],),
initializer=self.bias_initializer,
trainable=True,
name='{}_bias'.format(self.name))
self.built = True
def call(self, x, mask=None):
mean = K.mean(x, axis=-1, keepdims=True)
std = K.std(x, axis=-1, keepdims=True)
norm = (x - mean) * (1 / (std + self.epsilon))
return norm * self.scale + self.bias
def compute_output_shape(self, input_shape):
return input_shape
def debug_print(*objects):
if debug:
print(*objects)
class SincConv1D(Layer):
def __init__(
self,
N_filt,
Filt_dim,
fs,
**kwargs):
self.N_filt = N_filt
self.Filt_dim = Filt_dim
self.fs = fs
super(SincConv1D, self).__init__(**kwargs)
def build(self, input_shape):
# The filters are trainable parameters.
self.filt_b1 = self.add_weight(
name='filt_b1',
shape=(self.N_filt,),
initializer='uniform',
trainable=True)
self.filt_band = self.add_weight(
name='filt_band',
shape=(self.N_filt,),
initializer='uniform',
trainable=True)
# Mel Initialization of the filterbanks
low_freq_mel = 80
high_freq_mel = (2595 * np.log10(1 + (self.fs / 2) / 700)) # Convert Hz to Mel
mel_points = np.linspace(low_freq_mel, high_freq_mel, self.N_filt) # Equally spaced in Mel scale
f_cos = (700 * (10 ** (mel_points / 2595) - 1)) # Convert Mel to Hz
b1 = np.roll(f_cos, 1)
b2 = np.roll(f_cos, -1)
b1[0] = 30
b2[-1] = (self.fs / 2) - 100
self.freq_scale = self.fs * 1.0
self.set_weights([b1 / self.freq_scale, (b2 - b1) / self.freq_scale])
# Get beginning and end frequencies of the filters.
min_freq = 50.0
min_band = 50.0
self.filt_beg_freq = K.abs(self.filt_b1) + min_freq / self.freq_scale
self.filt_end_freq = self.filt_beg_freq + (K.abs(self.filt_band) + min_band / self.freq_scale)
# Filter window (hamming).
n = np.linspace(0, self.Filt_dim, self.Filt_dim)
window = 0.54 - 0.46 * K.cos(2 * math.pi * n / self.Filt_dim)
window = K.cast(window, "float32")
self.window = K.variable(window)
debug_print(" window", self.window.shape)
# TODO what is this?
t_right_linspace = np.linspace(1, (self.Filt_dim - 1) / 2, int((self.Filt_dim - 1) / 2))
self.t_right = K.variable(t_right_linspace / self.fs)
debug_print(" t_right", self.t_right)
super(SincConv1D, self).build(input_shape) # Be sure to call this at the end
def call(self, x, **kwargs):
debug_print("call")
# filters = K.zeros(shape=(N_filt, Filt_dim))
# Compute the filters.
output_list = []
for i in range(self.N_filt):
low_pass1 = 2 * self.filt_beg_freq[i] * sinc(self.filt_beg_freq[i] * self.freq_scale, self.t_right)
low_pass2 = 2 * self.filt_end_freq[i] * sinc(self.filt_end_freq[i] * self.freq_scale, self.t_right)
band_pass = (low_pass2 - low_pass1)
band_pass = band_pass / K.max(band_pass)
output_list.append(band_pass * self.window)
filters = K.stack(output_list) # (80, 251)
filters = K.transpose(filters) # (251, 80)
filters = K.reshape(filters, (self.Filt_dim, 1,
self.N_filt)) # (251,1,80) in TF: (filter_width, in_channels, out_channels) in
# PyTorch (out_channels, in_channels, filter_width)
'''Given an input tensor of shape [batch, in_width, in_channels] if data_format is "NWC", or [batch,
in_channels, in_width] if data_format is "NCW", and a filter / kernel tensor of shape [filter_width,
in_channels, out_channels], this op reshapes the arguments to pass them to conv2d to perform the equivalent
convolution operation. Internally, this op reshapes the input tensors and invokes tf.nn.conv2d. For example,
if data_format does not start with "NC", a tensor of shape [batch, in_width, in_channels] is reshaped to [
batch, 1, in_width, in_channels], and the filter is reshaped to [1, filter_width, in_channels, out_channels].
The result is then reshaped back to [batch, out_width, out_channels] (where out_width is a function of the
stride and padding as in conv2d) and returned to the caller. '''
# Do the convolution.
debug_print("call")
debug_print(" x", x)
debug_print(" filters", filters)
out = K.conv1d(
x,
kernel=filters
)
debug_print(" out", out)
return out
def compute_output_shape(self, input_shape):
new_size = conv_utils.conv_output_length(
input_shape[1],
self.Filt_dim,
padding="valid",
stride=1,
dilation=1)
return (input_shape[0],) + (new_size,) + (self.N_filt,)
def sinc(band, t_right):
y_right = K.sin(2 * math.pi * band * t_right) / (2 * math.pi * band * t_right)
# y_left = flip(y_right, 0) TODO remove if useless
y_left = K.reverse(y_right, 0)
y = K.concatenate([y_left, K.variable(K.ones(1)), y_right])
return y
解决方案
在 Sinc 层的原始源代码中,有几行使其与 tensorflow 2.0 及更高版本不兼容,但是,您发布的代码看起来像是与 tf 2.2 一起使用的更新。我用 tf 2.0 尝试了这个新版本,但没有用。
我修改了 autors 发布的代码的先前版本,并且在 tf 2.0 中运行良好,出于某种原因,这些行:
window = K.cast(window, "float32") window = K.variable(window)
编译网络时产生了一些问题,所以我只是注释了行 window = K.variable(window) 并且工作正常。我在 tf 2.0 中适用于我的代码下方发布。
`class SincConv1D(层):
def __init__(self, N_filt, Filt_dim, fs, **kwargs):
self.N_filt=N_filt
self.Filt_dim=Filt_dim
self.fs=fs
super(SincConv1D, self).__init__()
def build(self, input_shape):
# The filters are trainable parameters.
self.filt_b1 = self.add_weight(
name='filt_b1',
shape=(self.N_filt,),
initializer='uniform',
trainable=True)
self.filt_band = self.add_weight(
name='filt_band',
shape=(self.N_filt,),
initializer='uniform',
trainable=True)
# Mel Initialization of the filterbanks
low_freq_mel = 80
high_freq_mel = (2595 * np.log10(1 + (self.fs / 2) / 700)) # Convert Hz to Mel
mel_points = np.linspace(low_freq_mel, high_freq_mel, self.N_filt) # Equally spaced in Mel scale
f_cos = (700 * (10**(mel_points / 2595) - 1)) # Convert Mel to Hz
b1 = np.roll(f_cos, 1)
b2 = np.roll(f_cos, -1)
b1[0] = 30
b2[-1] = (self.fs / 2) - 100
self.freq_scale=self.fs * 1.0
self.set_weights([b1/self.freq_scale, (b2-b1)/self.freq_scale])
super(SincConv1D, self).build(input_shape) # Be sure to call this at the end
def call(self, x):
# print(x.shape)
# debug_print("call")
#filters = K.zeros(shape=(N_filt, Filt_dim))
# Get beginning and end frequencies of the filters.
min_freq = 50.0
min_band = 50.0
filt_beg_freq = K.abs(self.filt_b1) + min_freq / self.freq_scale
filt_end_freq = filt_beg_freq + (K.abs(self.filt_band) + min_band / self.freq_scale)
# Filter window (hamming).
n = np.linspace(0, self.Filt_dim, self.Filt_dim)
window = 0.54 - 0.46 * K.cos(2 * math.pi * n / self.Filt_dim)
window = K.cast(window, "float32")
# window = K.variable(window)
# debug_print(" window", window)
# TODO what is this?
t_right_linspace = np.linspace(1, (self.Filt_dim - 1) / 2, int((self.Filt_dim -1) / 2))
t_right = K.variable(t_right_linspace / self.fs)
# debug_print(" t_right", t_right)
# Compute the filters.
output_list = []
for i in range(self.N_filt):
low_pass1 = 2 * filt_beg_freq[i] * sinc(filt_beg_freq[i] * self.freq_scale, t_right)
low_pass2 = 2 * filt_end_freq[i] * sinc(filt_end_freq[i] * self.freq_scale, t_right)
band_pass= (low_pass2 - low_pass1)
band_pass = band_pass / K.max(band_pass)
output_list.append(band_pass * window)
filters = K.stack(output_list) #(80, 251)
filters = K.transpose(filters) #(251, 80)
filters = K.reshape(filters, (self.Filt_dim, 1,self.N_filt)) #(251,1,80) in TF: (filter_width, in_channels, out_channels) in PyTorch (out_channels, in_channels, filter_width)
'''
Given an input tensor of shape [batch, in_width, in_channels] if data_format is "NWC",
or [batch, in_channels, in_width] if data_format is "NCW", and a filter / kernel tensor of shape [filter_width, in_channels, out_channels],
this op reshapes the arguments to pass them to conv2d to perform the equivalent convolution operation.
Internally, this op reshapes the input tensors and invokes tf.nn.conv2d. For example, if data_format does not start with "NC",
a tensor of shape [batch, in_width, in_channels] is reshaped to [batch, 1, in_width, in_channels], and the filter is reshaped to
[1, filter_width, in_channels, out_channels]. The result is then reshaped back to [batch, out_width, out_channels]
(where out_width is a function of the stride and padding as in conv2d) and returned to the caller.
'''
# Do the convolution.
# debug_print("call")
# print(" x", x)
# print(" filters", filters)
out = K.conv1d(x, kernel=filters)
# print('e')
return out`
我不知道为什么这些行会导致冲突,但我希望这段代码对你有用。
推荐阅读
- python-3.x - 当 URL 读取时间过长时如何跳出循环
- ios - Nativescript 角度 RadSideDrawer 未定义
- c++ - 如何从向量中删除某些指定值?
- python - 带有 python 包的无服务器 AWS Lambda 层不起作用。奇怪的哈希添加到包名称
- django - 如何将对象列表发送到下一个 Django 模板?
- python - Django:根据一对多关系过滤记录
- c# - 列表中的多个条件位置 c#
- tensorflow - tensorflow 警告 - 发现未跟踪的函数,例如 lstm_cell_6_layer_call_and_return_conditional_losses
- javascript - JavaScript 是编译型语言还是解释型语言,还是两者兼而有之?
- haskell - 函数 len 中的非详尽模式