python - 一批 2D 张量中前 n 个分位数的二进制掩码,但每个张量都有单独的 n
问题描述
我有一个 shape 的张量 A 和一个 shape 的(100, 16, 16)
张量 B (100)
,其中 100 是批量大小。我想创建一个具有 shape 的 A 的二进制掩码(100, 16, 16)
,其中在掩码的每个元素(元素具有 shape (1, 16, 16)
)中,1
如果该元素大于计算的分位数值,则值为 else 0
。张量 B 中的每个元素依次表示 A 中每个单独元素的百分位值。如果 B 只是一个标量,我可以使用:
flat_A = torch.reshape(A, (100, -1))
quants = torch.quantile(flat_A, B, dim=1)
quants = torch.reshape(quants, (100, 1, 1))
mask = torch.where(A >= quants, 1, 0)
# quants will have shape (100, 1, 1)
问题是:如果 B 是我上面所说的形状的一维张量,(100)
我如何计算 A 中每个单独元素的百分位值?我尝试了以下方法,但结果看起来不像我预期的那样:
>>> torch.quantile(flat_A, B, dim=1).shape
torch.Size([100, 100])
>>> torch.quantile(flat_A, B, dim=0).shape
torch.Size([100, 256])
我认为结果的形状应该是(100)
,所以我可以使用mask = torch.where(A >= quants, 1, 0)
,或者我误解了它?
对于更多上下文,这个问题也是我之前在这里遇到的标量 B 值问题的扩展。
解决方案
这是使用torch.quantile()
函数的一种方式。请注意,为了简单起见,这里我使用形状为(5, 2, 2)而不是(100, 16, 16)的张量。
import torch
# Generate some data of shape (5, 2, 2)
A = torch.arange(5 * 2 * 2).reshape(5, 2, 2) + 1.0
B = torch.linspace(0, 1, 5) # 5 quantile values for each element in A
Af = A.reshape(A.shape[0], -1) # flattens A to a 2D tensor
quantiles = torch.quantile(Af, B, dim = 1, keepdim = True)
quants = quantiles[torch.arange(A.shape[0]), torch.arange(A.shape[0]), 0]
mask = (A >= quants[:, None, None]).type(torch.uint8)
这里的张量quantiles
是有形状的,torch.Size([5, 5, 1])
因为它存储了in (或in中的行B
)中每个元素的每个分位数值的阈值。由于我们有 5 个分位数,因此我们为 中的每个元素获得 5 个阈值。A
Af
A
例如,th 分位数quantiles[i, j, 0]
的阈值为or ,并且您基本上需要在批量大小或 5 范围内的值。B[i]
A[j]
Af[j]
quantiles[k, k, 0]
k
现在,为了满足您需要相应分位数 inB
和 in 元素的阈值的要求A
,只需索引出对角线元素 fromquantiles
并填充quants
具有 shape的元素torch.Size([5])
。
最后得到mask
,A
与每个元素的相应阈值进行比较。请注意,这使用与阈值进行广播的元素比较。mask
具有所需的形状torch.Size([5, 2, 2])
。
推荐阅读
- python-3.x - 对包含元组列表的 pandas DataFrame 列中的第一个元素求和返回 ValueError
- puppeteer - 无法抓取网站运行代码时出错。TimeoutError: Navigation Timeout Exceeded: 超过 30000ms
- javascript - 如何阻止用户在某个点滚动?
- reactjs - 设置用户后组件不会重新渲染
- java - 在 BouncyCastle BCFKS 密钥库中存储 X25519 密钥对
- android - 拖动时如何更新 SeekBarPreference 值?
- redis - 下划线作为redis中的键
- html - 角度材料步进器无法与角度形式的 mat-grid-list 一起正常工作
- ios - 我有一个表格视图单元格,我从视图控制器调用一个函数,该函数位于表格视图单元格中,如图所示,但它显示错误
- events - Linux:使用 entr 和 mutt 创建新文件时发送电子邮件