首页 > 解决方案 > `for i, data in enumerate(trainloader)` 需要很长时间才能执行

问题描述

我正在尝试使用 PyTorch 训练 GAN 模型,问题是代码在第二个 for 循环(即 `for i, data in enumerate(trainloader)`)处花费了很长时间;我甚至只取了数据集的一部分,但问题依旧。

为了更好地了解整个代码,您可以在这里找到我的笔记本:https://colab.research.google.com/drive/1STi87M2pNjVOt-LGc6rZ5-kpAf0NwFAm?usp=sharing

以及数据集:https://drive.google.com/drive/folders/1LpEJwa_OlirZV2mK9vdxNSWZor-AFty_?usp=sharing

出问题的代码位于笔记本的最后一个单元格中。下面是相关代码(如果您对这个问题有任何想法,请告诉我,谢谢):

谢谢!无论能否提供帮助,我都感谢每一位愿意花时间阅读的人。

import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import numpy as np
import matplotlib.pyplot as plt
#from dataset import NTUSkeletonDataset
from torch.utils.data import Dataset, DataLoader
#import GAN
from torch.autograd import Variable
import matplotlib.pyplot as plt
import time

# ---- Data ------------------------------------------------------------------
# Root directory handed to the dataset class.
dataroot = "Data/nturgb+d_skeletons"
# Number of samples the DataLoader puts in one mini-batch.
batch_size = 5

# ---- Model -----------------------------------------------------------------
# Dimensionality of the latent vector z fed to the generator.
latent_dim = 20

# ---- Optimisation ----------------------------------------------------------
# How many passes over the training set to run.
num_epochs = 200
# Generator and discriminator share the same learning rate.
lrG = lrD = 5e-05
# Bound used when clamping the discriminator weights after each update.
clip_value = 1e-2
# The generator is updated once every `n_critic` discriminator updates.
n_critic = 20

# Select the compute device: first CUDA GPU when available, CPU otherwise.
# `Tensor` is the matching float tensor type used to cast incoming batches.
cuda = torch.cuda.is_available()
if cuda:
    device = torch.device("cuda:0")
    Tensor = torch.cuda.FloatTensor
else:
    device = torch.device("cpu")
    Tensor = torch.FloatTensor

# Dataset and shuffled mini-batch loader (4 worker processes).
trainset = NTUSkeletonDataset(root_dir=dataroot, pinpoint=1, merge=2)
trainloader = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# Networks, placed on the selected device.
generator = Gen0(latent_dim).to(device)
discriminator = Dis0().to(device)

# One RMSprop optimizer per network.
optimizer_G = optim.RMSprop(generator.parameters(), lr=lrG)
optimizer_D = optim.RMSprop(discriminator.parameters(), lr=lrD)

# Loss history indexed as [epoch, loss kind, generator step]; the three loss
# kinds are written by the training loop. One slot is allocated per possible
# generator step in an epoch, plus one spare.
epoch_loss = np.zeros((num_epochs, 3, len(trainloader) // n_critic + 1))

# Training loop with a Wasserstein-style critic objective and weight clipping:
# the discriminator (critic) is updated on every batch, the generator once
# every `n_critic` batches. Losses are recorded at each generator update.
for epoch in range(num_epochs):
    j = 0  # number of generator updates recorded so far in this epoch
    print("Boucle 1")  # debug trace: start of an epoch
    epoch_start = time.time()
    for i, data in enumerate(trainloader):
        print("something")  # debug trace: a batch was delivered by the loader
        # Collapse all leading dimensions so that every row is one sample.
        data = data.reshape(-1, data.size(-1))

        # ---- Discriminator (critic) update ----
        optimizer_D.zero_grad()

        # Move the batch to the training device as float32 in one step.
        real_skeleton = data.to(device=device, dtype=torch.float32)

        critic_real = -torch.mean(discriminator(real_skeleton))

        # Sample noise as generator input, directly on the training device.
        z = torch.randn(real_skeleton.size(0), latent_dim, device=device)

        # Generate a batch of fake skeletons; detach so that this step does
        # not back-propagate into the generator.
        fake_skeleton = generator(z).detach()

        critic_fake = torch.mean(discriminator(fake_skeleton))

        loss_D = critic_real + critic_fake
        loss_D.backward()

        optimizer_D.step()

        # Clip the discriminator weights to [-clip_value, clip_value].
        with torch.no_grad():
            for p in discriminator.parameters():
                p.clamp_(-clip_value, clip_value)

        # ---- Generator update, once every n_critic batches ----
        if i % n_critic == n_critic - 1:
            optimizer_G.zero_grad()

            # Generate a batch of skeletons from the same noise, this time
            # keeping the graph so gradients reach the generator.
            gen_skeleton = generator(z)
            # Adversarial loss: the generator tries to raise the critic score.
            loss_G = -torch.mean(discriminator(gen_skeleton))

            loss_G.backward()
            optimizer_G.step()

            for k, l in enumerate((loss_G, critic_real, critic_fake)):
                epoch_loss[epoch, k, j] = l.item()
            j += 1

    epoch_end = time.time()
    print('[%d] time eplased: %.3f' % (epoch, epoch_end-epoch_start))

    # Average only over the slots actually written this epoch; the buffer is
    # pre-allocated with zeros, so averaging the whole row would bias the
    # reported means toward 0. With no generator update there is nothing to
    # average, hence NaN.
    if j:
        epoch_means = epoch_loss[epoch, :, :j].mean(axis=-1)
    else:
        epoch_means = np.full(3, np.nan)
    for label, value in zip(('G', 'critic real', 'critic fake'), epoch_means):
        print('\t', label, value)

    # Checkpoint the generator every 20 epochs. torch.save serializes
    # immediately, so no defensive copy of the state dict is needed.
    if epoch % 20 == 19:
        torch.save(generator.state_dict(), 'gen0_%d.pt' % epoch)

np.save('gen0_epoch_loss.npy', epoch_loss)

标签: python, for-loop, pytorch, generative-adversarial-network

解决方案


推荐阅读