python - 使用 Keras 中的功能 API 向 GRU 模型添加遮罩层的正确方法?
问题描述
我试图弄清楚如何在 Keras 中使用带有功能 API的掩蔽层。
使用非功能性 Keras 语法,我可以轻松地创建一个 GRU 模型,它可以像这样屏蔽所有零值:
model = tf.keras.Sequential()
model.add( tf.keras.layers.Masking( mask_value = 0.0, input_shape = ( nTimeSteps, nVariables ) ) )
model.add( tf.keras.layers.GRU( 32 ) )
model.add( tf.keras.layers.Dense( 10, activation = "softmax" ) )
opt = tf.keras.optimizers.SGD( learning_rate = 0.001 )
model.compile( loss = 'categorical_crossentropy', optimizer = opt, metrics = ['accuracy'] )
到目前为止,我尝试使用功能 API 复制此模型如下所示:
x = tf.keras.layers.Masking( mask_value = 0.0, input_shape = ( nTimeSteps, nVariables ) )
x = tf.keras.layers.GRU( 32 )( x )
z = tf.keras.layers.Dense( numberOfOutputs, activation = "softmax" )( x )
model = tf.keras.Model( inputs = x, outputs = z )
opt = tf.keras.optimizers.SGD( learning_rate = 0.001 )
model.compile( loss = 'categorical_crossentropy', optimizer = opt, metrics = ['accuracy'] )
但是,它不起作用 - 它会产生以下错误:
AttributeError: 'Masking' object has no attribute 'shape'
将屏蔽层与功能 API 一起使用的正确方法是什么?
解决方案
你错过了Input
功能 API 格式的层。这是一个虚拟示例
nsamples = 10
nTimeSteps, nVariables = 6, 4
numberOfOutputs = 2
X = np.random.randint(0,6, (nsamples ,nTimeSteps, nVariables))
y = np.random.randint(0,numberOfOutputs, nsamples)
inp = tf.keras.Input(shape = ( nTimeSteps, nVariables ))
x = tf.keras.layers.Masking( mask_value = 0.0 )(inp)
x = tf.keras.layers.GRU( 32 )( x )
z = tf.keras.layers.Dense( numberOfOutputs, activation = "softmax" )( x )
model = tf.keras.Model( inputs = inp, outputs = z )
opt = tf.keras.optimizers.SGD( learning_rate = 0.001 )
model.compile( loss = 'sparse_categorical_crossentropy',
optimizer = opt, metrics = ['accuracy'] )
model.fit(X,y, epochs=3)
推荐阅读
- javascript - 如何在 html 中显示多个常量?
- php - 如何在父元素中动态显示子元素
- android - 试图在 android.R.class 中找到特定的图标。是否有所有图像的预览?
- javascript - ReactJS/CSS 利用元素的整页
- javascript - 一个元素的类可以根据浏览器是扩大还是缩小而改变?
- python - 为什么我的 QDialogButtonBox 在 QMainWindow 中不起作用?
- linear-algebra - 使用内积和范数简化瑞利商(快速问题)
- windows - 无法验证 veracrypt 可执行文件的 PGP 签名
- database - 如何在 Apache IoTDB 中建模大量相同的时间序列设备?
- python - 无法 sudo apt-get install python3-venv 或 python-venv