PyTorch CNN test results are stuck

Problem description

Training code

```python
import os

import torch

# net, criterion, optimizer, num_epochs, PATH, train_input, train_list,
# train_folder_list2 and save_checkpoint are defined elsewhere (not shown).
start_epoch, start_i = 0, 0  # defaults used when no checkpoint exists

if os.path.isfile(PATH):
    print("checkpoint training '{}' ...".format(PATH))
    checkpoint = torch.load(PATH)
    start_epoch = checkpoint['epoch']
    start_i = checkpoint['i']
    net.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (trained for {} epochs, {} i)".format(
        PATH, checkpoint['epoch'], checkpoint['i']))
else:
    print('new training')

for epoch in range(num_epochs):  # loop over the dataset multiple times
    net.train()  # training mode (matters for dropout / batch-norm layers)
    running_loss = 0.0
    for i in range(len(train_folder_list2)):
        # get the inputs: train_input holds the images, train_list the labels
        inputs, labels = train_input[i], train_list[i]
        inputs = torch.as_tensor(inputs).cuda()
        # (N, H, W, C) -> (N, C, W, H); the test code applies the same transpose
        inputs = inputs.transpose(1, 3)
        labels = torch.as_tensor(labels).cuda()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 1:  # periodically save a checkpoint
            save_checkpoint({
                'epoch': start_epoch + epoch + 1,
                'i': start_i + i + 1,
                'state_dict': net.state_dict(),
            })
```
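The training loop calls `save_checkpoint`, but that function is not shown in the question. Below is a minimal sketch of what such a helper might look like. It assumes the helper simply writes the dict with `torch.save` to the same `./checkpoint.pth` path the test code loads. The function body and the default path are assumptions, not the asker's code.

```python
import os

import torch


def save_checkpoint(state: dict, path: str = "./checkpoint.pth") -> None:
    """Hypothetical helper: persist a checkpoint dict with torch.save.

    Writes to a temporary file first and then renames it over the target,
    so an interrupted save does not leave a truncated checkpoint behind.
    """
    tmp_path = path + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)  # rename over any existing checkpoint
```

To resume the optimizer as well, you could store `optimizer.state_dict()` in the same dict and restore it with `optimizer.load_state_dict(...)` next to `net.load_state_dict(...)`.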

Test code

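```python
import os

import torch

PATH = './checkpoint.pth'
model = Net().cuda()  # Net is the CNN class shown below
if os.path.isfile(PATH):
    print('checkpoint check!')
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['state_dict'])
model.eval()  # inference mode

results = []
with torch.no_grad():  # gradients are not needed at test time
    # note: this loop evaluates the model on the training images (train_input)
    for k in range(len(train_folder_list2)):
        inputs = torch.as_tensor(train_input[k]).cuda()
        inputs = inputs.transpose(1, 3)  # same layout change as in training
        outputs = model(inputs)
        result = outputs.cpu().numpy()
        results.append(result)  # keep every prediction, not just the last
```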

This code is meant to find the edges in an image.

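I run the training code to train the network and then evaluate it with the test code, but it does not seem to find any edges in the images. Whatever image I feed in, the predicted edges always end up on the same side.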
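One way to confirm this symptom is to check how much each of the 18 output values varies across several different images. This assumes the per-image outputs from the test loop are collected into a list, for example by appending each `result`. If every coordinate has a near-zero spread, the network is returning essentially the same prediction for every input. The function names and the `1e-3` tolerance are illustrative choices, not part of the original code.

```python
from statistics import pstdev
from typing import List, Sequence


def output_spread(results: Sequence[Sequence[float]]) -> List[float]:
    """Population standard deviation of each output coordinate across images."""
    if not results:
        raise ValueError("need at least one prediction")
    n_out = len(results[0])
    return [pstdev([float(r[j]) for r in results]) for j in range(n_out)]


def looks_constant(results: Sequence[Sequence[float]], tol: float = 1e-3) -> bool:
    """True if no output coordinate varies by more than `tol` across images."""
    return max(output_spread(results)) < tol
```

With the collected predictions, `looks_constant(results)` returning `True` means the outputs barely depend on the input. The check is only meaningful with at least two different images, because a single prediction always has zero spread.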

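**Edit: added the CNN code.** For reference, here is the CNN code as well. The input data is kept in lists, with the images and the labels stored in separate lists.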

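```python
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)   # 3 input channels -> 6, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)    # halves height and width
        self.conv2 = nn.Conv2d(6, 16, 5)  # 6 -> 16 channels, 5x5 kernel
        # 293904 = 16 channels x (height x width) after the second pooling,
        # which ties the network to one fixed input resolution
        self.fc1 = nn.Linear(293904, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 18)      # 18 output values

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 293904)            # flatten each sample
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = x.view(18)                    # only valid for a batch size of 1
        return x
```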
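The hard-coded `293904` in `fc1` fixes the input resolution the network accepts. Each unpadded 5×5 convolution shrinks a side by 4, and each 2×2 max-pool halves it (rounding down), so the flattened size can be computed from the image size. The helpers below are an illustrative sketch. The 480×640 example is one input size that produces exactly 293904 (16 × 117 × 157); the question does not confirm this size.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of nn.Conv2d along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size: int, kernel: int = 2, stride: int = 2) -> int:
    """Output length of nn.MaxPool2d(kernel, stride) along one dimension."""
    return (size - kernel) // stride + 1


def flattened_size(height: int, width: int) -> int:
    """Number of features reaching fc1 for a (3, height, width) input to Net."""
    for _ in range(2):  # (5x5 conv -> 2x2 max-pool), applied twice
        height = pool_out(conv_out(height, 5))
        width = pool_out(conv_out(width, 5))
    return 16 * height * width  # conv2 produces 16 channels


print(flattened_size(480, 640))  # → 293904
```

The product is symmetric in height and width, so `transpose(1, 3)` swapping the two axes does not change the count. An image of any other resolution would no longer match `fc1`.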

Tags: deep-learning, pytorch, conv-neural-network
