TensorFlow error - The argument cell is not an RNNCell, properties missing, method required, not callable

Problem description

I am trying to train a text summarization model, but I get this error:

The argument cell is not an RNNCell: 'output_size' property is missing, 'state_size' property is missing, either 'zero_state' or 'get_initial_state' method is required, is not callable.

I am not sure what the root cause of the problem is here. My TensorFlow version is 1.12.

The relevant part of the code is below:

class BiGRUModel(object):

def __init__(self,
             source_vocab_size,
             target_vocab_size,
             buckets,
             state_size,
             num_layers,
             embedding_size,
             max_gradient,
             batch_size,
             learning_rate,
             forward_only=False,
             dtype=tf.float32):

    self.source_vocab_size = source_vocab_size
    self.target_vocab_size = target_vocab_size
    self.buckets = buckets
    self.batch_size = batch_size
    self.learning_rate = learning_rate
    self.global_step = tf.Variable(0, trainable=False, name="global_step")
    self.state_size = state_size

    self.encoder_inputs = tf.placeholder(
        tf.int32, shape=[self.batch_size, None])
    self.decoder_inputs = tf.placeholder(
        tf.int32, shape=[self.batch_size, None])
    self.decoder_targets = tf.placeholder(
        tf.int32, shape=[self.batch_size, None])
    self.encoder_len = tf.placeholder(tf.int32, shape=[self.batch_size])
    self.decoder_len = tf.placeholder(tf.int32, shape=[self.batch_size])
    self.beam_tok = tf.placeholder(tf.int32, shape=[self.batch_size])
    self.prev_att = tf.placeholder(tf.float32, shape=[self.batch_size, state_size * 2])

    encoder_fw_cell = tf.contrib.rnn.GRUCell(state_size)
    encoder_bw_cell = tf.contrib.rnn.GRUCell(state_size)
    decoder_cell = tf.contrib.rnn.GRUCell(state_size)

    if not forward_only:
        encoder_fw_cell = tf.contrib.rnn.DropoutWrapper(
            encoder_fw_cell, output_keep_prob=0.50)
        encoder_bw_cell = tf.contrib.rnn.DropoutWrapper(
            encoder_bw_cell, output_keep_prob=0.50)
        decoder_cell = tf.contrib.rnn.DropoutWrapper(
            decoder_cell, output_keep_prob=0.50)


    with tf.variable_scope("seq2seq", dtype=dtype):
        with tf.variable_scope("encoder"):

            encoder_emb = tf.get_variable(
                "embedding", [source_vocab_size, embedding_size],
                initializer=emb_init)

            encoder_inputs_emb = tf.nn.embedding_lookup(
                encoder_emb, self.encoder_inputs)

            encoder_outputs, encoder_states = \
                tf.nn.bidirectional_dynamic_rnn(
                    encoder_fw_cell, encoder_bw_cell, encoder_inputs_emb,
                    sequence_length=self.encoder_len, dtype=dtype)

        with tf.variable_scope("init_state"):
            init_state = fc_layer(tf.concat(encoder_states, 1), state_size)
            # the shape of bidirectional_dynamic_rnn is weird
            # None for batch_size
            self.init_state = init_state
            self.init_state.set_shape([self.batch_size, state_size])
            self.att_states = tf.concat(encoder_outputs, 2)
            self.att_states.set_shape([self.batch_size, None, state_size*2])

        with tf.variable_scope("attention"):
            attention = tf.contrib.seq2seq.BahdanauAttention(
                state_size, self.att_states, self.encoder_len)
            decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
                decoder_cell, attention, state_size * 2)
            wrapper_state = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention, initial_cell_state=self.init_state)
            # wrapper_state = tf.contrib.seq2seq.AttentionWrapper(
            #     state_size=self.init_state, output_size=self.prev_att)
            decoder_initial_state = attention.zero_state(dtype, batch_size=self.batch_size * beam_width)

Tags: python, python-3.x, tensorflow

Solution


In the last line of your code:

wrapper_state = tf.contrib.seq2seq.AttentionWrapper(self.init_state, self.prev_att)

You are passing init_state and prev_att to the AttentionWrapper class, while AttentionWrapper expects the following arguments:

__init__(
    cell,
    attention_mechanism,
    attention_layer_size=None,
    alignment_history=False,
    cell_input_fn=None,
    output_attention=True,
    initial_cell_state=None,
    name=None,
    attention_layer=None
)
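
The first positional argument must be an RNNCell (here, the GRU decoder cell, possibly dropout-wrapped) and the second an attention mechanism; a plain Tensor such as init_state has no output_size or state_size and is not callable, which is exactly what the error message is checking. A minimal sketch of a corrected construction, assuming TF 1.12's tf.contrib.seq2seq API and the variables already defined in the question (decoder_cell, attention, self.init_state, state_size, dtype):

# Sketch only, not the poster's exact code: wrap the RNNCell with the
# attention mechanism and feed the encoder-derived state through
# initial_cell_state instead of passing it as the first argument.
decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    decoder_cell,                        # cell: an RNNCell (e.g. GRUCell)
    attention,                           # attention_mechanism: BahdanauAttention
    attention_layer_size=state_size,     # size of the attention output layer
    initial_cell_state=self.init_state)  # encoder summary as the starting cell state

# The decoder's initial AttentionWrapperState then comes from the wrapper
# itself, not from the attention mechanism:
decoder_initial_state = decoder_cell.zero_state(self.batch_size, dtype)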
