Why is the gap between the two predictions so large?

Problem Description

I built a Siamese network to classify four categories of calligraphy style. The network has two branches, and the loss function is the sum of three terms: a CrossEntropyLoss on each branch's classification output, plus a weighted ContrastiveLoss between the two branches. The problem is that the first branch's classification accuracy improves as the training epochs go on and reaches 95% or higher, while the other branch barely changes and stays around 50%. Where is the problem?
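
In other words, with LAMBDA as the weight on the pairwise term, the total loss per batch is

loss = CrossEntropy(output1, label1) + CrossEntropy(output2, label2) + LAMBDA * Contrastive(output1, output2, target)

which mirrors the loss = loss_1 + loss_2 + LAMBDA * twin_loss line in the training loop further down.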

import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils import data


class TwinData(data.Dataset):

    def __init__(self, imgdataset, trans=None):
        self.imgdataset = imgdataset
        self.trans = trans

    def __getitem__(self, index):
        # imgdataset.imgs is a list of (path, class_index) tuples
        img1_t = self.imgdataset.imgs[index]
        # with probability 0.5 pick a second image from the same class,
        # otherwise pick one from a different class
        same = random.randint(0, 1)
        if same:
            while True:
                img2_t = random.choice(self.imgdataset.imgs)
                if img1_t[1] == img2_t[1]:
                    break
        else:
            while True:
                img2_t = random.choice(self.imgdataset.imgs)
                if img1_t[1] != img2_t[1]:
                    break

        img1 = Image.open(img1_t[0])
        img2 = Image.open(img2_t[0])
        img1 = img1.convert("L")
        img2 = img1.convert("L")

        if self.trans:
            img1 = self.trans(img1)
            img2 = self.trans(img2)

        # target = 0.0 for a same-class pair, 1.0 for a different-class pair
        return img1, img2, img1_t[1], img2_t[1], torch.from_numpy(np.array(int(img1_t[1] != img2_t[1]), dtype=np.float32))

    def __len__(self):
        return len(self.imgdataset.imgs)
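
TwinData only touches imgdataset.imgs, a list of (path, class_index) tuples, so the wrapped dataset is presumably a torchvision ImageFolder. A minimal sketch of how the dataloader used below might be built (the directory path, image size and batch size are assumptions, not taken from the post):

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])
train_folder = datasets.ImageFolder('data/train')            # hypothetical path
twin_train_dataset = TwinData(train_folder, trans=transform)
twin_train_dataloader = DataLoader(twin_train_dataset, batch_size=32, shuffle=True)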

class SiameseNet(nn.Module):

    def __init__(self):
        super(SiameseNet, self).__init__()
        self.net = BasicNet()

    def forward(self, x1, x2):
        # both branches go through the same BasicNet, so the weights are shared
        output1 = self.net(x1)
        output2 = self.net(x2)
        return output1, output2

    def get_feature(self, x):
        x = self.net(x)
        return x

BasicNet is defined in my file; I have not pasted it here.
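
Since the branch outputs are fed both into a four-way CrossEntropyLoss and into the contrastive loss, BasicNet presumably is a small CNN over the single-channel grayscale input that ends in 4 logits. A rough stand-in, purely for context (every layer here is an assumption; the real BasicNet was not posted):

class BasicNet(nn.Module):
    def __init__(self, num_classes=4):
        super(BasicNet, self).__init__()
        # hypothetical feature extractor over 1-channel (grayscale) images
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        self.classifier = nn.Linear(32 * 4 * 4, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)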

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # label is 0 for a same-class pair and 1 for a different-class pair
        distance = F.pairwise_distance(output1, output2, keepdim=True)
        # pull same-class pairs together; push different-class pairs
        # at least `margin` apart
        loss = torch.mean((1 - label) * torch.pow(distance, 2) +
                          (label) * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2))

        return loss
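
The training loop below refers to net, criterion_1, criterion_2, criterion_3, optimizer, EPOCH, LAMBDA and EPSILON, none of which appear in the excerpt. A minimal sketch of that setup, with assumed values for the constants:

import torch.optim as optim

net = SiameseNet().cuda()
criterion_1 = nn.CrossEntropyLoss()        # classification loss for branch 1
criterion_2 = nn.CrossEntropyLoss()        # classification loss for branch 2
criterion_3 = ContrastiveLoss(margin=1.0)  # pairwise loss between the two outputs
optimizer = optim.Adam(net.parameters(), lr=1e-3)
EPOCH, LAMBDA, EPSILON = 50, 0.5, 1e-8     # assumed values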
for epoch in range(EPOCH):
    print('*' * 30, 'epoch {}'.format(epoch + 1), '*' * 30)
    net.train()
    running_loss, running_acc_1, running_acc_2 = 0.0, 0.0, 0.0
    for i, (img1, img2, label1, label2, target) in enumerate(twin_train_dataloader):
        img1, img2, label1, label2, target = img1.cuda(), img2.cuda(), label1.cuda(), label2.cuda(), target.cuda()
        optimizer.zero_grad()
        outputs_1, outputs_2 = net(img1, img2)
        # per-branch classification losses
        loss_1 = criterion_1(outputs_1 + EPSILON, label1)
        loss_2 = criterion_2(outputs_2 + EPSILON, label2)
        # contrastive loss between the two branch outputs
        twin_loss = criterion_3(outputs_1, outputs_2, target)
        loss = loss_1 + loss_2 + LAMBDA * twin_loss
        running_loss += loss.item() * target.size(0)
        # per-branch accuracy bookkeeping
        _, preds_1 = torch.max(outputs_1, 1)
        _, preds_2 = torch.max(outputs_2, 1)
        num_correct_1 = (preds_1 == label1).sum()
        num_correct_2 = (preds_2 == label2).sum()
        running_acc_1 += num_correct_1.item()
        running_acc_2 += num_correct_2.item()
        loss.backward()
        optimizer.step()
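
The excerpt stops before the accumulated counters are reported. Presumably the per-epoch figures quoted above (about 95% vs. about 50%) come from something along these lines, dividing by the dataset size:

    n = len(twin_train_dataset)
    print('loss: {:.4f}  acc_1: {:.4f}  acc_2: {:.4f}'.format(
        running_loss / n, running_acc_1 / n, running_acc_2 / n))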

Tags: deep-learning, pytorch, image-classification, siamese-network

Solution


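Judging only from the code posted, the most likely culprit is one line in TwinData.__getitem__:

img1 = img1.convert("L")
img2 = img1.convert("L")    # img2 is overwritten with a grayscale copy of img1

Because of this typo, both branches receive the same image, yet the second branch is still scored against label2. Pairs are sampled so that label2 equals label1 only about half of the time, which is exactly why the second branch's accuracy sits near 50% no matter how long training runs, and the contrastive loss ends up comparing two (near-)identical outputs. The fix:

img1 = img1.convert("L")
img2 = img2.convert("L")

As a side note, adding EPSILON to the logits before CrossEntropyLoss has no effect, since softmax is invariant to adding the same constant to every logit.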