Tiny model for CIFAR10¶

Configuration¶

Imports

In [1]:
import math
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt

from tqdm import tqdm, trange

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets
import torchvision.transforms.v2 as transforms

Configuration

In [2]:
# ---- Configuration ----
DATA_DIR='./data'  # download/cache location for CIFAR10

NUM_CLASSES = 10      # CIFAR10 has 10 categories
NUM_WORKERS = 8       # DataLoader worker processes
BATCH_SIZE = 32
EPOCHS = 2000
LEARNING_RATE = 1e-2  # peak LR of the one-cycle schedule (see OneCycleLR below)
WEIGHT_DECAY = 1e-3   # AdamW decoupled weight decay
In [3]:
# Prefer GPU when available; everything below moves tensors/models to DEVICE.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda

Data¶

In [4]:
# Training-time augmentation pipeline (torchvision transforms v2).
train_transform = transforms.Compose([
    transforms.ToImage(),  # PIL image -> tv_tensors.Image (uint8)
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),  # translation augmentation
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToDtype(torch.float, scale=True),  # uint8 [0, 255] -> float [0, 1]
    transforms.RandomErasing(p=1.0, value=0.)  # always erase one random patch (cutout-style)
])
In [5]:
# Evaluation transform: tensor conversion and scaling only, no augmentation.
val_transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float, scale=True),
])
In [6]:
# CIFAR10 train/test splits; downloaded into DATA_DIR on first run.
train_dset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_transform)
test_dset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=val_transform)
In [7]:
# Only the training data is shuffled; the test set keeps its natural order.
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
In [8]:
def dataset_show_image(dset, idx):
    """Display sample `idx` of `dset` (a CHW float tensor) with its class label."""
    image, label = dset[idx]
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_axis_off()
    # imshow wants HWC, the dataset yields CHW — move the channel axis last.
    ax.imshow(np.moveaxis(image.numpy(), 0, -1))
    ax.set_title("Ground truth: {}".format(dset.classes[label]))
    plt.show()
In [9]:
dataset_show_image(test_dset, 1)
No description has been provided for this image

Model¶

Utilities¶

From arXiv:1904.11486 [cs.CV]

In [10]:
class BlurPool(nn.Module):
    """Anti-aliased downsampling (arXiv:1904.11486).

    Low-pass filters the input with a fixed binomial blur kernel, then
    subsamples with the given stride. The same 2-D kernel is applied
    depthwise, so the module works for any number of input channels.
    """
    def __init__(self, stride=2, filter_size=4):
        super().__init__()
        self.stride = stride
        # "same"-style padding for this filter size / stride combination.
        self.padding = (filter_size - stride) // 2
        # A buffer follows .to(device)/.state_dict() but is never trained.
        self.register_buffer("filt", self.get_filter(filter_size))

    def forward(self, x):
        num_channels = x.size(1)
        # Replicate the single kernel across channels for a depthwise conv.
        kernel = self.filt.expand(num_channels, 1, -1, -1)
        return F.conv2d(x, kernel, stride=self.stride,
                        padding=self.padding, groups=num_channels)

    def get_filter(self, size):
        """Build the normalized (1, 1, size, size) binomial blur kernel."""
        row = torch.tensor(self.binomial_coefficients(size - 1)).float()
        kernel = torch.outer(row, row)
        kernel = kernel / kernel.sum()  # weights sum to 1
        return kernel.unsqueeze(0).unsqueeze(0)

    @staticmethod
    def binomial_coefficients(n):
        """Return row n of Pascal's triangle as a list of ints."""
        row = [1]
        for k in range(1, n + 1):
            row.append(row[-1] * (n + 1 - k) // k)
        return row

ECA channel attention, arXiv:1910.03151 [cs.CV]

In [11]:
class ECA(nn.Module):
    def __init__(self, channels):
        super().__init__()
        k_size = self.get_k_size(channels)
        padding = (k_size - 1) // 2

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=padding)

    def forward(self, x):
        s = self.pool(x)
        c = s.size(1)
        s = s.view(-1, 1, c)
        s = self.conv(s)
        s = s.view(-1, c, 1, 1)
        s = torch.sigmoid(s)

        return x * s

    @staticmethod
    def get_k_size(channels, gamma=2, b=1):
        t = int(abs((math.log(channels, 2) + b) / gamma))
        k = t if t % 2 else t + 1
        return k

Mixed Depthwise Convolution, from arXiv:1907.09595

In [12]:
class MixConv(nn.Module):
    def __init__(self, channels, kernel_sizes):
        super().__init__()
        num_groups = len(kernel_sizes)
        self.splits = self.split_channels(channels, num_groups)

        self.convs = nn.ModuleList()
        for ch, ks in zip(self.splits, kernel_sizes):
            padding = (ks - 1) // 2
            self.convs.append(nn.Conv2d(ch, ch, ks, padding=padding, groups=ch, bias=False))

    def forward(self, x):
        x_split = torch.split(x, self.splits, dim=1)
        x_out = [conv(x_s) for conv, x_s in zip(self.convs, x_split)]
        x = torch.cat(x_out, dim=1)
        return x

    @staticmethod
    def split_channels(channels, num_groups):
        splits = [channels//num_groups] * num_groups
        splits[0] += channels % num_groups
        return splits
In [13]:
class NormAct(nn.Sequential):
    """Normalization + activation unit: BatchNorm2d followed by GELU."""
    def __init__(self, channels):
        layers = (nn.BatchNorm2d(channels), nn.GELU())
        super().__init__(*layers)

Block¶

InĀ [14]:
class SpatialMixer(nn.Sequential):
    """Spatial mixing stage: optional anti-aliased downsampling, then a
    mixed-kernel depthwise conv followed by norm + activation."""
    def __init__(self, channels, stride=1):
        layers = []
        if stride > 1:
            # Blur-pool first so the depthwise conv sees the downsampled map.
            layers.append(BlurPool(stride, filter_size=6))
        layers.append(MixConv(channels, kernel_sizes=[3, 5]))
        layers.append(NormAct(channels))
        super().__init__(*layers)
In [15]:
class ChannelMixer(nn.Sequential):
    """Pointwise (1x1) channel bottleneck: reduce to half width, norm+act,
    then expand back to the original channel count."""
    def __init__(self, channels):
        hidden = channels // 2
        reduce = nn.Conv2d(channels, hidden, 1, bias=False)
        expand = nn.Conv2d(hidden, channels, 1, bias=False)
        super().__init__(reduce, NormAct(hidden), expand)
In [16]:
class ResidualBlock(nn.Module):
    """Pre-activation residual block with a learnable residual scale γ.

    γ is initialized to zero, so the block starts out as (a possibly
    downsampled) identity mapping; the residual branch is blended in as
    γ is learned.
    """
    def __init__(self, channels, stride=1, p_drop=0.):
        super().__init__()
        drop = nn.Dropout(p_drop)
        if stride > 1:
            # Downsample the identity path to match the residual path.
            self.shortcut = nn.Sequential(nn.AvgPool2d(stride), drop)
        else:
            self.shortcut = drop
        self.residual = nn.Sequential(
            NormAct(channels),
            SpatialMixer(channels, stride),
            ECA(channels),
            ChannelMixer(channels),
        )
        self.γ = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        return self.shortcut(x) + self.γ * self.residual(x)

Main model¶

In [17]:
class Head(nn.Sequential):
    """Classification head: project to a small feature space with a 1x1 conv,
    global-average pool, then a linear classifier."""
    def __init__(self, channels, classes, num_features=20, p_drop=0.):
        layers = [
            NormAct(channels),
            nn.Dropout(p_drop),
            nn.Conv2d(channels, num_features, 1, bias=False),
            NormAct(num_features),
            nn.AdaptiveAvgPool2d(1),  # (B, F, H, W) -> (B, F, 1, 1)
            nn.Flatten(),
            nn.Linear(num_features, classes),
        ]
        super().__init__(*layers)
In [18]:
class Stem(nn.Sequential):
    """Input stem: 3x3 conv to an intermediate width, norm+act, then a 1x1
    projection to the network width."""
    def __init__(self, in_channels, out_channels, mid_channels=32):
        conv3x3 = nn.Conv2d(in_channels, mid_channels, 3, padding=1, bias=False)
        project = nn.Conv2d(mid_channels, out_channels, 1, bias=False)
        super().__init__(conv3x3, NormAct(mid_channels), project)
In [19]:
class Net(nn.Sequential):
    """Tiny constant-width model: stem, seven residual blocks (three of them
    with stride 2, so 32x32 input ends at 4x4), and a classification head."""
    def __init__(self, classes, width=32, in_channels=3, res_p_drop=0., head_p_drop=0.):
        strides = [1] + [2, 1] * 3
        blocks = [ResidualBlock(width, stride=s, p_drop=res_p_drop) for s in strides]
        super().__init__(
            Stem(in_channels, width),
            *blocks,
            Head(width, classes, p_drop=head_p_drop)
        )
In [20]:
def reset_parameters(model):
    """Re-initialize all weights in `model`.

    Conv/linear layers get Xavier-normal weights with zero bias, batch norms
    get weight 1 / bias 0, and each ResidualBlock's γ is zeroed so the block
    starts as an identity mapping.
    """
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.constant_(module.weight, 1.)
            nn.init.zeros_(module.bias)
        elif isinstance(module, ResidualBlock):
            nn.init.zeros_(module.γ)
In [21]:
model = Net(NUM_CLASSES, width=92, res_p_drop=0.1, head_p_drop=0.1).to(DEVICE)
In [22]:
reset_parameters(model)
In [23]:
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 79,597

Training¶

Training functions¶

In [24]:
def iterate(step_fn, loader):
    """Run `step_fn(x, y)` over every batch of `loader` and aggregate metrics.

    `step_fn` must return `(loss, out)` where `loss` is the *mean* loss over
    the batch and `out` holds the raw class scores. Returns a dict with the
    per-sample average loss and the accuracy.
    """
    num_samples = 0
    total_loss = 0.
    num_correct = 0
    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        loss, out = step_fn(x, y)
        pred = out.argmax(axis=-1)
        correct = (pred == y)
        loss, correct = loss.cpu().numpy(), correct.cpu().numpy()
        batch_size = correct.shape[0]
        num_samples += batch_size
        # BUG FIX: `loss` is a per-batch mean, so it must be weighted by the
        # batch size before dividing by `num_samples`. The old code summed
        # batch means and divided by the sample count, under-reporting the
        # loss by roughly a factor of BATCH_SIZE (and inconsistently for the
        # smaller final batch).
        total_loss += loss * batch_size
        num_correct += np.sum(correct)

    avg_loss = total_loss / num_samples
    acc = num_correct / num_samples
    metrics = {"loss": avg_loss, "acc": acc}
    return metrics
In [25]:
def train(model, loss_fn, optimizer, loader, batch_scheduler):
    """Run one training epoch; the LR scheduler is stepped after every batch."""
    def train_step(x, y):
        out = model(x)
        loss = loss_fn(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_scheduler.step()  # per-batch schedule (OneCycleLR)
        return loss.detach(), out.detach()

    model.train()
    return iterate(train_step, loader)
In [26]:
def evaluate(model, loss_fn, loader):
    """Evaluate `model` on `loader` with gradients disabled."""
    def eval_step(x, y):
        out = model(x)
        return loss_fn(out, y).detach(), out.detach()

    model.eval()
    with torch.inference_mode():
        return iterate(eval_step, loader)
In [27]:
def update_history(history, metrics, name):
    """Append each metric value to `history` under the key "<name> <metric>"."""
    for key, value in metrics.items():
        history["{} {}".format(name, key)].append(value)
In [28]:
def history_plot_train_val(history, key):
    """Plot the train and validation curves of `key` against epoch number."""
    fig = plt.figure()
    ax = fig.add_subplot(111)
    epochs = np.arange(1, len(history['train ' + key]) + 1)
    for split in ('train', 'val'):
        ax.plot(epochs, history[split + ' ' + key], '.-', label=split)
    ax.set_xlabel('epoch')
    ax.set_ylabel(key)
    ax.legend()
    ax.grid()
    plt.show()

Start training¶

In [29]:
loss = nn.CrossEntropyLoss()
In [30]:
# AdamW (decoupled weight decay) over the trainable parameters only.
optimizer = optim.AdamW([p for p in model.parameters() if p.requires_grad],
                        lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
In [31]:
# One-cycle LR schedule, stepped once per batch (batch_scheduler.step() in train()).
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
                                             steps_per_epoch=len(train_loader), epochs=EPOCHS)
In [32]:
history = defaultdict(list)
In [33]:
# Main training loop: one train pass + one evaluation per epoch, with
# live train/val accuracy shown in the progress bar.
pbar = trange(EPOCHS, ncols=140)
for epoch in pbar:
    train_metrics = train(model, loss, optimizer, train_loader, lr_scheduler)
    update_history(history, train_metrics, "train")

    val_metrics = evaluate(model, loss, test_loader)
    update_history(history, val_metrics, "val")
    pbar.set_postfix({"acc": train_metrics['acc'], "val acc": val_metrics['acc']})
100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 2000/2000 [27:06:35<00:00, 48.80s/it, acc=0.966, val acc=0.947]

In [34]:
history_plot_train_val(history, 'loss')
No description has been provided for this image
In [35]:
history_plot_train_val(history, 'acc')
No description has been provided for this image