首页 > 解决方案 > 手写文本识别(CNN + LSTM + CTC)RNN解释需要

问题描述

我正在尝试理解以下代码,它位于 python 和 tensorflow 中。我正在尝试实现手写文本识别。我在这里指的是以下代码

我不明白为什么 RNN 输出通过“atrous_conv2d”

这是我的模型的架构,接受一个 CNN 输入并传入这个 RNN 进程,然后将其传递给一个 CTC。

 def build_RNN(self, rnnIn4d):

    rnnIn3d = tf.squeeze(rnnIn4d, axis=[2])  # squeeze remove 1 dimensions, here it removes the 2nd index

    n_hidden = 256
    n_layers = 2
    cells = []

    for _ in range(n_layers):
        cells.append(tf.nn.rnn_cell.LSTMCell(num_units=n_hidden))

    stacked = tf.nn.rnn_cell.MultiRNNCell(cells)  # combine the 2 LSTMCell created

    # BxTxF -> BxTx2H
    ((fw, bw), _) = tf.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnnIn3d,
                                                    dtype=rnnIn3d.dtype)

    # BxTxH + BxTxH -> BxTx2H -> BxTx1X2H
    concat = tf.expand_dims(tf.concat([fw, bw], 2), 2)

    # project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC
    kernel = tf.Variable(tf.truncated_normal([1, 1, n_hidden * 2, len(self.char_list) + 1], stddev=0.1))
    rnn = tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME')

    return tf.squeeze(rnn, axis=[2])

标签: pythontensorflowdeep-learninghandwriting

解决方案


CTC 损失层的输入格式为 B x T x C

B - 批量大小 T - 输出的最大长度(由于空白字符,最大字长的两倍) C - 字符数 + 1(空白字符)

atrous 的输入是形状 (B x T x 1 X 2T) == (batch, height ,width ,channel) 我们使用的过滤器是 (1,1,2T,C) == (height ,width ,input channel ,输出通道)

在 atrous CNN 之后,我们将得到 (B ,T ,1 ,C),这是 CTC 所需的输出

注意:我们将在我们将图像输入到 CNN 之前进行转置,因为 tf 是 row major。

速率为 1 的 atrous 与正常的卷积层相同。


推荐阅读