nlp - 使用 2D 索引张量从 3D 张量中检索元素
问题描述
我正在玩 GPT2,我有 2 个张量:
O:形状为 (B, S-1, V) 的输出张量,其中 B 是批量大小 S 是时间步数,V 是词汇量大小。这是生成模型的输出,并沿第二维进行了软最大化。
L:一个 2D 张量形状 (B, S-1),其中每个元素是每个样本的每个时间步的正确标记的索引。这基本上是标签。
我想根据张量L从张量O中提取相应正确标记的预测概率,这样我最终将得到一个二维张量形状 (B, S)。除了使用循环之外,有没有一种有效的方法来做到这一点?
解决方案
作为参考,我的回答基于这篇 Medium 文章。
本质上,您的答案在于torch.gather
,假设您的两个张量都只是常规torch.Tensor
s (或可以转换为一个)。
import torch
# Specify some arbitrary dimensions for now
B = 3
V = 6
S = 4
# Make example reproducible
torch.manual_seed(42)
# L necessarily has to be a torch.LongTensor, otherwise indexing will fail.
L = torch.randint(0, V, size=[B, S])
O = torch.rand([B, S, V])
# Now collect the results. L needs to have similar dimension,
# except in the axis you want to collect along.
X = torch.gather(O, dim=2, index=L.unsqueeze(dim=2))
# Make sure X has no "unnecessary" dimension
X = X.squeeze(dim=2)
很难看出这是否会产生完全正确的结果,这就是为什么我包含了一个随机种子,它使示例在结果中具有确定性,并且您可以轻松验证它是否可以获得所需的结果。但是,为了澄清起见,也可以使用低维张量,这会更清楚到底是做什么torch.gather
的。
请注意,torch.gather
理论上还允许您在同一行中索引多个索引。这意味着如果您获得了一个多个值正确的多类示例,您可以类似地使用L
shape的张量[B, S, number_of_correct_samples]
。
推荐阅读
- google-bigquery - 有没有办法在 Big-query 中过滤掉我的项目中的两个特定名称?
- php - Laravel 页面自动恢复数据,无需刷新
- mysql - SQL:如何从多个表中获取计数到一个查询中?
- c# - Radzen Blazor 对话框未关闭
- laravel - Laravel 8“在此服务器上找不到请求的资源 /dashboard。”
- xamarin - 如何将数据从 TimePicker 和 Editor 传递到 Xamarin Form 中的标签?
- python - Python奇异值分解不匹配顺序和符号
- excel - 使用表格标题作为单元格中的内容
- c# - MVC 显示集合:InvalidCastException:无法将“Models.ConversionRate”类型的对象转换为“System.Collections.IEnumerable”类型
- java - 将字符串拆分为具有动态长度的不同部分