首页 > 解决方案 > 这个用于文本生成的 RNN 结构是否正确?有什么要改进的?

问题描述

我用 keras 制作了一个用于文本生成的RNN 。受官示例的启发,我决定将其改编为法国诗歌,并逐字预测,而不是逐字预测。

1)我创建了训练示例和标签如下:

从一个poetry-french.txt文件中,我组成了:

相当于这个

['Un', 'poète', 'est', 'parti', 'sa', 'tombe', 'fermée','.', '\n', 'Pas', 'un', 'chant', ',', 'pas', 'un', 'mot', 'dans', 'cette', 'langue', 'aimée']

然后,我用这段代码制作了一个训练示例列表,以及对应的标签(都是一次性向量的矩阵):

maxlen = 250 # lenght (<=> number of words) of each training example
step = 30 
X = []       # List of all training examples
Y = []       # List all corresponding labels
for i in range(0, len(new_text) - maxlen, step):
    X.append(new_text[i: i + maxlen])
    Y.append(new_text[i + maxlen])


X = 

        [ 
            [
                [0..010..0]  Un
                [0..010..0]  poète
                [0..010..0]  est
                       .....
                [0..010..0]  langue
                [0..010..0]  aimée

            ]

                .......
            [
                [0..010..0] brisant
                [0..010..0] ma
                [0..010..0] lyre
                ....
                [0..010..0] n'attendait
                [0..010..0] pas
            ]
        ]



Y = [

            [0..010..0]    '\n' THIS IS the next caracter that follows the 1st training example
                                (Un poète est blablabla langue aimée \n)

                ....

            [0..010..0]


        ]

Y[i] 是语料库中 X[i] 之后的单词(X[i] 是单词列表,表示文本的摘录)。

2)我通过以下方式构建了RNN:

# Parameters
params = {
          'batch_size': 12,
          'shuffle': True,
          'word_indices': word_indices,
          'indices_word': indices_word,
          'maxlen': maxlen
          }


training_generator = DataGenerator(X, Y, **params)


model = Sequential()
model.add(LSTM(128, input_shape=(maxlen, len(words))))
model.add(Dense(len(words), activation='softmax'))

optimizer = RMSprop(learning_rate=0.01)
model.compile(loss='categorical_crossentropy', optimizer=optimizer)

model.fit_generator(generator=training_generator, steps_per_epoch = 6, epochs=22, callbacks=[print_callback])

当我进行培训时,我看到两个问题:

于是出现了两个问题:

1)你认为我的RNN整体结构正确吗?

2)当连续生成超过 3 个“\n”时,我如何“告诉”网络它做错了?

我希望我的解释很清楚,如有必要,我会编辑帖子。

是我的代码的 GitHub 链接

非常感谢 !

艾美瑞克

标签: python-3.xkerasrecurrent-neural-network

解决方案


推荐阅读