Imports
import numpy as np
from collections import defaultdict
from functools import partial
import matplotlib.pyplot as plt
import warnings
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torchvision import datasets
import torchvision.transforms.functional as TF
import albumentations as A
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
import ignite.metrics
import ignite.contrib.handlers
Configuration
# --- Configuration ----------------------------------------------------------
# Root directory where torchvision downloads the Oxford-IIIT Pet dataset.
DATA_DIR='./data'
# Images/masks are resized to IMAGE_SIZE x IMAGE_SIZE; must be divisible by
# 2**num_downsamplings so the pool/upsample round-trips match exactly.
IMAGE_SIZE = 256
# DataLoader worker processes.
NUM_WORKERS = 8
BATCH_SIZE = 32
EPOCHS = 80
# Peak learning rate for the OneCycleLR schedule defined below.
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-2
# Prefer GPU when available.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
warnings.simplefilter("ignore")
class ConvBlock(nn.Sequential):
    """Conv2d -> BatchNorm2d -> optional ReLU, with 'same' padding.

    The convolution carries no bias because the following BatchNorm layer
    would cancel it out anyway.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, act=True):
        same_padding = (kernel_size - 1) // 2
        modules = [
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      stride=stride, padding=same_padding, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        if act:
            modules.append(nn.ReLU(inplace=True))
        super().__init__(*modules)
class Branch(nn.Module):
    """Residual encoder/decoder branch at one resolution level.

    The residual path halves the spatial size, doubles the channel count,
    optionally runs a nested block at the lower resolution, projects back
    to the input channel count, and upsamples again.  Its contribution is
    scaled by a learnable ``gamma`` initialised to zero, so at init the
    module behaves as ReLU(x).
    """
    def __init__(self, channels, block=None):
        super().__init__()
        expanded = channels * 2
        inner = nn.Identity() if block is None else block(expanded)
        self.residual = nn.Sequential(
            nn.MaxPool2d(2),
            ConvBlock(channels, expanded, 3),
            inner,
            ConvBlock(expanded, channels, 3, act=False),
            nn.Upsample(scale_factor=2, mode='nearest'),
        )
        self.gamma = nn.Parameter(torch.zeros(1))
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.act(x + self.gamma * self.residual(x))
def get_branch(num):
    """Return a factory that builds a Branch nested `num` levels deep.

    Returns None for num <= 0 (no nested block); otherwise a partial whose
    inner block recurses one level shallower.
    """
    if num <= 0:
        return None
    return partial(Branch, block=get_branch(num - 1))
class Net(nn.Sequential):
    """Fully-convolutional segmentation network.

    Pipeline: stem conv -> recursively nested encoder/decoder Branch ->
    refinement conv -> 1x1 classifier head producing `classes` logit maps
    at the input resolution.
    """
    def __init__(self, classes, num_downsamplings, channels=32, in_channels=3):
        stem = ConvBlock(in_channels, channels, 3)
        body = Branch(channels, get_branch(num_downsamplings - 1))
        neck = ConvBlock(channels, channels, 3)
        head = nn.Conv2d(channels, classes, 1)
        super().__init__(stem, body, neck, head)
def init_linear(m):
    """Kaiming-normal init for conv/linear weights; zero-init any bias.

    Intended for use with ``module.apply``; non-conv/linear modules are
    left untouched.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.kaiming_normal_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)
# Build the segmentation model (1 output channel = binary-mask logits) and
# apply Kaiming initialisation to every conv/linear layer.
model = Net(classes=1, num_downsamplings=5).to(DEVICE)
model.apply(init_linear);
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 12,586,822
model
Net( (0): ConvBlock( (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) ) (1): Branch( (residual): Sequential( (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (1): ConvBlock( (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) ) (2): Branch( (residual): Sequential( (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (1): ConvBlock( (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) ) (2): Branch( (residual): Sequential( (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (1): ConvBlock( (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) ) (2): Branch( (residual): Sequential( (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (1): ConvBlock( (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) ) (2): Branch( (residual): Sequential( (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (1): ConvBlock( (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) ) (2): Identity() (3): ConvBlock( (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): 
BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (4): Upsample(scale_factor=2.0, mode=nearest) ) (act): ReLU(inplace=True) ) (3): ConvBlock( (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (4): Upsample(scale_factor=2.0, mode=nearest) ) (act): ReLU(inplace=True) ) (3): ConvBlock( (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (4): Upsample(scale_factor=2.0, mode=nearest) ) (act): ReLU(inplace=True) ) (3): ConvBlock( (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (4): Upsample(scale_factor=2.0, mode=nearest) ) (act): ReLU(inplace=True) ) (3): ConvBlock( (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (4): Upsample(scale_factor=2.0, mode=nearest) ) (act): ReLU(inplace=True) ) (2): ConvBlock( (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) ) (3): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1)) )
class TransformFn:
    """Adapter joining an albumentations pipeline to (image, trimap) pairs.

    Trimap pixels equal to 1 (presumably the foreground class — verify
    against the dataset docs) become a binary mask; image and mask are
    augmented together, then converted to tensors.
    """
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, image, target):
        image_np = np.asarray(image)
        # Only trimap value 1 is kept as foreground.
        mask_np = (np.asarray(target) == 1).astype(np.uint8, copy=False)
        augmented = self.transform(image=image_np, mask=mask_np)
        image_t = TF.to_tensor(augmented['image'])
        # Add a leading channel axis so the mask is (1, H, W) like the logits.
        mask_t = torch.as_tensor(augmented['mask'][None, ...], dtype=torch.float)
        return image_t, mask_t
# Training-time augmentation: resize, then random flip / mild affine /
# photometric jitter.  Geometric ops are applied identically to the mask.
train_transform = A.Compose([
    A.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE),
    A.HorizontalFlip(p=0.5),
    A.Affine(
        scale=(0.9, 1.1),
        translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
        rotate=(-10, 10),
        shear=(-8, 8),
        p=1
    ),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1),
])
# Validation: resize only — no stochastic augmentation.
val_transform = A.Compose([
    A.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE)
])
# Oxford-IIIT Pet with segmentation trimaps; TransformFn turns each
# (PIL image, PIL trimap) pair into (float image tensor, binary mask tensor).
train_dset = datasets.OxfordIIITPet(root=DATA_DIR, split="trainval", target_types='segmentation', download=True,
                                    transforms=TransformFn(train_transform))
val_dset = datasets.OxfordIIITPet(root=DATA_DIR, split="test", target_types='segmentation', download=True,
                                  transforms=TransformFn(val_transform))
Data Loader
# pin_memory speeds up host->GPU transfers; only the training set is shuffled.
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
                                           num_workers=NUM_WORKERS, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_dset, batch_size=BATCH_SIZE, shuffle=False,
                                         num_workers=NUM_WORKERS, pin_memory=True)
def plot_image(image):
    """Display a single image (PIL image or array) with the axes hidden."""
    _, axis = plt.subplots()
    axis.set_axis_off()
    axis.imshow(image)
    plt.show()
def plot_tensor_image(image):
    """Convert a CHW image tensor to PIL and display it."""
    plot_image(TF.to_pil_image(image))
def plot_image_mask(image, mask, color='red'):
    """Overlay a boolean mask on an image tensor and display the result."""
    image_u8 = TF.convert_image_dtype(image, dtype=torch.uint8)
    overlay = torchvision.utils.draw_segmentation_masks(
        image=image_u8, masks=mask.to(torch.bool), alpha=0.5, colors=color)
    plot_tensor_image(overlay)
plot_image_mask(*val_dset[0])
class MeanDice(ignite.metrics.Metric):
    """Ignite metric: mean per-example Dice coefficient for binary masks.

    Logits are passed through a sigmoid and thresholded at `threshold`;
    each example's Dice score is smoothed by `smooth` and the scores are
    averaged over all examples seen since the last reset.
    """
    def __init__(self, threshold=0.5, smooth=0.01, output_transform=lambda x: x, device="cpu"):
        # Hyper-parameters must be set first: Metric.__init__ calls reset().
        self.threshold = threshold
        self.smooth = smooth
        super().__init__(output_transform=output_transform, device=device)

    def reset(self):
        self._dice_sum = 0.
        self._num_examples = 0
        super().reset()

    def update(self, data):
        logits, targets = data[0].detach(), data[1].detach()
        predictions = (torch.sigmoid(logits) > self.threshold).float()
        batch = logits.size(0)
        flat_pred = predictions.view(batch, -1)
        flat_true = targets.view(batch, -1)
        overlap = (flat_pred * flat_true).sum(1)
        total = flat_pred.sum(1) + flat_true.sum(1)
        per_example = (2. * overlap + self.smooth) / (total + self.smooth)
        self._dice_sum += per_example.sum().item()
        self._num_examples += batch

    def compute(self):
        return self._dice_sum / self._num_examples
# Binary segmentation loss computed directly on logits (numerically stable).
loss = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# One-cycle schedule; stepped once per *iteration* by the handler below.
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
                                             steps_per_epoch=len(train_loader), epochs=EPOCHS)
Trainer
# Trainer engine: each iteration runs forward/backward/step on one batch.
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE)
# OneCycleLR must advance every iteration, not every epoch.
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());
# Exponential running average of the batch loss, exposed as trainer metric "loss".
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
Evaluator
# Evaluator computes validation loss and the custom mean Dice metric.
val_metrics = {"mean Dice": MeanDice(), "loss": ignite.metrics.Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
# Per-epoch curves collected for the plots further down.
history = defaultdict(list)

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    """After each training epoch: run validation, record curves, print a line."""
    state = engine.state
    history['train loss'].append(state.metrics["loss"])
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    history['val loss'].append(metrics["loss"])
    history['val mean Dice'].append(metrics["mean Dice"])
    print("{}/{} - train: loss {:.3f}; val: loss {:.3f} mean Dice {:.3f}".format(
        state.epoch, state.max_epochs, state.metrics["loss"],
        metrics["loss"], metrics["mean Dice"]))
trainer.run(train_loader, max_epochs=EPOCHS);
1/80 - train: loss 0.795; val: loss 0.661 mean Dice 0.522 2/80 - train: loss 0.483; val: loss 0.413 mean Dice 0.673 3/80 - train: loss 0.354; val: loss 0.383 mean Dice 0.707 4/80 - train: loss 0.321; val: loss 0.308 mean Dice 0.750 5/80 - train: loss 0.290; val: loss 0.289 mean Dice 0.784 6/80 - train: loss 0.268; val: loss 0.282 mean Dice 0.783 7/80 - train: loss 0.256; val: loss 0.254 mean Dice 0.803 8/80 - train: loss 0.247; val: loss 0.336 mean Dice 0.765 9/80 - train: loss 0.248; val: loss 0.252 mean Dice 0.799 10/80 - train: loss 0.232; val: loss 0.266 mean Dice 0.786 11/80 - train: loss 0.214; val: loss 0.279 mean Dice 0.793 12/80 - train: loss 0.213; val: loss 0.216 mean Dice 0.819 13/80 - train: loss 0.204; val: loss 0.246 mean Dice 0.818 14/80 - train: loss 0.205; val: loss 0.280 mean Dice 0.802 15/80 - train: loss 0.196; val: loss 0.234 mean Dice 0.790 16/80 - train: loss 0.187; val: loss 0.211 mean Dice 0.831 17/80 - train: loss 0.184; val: loss 0.204 mean Dice 0.847 18/80 - train: loss 0.176; val: loss 0.196 mean Dice 0.846 19/80 - train: loss 0.174; val: loss 0.279 mean Dice 0.733 20/80 - train: loss 0.166; val: loss 0.267 mean Dice 0.807 21/80 - train: loss 0.171; val: loss 0.208 mean Dice 0.844 22/80 - train: loss 0.163; val: loss 0.194 mean Dice 0.852 23/80 - train: loss 0.156; val: loss 0.210 mean Dice 0.826 24/80 - train: loss 0.149; val: loss 0.173 mean Dice 0.870 25/80 - train: loss 0.149; val: loss 0.182 mean Dice 0.851 26/80 - train: loss 0.146; val: loss 0.181 mean Dice 0.866 27/80 - train: loss 0.142; val: loss 0.276 mean Dice 0.813 28/80 - train: loss 0.144; val: loss 0.175 mean Dice 0.852 29/80 - train: loss 0.138; val: loss 0.167 mean Dice 0.867 30/80 - train: loss 0.133; val: loss 0.161 mean Dice 0.872 31/80 - train: loss 0.134; val: loss 0.199 mean Dice 0.858 32/80 - train: loss 0.126; val: loss 0.152 mean Dice 0.877 33/80 - train: loss 0.120; val: loss 0.171 mean Dice 0.862 34/80 - train: loss 0.120; val: loss 0.211 mean Dice 0.851 
35/80 - train: loss 0.120; val: loss 0.159 mean Dice 0.881 36/80 - train: loss 0.122; val: loss 0.172 mean Dice 0.868 37/80 - train: loss 0.115; val: loss 0.194 mean Dice 0.865 38/80 - train: loss 0.115; val: loss 0.158 mean Dice 0.878 39/80 - train: loss 0.110; val: loss 0.163 mean Dice 0.880 40/80 - train: loss 0.103; val: loss 0.157 mean Dice 0.878 41/80 - train: loss 0.099; val: loss 0.185 mean Dice 0.865 42/80 - train: loss 0.104; val: loss 0.228 mean Dice 0.845 43/80 - train: loss 0.102; val: loss 0.159 mean Dice 0.882 44/80 - train: loss 0.101; val: loss 0.191 mean Dice 0.867 45/80 - train: loss 0.096; val: loss 0.153 mean Dice 0.884 46/80 - train: loss 0.093; val: loss 0.153 mean Dice 0.884 47/80 - train: loss 0.090; val: loss 0.146 mean Dice 0.894 48/80 - train: loss 0.087; val: loss 0.150 mean Dice 0.884 49/80 - train: loss 0.086; val: loss 0.160 mean Dice 0.885 50/80 - train: loss 0.082; val: loss 0.153 mean Dice 0.896 51/80 - train: loss 0.081; val: loss 0.146 mean Dice 0.893 52/80 - train: loss 0.084; val: loss 0.147 mean Dice 0.894 53/80 - train: loss 0.079; val: loss 0.146 mean Dice 0.898 54/80 - train: loss 0.075; val: loss 0.156 mean Dice 0.891 55/80 - train: loss 0.077; val: loss 0.148 mean Dice 0.897 56/80 - train: loss 0.072; val: loss 0.148 mean Dice 0.896 57/80 - train: loss 0.072; val: loss 0.151 mean Dice 0.897 58/80 - train: loss 0.070; val: loss 0.144 mean Dice 0.899 59/80 - train: loss 0.068; val: loss 0.164 mean Dice 0.894 60/80 - train: loss 0.067; val: loss 0.150 mean Dice 0.899 61/80 - train: loss 0.065; val: loss 0.145 mean Dice 0.901 62/80 - train: loss 0.065; val: loss 0.149 mean Dice 0.901 63/80 - train: loss 0.062; val: loss 0.147 mean Dice 0.900 64/80 - train: loss 0.063; val: loss 0.151 mean Dice 0.899 65/80 - train: loss 0.061; val: loss 0.148 mean Dice 0.905 66/80 - train: loss 0.061; val: loss 0.147 mean Dice 0.904 67/80 - train: loss 0.061; val: loss 0.150 mean Dice 0.904 68/80 - train: loss 0.059; val: loss 0.147 mean Dice 
0.905 69/80 - train: loss 0.059; val: loss 0.148 mean Dice 0.905 70/80 - train: loss 0.058; val: loss 0.145 mean Dice 0.906 71/80 - train: loss 0.058; val: loss 0.148 mean Dice 0.905 72/80 - train: loss 0.057; val: loss 0.150 mean Dice 0.905 73/80 - train: loss 0.057; val: loss 0.151 mean Dice 0.905 74/80 - train: loss 0.056; val: loss 0.148 mean Dice 0.905 75/80 - train: loss 0.057; val: loss 0.148 mean Dice 0.906 76/80 - train: loss 0.055; val: loss 0.149 mean Dice 0.906 77/80 - train: loss 0.055; val: loss 0.150 mean Dice 0.906 78/80 - train: loss 0.056; val: loss 0.149 mean Dice 0.906 79/80 - train: loss 0.056; val: loss 0.148 mean Dice 0.906 80/80 - train: loss 0.055; val: loss 0.150 mean Dice 0.905
def plot_history_train_val(history, key):
    """Plot the 'train <key>' and 'val <key>' series against epoch number."""
    train_series = history['train ' + key]
    val_series = history['val ' + key]
    epochs = np.arange(1, len(train_series) + 1)
    _, axis = plt.subplots()
    axis.plot(epochs, train_series, '.-', label='train')
    axis.plot(epochs, val_series, '.-', label='val')
    axis.set_xlabel('epoch')
    axis.set_ylabel(key)
    axis.legend()
    axis.grid()
    plt.show()
def plot_history(history, key):
    """Plot a single `history[key]` series against epoch number."""
    series = history[key]
    epochs = np.arange(1, len(series) + 1)
    _, axis = plt.subplots()
    axis.plot(epochs, series, '.-')
    axis.set_xlabel('epoch')
    axis.set_ylabel(key)
    axis.grid()
    plt.show()
# Learning curves.
plot_history_train_val(history, 'loss')
plot_history(history, 'val mean Dice')
# Switch to inference mode (BatchNorm uses running statistics) before visualising.
model.eval();
def compare_detection(model, image, target, threshold=0.5):
    """Overlay predicted (red) and ground-truth (green) masks on one image.

    Runs the model on a single image, thresholds the sigmoid of the logits
    at `threshold`, and draws both masks on top of the input.
    """
    with torch.no_grad():
        logits = model(image.unsqueeze(0).to(DEVICE))
    probs = torch.sigmoid(logits.detach().cpu().squeeze(0))
    predicted = probs > threshold
    base = TF.convert_image_dtype(image, dtype=torch.uint8)
    overlay = torchvision.utils.draw_segmentation_masks(
        image=base, masks=predicted, alpha=0.5, colors='red')
    overlay = torchvision.utils.draw_segmentation_masks(
        image=overlay, masks=target.to(torch.bool), alpha=0.5, colors='green')
    plot_tensor_image(overlay)
compare_detection(model, *val_dset[1])