首页 > 解决方案 > 如何在 Tensorflow 2 中实现稀疏嵌入,如 Pytorch Embedding(sparse=True)?

问题描述

我有一个网络有很多需要嵌入的项目。

但是,在每个训练批次中,实际使用的项目只有很小一部分。

如果我使用普通tf.keras.layers.Embedding层,它会将所有项目添加到网络参数中,从而消耗大量内存并显着降低分布式训练的速度,因为在每一步中,所有未使用的 grads 仍然是同步的。

我想要的是,在每个训练步骤中,只有实际使用的嵌入权重被添加到图中并被计算和同步。

Pytorch已经通过torch.nn.Embedding(sparse=True).

如何在 Tensorflow 2 中实现这一点?

标签: tensorflowpytorch

解决方案


我的错...检查 tf.GradientTape() 告诉我 tf.gather 的梯度已经是一个稀疏张量,所以这不需要麻烦。


推荐阅读