首页 > 解决方案 > 张量流2中的ConvLSTMCell

问题描述

从 1 升级到 tensorflow 版本 2 后,来自 tf.contrib 的所有模块都已折旧。

为了应用注意力方法,我需要每个细胞的状态。

最初,我在 tf 版本 1 中所做的是:


#ConvLSTMCell
convlstm_layer = tf.contrib.rnn.ConvLSTMCell(
                conv_ndims = 2,    
                input_shape = [10, 10, 32],
                output_channels = 32,
                kernel_shape = [2, 2],
                use_bias = True,
                skip_connection = False,
                forget_bias = 1.0,
                initializers = None,
                )


# Run RNN with ConvLSTMCell
outputs, state = tf.compat.v1.nn.dynamic_rnn(convlstm_layer, conv1_out, time_major = False, dtype = input.dtype)

现在,我正在尝试将其转换为 tf 版本 2 中的代码。

然而,正如我上面提到的,两个模块(tf.contrib 和 tf.compat)都被贬值了。

我找到了 tf.compat.v1.nn.dynamic_rnn 的替代品,即tf.keras.layers.rnn

但是没有这样的函数可以创建 ConvLSTMCell。有什么建议吗?

标签: pythontensorflowtensorflow2.0tf.keras

解决方案


我认为您正在寻找的内容在这里:https ://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM2D?version=stable

您可以在代码中导入它,例如:

import tensorflow as tf
conv_lstm_layer = tf.keras.layers.ConvLSTM2D(my_parameters)

推荐阅读