首页 > 解决方案 > 如何编码具有可变长度的分类数据,以便在 PyTorch 中提取到 nn.Embedding

问题描述

假设我为每个样本都有一个名为 movie_genre 的数据字段movie,它是从以下流派中选择的:

Action
Adventure
Animation
Comedy
...

对于每个movie,它可能包含多种流派:

mid    genres
1      Action | Adventure
2      Animation
3      Comedy | Adventure | Action

这意味着,电影的流派是一个变量列表。

如果我使用一个热向量来编码genre,Action 可以编码为 (1, 0, 0, 0),Adventure 可以编码为(0, 1, 0, 0),依此类推。

所以mid1的movie可以编码为(1, 1, 0, 0),mid2的genre可以编码为(0, 0, 1, 0),以此类推。

但是,pytorch 嵌入层nn.Embedding将包含索引的张量作为输入,而不是 one-hot 向量。那么我应该如何对数据进行编码以便可以将其提取到嵌入层中呢?

标签: machine-learningpytorchdata-miningencodeword-embedding

解决方案


目前我可以想到两种方法:

  • 将您的多标签问题转换为多类别问题。也就是说为每个标签组合创建一个新标签(例如 Action | Adventure 成为自己的标签),然后像往常一样嵌入这些新标签。
  • 分别嵌入每个类别,并将列表中出现的所有类别的嵌入相加。

编辑:您可以使用 pytorch nn.EmbeddingBag 以有效的方式执行第二个操作:https ://pytorch.org/docs/stable/nn.html?highlight=nn%20e#torch.nn.EmbeddingBag


推荐阅读