ResNet on CIFAR10 using PyTorch Lightning

Imports

In [ ]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

import pytorch_lightning as pl
In [ ]:
%load_ext tensorboard

Configuration

In [ ]:
DATA_DIR='./data'

NUM_CLASSES = 10
NUM_WORKERS = 24
BATCH_SIZE = 32
EPOCHS = 50

Data

In [ ]:
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor()
])
In [ ]:
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, train_transform, data_dir='./', batch_size=32, num_workers=8):
        super().__init__()
        self.train_transform = train_transform
        self.val_transform = transforms.ToTensor()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
    
    def prepare_data(self):
        # download once (runs on a single process, before setup)
        datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        datasets.CIFAR10(root=self.data_dir, train=False, download=True)
    
    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            self.train_dset = datasets.CIFAR10(root=self.data_dir, train=True,
                                               transform=self.train_transform)
            self.val_dset = datasets.CIFAR10(root=self.data_dir, train=False,
                                             transform=self.val_transform)
    
    def train_dataloader(self):
        # shuffle the training set each epoch; pin memory speeds up host-to-GPU transfer
        return torch.utils.data.DataLoader(self.train_dset, batch_size=self.batch_size,
                                           shuffle=True, num_workers=self.num_workers,
                                           pin_memory=True)


    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dset, batch_size=self.batch_size,
                                           num_workers=self.num_workers, pin_memory=True)
In [ ]:
dm = CIFAR10DataModule(train_transform, data_dir=DATA_DIR, batch_size=BATCH_SIZE,
                       num_workers=NUM_WORKERS)
In [ ]:
dm.prepare_data()
In [ ]:
dm.setup()
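
As a quick sanity check, a single batch can be drawn from the training loader to confirm tensor shapes (a minimal sketch; CIFAR-10 images are 3x32x32):

In [ ]:
x, y = next(iter(dm.train_dataloader()))
x.shape, y.shape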

Model

In [ ]:
@torch.no_grad()
def init_linear(m):
    # Kaiming (He) initialisation for conv and linear weights, zeros for biases
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None: nn.init.zeros_(m.bias)
In [ ]:
def conv_bn(in_channels, out_channels, kernel_size=3, stride=1):
    padding = (kernel_size - 1) // 2
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels)
    )
In [ ]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, res_channels, stride=1):
        super().__init__()
        self.shortcut = self.get_shortcut(in_channels, res_channels, stride)
        
        self.residual = nn.Sequential(
            conv_bn(in_channels, res_channels, stride=stride),
            nn.ReLU(inplace=True),
            conv_bn(res_channels, res_channels)
        )
        self.act = nn.ReLU(inplace=True)
        
        # learnable scale on the residual branch, initialised to zero so each block starts as the identity
        self.gamma = nn.Parameter(torch.zeros(1))
    
    def forward(self, x):
        out = self.shortcut(x) + self.gamma * self.residual(x)
        return self.act(out)
    
    def get_shortcut(self, in_channels, res_channels, stride):
        # identity when shapes match; otherwise average-pool to downsample and/or 1x1 conv to change channels
        layers = []
        if stride > 1: layers.append(nn.AvgPool2d(stride))
        if in_channels != res_channels: layers.append(conv_bn(in_channels, res_channels, 1))
        return nn.Sequential(*layers)
In [ ]:
def residual_body(in_channels, repetitions, strides):
    layers = []
    res_channels = in_channels
    for rep, stride in zip(repetitions, strides):
        # only the first block of each stage uses the given stride; channels double between stages
        for _ in range(rep):
            layers.append(ResidualBlock(in_channels, res_channels, stride))
            in_channels = res_channels
            stride = 1
        res_channels = res_channels * 2
    return nn.Sequential(*layers)
In [ ]:
def stem(channel_list, stride):
    layers = []
    for in_channels, out_channels in zip(channel_list, channel_list[1:]):
        layers += [conv_bn(in_channels, out_channels, stride=stride), nn.ReLU(inplace=True)]
        stride = 1
    return nn.Sequential(*layers)
In [ ]:
def head(in_channels, classes, p_drop=0.):
    return nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(p_drop),
            nn.Linear(in_channels, classes)
        )
In [ ]:
def resnet(repetitions, classes, strides=None, p_drop=0.):
    if not strides: strides = [2] * (len(repetitions) + 1)
    return nn.Sequential(
        stem([3, 32, 32, 64], strides[0]),
        residual_body(64, repetitions, strides[1:]),
        head(64 * 2**(len(repetitions) - 1), classes, p_drop)
    )
In [ ]:
model = resnet([2, 2, 2, 2], NUM_CLASSES, strides=[1, 1, 2, 2, 2], p_drop=0.3)
In [ ]:
model.apply(init_linear);
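
A dummy forward pass confirms that the network maps a batch of CIFAR-sized images to NUM_CLASSES logits (a minimal sketch):

In [ ]:
with torch.no_grad():
    out = model(torch.randn(2, 3, 32, 32))
out.shape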

Training

In [ ]:
class ClassificationTask(pl.LightningModule):
    def __init__(self, model, max_lr, epochs, steps_per_epoch):
        super().__init__()
        self.save_hyperparameters('max_lr', 'epochs')
        self.steps_per_epoch = steps_per_epoch
        self.model = model
        self.loss = nn.CrossEntropyLoss()
        self.train_acc = pl.metrics.Accuracy(compute_on_step=False)
        self.val_acc = pl.metrics.Accuracy(compute_on_step=False)
    
    def forward(self, x):
        return self.model(x)
    
    def _shared_step(self, batch, metric, prefix):
        x, y = batch
        logits = self.model(x)
        loss = self.loss(logits, y)
        metric(logits, y)
        self.log(f'{prefix}_loss', loss)
        return loss
    
    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, self.train_acc, 'train')
    
    def training_epoch_end(self, outs):
        self.log('train_acc', self.train_acc.compute())
    
    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, self.val_acc, 'val')
    
    def validation_epoch_end(self, val_outs):
        self.log('val_acc', self.val_acc.compute())
    
    def configure_optimizers(self):
        optimizer = optim.AdamW(self.model.parameters(), weight_decay=1e-2)
        lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.hparams.max_lr,
                                                     steps_per_epoch=self.steps_per_epoch,
                                                     epochs=self.hparams.epochs)
        lr_dict = {'scheduler': lr_scheduler, 'interval': 'step'}
        return [optimizer], [lr_dict]
In [ ]:
classifier = ClassificationTask(model, max_lr=1e-2, epochs=EPOCHS,
                                steps_per_epoch=len(dm.train_dataloader()))

Start Training

In [ ]:
# max_epochs must match the epochs passed to OneCycleLR, otherwise the scheduler runs out of steps
trainer = pl.Trainer(gpus=1, max_epochs=EPOCHS)
In [ ]:
lr_finder = trainer.tuner.lr_find(classifier, datamodule=dm, min_lr=1e-6, max_lr=1e-1)
In [ ]:
fig = lr_finder.plot(suggest=True)
fig.show()
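
The suggested rate can optionally replace the hand-picked max_lr before training (a minimal sketch; suggestion() may return None if no clear minimum is found):

In [ ]:
suggested_lr = lr_finder.suggestion()
if suggested_lr is not None:
    classifier.hparams.max_lr = suggested_lr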
In [ ]:
%tensorboard --logdir "lightning_logs/"
In [ ]:
trainer.fit(classifier, datamodule=dm)
In [ ]:
trainer.save_checkpoint('model.ckpt')
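
The checkpoint can later be restored for inference. Since only max_lr and epochs were stored via save_hyperparameters, the model and steps_per_epoch arguments have to be passed back in (a minimal sketch):

In [ ]:
restored = ClassificationTask.load_from_checkpoint('model.ckpt', model=model,
                                                   steps_per_epoch=len(dm.train_dataloader()))
restored.eval();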