pytorch autoencoder model evaluation fail

Problem description

I am a complete beginner with PyTorch. I trained an autoencoder network so that I can plot the distribution of the latent vectors (the output of the encoder).

This is the code that I used for network training.

import torch
import torchvision
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import os
import glob

dir_img_decoded = '/media/dohyeong/HDD/mouth_autoencoder/dc_img_2'
if not os.path.exists(dir_img_decoded):
    os.mkdir(dir_img_decoded)

dir_check_point = '/media/dohyeong/HDD/mouth_autoencoder/ckpt_2'
if not os.path.exists(dir_check_point):
    os.mkdir(dir_check_point)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

num_epochs = 200
batch_size = 150  # larger batch size -> higher GPU memory usage
learning_rate = 1e-3

dir_dataset = '/media/dohyeong/HDD/mouth_autoencoder/mouth_crop/dir_normalized_mouth_cropped_images'
images = glob.glob(os.path.join(dir_dataset, '*.png'))
train_images = images[:-113]
test_images = images[-113:]

train_images.sort()
test_images.sort()

class TrumpMouthDataset(Dataset):
    def __init__(self, images):
        super(TrumpMouthDataset, self).__init__()
        self.images = images

        self.transform = transforms.Compose([
            # transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __getitem__(self, index):
        image = Image.open(self.images[index])

        return self.transform(image)

    def __len__(self):
        return len(self.images)


train_dataset = TrumpMouthDataset(train_images)
test_dataset = TrumpMouthDataset(test_images)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)


class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(60000, 60),
            nn.ReLU(True),
            nn.Linear(60, 3),
            nn.ReLU(True),
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 60),
            nn.ReLU(True),
            nn.Linear(60, 60000),
            nn.Tanh()
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)

        return encoded, decoded


model = Autoencoder().cuda()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
    model.to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),
                             lr=learning_rate,
                             weight_decay=1e-5)

for epoch in range(num_epochs):

    total_loss = 0

    for index, imgs in enumerate(train_dataloader):
        imgs = imgs.to(device)

        # ===================forward=====================
        encoded, decoded = model(imgs)  # forward() returns both the latent code and the reconstruction

        imgs_flatten = imgs.view(imgs.size(0), -1)
        loss = criterion(decoded, imgs_flatten)

        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        print('{} Epoch, [{}/{}] batch, loss: {:.4f}'.format(epoch, index + 1, len(train_dataloader), loss.item()))

    avg_loss = total_loss / len(train_dataloader)  # mean of the per-batch losses
    print('{} Epoch, avg_loss: {:.4f}'.format(epoch, avg_loss))


    if epoch % 10 == 0:
        check_point_file = os.path.join(dir_check_point, str(epoch) + ".pth")
        torch.save(model.state_dict(), check_point_file)

After training, I tried to get encoded values using this code.

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

check_point = '/media/dohyeong/HDD/mouth_autoencoder/290.pth'
model = torch.load(check_point)

for index, imgs in enumerate(train_dataloader):

    imgs = imgs.to(device)

    # ===================evaluate=====================
    encoded, _ = model(imgs)

It failed with this error message: "TypeError: 'collections.OrderedDict' object is not callable". May I get some help?

Tags: python, deep-learning, pytorch

Solution


Hi and welcome to the PyTorch community :D

TL;DR

Change model = torch.load(check_point) to model.load_state_dict(torch.load(check_point)).


The only problem is with the line:

model = torch.load(check_point)

The way you saved the checkpoint was:

torch.save(model.state_dict(), check_point_file)

That is, you saved the model's state_dict (which is just a dictionary of the various parameters that together describe the current instance of the model) in check_point_file.

Now, in order to load it back, just reverse the process. check_point_file contains just the state_dict.

It knows nothing about the internals of the model - what its architecture is, how it's supposed to work, and so on.

So, load it back:

state_dict = torch.load(check_point)
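If you're curious what's actually in there, you can print it: each entry is just a parameter name mapped to a tensor of weights. A minimal sketch, continuing from the line above:

for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
# e.g. encoder.0.weight (60, 60000)
#      encoder.0.bias   (60,)
# (the names carry a 'module.' prefix if the model was wrapped in nn.DataParallel when it was saved)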

This state_dict can now be copied onto your Model instance as follows:

model.load_state_dict(state_dict)

Or, more succinctly,

model.load_state_dict(torch.load(check_point))

You got the error because torch.load(check_point) returned the state_dict, which you then assigned to model.

When you subsequently called model(imgs), model was an OrderedDict object (not callable).

Hence the error.
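
Putting it all together, your evaluation snippet would look roughly like this. It's only a sketch: it assumes the Autoencoder class and train_dataloader from your training script are in scope, and model.eval() / torch.no_grad() aren't needed to fix the error, they're just good practice for evaluation. (One caveat: if your training run really did wrap the model in nn.DataParallel, the saved keys carry a 'module.' prefix, so load the state_dict into a wrapped model or strip the prefix first.)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

check_point = '/media/dohyeong/HDD/mouth_autoencoder/290.pth'

# Recreate the architecture first, then copy the saved parameters onto it.
model = Autoencoder().to(device)
model.load_state_dict(torch.load(check_point))
model.eval()  # switch to evaluation mode

all_encoded = []
with torch.no_grad():  # no gradients needed when you only want the latent vectors
    for index, imgs in enumerate(train_dataloader):
        imgs = imgs.to(device)
        encoded, _ = model(imgs)
        all_encoded.append(encoded.cpu())

latents = torch.cat(all_encoded)  # one latent vector per image: these are the values you wanted to plot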

See the Serialization Semantics Notes for more details.

Apart from that, your code sure is thorough for a beginner. Great going!


P.S. Your device-agnostic setup is brilliant! Perhaps you'd want to take a look at the following (there's a short sketch after the list):

  1. the line model = Autoencoder().cuda()
  2. The map_location argument of torch.load()
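
For example, a device-agnostic version of those two spots might look like this (again just a sketch, reusing check_point and the Autoencoder class from above):

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = Autoencoder().to(device)  # runs on CPU too, unlike Autoencoder().cuda()

# map_location remaps the checkpoint's tensors onto the chosen device, so a
# checkpoint saved on a GPU machine can also be loaded on a CPU-only machine.
model.load_state_dict(torch.load(check_point, map_location=device))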
