Imports
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
import ignite.metrics
import ignite.contrib.handlers
Configuration
DATA_DIR='./data'
NUM_CLASSES = 10
NUM_WORKERS = 8
BATCH_SIZE = 32
EPOCHS = 100
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor()
])
train_dset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_transform)
test_dset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transforms.ToTensor())
Files already downloaded and verified
Files already downloaded and verified
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
                                           num_workers=NUM_WORKERS, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False,
                                          num_workers=NUM_WORKERS, pin_memory=True)
def dataset_show_image(dset, idx):
    X, Y = dset[idx]
    title = "Ground truth: {}".format(dset.classes[Y])
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_axis_off()
    ax.imshow(np.moveaxis(X.numpy(), 0, -1))
    ax.set_title(title)
    plt.show()
dataset_show_image(test_dset, 1)
Utilities
def init_linear(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None: nn.init.zeros_(m.bias)
class Residual(nn.Module):
    def __init__(self, residual, shortcut=None):
        super().__init__()
        self.shortcut = nn.Identity() if shortcut is None else shortcut
        self.residual = residual
        self.gamma = nn.Parameter(torch.zeros(1))  # zero-init scaling, so each block starts as an identity mapping (ReZero/SkipInit-style)
    
    def forward(self, x):
        return self.shortcut(x) + self.gamma * self.residual(x)
class NormAct(nn.Sequential):
    def __init__(self, channels):
        super().__init__(
            nn.BatchNorm2d(channels),
            nn.SiLU(inplace=True)
        )
class ConvBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super().__init__(
            NormAct(in_channels),
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups),
        )
class SqueezeExciteBlock(nn.Module):
    def __init__(self, channels, reduced_channels):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, reduced_channels, kernel_size=1),
            nn.SiLU(inplace=True),
            nn.Conv2d(reduced_channels, channels, kernel_size=1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        return x * self.se(x)
EfficientNet
Inverted residual block
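The inverted residual (MBConv) block expands the channels with a pointwise convolution, applies a depthwise convolution followed by squeeze-and-excitation, and projects back to the output width with a final pointwise convolution; the shortcut and the zero-initialized gamma come from the Residual wrapper above.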
class MBConvResidual(nn.Sequential):
    def __init__(self, in_channels, out_channels, expansion, kernel_size=3, stride=1):
        mid_channels = in_channels * expansion
        squeeze_channels = in_channels // 4
        super().__init__(
            ConvBlock(in_channels, mid_channels, 1), # Pointwise
            ConvBlock(mid_channels, mid_channels, kernel_size, stride=stride, groups=mid_channels), # Depthwise
            NormAct(mid_channels),
            SqueezeExciteBlock(mid_channels, squeeze_channels),
            nn.Conv2d(mid_channels, out_channels, kernel_size=1) # Pointwise
        )
class MBConvBlock(Residual):
    def __init__(self, in_channels, out_channels, expansion, kernel_size=3, stride=1):
        residual = MBConvResidual(in_channels, out_channels, expansion, kernel_size, stride)
        shortcut = self.get_shortcut(in_channels, out_channels, stride)
        super().__init__(residual, shortcut)
    
    def get_shortcut(self, in_channels, out_channels, stride):
        if in_channels != out_channels:
            shortcut = nn.Conv2d(in_channels, out_channels, 1)
            if stride > 1:
                shortcut = nn.Sequential(nn.AvgPool2d(stride), shortcut)
        elif stride > 1:
            shortcut = nn.AvgPool2d(stride)
        else:
            shortcut = nn.Identity()
        return shortcut
class BlockStack(nn.Sequential):
    def __init__(self, num_layers, channel_list, strides, expansion=4, kernel_size=3):
        layers = []
        for num, in_channels, out_channels, stride in zip(num_layers, channel_list, channel_list[1:], strides):
            for _ in range(num):
                layers.append(MBConvBlock(in_channels, out_channels, expansion, kernel_size, stride))
                in_channels = out_channels
                stride = 1
        super().__init__(*layers)
class Head(nn.Sequential):
    def __init__(self, in_channels, classes, mult=4, p_drop=0.):
        mid_channels = in_channels * mult
        super().__init__(
            ConvBlock(in_channels, mid_channels, 1),
            NormAct(mid_channels),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(p_drop),
            nn.Linear(mid_channels, classes)
        )
class Stem(nn.Sequential):
    def __init__(self, in_channels, mid_channels, out_channels, stride):
        squeeze_channels = mid_channels // 4
        super().__init__(
            nn.Conv2d(in_channels, mid_channels, 3, stride=stride, padding=1),
            ConvBlock(mid_channels, mid_channels, 3, groups=mid_channels), # Depthwise
            NormAct(mid_channels),
            SqueezeExciteBlock(mid_channels, squeeze_channels),
            nn.Conv2d(mid_channels, out_channels, kernel_size=1) # Pointwise
        )
class EfficientNet(nn.Sequential):
    def __init__(self, classes,  num_layers, channel_list, strides, expansion=4,
                 in_channels=3, head_p_drop=0.):
        super().__init__(
            Stem(in_channels, *channel_list[:2], stride=strides[0]),
            BlockStack(num_layers, channel_list[1:], strides[1:], expansion),
            Head(channel_list[-1], classes, p_drop=head_p_drop)
        )
model = EfficientNet(NUM_CLASSES,
                     num_layers =            [4,  4,   4,   4],
                     channel_list = [32, 16, 32, 64, 128, 256],
                     strides =           [1,  1,  2,   2,   2],
                     expansion = 4,
                     head_p_drop = 0.3)
model.apply(init_linear);
model.to(DEVICE);
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 3,351,478
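As a quick sanity check (not in the original notebook), a forward pass on a dummy CIFAR-sized batch should produce one logit per class; a minimal sketch:
with torch.no_grad():
    dummy = torch.zeros(2, 3, 32, 32, device=DEVICE)  # fake batch of two 32x32 RGB images
    print(model(dummy).shape)                          # expected: torch.Size([2, 10])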
class History:
    def __init__(self):
        self.values = defaultdict(list)
    def append(self, key, value):
        self.values[key].append(value)
    def reset(self):
        for k in self.values.keys():
            self.values[k] = []
    def _begin_plot(self):
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
    def _end_plot(self, ylabel):
        self.ax.set_xlabel('epoch')
        self.ax.set_ylabel(ylabel)
        plt.show()
    def _plot(self, key, line_type='-', label=None):
        if label is None: label=key
        xs = np.arange(1, len(self.values[key])+1)
        self.ax.plot(xs, self.values[key], line_type, label=label)
    def plot(self, key):
        self._begin_plot()
        self._plot(key, '-')
        self._end_plot(key)
    def plot_train_val(self, key):
        self._begin_plot()
        self._plot('train ' + key, '.-', 'train')
        self._plot('val ' + key, '.-', 'val')
        self.ax.legend()
        self._end_plot(key)
loss = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-6, weight_decay=1e-2)
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-2,
                                             steps_per_epoch=len(train_loader), epochs=EPOCHS)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());  # OneCycleLR is stepped once per iteration, not per epoch
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
val_metrics = {"accuracy": ignite.metrics.Accuracy(), "loss": ignite.metrics.Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
history = History()
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    train_state = engine.state
    epoch = train_state.epoch
    max_epochs = train_state.max_epochs
    train_loss = train_state.metrics["loss"]
    history.append('train loss', train_loss)
    
    evaluator.run(test_loader)
    val_metrics = evaluator.state.metrics
    val_loss = val_metrics["loss"]
    val_acc = val_metrics["accuracy"]
    history.append('val loss', val_loss)
    history.append('val acc', val_acc)
    
    print("{}/{} - train: loss {:.3f}; val: loss {:.3f} accuracy {:.3f}".format(
        epoch, max_epochs, train_loss, val_loss, val_acc))
trainer.run(train_loader, max_epochs=EPOCHS);
1/100 - train: loss 1.085; val: loss 0.977 accuracy 0.653
2/100 - train: loss 0.779; val: loss 0.679 accuracy 0.762
3/100 - train: loss 0.648; val: loss 0.593 accuracy 0.792
4/100 - train: loss 0.569; val: loss 0.535 accuracy 0.815
5/100 - train: loss 0.560; val: loss 0.573 accuracy 0.808
6/100 - train: loss 0.518; val: loss 0.514 accuracy 0.822
7/100 - train: loss 0.500; val: loss 0.557 accuracy 0.817
8/100 - train: loss 0.473; val: loss 0.649 accuracy 0.785
9/100 - train: loss 0.454; val: loss 0.504 accuracy 0.831
10/100 - train: loss 0.462; val: loss 0.519 accuracy 0.824
11/100 - train: loss 0.449; val: loss 0.504 accuracy 0.833
12/100 - train: loss 0.447; val: loss 0.486 accuracy 0.836
13/100 - train: loss 0.458; val: loss 0.512 accuracy 0.833
14/100 - train: loss 0.445; val: loss 0.454 accuracy 0.846
15/100 - train: loss 0.412; val: loss 0.519 accuracy 0.833
16/100 - train: loss 0.423; val: loss 0.415 accuracy 0.864
17/100 - train: loss 0.456; val: loss 0.550 accuracy 0.824
18/100 - train: loss 0.454; val: loss 0.515 accuracy 0.828
19/100 - train: loss 0.454; val: loss 0.486 accuracy 0.835
20/100 - train: loss 0.454; val: loss 0.475 accuracy 0.841
21/100 - train: loss 0.474; val: loss 0.448 accuracy 0.849
22/100 - train: loss 0.462; val: loss 0.520 accuracy 0.835
23/100 - train: loss 0.468; val: loss 0.544 accuracy 0.820
24/100 - train: loss 0.442; val: loss 0.550 accuracy 0.828
25/100 - train: loss 0.428; val: loss 0.541 accuracy 0.823
26/100 - train: loss 0.455; val: loss 0.579 accuracy 0.816
27/100 - train: loss 0.451; val: loss 0.463 accuracy 0.842
28/100 - train: loss 0.447; val: loss 0.929 accuracy 0.735
29/100 - train: loss 0.439; val: loss 0.487 accuracy 0.840
30/100 - train: loss 0.417; val: loss 0.423 accuracy 0.857
31/100 - train: loss 0.434; val: loss 0.598 accuracy 0.808
32/100 - train: loss 0.453; val: loss 0.544 accuracy 0.816
33/100 - train: loss 0.439; val: loss 0.453 accuracy 0.851
34/100 - train: loss 0.422; val: loss 0.505 accuracy 0.831
35/100 - train: loss 0.430; val: loss 0.575 accuracy 0.812
36/100 - train: loss 0.432; val: loss 0.469 accuracy 0.848
37/100 - train: loss 0.439; val: loss 0.524 accuracy 0.830
38/100 - train: loss 0.437; val: loss 0.548 accuracy 0.822
39/100 - train: loss 0.434; val: loss 0.639 accuracy 0.808
40/100 - train: loss 0.401; val: loss 0.472 accuracy 0.843
41/100 - train: loss 0.394; val: loss 0.515 accuracy 0.842
42/100 - train: loss 0.430; val: loss 0.626 accuracy 0.805
43/100 - train: loss 0.386; val: loss 0.472 accuracy 0.848
44/100 - train: loss 0.418; val: loss 0.673 accuracy 0.789
45/100 - train: loss 0.405; val: loss 0.409 accuracy 0.868
46/100 - train: loss 0.358; val: loss 0.437 accuracy 0.856
47/100 - train: loss 0.375; val: loss 0.407 accuracy 0.863
48/100 - train: loss 0.383; val: loss 0.469 accuracy 0.851
49/100 - train: loss 0.363; val: loss 0.404 accuracy 0.862
50/100 - train: loss 0.366; val: loss 0.400 accuracy 0.869
51/100 - train: loss 0.350; val: loss 0.523 accuracy 0.837
52/100 - train: loss 0.318; val: loss 0.392 accuracy 0.867
53/100 - train: loss 0.358; val: loss 0.433 accuracy 0.858
54/100 - train: loss 0.357; val: loss 0.379 accuracy 0.876
55/100 - train: loss 0.328; val: loss 0.427 accuracy 0.860
56/100 - train: loss 0.306; val: loss 0.545 accuracy 0.828
57/100 - train: loss 0.256; val: loss 0.386 accuracy 0.876
58/100 - train: loss 0.329; val: loss 0.384 accuracy 0.875
59/100 - train: loss 0.302; val: loss 0.378 accuracy 0.875
60/100 - train: loss 0.290; val: loss 0.389 accuracy 0.873
61/100 - train: loss 0.327; val: loss 0.424 accuracy 0.863
62/100 - train: loss 0.263; val: loss 0.331 accuracy 0.895
63/100 - train: loss 0.245; val: loss 0.394 accuracy 0.876
64/100 - train: loss 0.276; val: loss 0.418 accuracy 0.864
65/100 - train: loss 0.264; val: loss 0.347 accuracy 0.887
66/100 - train: loss 0.263; val: loss 0.343 accuracy 0.887
67/100 - train: loss 0.219; val: loss 0.331 accuracy 0.892
68/100 - train: loss 0.229; val: loss 0.288 accuracy 0.905
69/100 - train: loss 0.205; val: loss 0.330 accuracy 0.898
70/100 - train: loss 0.194; val: loss 0.272 accuracy 0.914
71/100 - train: loss 0.176; val: loss 0.295 accuracy 0.905
72/100 - train: loss 0.164; val: loss 0.261 accuracy 0.917
73/100 - train: loss 0.197; val: loss 0.276 accuracy 0.915
74/100 - train: loss 0.165; val: loss 0.274 accuracy 0.914
75/100 - train: loss 0.168; val: loss 0.264 accuracy 0.918
76/100 - train: loss 0.123; val: loss 0.257 accuracy 0.920
77/100 - train: loss 0.140; val: loss 0.267 accuracy 0.917
78/100 - train: loss 0.130; val: loss 0.302 accuracy 0.911
79/100 - train: loss 0.100; val: loss 0.234 accuracy 0.928
80/100 - train: loss 0.106; val: loss 0.249 accuracy 0.926
81/100 - train: loss 0.106; val: loss 0.247 accuracy 0.928
82/100 - train: loss 0.088; val: loss 0.233 accuracy 0.934
83/100 - train: loss 0.066; val: loss 0.230 accuracy 0.934
84/100 - train: loss 0.084; val: loss 0.239 accuracy 0.933
85/100 - train: loss 0.053; val: loss 0.227 accuracy 0.936
86/100 - train: loss 0.053; val: loss 0.222 accuracy 0.938
87/100 - train: loss 0.042; val: loss 0.233 accuracy 0.939
88/100 - train: loss 0.039; val: loss 0.220 accuracy 0.941
89/100 - train: loss 0.031; val: loss 0.217 accuracy 0.946
90/100 - train: loss 0.026; val: loss 0.226 accuracy 0.942
91/100 - train: loss 0.022; val: loss 0.223 accuracy 0.946
92/100 - train: loss 0.025; val: loss 0.226 accuracy 0.945
93/100 - train: loss 0.022; val: loss 0.224 accuracy 0.945
94/100 - train: loss 0.018; val: loss 0.216 accuracy 0.949
95/100 - train: loss 0.016; val: loss 0.211 accuracy 0.949
96/100 - train: loss 0.009; val: loss 0.217 accuracy 0.949
97/100 - train: loss 0.009; val: loss 0.213 accuracy 0.950
98/100 - train: loss 0.010; val: loss 0.213 accuracy 0.951
99/100 - train: loss 0.009; val: loss 0.211 accuracy 0.950
100/100 - train: loss 0.011; val: loss 0.214 accuracy 0.950
history.plot_train_val('loss')
history.plot('val acc')
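The trained weights can be saved for later reuse. This step is not part of the original notebook and the filename is only an example:
torch.save(model.state_dict(), "efficientnet_cifar10.pt")  # hypothetical checkpoint path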