python - Keras Word2Vec 实现
问题描述
我正在使用http://adventuresinmachinelearning.com/word2vec-keras-tutorial/中的实现来了解 word2Vec。我不明白的是为什么损失函数没有减少?
Iteration 119200, loss=0.7305528521537781
Iteration 119300, loss=0.6254740953445435
Iteration 119400, loss=0.8255964517593384
Iteration 119500, loss=0.7267132997512817
Iteration 119600, loss=0.7213149666786194
Iteration 119700, loss=0.6156617999076843
Iteration 119800, loss=0.11473365128040314
Iteration 119900, loss=0.6617216467857361
据我了解,网络是此任务中使用的标准网络:
input_target = Input((1,))
input_context = Input((1,))
embedding = Embedding(vocab_size, vector_dim, input_length=1, name='embedding')
target = embedding(input_target)
target = Reshape((vector_dim, 1))(target)
context = embedding(input_context)
context = Reshape((vector_dim, 1))(context)
dot_product = Dot(axes=1)([target, context])
dot_product = Reshape((1,))(dot_product)
output = Dense(1, activation='sigmoid')(dot_product)
model = Model(inputs=[input_target, input_context], outputs=output)
model.compile(loss='binary_crossentropy', optimizer='rmsprop') #adam??
单词来自http://mattmahoney.net/dc/text8.zip(英文文本)的大小为 10000 的词汇表
我注意到的是,有些单词是及时学习的,比如数字和文章的上下文很容易猜到,但损失从一开始就一直停留在 0.7 左右,并且随着迭代的进行,它只会随机波动。
训练部分是这样制作的(由于没有标准拟合方法,我觉得很奇怪)
arr_1 = np.zeros((1,))
arr_2 = np.zeros((1,))
arr_3 = np.zeros((1,))
for cnt in range(epochs):
idx = np.random.randint(0, len(labels)-1)
arr_1[0,] = word_target[idx]
arr_2[0,] = word_context[idx]
arr_3[0,] = labels[idx]
loss = model.train_on_batch([arr_1, arr_2], arr_3)
if cnt % 100 == 0:
print("Iteration {}, loss={}".format(cnt, loss))
我是否错过了有关此类网络的重要信息?没写的实现和上面的链接一模一样
解决方案
我遵循了相同的教程,并且在算法再次通过样本后损失下降了。请注意,损失函数仅针对当前目标和上下文词对计算。在本教程的代码示例中,一个 epoch 只是一个示例,因此您需要超过目标词和上下文词的数量才能达到损失下降的点。
我用以下行实现了培训部分
model.fit([word_target, word_context], labels, epochs=5)
请注意,这可能需要很长时间,具体取决于语料库的大小。该train_on_batch
功能使您可以更好地控制训练,您可以在训练的每一步改变批量大小或选择您选择的样本。
推荐阅读
- javascript - 等待直到内部触发函数
- javascript - 如何修复电子中的“不允许加载本地资源”错误,因为“#”转换为“%23”
- meteor - 如何连接mongodb和meteor
- java - 如何使用 mockmvc 进行错误处理测试?
- angular - 如何在文本文件中写入数据或使用 Angular 4+ 修改文本文件的内容
- laravel - 我收到错误“stdClass 类的对象无法转换为字符串”
- scala - 列表上的 groupBy 作为 LinkedHashMap 而不是 Map
- java - 检查数据库中的现有条目 + 自动增量 id -> 多个条件
- git - 使用 git difftool 从命令行在 Visual Studio Code 中启动 GIT 比较
- django - 从 Django 中的 Datatable 获取查询列表参数