pytorch - 试图用另一个相同维度的张量来掩盖张量,得到“索引 1 超出尺寸 1 的维度 0 的范围”
问题描述
attn_weights = F.softmax(self.attn(torch.cat((input, hidden_cat), 2)), dim=2)
attn_weights[mask] = float('-inf')
attn_applied = torch.bmm(attn_weights.transpose(0,1),encoder_outputs.transpose(0,1)).transpose(0,1)
attn_output = torch.cat((input, attn_applied), 2)
所以我试图将掩码中所有等于 1 的索引设置为负无穷大,但是那条线
attn_weights[mask] = float('-inf')
不断抛出此异常“索引 1 超出尺寸 1 的维度 0 的范围”不确定发生了什么 attn_weights 和掩码都具有相同的维度,即 1 x 2048 x 40。
解决方案
原来掩码张量的 dtype 必须是 torch.uint8 或 torch.bool 我有它 torch.long
推荐阅读
- internet-explorer - 占位符在 IE 中无法正常工作,但适用于 Chrome
- thingsboard - 通过警报存在为地图小部件中的标记着色
- angular - Angular - 多次订阅而不触发多次调用?
- javascript - 是否可以触发位于其他元素之后的元素事件
- mysql - 从请求数据生成 XML 代码并插入数据库
- node.js - 如何在没有套接字的情况下及时发送多个 HTTP 响应(报告服务器处理状态)
- php - Symfony 4 在 Windows 10 上运行缓慢
- c++ - if 检查和 while 检查之间是否存在速度差异
- aws-code-deploy - 我应该在 CodeDeploy 的哪里添加 AppSpec.yml 文件
- matlab - 日期格式 Matlab