PixelShuffle for image upsampling

W. Shi, J. Caballero, F. Huszár, J. Totz, A. P. Aitken, R. Bishop, D. Rueckert, Z. Wang, Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network, arXiv:1609.05158 [cs.CV] (2016).

A. Aitken, C. Ledig, L. Theis, J. Caballero, Z. Wang, W. Shi, Checkerboard artifact free sub-pixel convolution, arXiv:1707.02937 [cs.CV] (2017).

Introduction

Methods for image upsampling (compared in the sketch below):

  • Interpolation, nn.Upsample
  • Transpose convolution, nn.ConvTranspose2d
  • Sub-pixel convolution, nn.PixelShuffle
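
All three approaches take a low-resolution feature map and produce one with twice the spatial size; they differ in whether and where learnable parameters enter. A minimal sketch comparing the input/output shapes (the channel counts here are arbitrary, chosen only for the example):

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 64, 20, 20)  # (N, C, H, W) low-resolution feature map

# 1. Interpolation: no learnable parameters, channel count is preserved
up_interp = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

# 2. Transposed convolution: learnable; kernel_size=2, stride=2 doubles H and W
up_convt = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)(x)

# 3. Sub-pixel convolution: convolution in low-resolution space, then PixelShuffle
conv = nn.Conv2d(64, 32 * 4, kernel_size=1)   # 4 = upscale_factor ** 2
up_shuffle = nn.PixelShuffle(2)(conv(x))

print(up_interp.shape)   # torch.Size([1, 64, 40, 40])
print(up_convt.shape)    # torch.Size([1, 32, 40, 40])
print(up_shuffle.shape)  # torch.Size([1, 32, 40, 40])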

Transposed convolution

[Figure: transposed convolution]

Blog post about checkerboard artifacts

Sub-pixel convolution

[Figure: sub-pixel convolution]
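
nn.PixelShuffle(r) rearranges a tensor of shape (N, C·r², H, W) into (N, C, H·r, W·r) by distributing each group of r² channels over an r×r spatial block. A small sketch of the rearrangement; the explicit view/permute version below is one way to reproduce it and is checked against F.pixel_shuffle on a random example:

import torch
import torch.nn.functional as F

r = 2
x = torch.randn(1, 8 * r * r, 5, 5)            # (N, C*r^2, H, W)

y = F.pixel_shuffle(x, r)                      # (N, C, H*r, W*r)
print(y.shape)                                 # torch.Size([1, 8, 10, 10])

# The same rearrangement written out explicitly:
n, c_r2, h, w = x.shape
c = c_r2 // (r * r)
y_manual = (x.view(n, c, r, r, h, w)           # split channels into (C, r, r)
             .permute(0, 1, 4, 2, 5, 3)        # interleave sub-pixel dims with H and W
             .reshape(n, c, h * r, w * r))
print(torch.equal(y, y_manual))                # expected: True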

Data

Using the Oxford-IIIT Pet Dataset: https://www.robots.ox.ac.uk/~vgg/data/pets/
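
To reproduce the directory layout assumed below (data/images and data/annotations/trimaps), the two archives linked from the dataset page can be unpacked into ./data. A rough sketch; the archive names images.tar.gz and annotations.tar.gz and the data/ URL prefix are assumptions based on the page above, so adjust them if they differ:

import tarfile
import urllib.request
from pathlib import Path

BASE_URL = 'https://www.robots.ox.ac.uk/~vgg/data/pets/data/'  # assumed download prefix
DATA_DIR = Path('./data')
DATA_DIR.mkdir(exist_ok=True)

for name in ['images.tar.gz', 'annotations.tar.gz']:
    archive = DATA_DIR / name
    if not archive.exists():
        urllib.request.urlretrieve(BASE_URL + name, archive)
    with tarfile.open(archive) as tar:
        tar.extractall(DATA_DIR)   # yields data/images and data/annotations/trimaps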

Imports

In [1]:
from pathlib import Path
from collections import defaultdict
import numpy as np
from PIL import Image

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms

import imgaug as ia
from imgaug import augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

import matplotlib.pyplot as plt

Configuration

In [2]:
BASE_DIR = Path('./data')
IMAGES_DIR = BASE_DIR / 'images'
TARGET_DIR = BASE_DIR / 'annotations' / 'trimaps'

SAVE_PATH = 'weights-segmentation.pkl'

BACKGROUND_SHAPE = (160, 160)
VAL_RATIO = 0.2
BATCH_SIZE = 16
In [3]:
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda

File lists

In [4]:
images_list = sorted(IMAGES_DIR.glob('*.jpg'))
In [5]:
targets_list = sorted(TARGET_DIR.glob('*.png'))
In [6]:
print("Number of images:", len(images_list))
print("Number of targets:", len(targets_list))
Number of images: 7390
Number of targets: 7390
In [7]:
images_train, images_val, targets_train, targets_val = train_test_split(images_list, targets_list,
        test_size=VAL_RATIO, random_state=0, shuffle=True)

Utils

In [8]:
def tensor_to_segmap(Y, threshold=0.5):
    mask = (Y > threshold).numpy()
    mask = np.moveaxis(np.uint8(mask), 0, -1)
    return SegmentationMapsOnImage(mask, BACKGROUND_SHAPE)
In [9]:
def to_np(X):
    x = np.moveaxis(X.numpy(), 0, -1)
    x = np.uint8(x * 255)
    return x
In [10]:
def plot_img(img):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_axis_off()
    ax.imshow(img, interpolation='bilinear')
    plt.show()
In [11]:
def loader_show_image(loader):
    X, Y = next(iter(loader))
    img = to_np(X[0])
    segmap = tensor_to_segmap(Y[0])
    img = segmap.draw_on_image(img)[0]
    plot_img(img)
In [12]:
def dataset_show_image(dset, idx):
    img, segmap = dset[idx]
    img = segmap.draw_on_image(img)[0]
    plot_img(img)
In [13]:
def show_image_with_segmap(img, segmap, true_segmap):
    img = true_segmap.draw_on_image(img, colors=[(0, 0, 0), (0, 255, 0)])[0]
    img = segmap.draw_on_image(img)[0]
    plot_img(img)

Dataset

In [14]:
class ImagesDataset(torch.utils.data.Dataset):
    def __init__(self, images_list, targets_list):
        self.images_list = images_list
        self.targets_list = targets_list
        self.num_images = len(images_list)
        self.images, self.masks = self._load_images_masks()
                
    def __len__(self):
        return self.num_images
    
    def __getitem__(self, idx):
        X, Y = self.images[idx], self.masks[idx]
        return X, Y
    
    def _load_images_masks(self):
        print("Loading images...", end=' ')
        images = self._load_images(self.images_list, BACKGROUND_SHAPE, resample=Image.BICUBIC, mode='RGB')
        targets = self._load_images(self.targets_list, BACKGROUND_SHAPE, resample=Image.NEAREST, mode='L')
        masks = self._create_masks(targets)
        print("Done.")
        
        return images, masks
    
    def _load_images(self, file_list, shape, resample, mode):
        images = []
        for img_path in file_list:
            img = self._load_and_resize_image(img_path, shape, resample, mode)
            img_np = np.asarray(img)
            images.append(img_np)
        return images
    
    def _load_and_resize_image(self, img_path, shape, resample, mode):
        img = Image.open(img_path)
        img = img.convert(mode)

        width_height_tuple = (shape[1], shape[0])
        if img.size != width_height_tuple:
            img = img.resize(width_height_tuple, resample)
        
        return img
    
    def _create_masks(self, images):
        # Trimap labels: 1 = pet (foreground), 2 = background, 3 = undefined border.
        # Keep only the pet pixels as a boolean foreground mask.
        masks = []
        for img in images:
            mask = (img == 1)
            masks.append(SegmentationMapsOnImage(mask, img.shape))
        return masks
In [15]:
train_dset = ImagesDataset(images_train, targets_train)
val_dset = ImagesDataset(images_val, targets_val)
Loading images... Done.
Loading images... Done.
In [16]:
dataset_show_image(val_dset, 1)

Transforms & DataLoader

In [17]:
def segmap_to_tensor(segmap):
    mask = segmap.get_arr()
    mask = np.expand_dims(mask, axis=0)
    mask = torch.as_tensor(mask, dtype=torch.float)
    return mask
In [18]:
def image_to_tensor(img):
    img = torch.from_numpy(img.transpose((2, 0, 1)))
    return img.float().div(255)
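
Note: np.asarray on a PIL image returns a read-only array, and torch.from_numpy on such an array is what appears to trigger the "non-writeable tensor" UserWarning seen during the first validation pass in the training log below. It is harmless here, but a copy silences it; one possible variant of the function above:

def image_to_tensor(img):
    # np.ascontiguousarray copies the (possibly read-only) array before wrapping it
    img = torch.from_numpy(np.ascontiguousarray(img.transpose((2, 0, 1))))
    return img.float().div(255)

Alternatively, using np.array(...) instead of np.asarray(...) when loading the images would copy them once at load time.
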
In [19]:
augmenter = iaa.Sequential([
                iaa.Fliplr(0.5),
                iaa.Affine(
                    scale=(0.9, 1.1),
                    translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
                    rotate=(-10, 10),
                    shear=(-8, 8),
                    mode='reflect'
                ),
                iaa.Multiply((0.8, 1.2), per_channel=0.2),
                iaa.LinearContrast((0.75, 1.5)),
                iaa.GaussianBlur((0., 3.))
            ], random_order=True)
In [20]:
class CollateFN:
    def __init__(self, augmenter=None):
        self.augmenter = augmenter
    
    def __call__(self, batch):
        images, segmaps = zip(*batch)
        
        if self.augmenter is not None:
            images, segmaps = self.augmenter(images=images, segmentation_maps=segmaps)
        
        images_tens = [image_to_tensor(img) for img in images]
        segmaps_tens = [segmap_to_tensor(segmap) for segmap in segmaps]
        X = torch.stack(images_tens, dim=0)
        Y = torch.stack(segmaps_tens, dim=0)
        return X, Y
In [21]:
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=CollateFN(augmenter))
val_loader = torch.utils.data.DataLoader(val_dset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=CollateFN())
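
The loader_show_image helper defined above can be used here to eyeball a batch, e.g. to check that the augmentations are applied consistently to image and mask:

loader_show_image(train_loader)   # augmented training sample with its mask drawn on top
loader_show_image(val_loader)     # un-augmented validation sample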

Model

Initialization of the sub-pixel convolution weights according to arXiv:1707.02937 [cs.CV] (ICNR): one sub-kernel is initialized and then repeated for all four sub-pixel positions, so that right after PixelShuffle the layer behaves like nearest-neighbour upsampling of its low-resolution output and checkerboard artifacts are avoided at initialization:

In [23]:
def init_subpixel(weight):
    co, ci, h, w = weight.shape
    co2 = co // 4
    # initialize one sub-kernel per output channel
    k = torch.empty([co2, ci, h, w])
    nn.init.kaiming_uniform_(k)
    # repeat it 4 times, once per sub-pixel position (ICNR)
    k = k.repeat_interleave(4, dim=0)
    weight.data.copy_(k)
In [24]:
def init_linear(m, relu=True):
    if relu: nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
    else: nn.init.xavier_uniform_(m.weight)
    if m.bias is not None: nn.init.zeros_(m.bias)
In [25]:
class ConvBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, act=True):
        padding = (kernel_size - 1) // 2
        layers = [
          nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False),
          nn.BatchNorm2d(out_channels)
        ]
        if act: layers.append(nn.ReLU(inplace=True))
        super().__init__(*layers)
    
    def reset_parameters(self):
        init_linear(self[0])
        self[1].reset_parameters()
In [26]:
class DoubleConv(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super().__init__(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels)
        )
In [27]:
class Down(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super().__init__(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )
In [28]:
class UpsampleBilinear(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super().__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        )
    
    def reset_parameters(self):
        init_linear(self[0])
In [29]:
class UpsampleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_t = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    
    def forward(self, x):
        return self.conv_t(x)
    
    def reset_parameters(self):
        init_linear(self.conv_t)
In [30]:
class UpsampleShuffle(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super().__init__(
            nn.Conv2d(in_channels, out_channels * 4, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.PixelShuffle(2)
        )
        
    def reset_parameters(self):
        init_subpixel(self[0].weight)
        nn.init.zeros_(self[0].bias)
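
A quick sanity check of this initialization (a sketch, relying on the repeat_interleave pattern in init_subpixel): with ICNR weights and zero bias, the four conv channels feeding each output channel are identical, so every 2×2 block produced by PixelShuffle is constant and the freshly reset layer should act like nearest-neighbour upsampling of its low-resolution output.

up = UpsampleShuffle(16, 8)
up.reset_parameters()

x = torch.randn(2, 16, 5, 5)
with torch.no_grad():
    y = up(x)                    # (2, 8, 10, 10)

# Each 2x2 block should equal its top-left pixel, i.e. nearest-neighbour upsampling
low_res = y[:, :, ::2, ::2]
print(torch.allclose(y, F.interpolate(low_res, scale_factor=2, mode='nearest')))  # expected: True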
In [31]:
class Up(nn.Module):
    def __init__(self, in_channels, out_channels, upsample):
        super().__init__()
        self.up = upsample(in_channels, in_channels // 2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
In [32]:
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, upsample):
        super().__init__()
        sizes = [64, 128, 256, 512, 1024]
        self.ini = DoubleConv(n_channels, sizes[0])
        self.down = nn.ModuleList([Down(sz1, sz2) for sz1, sz2 in zip(sizes, sizes[1:])])
        
        sizes.reverse()
        self.up = nn.ModuleList([Up(sz1, sz2, upsample) for sz1, sz2 in zip(sizes, sizes[1:])])
        self.out = nn.Conv2d(sizes[-1], n_classes, kernel_size=1)

    def forward(self, x):
        x = self.ini(x)
        
        xs = []
        for down_layer in self.down:
            xs.append(x)
            x = down_layer(x)
        
        for up_layer, xi in zip(self.up, reversed(xs)):
            x = up_layer(x, xi)
            
        logits = self.out(x)
        return logits

    def _reset_children(self, module):
        for m in module.children():
            if hasattr(m, 'reset_parameters'):
                m.reset_parameters()
            else:
                self._reset_children(m)
    
    def reset_parameters(self):
        self._reset_children(self)
In [33]:
model = UNet(3, 1, UpsampleShuffle).to(DEVICE)
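
Before training, a quick shape check never hurts (and model.reset_parameters() can be called at this point if the custom initialization defined above is wanted):

with torch.no_grad():
    out = model(torch.randn(1, 3, *BACKGROUND_SHAPE).to(DEVICE))
print(out.shape)   # torch.Size([1, 1, 160, 160])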

Metrics

In [34]:
def MeanDice(logits, targets, threshold=0.5):
    probs = torch.sigmoid(logits)
    labels = (probs > threshold).float()
    num = targets.size(0)
    m1 = labels.view(num, -1)
    m2 = targets.view(num, -1)
        
    intersection = (m1 * m2).sum(1)
    cardinality = m1.sum(1) + m2.sum(1)
    dice = (2. * intersection) / (cardinality + 1e-7)
    m_dice = torch.mean(dice)
    return m_dice
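
A tiny check on hand-made masks to illustrate the metric; large positive or negative logits stand in for confident predictions:

targets = torch.tensor([[[[1., 1.], [0., 0.]]]])          # (N=1, 1, 2, 2)
perfect = targets * 20. - 10.                             # logits: +10 on the mask, -10 elsewhere
half = torch.tensor([[[[10., -10.], [-10., -10.]]]])      # finds only one of the two mask pixels

print(MeanDice(perfect, targets))   # tensor(1.0000)
print(MeanDice(half, targets))      # tensor(0.6667) = 2*1 / (1 + 2)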

Training

In [35]:
def show_or_save(fig, filename=None):
    if filename:
        fig.savefig(filename, bbox_inches='tight', pad_inches=0.05)
        plt.close(fig)
    else:
        plt.show()
In [36]:
class History:
    def __init__(self):
        self.values = defaultdict(list)

    def append(self, key, value):
        self.values[key].append(value)

    def reset(self):
        for k in self.values.keys():
            self.values[k] = []

    def _plot(self, key, line_type='-', label=None):
        if not label: label=key
        xs = np.arange(1, len(self.values[key])+1)
        self.ax.plot(xs, self.values[key], line_type, label=label)

    def plot_train_val(self, key, x_is_batch=False, ylog=False, filename=None):
        fig = plt.figure()
        self.ax = fig.add_subplot(111)
        self._plot('train ' + key, '.-', 'train')
        self._plot('val ' + key, '.-', 'val')
        self.ax.legend()
        if ylog: self.ax.set_yscale('log')
        self.ax.set_xlabel('batch' if x_is_batch else 'epoch')
        self.ax.set_ylabel(key)
        show_or_save(fig, filename)
In [37]:
class Learner:
    def __init__(self, model, loss, optimizer, train_loader, val_loader, device,
                 epoch_scheduler=None, batch_scheduler=None):
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.epoch_scheduler = epoch_scheduler
        self.batch_scheduler = batch_scheduler
        self.history = History()
    
    
    def iterate(self, loader, backward_pass=False):
        total_loss = 0.0
        num_samples = 0
        mean_dice = 0.0

        for X, Y in loader:
            X, Y = X.to(self.device), Y.to(self.device)
            Y_pred = self.model(X)
            batch_size = X.size(0)
            batch_loss = self.loss(Y_pred, Y)
            if backward_pass:
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()
                if self.batch_scheduler is not None:
                    self.batch_scheduler.step()
            
            Y_pred.detach_() # conserve memory
            total_loss += batch_size * batch_loss.item()
            mean_dice += batch_size * MeanDice(Y_pred, Y).item()
            num_samples += batch_size
    
        avg_loss = total_loss / num_samples
        mean_dice = mean_dice / num_samples
        return avg_loss, mean_dice
    
    
    def train(self):
        self.model.train()
        train_loss, train_m_dice = self.iterate(self.train_loader, backward_pass=True)
        print(f'train: loss {train_loss:.3f}, mean dice {train_m_dice:.3f}')
        self.history.append('train loss', train_loss)
        self.history.append('train mean dice', train_m_dice)

        
    def validate(self):
        self.model.eval()
        with torch.no_grad():
            val_loss, val_m_dice = self.iterate(self.val_loader)
        print(f'val: loss {val_loss:.3f}, mean dice {val_m_dice:.3f}')
        self.history.append('val loss', val_loss)
        self.history.append('val mean dice', val_m_dice)


    def fit(self, epochs):
        for i in range(epochs):
            print(f'{i+1}/{epochs}')
            self.train()
            self.validate()
            if self.epoch_scheduler is not None:
                self.epoch_scheduler.step()
In [38]:
#model.load_state_dict(torch.load(SAVE_PATH))
In [39]:
loss = nn.BCEWithLogitsLoss()
In [40]:
#optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
In [41]:
learner = Learner(model, loss, optimizer, train_loader, val_loader, DEVICE)
In [42]:
EPOCHS = 20
In [43]:
learner.batch_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-1,
                                                        steps_per_epoch=len(train_loader),
                                                        epochs=EPOCHS)
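
To see the learning-rate curve the scheduler will follow, it can be simulated on a throw-away optimizer before training (a sketch; only the schedule is stepped, no training happens):

dummy_opt = optim.Adam([torch.zeros(1, requires_grad=True)], lr=1e-3)
dummy_sched = optim.lr_scheduler.OneCycleLR(dummy_opt, max_lr=1e-1,
                                            steps_per_epoch=len(train_loader),
                                            epochs=EPOCHS)
lrs = []
for _ in range(len(train_loader) * EPOCHS):
    lrs.append(dummy_sched.get_last_lr()[0])
    dummy_opt.step()
    dummy_sched.step()

plt.plot(lrs)
plt.xlabel('batch')
plt.ylabel('learning rate')
plt.show()
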
In [44]:
learner.fit(EPOCHS)
1/20
train: loss 0.483, mean dice 0.464
/home/aiserver/.virtualenvs/deeplearning/lib/python3.6/site-packages/ipykernel_launcher.py:2: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /pytorch/torch/csrc/utils/tensor_numpy.cpp:141.)
  
val: loss 0.426, mean dice 0.620
2/20
train: loss 0.421, mean dice 0.608
val: loss 0.417, mean dice 0.587
3/20
train: loss 0.400, mean dice 0.642
val: loss 0.464, mean dice 0.441
4/20
train: loss 0.369, mean dice 0.676
val: loss 0.405, mean dice 0.591
5/20
train: loss 0.338, mean dice 0.707
val: loss 0.402, mean dice 0.572
6/20
train: loss 0.315, mean dice 0.731
val: loss 0.336, mean dice 0.665
7/20
train: loss 0.298, mean dice 0.745
val: loss 0.455, mean dice 0.726
8/20
train: loss 0.283, mean dice 0.757
val: loss 0.278, mean dice 0.776
9/20
train: loss 0.269, mean dice 0.772
val: loss 0.286, mean dice 0.792
10/20
train: loss 0.256, mean dice 0.784
val: loss 0.306, mean dice 0.745
11/20
train: loss 0.247, mean dice 0.791
val: loss 0.232, mean dice 0.815
12/20
train: loss 0.235, mean dice 0.803
val: loss 0.252, mean dice 0.805
13/20
train: loss 0.222, mean dice 0.814
val: loss 0.238, mean dice 0.789
14/20
train: loss 0.218, mean dice 0.818
val: loss 0.209, mean dice 0.840
15/20
train: loss 0.207, mean dice 0.827
val: loss 0.211, mean dice 0.824
16/20
train: loss 0.197, mean dice 0.835
val: loss 0.192, mean dice 0.852
17/20
train: loss 0.193, mean dice 0.839
val: loss 0.190, mean dice 0.851
18/20
train: loss 0.184, mean dice 0.844
val: loss 0.183, mean dice 0.859
19/20
train: loss 0.182, mean dice 0.846
val: loss 0.184, mean dice 0.856
20/20
train: loss 0.181, mean dice 0.846
val: loss 0.181, mean dice 0.858
In [45]:
learner.history.plot_train_val('loss')
In [46]:
learner.history.plot_train_val('mean dice')
In [47]:
torch.save(model.state_dict(), SAVE_PATH)
Final validation loss and mean Dice for the three upsampling variants:

Method             Loss    Mean Dice
UpsampleConv       0.184   0.852
UpsampleShuffle    0.181   0.858
UpsampleBilinear   0.174   0.864

Testing

In [48]:
model.load_state_dict(torch.load(SAVE_PATH))
Out[48]:
<All keys matched successfully>
In [49]:
model.eval();
In [50]:
X_test, Y_test = val_dset[1]
Y_pred = model(image_to_tensor(X_test).unsqueeze(0).to(DEVICE))
Y_pred = Y_pred.detach().cpu().squeeze(0)
In [51]:
true_segmap = Y_test
segmap = tensor_to_segmap(torch.sigmoid(Y_pred), threshold=0.5)
In [52]:
show_image_with_segmap(X_test, segmap, true_segmap)
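
For a number rather than a picture, the mean Dice over the whole validation set can be recomputed with the helpers defined above (a short sketch, weighting each batch by its size):

total_dice = 0.0
with torch.no_grad():
    for X, Y in val_loader:
        X, Y = X.to(DEVICE), Y.to(DEVICE)
        total_dice += MeanDice(model(X), Y).item() * X.size(0)
print(total_dice / len(val_dset))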