python - PyTorch 嵌入层报错:IndexError: index out of range in self
问题描述
如标题所述,我遇到了 PyTorch 错误:“IndexError: index out of range in self”。只要数据集超过 500 行,就会出现此错误;另外,重新加载模型并用第二个数据集运行时,也会出现同样的错误。我已经尝试过各种手动设置嵌入大小的方法,但网上找到的办法都不起作用。下面附上模型、优化器和训练循环的代码,非常感谢您的帮助。
class Model(nn.Module):
    """Tabular network: one embedding per categorical column plus numeric features, fed to an MLP.

    Args:
        embedding_size: iterable of ``(num_embeddings, embedding_dim)`` pairs, one per
            categorical column, in the same order as the columns of ``x_categorical``.
            Any iterable is accepted (it is materialised once, since it is read twice).
        num_numerical_cols: number of numeric feature columns (width of ``x_numerical``).
        output_size: number of outputs of the final linear layer.
        layers: widths of the hidden layers; may be empty.
        p: dropout probability used after the embeddings and after each hidden layer.
    """

    def __init__(self, embedding_size, num_numerical_cols, output_size, layers, p=0.4):
        super().__init__()
        # Materialise once: it is iterated twice below, which would silently break for a generator.
        embedding_size = list(embedding_size)
        self.all_embeddings = nn.ModuleList(
            [nn.Embedding(ni, nf) for ni, nf in embedding_size]
        )
        self.embedding_dropout = nn.Dropout(p)
        self.batch_norm_num = nn.BatchNorm1d(num_numerical_cols)

        # Concatenated embedding width + numeric width = input width of the first Linear.
        embedding_width = sum(nf for _, nf in embedding_size)
        input_size = embedding_width + num_numerical_cols

        all_layers = []
        for width in layers:
            all_layers.extend(
                [
                    nn.Linear(input_size, width),
                    nn.ReLU(inplace=True),
                    nn.BatchNorm1d(width),
                    nn.Dropout(p),
                ]
            )
            input_size = width
        # Use input_size rather than layers[-1] so an empty `layers` list also works.
        all_layers.append(nn.Linear(input_size, output_size))
        self.layers = nn.Sequential(*all_layers)

    def forward(self, x_categorical, x_numerical):
        """Embed each categorical column, append normalised numeric features, and run the MLP.

        Args:
            x_categorical: integer tensor; column ``i`` is looked up in ``self.all_embeddings[i]``.
            x_numerical: numeric tensor normalised by ``BatchNorm1d(num_numerical_cols)``.

        Returns:
            The output of ``self.layers``, with ``output_size`` features per row.

        Raises:
            IndexError: if a categorical column contains a code that is negative or
                ``>= num_embeddings`` for that column's embedding.
        """
        embeddings = []
        for i, embedding in enumerate(self.all_embeddings):
            column = x_categorical[:, i]
            # Check the range ourselves so the error names the offending column instead of
            # the opaque "index out of range in self" raised inside F.embedding.
            if column.numel() > 0:
                lo, hi = int(column.min()), int(column.max())
                if lo < 0 or hi >= embedding.num_embeddings:
                    raise IndexError(
                        f"categorical column {i} has codes in [{lo}, {hi}] but its "
                        f"embedding only has {embedding.num_embeddings} rows (valid "
                        f"codes 0..{embedding.num_embeddings - 1}); size the embeddings "
                        f"from the full category set, not just the training data"
                    )
            embeddings.append(embedding(column))
        x = torch.cat(embeddings, 1)
        x = self.embedding_dropout(x)
        x_numerical = self.batch_norm_num(x_numerical)
        x = torch.cat([x, x_numerical], 1)
        return self.layers(x)
# NOTE(review): categorical_embedding_sizes must be derived from the full category set
# (all splits / future data) and the same category->code mapping must be reused after
# reloading the model; otherwise unseen codes trigger "IndexError: index out of range in self".
model = Model(categorical_embedding_sizes, numerical_data.shape[1], 5, [400, 100, 50], p=0.4)
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

epochs = 100
aggregated_losses = []  # one float per epoch

# Full-batch training; `i` is the 1-based epoch number used in the log lines.
for i in range(1, epochs + 1):
    y_pred = model(categorical_train_data, numerical_train_data)
    single_loss = loss_function(y_pred, train_outputs)
    # Store a plain float: the loss tensor carries autograd history and cannot be
    # converted to numpy (e.g. for plotting) without detaching it first.
    aggregated_losses.append(single_loss.item())
    print(f'epoch: {i:3} loss: {single_loss.item():10.8f}')

    optimizer.zero_grad()
    single_loss.backward()
    optimizer.step()

print(f'epoch: {i:3} loss: {single_loss.item():10.10f}')
完整的错误信息如下:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-157-202810d3193a> in <module>
4 for i in range(epochs):
5 i += 1
----> 6 y_pred = model(categorical_train_data, numerical_train_data)
7 single_loss = loss_function(y_pred, train_outputs)
8 aggregated_losses.append(single_loss)
C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)
<ipython-input-117-fd6404aba4b5> in forward(self, x_categorical, x_numerical)
25 embeddings = []
26 for i,e in enumerate(self.all_embeddings):
---> 27 embeddings.append(e(x_categorical[:,i]))
28 x = torch.cat(embeddings, 1)
29 x = self.embedding_dropout(x)
C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)
C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\sparse.py in forward(self, input)
112 return F.embedding(
113 input, self.weight, self.padding_idx, self.max_norm,
--> 114 self.norm_type, self.scale_grad_by_freq, self.sparse)
115
116 def extra_repr(self):
C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
1722 # remove once script supports set_grad_enabled
1723 _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1724 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
1725
1726
IndexError: index out of range in self
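解决方案
该错误表示传给 nn.Embedding 的某个类别编码为负数,或不小于该嵌入层的 num_embeddings(有效范围是 0 到 num_embeddings - 1)。常见原因有两种:一是嵌入大小只根据部分数据(例如前 500 行或训练集)计算,而更多数据中出现了新的类别;二是重新加载模型后对新数据重新做了类别编码,导致编码与训练时不一致。解决办法是基于全部数据的类别集合计算每一列的嵌入大小,并在训练和推理时复用同一套类别到编码的映射。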
推荐阅读
- angular - Angular 6不显示webp图像
- java - Java FX 更改 ImageView 图像或图像 URL?
- ruby-on-rails - `PG::CONNECTIONBad: FATAL: 角色“my_username”不存在` 但是“my_username”在 psql 中被列为用户
- ip - 使用交换服务器发送邮件
- amazon-web-services - 如何在 Elastic Beanstalk 中使用来自第三方提供商的域名
- interface-builder - 我不能在导航栏和表格视图之间添加 UIView 有什么原因吗?
- eclipse - Scala IDE在启动时抛出错误
- swift - 将 UITableView 约束到 InputAccessory
- android - 可以在firebase中有一个包含搜索查询吗?
- google-sheets - Google Sheet - 查找范围/列之间的部分匹配