pytorch - 设备 = TPU:0 中的异常:如果设备数 (1) 不同于 8,则无法复制
问题描述
我试图创建一个 gan,它将通过 kaggle 的数据集生成动漫面孔。
我在 colab 上使用 pytorch,为了更快的训练,我使用了 tpu 和 pytorch_xla
但是当我运行代码时,它会生成一个错误,并说 Exception in device=TPU:0: Cannot replicate if number of devices (1) is different from 8
这是我的代码
# -*- coding: utf-8 -*-
"""x.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1W8fEqtdMRIaiKGWvVrrYv0boMVRBE3ZS
"""
import torch
import torchvision
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
nn = torch.nn
F = nn.functional
from google.colab import files
files.upload()
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json
! kaggle datasets download -d splcher/animefacedataset
! unzip /content/animefacedataset.zip
! mkdir images/animeface
! mv images/* images/animeface
image_size = (64, 64)
batch_size = 128
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
SERIAL_EXEC = xmp.MpSerialExecutor()
T = torchvision.transforms
get_dataset = lambda: torchvision.datasets.ImageFolder("/content/images", transform=T.Compose([
T.Resize(image_size),
T.CenterCrop(image_size),
T.ToTensor(),
T.Normalize(*stats)
])
)
dataset = SERIAL_EXEC.run(get_dataset)
sampler = torch.utils.data.distributed.DistributedSampler(
dataset,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
shuffle=True
)
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
drop_last=True
)
# Commented out IPython magic to ensure Python compatibility.
import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
# %matplotlib inline
def denorm(img_tensors):
return img_tensors * stats[1][0] + stats[0][0]
def show_images(images, nmax=64):
fig, ax = plt.subplots(figsize=(8, 8))
ax.set_xticks([]); ax.set_yticks([])
ax.imshow(make_grid(denorm(images.detach()[:nmax]), nrow=8).permute(1, 2, 0))
def show_batch(dl, nmax=64):
for images, _ in dl:
show_images(images, nmax)
break
show_batch(dataloader)
from IPython.display import HTML
class ModelBase():
def info(self, dummy_inp):
self.eval()
last_out = dummy_inp
var = ""
for idx, (name, model) in enumerate(self.named_children()):
try:
var += (f"""
<tr>
<td><b>{idx}</b></td>
<td syle="text-align: left!important;"><i>({name}): {model}</i></td>
<td>{list(last_out.shape)}</td>
<td>{list(model(last_out).shape)}</td>
</tr>
""")
last_out = model(last_out)
except Exception as e:
var += (f"""
<tr>
<td><b>{idx}</b></td>
<td syle="text-align: left!important;"><i>({name}): {model}</i></td>
<td>{list(last_out.shape)}</td>
<td class="exception"><b>Exception</b>: {e}</td>
</tr>
""")
break
self.train()
return HTML(f"""
<style>
table tr {"{"}
border-collapse: collapse;
text-align: center;
{"}"}
table tr:nth-child(even){"{"}
background-color: #f2f2f2
{"}"}
.exception{"{"}
color: red;
{"}"}
</style>
<table>
<colgroup>
<col span="1" style="width: 10%">
<col span="1" style="width: 60%">
<col span="1" style="width: 15%">
<col span="1" style="width: 15%">
</colgroup>
<tbody>
<tr>
<th><b>Index<b></th>
<th><b>Model</b></th>
<th><b>Input Shape</b></th>
<th><b>Output Shape</b></th>
</tr>
{var}
</tbody>
</table>
""")
class DiscriminatorModel(ModelBase, nn.Sequential):
def __init__(self):
super().__init__(
nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
nn.Flatten(),
nn.Sigmoid()
)
D = xmp.MpModelWrapper(DiscriminatorModel())
D
def get_default_device():
import os
"""Pick TPU if avilable, else GPU if available, else CPU"""
if 'COLAB_TPU_ADDR' in os.environ:
return xm.xla_device()
elif torch.cuda.is_available():
return torch.device('cuda')
else:
return torch.device('cpu')
def to_device(data, device):
"""Move tensor(s) to chosen device"""
if isinstance(data, (list,tuple)):
return [to_device(x, device) for x in data]
return data.to(device)
class DeviceDataLoader():
"""Wrap a dataloader to move data to a device"""
def __init__(self, dl, device):
self.dl = dl
self.device = device
def __iter__(self):
"""Yield a batch of data after moving it to device"""
for b in self.dl:
yield to_device(b, self.device)
def __len__(self):
"""Number of batches"""
return len(self.dl)
device = get_default_device()
device
D = to_device(D, device)
D.info(torch.ones(1, 3, 64, 64).to(device))
latent_size = 128
class GeneratorModel(ModelBase, nn.Sequential):
def __init__(self):
super().__init__(
nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
nn.Tanh()
)
G = xmp.MpModelWrapper(GeneratorModel())
G = to_device(G, device)
G.info(torch.ones(1, latent_size, 1, 1).to(device))
xb = torch.randn(batch_size, latent_size, 1, 1).to(device)
output = G(xb)
output.shape
show_images(output.cpu())
def train_discriminator(real_images, d_opt):
d_opt.zero_grad()
real_preds = D(real_images)
real_targets = torch.ones(real_images.size(0), 1, device=device)
real_loss = F.binary_cross_entropy(real_preds, real_targets)
real_score = torch.mean(real_preds).item()
latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
fake_images = G(latent)
fake_targets = torch.zeros(fake_images.size(0), 1, device=device)
fake_preds = D(fake_images)
fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
fake_score = torch.mean(fake_preds).item()
loss = real_loss + fake_loss
loss.backward()
xm.optimizer_step(d_opt)
return loss.item(), real_score, fake_score
def train_generator(opt_g):
# Clear generator gradients
opt_g.zero_grad()
# Generate fake images
latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
fake_images = G(latent)
# Try to fool the discriminator
preds = D(fake_images)
targets = torch.ones(batch_size, 1, device=device)
loss = F.binary_cross_entropy(preds, targets)
# Update generator weights
loss.backward()
xm.optimizer_step(opt_g)
return loss.item()
from torchvision.utils import save_image
import os
sample_dir = 'generated'
os.makedirs(sample_dir, exist_ok=True)
def save_samples(index, latent_tensors, show=True):
fake_images = G(latent_tensors)
fake_fname = 'generated-images-{0:0=4d}.png'.format(index)
save_image(denorm(fake_images), os.path.join(sample_dir, fake_fname), nrow=8)
print('Saving', fake_fname)
if show:
fig, ax = plt.subplots(figsize=(8, 8))
ax.set_xticks([]); ax.set_yticks([])
ax.imshow(make_grid(fake_images.cpu().detach(), nrow=8).permute(1, 2, 0))
fixed_latent = torch.randn(64, latent_size, 1, 1, device=device)
save_samples(0, fixed_latent)
para_loader = pl.ParallelLoader(dataloader, [device])
def fit(rank, epochs, lr, start_idx=1):
torch.set_default_tensor_type('torch.FloatTensor')
if device == torch.device("cuda"):
torch.cuda.empty_cache()
tracker = xm.RateTracker()
# Losses & scores
losses_g = []
losses_d = []
real_scores = []
fake_scores = []
# Create optimizers
opt_d = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
opt_g = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
for epoch in range(epochs):
dataloader = para_loader.per_device_loader(device)
t = tqdm(enumerate(dataloader), desc="")
for idx, (real_images, _) in t:
# Train discriminator
loss_d, real_score, fake_score = train_discriminator(real_images.cuda(), opt_d)
# Train generator
loss_g = train_generator(opt_g)
tracker.add(batch_size)
t.set_description("[xla:{}]({}) loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}"
.format(xm.get_ordinal(), x, loss_g, loss_d, real_score, fake_score))
# Record losses & scores
losses_g.append(loss_g)
losses_d.append(loss_d)
real_scores.append(real_score)
fake_scores.append(fake_score)
# Log losses & scores (last batch)
print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
epoch+1, epochs, loss_g, loss_d, real_score, fake_score))
# Save generated images
if rank == 0:
save_samples(epoch+start_idx, fixed_latent, show=False)
return losses_g, losses_d, real_scores, fake_scores
lr = 0.0002
lr = lr * xm.xrt_world_size()
epochs = 25
history = xmp.spawn(fit, args=(epochs, lr), nprocs=8,
start_method='fork')
解决方案
此错误表明系统为您的作业收到了意外数量的进程。尝试调用:
history = xmp.spawn(fit, args=(epochs, lr), nprocs=1, start_method='fork')
推荐阅读
- sql - SQL Server 2016:如何获取单行视图
- amazon-web-services - 无法再使用 React native 将图像上传到 Amazon S3 存储桶
- reactjs - 如何在 KeyboardDatePicker 中显示自定义文本(占位符)
- python - 将嵌套的 Jsonl 文件转换为 CSV 格式:取消嵌套 Jsonl 并提取为 CSV
- python - 使用 Python 文件在 Robot Framework 中将变量定义为变量
- java - 如何在 Java 中使用 double[] 参数填充 setter 方法?
- python - 没有名为“火炬”的模块
- flutter - Rotation Transition 逆时针旋转颤振
- python - 为什么这个程序不交换输入列表的第一项和最后一项?
- python - SumOfLongRootToLeafPath 函数返回的值如何