首页 > 解决方案 > 如何使用 tensorflow SimpleRNNCell 处理批处理数据集?

问题描述

我正在使用 Tensorflow 创建 Seq2Seq 模型。我尝试使用小批量处理数据集。当我使用batch()Tensorflow 中的方法构建数据集时,数据集形状变为(None,10). 但是,当向其提供数据时SimpleRNNCell会引发错误:

ValueError: Shape must be rank 2 but is rank 1 for 'simple_rnn_cell/MatMul_1' (op: 'MatMul') with input shapes: [10], [10,10].

代码是这样的:

    def decoder(self, input_x, real_y, encoder_outputs, training=False):

      decoder_state, cell_states = encoder_outputs, []

      predict_shape = (5, 1)
      output = tf.convert_to_tensor(np.zeros(predict_shape), dtype=tf.float32)

      for x in range(self.max_output):
        # below code raises error, here output and decoder_state shape is  (5, 1) (?, 10)
        output, decoder_state = self.decoder_rnn(output, decoder_state)

标签: tensorflowkerasdeep-learning

解决方案


推荐阅读