首页 > 解决方案 > 如何在 TF 2.0 上运行这个在 TF 1.0 中构建的自定义 AttentionLSTM 类?

问题描述

我正在建立在 TF 1 中构建自定义 AttentionLSTM 层的同事的工作之上。我想使用 TF 2。

我已将顶部的所有导入语句更改为from tensorflow.keras import .... 但是有两个我还没有想出如何改变。

from keras.legacy import interfaces
from keras.layers import Recurrent

两者都在AttentionLSTM类定义中使用一次,其他地方都没有。

class AttentionLSTM(Recurrent):

    @interfaces.legacy_recurrent_support
    def __init__(self, units,
                 activation='tanh',
                 recurrent_activation='hard_sigmoid',
                 attention_activation='tanh',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 attention_initializer='orthogonal',
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 attention_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 attention_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 return_attention=False,
                 **kwargs):
        ...

装修师是做什么的interfaces?我需要更改哪些内容才能在 TF 2 中使用此类?

注意:我认为我应该将Recurrent导入更改为,from tensorflow.keras.layers import RNN但担心会搞乱interfaces装饰器的工作。

标签: tensorflowkerastensorflow2.0tf.keras

解决方案


from keras.legacy import interfaces并从这keras.layers import Recurrent两个库中使用 Keras 2.3.1。最新的 TensorFlow 版本具有默认的 Keras 2.4.3 版本。为了使用这两个库,请将您的 Keras 降级到 2.3.1。

Tensorflow.keras 没有这样的库。

并供keras.layers import Recurrent使用tf.keras.layers.RNN

查看 Keras 发行说明


推荐阅读