Exception in device=TPU:0: Cannot replicate if number of devices (1) is different from 8

Problem description

I am trying to create a GAN that generates anime faces from a Kaggle dataset.

I am using PyTorch on Colab, and to speed up training I am using a TPU with pytorch_xla.

But when I run the code, it raises an error: Exception in device=TPU:0: Cannot replicate if number of devices (1) is different from 8

Here is my code:

# -*- coding: utf-8 -*-
"""x.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1W8fEqtdMRIaiKGWvVrrYv0boMVRBE3ZS
"""

import torch
import torchvision

!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu

nn = torch.nn
F = nn.functional

from google.colab import files

files.upload()

! mkdir ~/.kaggle

! cp kaggle.json ~/.kaggle/

! chmod 600 ~/.kaggle/kaggle.json

! kaggle datasets download -d splcher/animefacedataset

! unzip /content/animefacedataset.zip

! mkdir images/animeface

! mv images/* images/animeface

image_size = (64, 64)
batch_size = 128
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

SERIAL_EXEC = xmp.MpSerialExecutor()

T = torchvision.transforms

get_dataset = lambda: torchvision.datasets.ImageFolder("/content/images", transform=T.Compose([
    T.Resize(image_size),
    T.CenterCrop(image_size),
    T.ToTensor(),
    T.Normalize(*stats)
  ])
)

dataset = SERIAL_EXEC.run(get_dataset)

sampler = torch.utils.data.distributed.DistributedSampler(
    dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True
)

dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size, 
                                         sampler=sampler,
                                         drop_last=True
)

# Commented out IPython magic to ensure Python compatibility.
import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
# %matplotlib inline

def denorm(img_tensors):
    return img_tensors * stats[1][0] + stats[0][0]

def show_images(images, nmax=64):
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(make_grid(denorm(images.detach()[:nmax]), nrow=8).permute(1, 2, 0))

def show_batch(dl, nmax=64):
    for images, _ in dl:
        show_images(images, nmax)
        break

show_batch(dataloader)

from IPython.display import HTML
class ModelBase():
    def info(self, dummy_inp):
        self.eval()
        last_out = dummy_inp
        var = ""
        for idx, (name, model) in enumerate(self.named_children()):
          try:
            var += (f"""
              <tr>
                  <td><b>{idx}</b></td>
                  <td style="text-align: left!important;"><i>({name}):  {model}</i></td>
                  <td>{list(last_out.shape)}</td>
                  <td>{list(model(last_out).shape)}</td>
              </tr>
            """)
            last_out = model(last_out)
          except Exception as e:
            var += (f"""
              <tr>
                  <td><b>{idx}</b></td>
                  <td style="text-align: left!important;"><i>({name}):  {model}</i></td>
                  <td>{list(last_out.shape)}</td>
                  <td class="exception"><b>Exception</b>: {e}</td>
              </tr>
            """)
            break
        self.train()
        return HTML(f"""
          <style>
              table tr {"{"}
                  border-collapse: collapse;
                  text-align: center;
              {"}"}
              table tr:nth-child(even){"{"}
                background-color: #f2f2f2
              {"}"}
              .exception{"{"}
                color: red;
              {"}"}
          </style>
          <table>
            <colgroup>
              <col span="1" style="width: 10%">
              <col span="1" style="width: 60%">
              <col span="1" style="width: 15%">
              <col span="1" style="width: 15%">
            </colgroup>
            <tbody>
              <tr>
                  <th><b>Index<b></th>
                  <th><b>Model</b></th>
                  <th><b>Input Shape</b></th>
                  <th><b>Output Shape</b></th>
              </tr>
              {var}
            </tbody>
          </table>
          """)

class DiscriminatorModel(ModelBase, nn.Sequential):
  def __init__(self):
    super().__init__(
      nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(64),
      nn.LeakyReLU(0.2, inplace=True),

      nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(128),
      nn.LeakyReLU(0.2, inplace=True),

      nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(256),
      nn.LeakyReLU(0.2, inplace=True),

      nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(512),
      nn.LeakyReLU(0.2, inplace=True),

      nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),

      nn.Flatten(),
      nn.Sigmoid()
    )

D = xmp.MpModelWrapper(DiscriminatorModel())

D

def get_default_device():
    import os
    """Pick TPU if avilable, else GPU if available, else CPU"""
    if 'COLAB_TPU_ADDR' in os.environ:
        return xm.xla_device()
    elif torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')
    
def to_device(data, device):
    """Move tensor(s) to chosen device"""
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device)

class DeviceDataLoader():
    """Wrap a dataloader to move data to a device"""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device
        
    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl: 
            yield to_device(b, self.device)

    def __len__(self):
        """Number of batches"""
        return len(self.dl)

device = get_default_device()
device

D = to_device(D, device)

D.info(torch.ones(1, 3, 64, 64).to(device))

latent_size = 128

class GeneratorModel(ModelBase, nn.Sequential):
  def __init__(self):
    super().__init__(
      nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
      nn.BatchNorm2d(512),
      nn.ReLU(True),
      
      nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(256),
      nn.ReLU(True),
      
      nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      
      nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      
      nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
      nn.Tanh()
    )

G = xmp.MpModelWrapper(GeneratorModel())

G = to_device(G, device)

G.info(torch.ones(1, latent_size, 1, 1).to(device))

xb = torch.randn(batch_size, latent_size, 1, 1).to(device)
output = G(xb)
output.shape

show_images(output.cpu())

def train_discriminator(real_images, d_opt):
  d_opt.zero_grad()

  real_preds = D(real_images)
  real_targets = torch.ones(real_images.size(0), 1, device=device)
  real_loss = F.binary_cross_entropy(real_preds, real_targets)
  real_score = torch.mean(real_preds).item()

  latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
  fake_images = G(latent)

  fake_targets = torch.zeros(fake_images.size(0), 1, device=device)
  fake_preds = D(fake_images)
  fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
  fake_score = torch.mean(fake_preds).item()

  loss = real_loss + fake_loss
  loss.backward()
  xm.optimizer_step(d_opt)

  return loss.item(), real_score, fake_score

def train_generator(opt_g):
    # Clear generator gradients
    opt_g.zero_grad()
    
    # Generate fake images
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    fake_images = G(latent)
    
    # Try to fool the discriminator
    preds = D(fake_images)
    targets = torch.ones(batch_size, 1, device=device)
    loss = F.binary_cross_entropy(preds, targets)
    
    # Update generator weights
    loss.backward()
    xm.optimizer_step(opt_g)
    
    return loss.item()

from torchvision.utils import save_image
from tqdm import tqdm

import os

sample_dir = 'generated'
os.makedirs(sample_dir, exist_ok=True)

def save_samples(index, latent_tensors, show=True):
    fake_images = G(latent_tensors)
    fake_fname = 'generated-images-{0:0=4d}.png'.format(index)
    save_image(denorm(fake_images), os.path.join(sample_dir, fake_fname), nrow=8)
    print('Saving', fake_fname)
    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(fake_images.cpu().detach(), nrow=8).permute(1, 2, 0))

fixed_latent = torch.randn(64, latent_size, 1, 1, device=device)

save_samples(0, fixed_latent)

para_loader = pl.ParallelLoader(dataloader, [device])

def fit(rank, epochs, lr, start_idx=1):
    torch.set_default_tensor_type('torch.FloatTensor')
    if device == torch.device("cuda"):
      torch.cuda.empty_cache()
    tracker = xm.RateTracker()
    # Losses & scores
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []
    
    # Create optimizers
    opt_d = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    
    for epoch in range(epochs):
        dataloader = para_loader.per_device_loader(device)
        t = tqdm(enumerate(dataloader), desc="")
        for idx, (real_images, _) in t:
            # Train discriminator (the ParallelLoader already puts the batch on the XLA device)
            loss_d, real_score, fake_score = train_discriminator(real_images, opt_d)
            # Train generator
            loss_g = train_generator(opt_g)

            tracker.add(batch_size)

            t.set_description("[xla:{}]({}) loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}"
            .format(xm.get_ordinal(), idx, loss_g, loss_d, real_score, fake_score))
            
        # Record losses & scores
        losses_g.append(loss_g)
        losses_d.append(loss_d)
        real_scores.append(real_score)
        fake_scores.append(fake_score)
        
        # Log losses & scores (last batch)
        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            epoch+1, epochs, loss_g, loss_d, real_score, fake_score))
    
        # Save generated images
        if rank == 0:
          save_samples(epoch+start_idx, fixed_latent, show=False)
    
    return losses_g, losses_d, real_scores, fake_scores

lr = 0.0002
lr = lr * xm.xrt_world_size()
epochs = 25

history = xmp.spawn(fit, args=(epochs, lr), nprocs=8,
          start_method='fork')

Tags: pytorch, generative-adversarial-network, tpu

Solution


This error means the job saw an unexpected number of processes: xmp.spawn was asked to replicate across 8 processes, but only 1 XLA device is visible to the job. Try calling:

history = xmp.spawn(fit, args=(epochs, lr), nprocs=1, start_method='fork')
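If you still want to use every TPU core rather than a single process, one alternative is not to hard-code nprocs at all: ask the runtime how many XLA devices it can actually see and pass that number (or pass nprocs=None, which tells torch_xla to use all available devices). The snippet below is only a sketch against the torch_xla 1.9 API used in the question; fit, epochs, and lr are the objects already defined above.

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# Ask the runtime how many XLA devices are visible to this process
# (8 on a full Colab TPU, 1 if only a single core can be seen).
n_devices = len(xm.get_xla_supported_devices())
print('visible XLA devices:', n_devices)

# nprocs must match what is actually available: either 1 or the full
# device count. nprocs=None would let xmp.spawn choose all devices itself.
history = xmp.spawn(fit, args=(epochs, lr),
                    nprocs=n_devices,
                    start_method='fork')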

