python - RuntimeError:张量 a (1024) 的大小必须与非单维 3 的张量 b (512) 的大小相匹配
问题描述
我正在做以下操作,
energy.masked_fill(mask == 0, float("-1e20"))
我的python痕迹如下,
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "seq_sum.py", line 418, in forward
enc_src = self.encoder(src, src_mask)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "seq_sum.py", line 71, in forward
src = layer(src, src_mask)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "seq_sum.py", line 110, in forward
_src, _ = self.self_attention(src, src, src, src_mask)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "seq_sum.py", line 191, in forward
energy = energy.masked_fill(mask == 0, float("-1e20"))
RuntimeError: The size of tensor a (1024) must match the size of tensor b (512) at non-singleton dimension 3
这些是我的注意力层代码,
Q = self.fc_q(query)
K = self.fc_k(key)
V = self.fc_v(value)
#Q = [batch size, query len, hid dim]
#K = [batch size, key len, hid dim]
#V = [batch size, value len, hid dim]
# Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
# K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
# V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).view(-1, 1024)
K = K.view(batch_size, -1, self.n_heads, self.head_dim).view(-1, 1024)
V = V.view(batch_size, -1, self.n_heads, self.head_dim).view(-1, 1024)
energy = torch.matmul(Q, K.transpose(1,0)) / self.scale
我正在按照下面的 github 代码执行我的 seq to seq 操作,seq2seq pytorch 实际测试代码可在下面的位置获得,用于测试 1024 到 1024 输出的 seq 的代码
在这里尝试的第二个示例我已经注释掉了 pos_embedding 由于具有大索引的 CUDA 错误(RuntimeError: cuda runtime error (59)
解决方案
我查看了您的代码(顺便说一下,没有运行seq_len = 10
),问题是您在代码中硬编码batch_size
为等于 1(行143
)。
看起来您尝试在其上运行模型的示例具有batch_size = 2
.
只需取消注释您编写的上一行,batch_size = query.shape[0]
一切运行正常。
推荐阅读
- python - 每年在python中生成日期范围?
- angular - IE 11 在尝试调用 HTTP 获取请求时给出“无效的调用对象”
- c# - 带有编码字符串的 URI 对象行为
- java - 如何限制用户在android中手动编辑日历
- javascript - useState 与 React 中的异步操作冲突
- ansible - 配置文件中的 Ansible 可选变量
- pyspark - 从 pyspark 数据框中删除具有相同值但在不同列中的重复行
- c# - 使用实体框架将前一天的记录复制到同一个表中
- ios - TableView:由于未捕获的异常“NSInternalInconsistencyException”错误而终止应用程序
- python - 桌面中长时间运行的后台作业