Why am I getting this error when training on my data?

Problem description

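Why do I get this error when I run training on my data? Below is my training code. The error is raised at the line loss = criterion(output, labels), and I don't understand why.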

import torch

# net and train_loader are defined earlier in the notebook (not shown)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_accuracies = []
train_loss = []
predictions = []
for epoch in range(5):
    iterations = 0
    running_loss = 0
    for i,(inputs,labels) in enumerate(train_loader):

        iterations+=1

        inputs = inputs.float()
        labels = labels.long()

        # Feed Forward
        output = net(inputs)
        # Loss Calculation
        loss = criterion(output, labels)

        running_loss = running_loss + loss.item()
        # Predicted class = index of the largest logit in each row
        _, prd = torch.max(output, dim = 1)
        # extend(), not append(prd.item()): .item() only works on a one-element tensor, so it fails for batch sizes > 1
        predictions.extend(prd.tolist())
        accuracy = (prd == labels).float().mean()
        train_accuracies.append(accuracy.item())
        train_loss.append(running_loss / iterations)

        #i = i.view(i.shape[0], -1)

        # Clear the gradient buffer (we don't want to accumulate gradients)
        optimizer.zero_grad()
        # Backpropagation 
        loss.backward()
        # Weight Update: w <-- w - lr * gradient
        optimizer.step()



        #print("Epoch [{}][{}/{}], Loss: {:.3f}".format(epoch, i, len(train_loader), running_loss / iterations))
        print("Epoch [{}][{}/{}], Loss: {:.3f}".format(epoch ,i , len(train_loader), running_loss))

The error I get is:

RuntimeError                              Traceback (most recent call last)
<ipython-input-76-4f34dec75c72> in <module>
     15         output = net(inputs)
     16         # Loss Calculation
---> 17         loss = criterion(output, labels)
     18 
     19         running_loss = running_loss + loss.item()

RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.  at C:\w\1\s\tmp_conda_3.7_055457\conda\conda-bld\pytorch_1565416617654\work\aten\src\THNN/generic/ClassNLLCriterion.c:94
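The assertion cur_target >= 0 && cur_target < n_classes says that every entry of labels must be a class index between 0 and n_classes - 1, where n_classes is the size of the second dimension of output. The minimal sketch below reproduces that condition on made-up data (the 3-class logits and the label values are assumptions, not the asker's data): a target equal to the number of classes trips the check. Newer PyTorch releases report the same problem as an IndexError ("Target ... is out of bounds") instead of the RuntimeError assertion in the traceback above, so the sketch catches both.

```python
import torch

# Hypothetical toy setup: a batch of 4 samples scored over 3 classes.
n_classes = 3
torch.manual_seed(0)
output = torch.randn(4, n_classes)  # logits, shape (N, C)
criterion = torch.nn.CrossEntropyLoss()

# Valid targets are class indices in [0, n_classes - 1].
good_labels = torch.tensor([0, 2, 1, 2])
good_loss = criterion(output, good_labels)
print("valid targets ->", good_loss.item())

# A target equal to n_classes (e.g. labels numbered 1..3 instead of 0..2)
# violates cur_target < n_classes.
bad_labels = torch.tensor([1, 2, 3, 3])
error_type = None
try:
    criterion(output, bad_labels)
except (IndexError, RuntimeError) as err:
    # Recent PyTorch raises IndexError; older builds raise the RuntimeError assertion.
    error_type = type(err).__name__
    print("invalid targets ->", error_type, err)
```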

Any ideas on this?

Tags: python, neural-network, pytorch

Solution
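A frequent cause of this assertion, offered here as a likely explanation rather than a confirmed diagnosis of the asker's data, is labels numbered from 1 (for example 1..3) while the network scores classes 0..2, or a final layer with fewer outputs than there are classes. The sketch below scans a loader for the smallest and largest label and then shifts 1-based labels down to 0-based. The helper label_range, the toy dataset, and n_outputs = 3 are hypothetical stand-ins for the asker's train_loader and net.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def label_range(loader):
    """Hypothetical helper: return (min, max) label over all (inputs, labels) batches."""
    lo, hi = None, None
    for _, labels in loader:
        labels = labels.long()
        batch_lo, batch_hi = labels.min().item(), labels.max().item()
        lo = batch_lo if lo is None else min(lo, batch_lo)
        hi = batch_hi if hi is None else max(hi, batch_hi)
    return lo, hi


# Hypothetical data: 6 samples with 5 features and 1-based labels 1..3.
inputs = torch.randn(6, 5)
labels = torch.tensor([1, 2, 3, 1, 2, 3])
loader = DataLoader(TensorDataset(inputs, labels), batch_size=4)

lo, hi = label_range(loader)
n_outputs = 3  # hypothetical: out_features of the net's final layer
if lo < 0 or hi >= n_outputs:
    print(f"labels span [{lo}, {hi}] but the net only scores classes 0..{n_outputs - 1}")

# One possible fix when labels are 1-based and contiguous: shift them to start at 0.
shifted = labels - 1
net = torch.nn.Linear(5, n_outputs)  # stand-in for the asker's net
loss = torch.nn.CrossEntropyLoss()(net(inputs), shifted)
print("loss after shifting labels:", loss.item())
```

In the training loop from the question, the equivalent change would be labels = labels.long() - 1, assuming the labels really are 1-based and contiguous. If the labels are already 0-based but there are more classes than the net's final layer outputs, the alternative is to set that layer's out_features to the number of classes.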
