W. Shi, J. Caballero, F. Huszár, J. Totz, A. P. Aitken, R. Bishop, D. Rueckert, Z. Wang, Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network, arXiv:1609.05158 [cs.CV] (2016).
A. Aitken, C. Ledig, L. Theis, J. Caballero, Z. Wang, W. Shi, Checkerboard artifact free sub-pixel convolution, arXiv:1707.02937 [cs.CV] (2017).
Methods for image upsampling compared below:
nn.Upsample (bilinear interpolation)
nn.ConvTranspose2d (transposed convolution)
nn.PixelShuffle (sub-pixel convolution)
Using the Oxford-IIIT Pet Dataset, https://www.robots.ox.ac.uk/~vgg/data/pets/
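As a quick standalone check (an illustrative sketch, independent of the pipeline below), the three modules change tensor shapes as follows:

import torch
import torch.nn as nn

x = torch.randn(1, 16, 8, 8)  # (N, C, H, W)

# Bilinear interpolation: channels unchanged, spatial size doubled
up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
print(up(x).shape)       # torch.Size([1, 16, 16, 16])

# Transposed convolution: learned upsampling, here 16 -> 8 channels
conv_t = nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2)
print(conv_t(x).shape)   # torch.Size([1, 8, 16, 16])

# Sub-pixel convolution: rearranges blocks of r^2 = 4 channels into space, 16 -> 4 channels
shuffle = nn.PixelShuffle(2)
print(shuffle(x).shape)  # torch.Size([1, 4, 16, 16])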
Imports
from pathlib import Path
from collections import defaultdict
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
import imgaug as ia
from imgaug import augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import matplotlib.pyplot as plt
Configuration
BASE_DIR = Path('./data')
IMAGES_DIR = BASE_DIR / 'images'
TARGET_DIR = BASE_DIR / 'annotations' / 'trimaps'
SAVE_PATH = 'weights-segmentation.pkl'
BACKGROUND_SHAPE = (160, 160)
VAL_RATIO = 0.2
BATCH_SIZE = 16
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
images_list = sorted(IMAGES_DIR.glob('*.jpg'))
targets_list = sorted(TARGET_DIR.glob('*.png'))
print("Number of images:", len(images_list))
print("Number of targets:", len(targets_list))
images_train, images_val, targets_train, targets_val = train_test_split(images_list, targets_list,
test_size=VAL_RATIO, random_state=0, shuffle=True)
def tensor_to_segmap(Y, threshold=0.5):
mask = (Y > threshold).numpy()
mask = np.moveaxis(np.uint8(mask), 0, -1)
return SegmentationMapsOnImage(mask, BACKGROUND_SHAPE)
def to_np(X):
x = np.moveaxis(X.numpy(), 0, -1)
x = np.uint8(x * 255)
return x
def plot_img(img):
fig = plt.figure()
ax = fig.add_subplot(111)
ax.set_axis_off()
ax.imshow(img, interpolation='bilinear')
plt.show()
def loader_show_image(loader):
X, Y = next(iter(loader))
img = to_np(X[0])
segmap = tensor_to_segmap(Y[0])
img = segmap.draw_on_image(img)[0]
plot_img(img)
def dataset_show_image(dset, idx):
img, segmap = dset[idx]
img = segmap.draw_on_image(img)[0]
plot_img(img)
def show_image_with_segmap(img, segmap, true_segmap):
img = true_segmap.draw_on_image(img, colors=[(0, 0, 0), (0, 255, 0)])[0]
img = segmap.draw_on_image(img)[0]
plot_img(img)
class ImagesDataset(torch.utils.data.Dataset):
def __init__(self, images_list, targets_list):
self.images_list = images_list
self.targets_list = targets_list
self.num_images = len(images_list)
self.images, self.masks = self._load_images_masks()
def __len__(self):
return self.num_images
def __getitem__(self, idx):
X, Y = self.images[idx], self.masks[idx]
return X, Y
def _load_images_masks(self):
print("Loading images...", end=' ')
images = self._load_images(self.images_list, BACKGROUND_SHAPE, resample=Image.BICUBIC, mode='RGB')
targets = self._load_images(self.targets_list, BACKGROUND_SHAPE, resample=Image.NEAREST, mode='L')
masks = self._create_masks(targets)
print("Done.")
return images, masks
def _load_images(self, file_list, shape, resample, mode):
images = []
for img_path in file_list:
img = self._load_and_resize_image(img_path, shape, resample, mode)
img_np = np.asarray(img)
images.append(img_np)
return images
def _load_and_resize_image(self, img_path, shape, resample, mode):
img = Image.open(img_path)
img = img.convert(mode)
width_height_tuple = (shape[1], shape[0])
if img.size != width_height_tuple:
img = img.resize(width_height_tuple, resample)
return img
def _create_masks(self, images):
# Oxford-IIIT Pet trimap labels: 1 = pet, 2 = background, 3 = border/undefined.
# The binary mask keeps only the pet pixels.
masks = []
for img in images:
mask = (img == 1)
masks.append(SegmentationMapsOnImage(mask, img.shape))
return masks
train_dset = ImagesDataset(images_train, targets_train)
val_dset = ImagesDataset(images_val, targets_val)
dataset_show_image(val_dset, 1)
def segmap_to_tensor(segmap):
mask = segmap.get_arr()
mask = np.expand_dims(mask, axis=0)
mask = torch.as_tensor(mask, dtype=torch.float)
return mask
def image_to_tensor(img):
img = torch.from_numpy(img.transpose((2, 0, 1)))
return img.float().div(255)
augmenter = iaa.Sequential([
iaa.Fliplr(0.5),
iaa.Affine(
scale=(0.9, 1.1),
translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
rotate=(-10, 10),
shear=(-8, 8),
mode='reflect'
),
iaa.Multiply((0.8, 1.2), per_channel=0.2),
iaa.LinearContrast((0.75, 1.5)),
iaa.GaussianBlur((0., 3.))
], random_order=True)
class CollateFN:
def __init__(self, augmenter=None):
self.augmenter = augmenter
def __call__(self, batch):
images, segmaps = zip(*batch)
if self.augmenter is not None:
images, segmaps = self.augmenter(images=images, segmentation_maps=segmaps)
images_tens = [image_to_tensor(img) for img in images]
segmaps_tens = [segmap_to_tensor(segmap) for segmap in segmaps]
X = torch.stack(images_tens, dim=0)
Y = torch.stack(segmaps_tens, dim=0)
return X, Y
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=CollateFN(augmenter))
val_loader = torch.utils.data.DataLoader(val_dset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=CollateFN())
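To eyeball the augmentation pipeline, draw one batch from the training loader and overlay its mask with the helper defined above:

loader_show_image(train_loader)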
U-Net (arXiv:1505.04597 [cs.CV])
Initialization according to arXiv:1707.02937 [cs.CV]:
def init_subpixel(weight):
co, ci, h, w = weight.shape
co2 = co // 4
# initialize a sub-kernel with one quarter of the output channels
k = torch.empty([co2, ci, h, w])
nn.init.kaiming_uniform_(k)
# repeat each sub-kernel channel 4 times so that, after PixelShuffle,
# all four sub-pixel positions start out identical (no checkerboard artifacts)
k = k.repeat_interleave(4, dim=0)
weight.data.copy_(k)
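A quick check (illustrative, not part of the model code) that this initialization removes the checkerboard pattern at the start of training: with the repeated kernels and zero bias, all four sub-pixel positions produced by PixelShuffle are identical.

conv = nn.Conv2d(16, 8 * 4, kernel_size=1)
init_subpixel(conv.weight)
nn.init.zeros_(conv.bias)
y = nn.PixelShuffle(2)(conv(torch.randn(1, 16, 4, 4)))
# compare two of the four sub-pixel positions within each 2x2 block
print(torch.allclose(y[..., 0::2, 0::2], y[..., 1::2, 1::2]))  # True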
def init_linear(m, relu=True):
if relu: nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
else: nn.init.xavier_uniform_(m.weight)
if m.bias is not None: nn.init.zeros_(m.bias)
class ConvBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, act=True):
padding = (kernel_size - 1) // 2
layers = [
nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False),
nn.BatchNorm2d(out_channels)
]
if act: layers.append(nn.ReLU(inplace=True))
super().__init__(*layers)
def reset_parameters(self):
init_linear(self[0])
self[1].reset_parameters()
class DoubleConv(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
ConvBlock(in_channels, out_channels),
ConvBlock(out_channels, out_channels)
)
class Down(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
class UpsampleBilinear(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
nn.Conv2d(in_channels, out_channels, kernel_size=1),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
)
def reset_parameters(self):
init_linear(self[0])
class UpsampleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv_t = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
def forward(self, x):
return self.conv_t(x)
def reset_parameters(self):
init_linear(self.conv_t)
class UpsampleShuffle(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
nn.Conv2d(in_channels, out_channels * 4, kernel_size=1),
nn.ReLU(inplace=True),
nn.PixelShuffle(2)
)
def reset_parameters(self):
init_subpixel(self[0].weight)
nn.init.zeros_(self[0].bias)
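For a sense of the relative cost of the three decoder blocks, a parameter count for one 512 -> 256 stage (illustrative):

print(sum(p.numel() for p in UpsampleBilinear(512, 256).parameters()))  # ~131k (1x1 conv only)
print(sum(p.numel() for p in UpsampleConv(512, 256).parameters()))      # ~525k (2x2 transposed conv)
print(sum(p.numel() for p in UpsampleShuffle(512, 256).parameters()))   # ~525k (1x1 conv to 4x channels)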
class Up(nn.Module):
def __init__(self, in_channels, out_channels, upsample):
super().__init__()
self.up = upsample(in_channels, in_channels // 2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class UNet(nn.Module):
def __init__(self, n_channels, n_classes, upsample):
super().__init__()
sizes = [64, 128, 256, 512, 1024]
self.ini = DoubleConv(n_channels, sizes[0])
self.down = nn.ModuleList([Down(sz1, sz2) for sz1, sz2 in zip(sizes, sizes[1:])])
sizes.reverse()
self.up = nn.ModuleList([Up(sz1, sz2, upsample) for sz1, sz2 in zip(sizes, sizes[1:])])
self.out = nn.Conv2d(sizes[-1], n_classes, kernel_size=1)
def forward(self, x):
x = self.ini(x)
xs = []
for down_layer in self.down:
xs.append(x)
x = down_layer(x)
for up_layer, xi in zip(self.up, reversed(xs)):
x = up_layer(x, xi)
logits = self.out(x)
return logits
def _reset_children(self, module):
for m in module.children():
if hasattr(m, 'reset_parameters'):
m.reset_parameters()
else:
self._reset_children(m)
def reset_parameters(self):
self._reset_children(self)
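A quick shape check of the network (illustrative; any of the three upsample blocks can be plugged in):

net = UNet(3, 1, UpsampleBilinear)
out = net(torch.randn(1, 3, 160, 160))
print(out.shape)  # torch.Size([1, 1, 160, 160]): one logit per pixel, same spatial size as the input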
model = UNet(3, 1, UpsampleShuffle).to(DEVICE)
model.reset_parameters()
def MeanDice(logits, targets, threshold=0.5):
probs = torch.sigmoid(logits)
labels = (probs > threshold).float()
num = targets.size(0)
m1 = labels.view(num, -1)
m2 = targets.view(num, -1)
intersection = (m1 * m2).sum(1)
cardinality = m1.sum(1) + m2.sum(1)
dice = (2. * intersection) / (cardinality + 1e-7)
m_dice = torch.mean(dice)
return m_dice
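A toy sanity check of the Dice score (illustrative): three predicted positives, two true positives, an overlap of two, so dice = 2*2 / (3+2) = 0.8.

logits = torch.tensor([[[[10., -10.], [10., 10.]]]])  # sigmoid > 0.5 for three of the four pixels
targets = torch.tensor([[[[1., 0.], [0., 1.]]]])      # two positive pixels in the ground truth
print(MeanDice(logits, targets))  # tensor(0.8000)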
def show_or_save(fig, filename=None):
if filename:
fig.savefig(filename, bbox_inches='tight', pad_inches=0.05)
plt.close(fig)
else:
plt.show()
class History:
def __init__(self):
self.values = defaultdict(list)
def append(self, key, value):
self.values[key].append(value)
def reset(self):
for k in self.values.keys():
self.values[k] = []
def _plot(self, key, line_type='-', label=None):
if not label: label=key
xs = np.arange(1, len(self.values[key])+1)
self.ax.plot(xs, self.values[key], line_type, label=label)
def plot_train_val(self, key, x_is_batch=False, ylog=False, filename=None):
fig = plt.figure()
self.ax = fig.add_subplot(111)
self._plot('train ' + key, '.-', 'train')
self._plot('val ' + key, '.-', 'val')
self.ax.legend()
if ylog: self.ax.set_yscale('log')
self.ax.set_xlabel('batch' if x_is_batch else 'epoch')
self.ax.set_ylabel(key)
show_or_save(fig, filename)
class Learner:
def __init__(self, model, loss, optimizer, train_loader, val_loader, device,
epoch_scheduler=None, batch_scheduler=None):
self.model = model
self.loss = loss
self.optimizer = optimizer
self.train_loader = train_loader
self.val_loader = val_loader
self.device = device
self.epoch_scheduler = epoch_scheduler
self.batch_scheduler = batch_scheduler
self.history = History()
def iterate(self, loader, backward_pass=False):
total_loss = 0.0
num_samples = 0
mean_dice = 0.0
for X, Y in loader:
X, Y = X.to(self.device), Y.to(self.device)
Y_pred = self.model(X)
batch_size = X.size(0)
batch_loss = self.loss(Y_pred, Y)
if backward_pass:
self.optimizer.zero_grad()
batch_loss.backward()
self.optimizer.step()
if self.batch_scheduler is not None:
self.batch_scheduler.step()
Y_pred.detach_() # conserve memory
total_loss += batch_size * batch_loss.item()
mean_dice += batch_size * MeanDice(Y_pred, Y).item()  # store a plain float, not a (possibly CUDA) tensor
num_samples += batch_size
avg_loss = total_loss / num_samples
mean_dice = mean_dice / num_samples
return avg_loss, mean_dice
def train(self):
self.model.train()
train_loss, train_m_dice = self.iterate(self.train_loader, backward_pass=True)
print(f'train: loss {train_loss:.3f}, mean dice {train_m_dice:.3f}')
self.history.append('train loss', train_loss)
self.history.append('train mean dice', train_m_dice)
def validate(self):
self.model.eval()
with torch.no_grad():
val_loss, val_m_dice = self.iterate(self.val_loader)
print(f'val: loss {val_loss:.3f}, mean dice {val_m_dice:.3f}')
self.history.append('val loss', val_loss)
self.history.append('val mean dice', val_m_dice)
def fit(self, epochs):
for i in range(epochs):
print(f'{i+1}/{epochs}')
self.train()
self.validate()
if self.epoch_scheduler is not None:
self.epoch_scheduler.step()
#model.load_state_dict(torch.load(SAVE_PATH))
loss = nn.BCEWithLogitsLoss()
#optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
learner = Learner(model, loss, optimizer, train_loader, val_loader, DEVICE)
EPOCHS = 20
learner.batch_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-1,
steps_per_epoch=len(train_loader),
epochs=EPOCHS)
learner.fit(EPOCHS)
learner.history.plot_train_val('loss')
learner.history.plot_train_val('mean dice')
torch.save(model.state_dict(), SAVE_PATH)
Method | Loss | Mean Dice |
---|---|---|
UpsampleConv | 0.184 | 0.852 |
UpsampleShuffle | 0.181 | 0.858 |
UpsampleBilinear | 0.174 | 0.864 |
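The other rows can be reproduced by swapping the upsample argument and re-running the same training loop, for example:

model = UNet(3, 1, UpsampleBilinear).to(DEVICE)
model.reset_parameters()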
model.load_state_dict(torch.load(SAVE_PATH))
model.eval();
X_test, Y_test = val_dset[1]
Y_pred = model(image_to_tensor(X_test).unsqueeze(0).to(DEVICE))
Y_pred = Y_pred.detach().cpu().squeeze(0)
true_segmap = Y_test
segmap = tensor_to_segmap(torch.sigmoid(Y_pred), threshold=0.5)  # apply sigmoid so the 0.5 threshold is on probabilities, as in MeanDice
show_image_with_segmap(X_test, segmap, true_segmap)