首页 > 解决方案 > 如何为 PyTorch 的 F.affine_grid 和 F.grid_sample 创建剪切矩阵?

问题描述

我需要创建一个与 autograd 兼容的剪切矩阵,适用于 B、C、H、W 张量,并为剪切值获取输入值(可能随机生成)。我怎样才能为此生成剪切矩阵?

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image


# Load image
def preprocess_simple(image_name, image_size):
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    image = Image.open(image_name).convert('RGB')
    return Loader(image).unsqueeze(0)
    
# Save image   
def deprocess_simple(output_tensor, output_name):
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.squeeze(0))
    image.save(output_name)


def get_shear_mat(theta):
    ...
    return shear_mat


def shear_img(x, theta, dtype):
    shear_mat = get_shear_mat(theta)
    grid = F.affine_grid(shear_mat , x.size()).type(dtype)
    x = F.grid_sample(x, grid)
    return x


# Shear tensor
test_input = # Test image
shear_values = (3,4) # Example values
sheared_tensor = shear_img(test_input, shear_values)

标签: pythonmatrixpytorch

解决方案


m剪切因子,那么theta = atan(1/m)就是剪切角。您现在可以选择水平剪切或垂直剪切。以下是您的实现方式get_shear_mat,您可以通过设置选择水平剪切,通过设置选择ax=0垂直剪切ax=1

def get_shear_mat(theta, ax=0):
    assert ax in [0, 1]
    m = 1 / torch.tan(torch.tensor(theta))
    if ax == 0: # Horizontal shear
        shear_mat = torch.tensor([[1, m, 0],
                         [0, 1, 0]])
    else: # Vertical shear
        shear_mat = torch.tensor([[1, 0, 0],
                         [m, 1, 0]])
    return shear_mat

请注意,剪切映射只是(x,y)原始图像中的点(x+my,y)到水平剪切点和(x,y+mx)垂直剪切点的映射。这正是我们在这里通过定义shear_mat上面所做的。

一个可选的修改,shear_img以支持第一行中批处理输入的操作。还添加一个参数 - axtoshear_img来定义我们想要水平 ( ax=0) 还是垂直 ( ax=1) 剪切:

def shear_img(x, ax, theta, dtype):
    shear_mat = get_shear_mat(theta, ax)[None, ...].type(dtype).repeat(x.shape[0], 1, 1)
    grid = F.affine_grid(shear_mat , x.size()).type(dtype)
    x = F.grid_sample(x.type(dtype), grid)
    return x

让我们在图像上测试这个实现:

# Let im be a 4D tensor of shape BxCxHxW (an image or a batch of images):
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor # Set type of data
sheared_im = shear_img(im, 0, np.pi/4, dtype) #Horizontal shear by shear angle of pi/4
plt.imshow(sheared_im.squeeze(0).permute(1,2,0)/255)
plt.show()

如果im是我们的裙子舞猫:

在此处输入图像描述

那么我们的情节将是:

在此处输入图像描述

如果我们想要一个垂直剪切:

sheared_im = shear_img(im, 1, np.pi/4, dtype) # Vertical shear by shear angle of pi/4
plt.imshow(sheared_im.squeeze(0).permute(1, 2, 0)/255)
plt.show()

我们获得:

在此处输入图像描述

万岁!


推荐阅读