Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import pytorch_lightning as pl
%load_ext tensorboard
Configuration
DATA_DIR='./data'
NUM_CLASSES = 10
NUM_WORKERS = 24
BATCH_SIZE = 32
EPOCHS = 50
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor()
])
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, train_transform, data_dir='./', batch_size=32, num_workers=8):
        super().__init__()
        self.train_transform = train_transform
        self.val_transform = transforms.ToTensor()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        datasets.CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            self.train_dset = datasets.CIFAR10(root=self.data_dir, train=True,
                                               transform=self.train_transform)
            self.val_dset = datasets.CIFAR10(root=self.data_dir, train=False,
                                             transform=self.val_transform)
    def train_dataloader(self):
        # shuffle the training set each epoch; validation order can stay fixed
        return torch.utils.data.DataLoader(self.train_dset, batch_size=self.batch_size,
                                           shuffle=True, num_workers=self.num_workers,
                                           pin_memory=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dset, batch_size=self.batch_size,
                                           num_workers=self.num_workers, pin_memory=True)
dm = CIFAR10DataModule(train_transform, data_dir=DATA_DIR, batch_size=BATCH_SIZE,
num_workers=NUM_WORKERS)
dm.prepare_data()
dm.setup()
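As a quick sanity check (not part of the original pipeline), we can pull one batch from the data module and confirm the tensor shapes; the variable names are illustrative.
xb, yb = next(iter(dm.train_dataloader()))
print(xb.shape, yb.shape)  # expected: torch.Size([32, 3, 32, 32]) torch.Size([32])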
@torch.no_grad()
def init_linear(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None: nn.init.zeros_(m.bias)
def conv_bn(in_channels, out_channels, kernel_size=3, stride=1):
    padding = (kernel_size - 1) // 2
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels)
    )
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, res_channels, stride=1):
        super().__init__()
        self.shortcut = self.get_shortcut(in_channels, res_channels, stride)
        self.residual = nn.Sequential(
            conv_bn(in_channels, res_channels, stride=stride),
            nn.ReLU(inplace=True),
            conv_bn(res_channels, res_channels)
        )
        self.act = nn.ReLU(inplace=True)
        # gamma scales the residual branch; initializing it to zero makes each block
        # behave like its shortcut at the start of training
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        out = self.shortcut(x) + self.gamma * self.residual(x)
        return self.act(out)

    def get_shortcut(self, in_channels, res_channels, stride):
        # identity when shapes match; otherwise downsample and/or project channels
        layers = []
        if stride > 1: layers.append(nn.AvgPool2d(stride))
        if in_channels != res_channels: layers.append(conv_bn(in_channels, res_channels, 1))
        return nn.Sequential(*layers)
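A small optional check of the zero-gamma trick: with gamma initialized to zero and an identity shortcut, a freshly built block reduces to a plain ReLU of its input. The tensor below is just a stand-in.
block = ResidualBlock(64, 64)   # identity shortcut, gamma == 0
x = torch.randn(2, 64, 8, 8)    # dummy activations
assert torch.allclose(block(x), F.relu(x))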
def residual_body(in_channels, repetitions, strides):
    layers = []
    res_channels = in_channels
    for rep, stride in zip(repetitions, strides):
        for _ in range(rep):
            layers.append(ResidualBlock(in_channels, res_channels, stride))
            in_channels = res_channels
            stride = 1  # only the first block of each stage downsamples
        res_channels = res_channels * 2  # double the width for the next stage
    return nn.Sequential(*layers)
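To see how the stage parameters unfold, a purely illustrative loop (using the repetitions and strides chosen further below) can list the per-block channel counts and strides:
body = residual_body(64, [2, 2, 2, 2], [1, 2, 2, 2])
for i, block in enumerate(body):
    conv = block.residual[0][0]  # first conv of the residual branch
    print(i, conv.in_channels, '->', conv.out_channels, 'stride', conv.stride)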
def stem(channel_list, stride):
    layers = []
    for in_channels, out_channels in zip(channel_list, channel_list[1:]):
        layers += [conv_bn(in_channels, out_channels, stride=stride), nn.ReLU(inplace=True)]
        stride = 1
    return nn.Sequential(*layers)
def head(in_channels, classes, p_drop=0.):
    return nn.Sequential(
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Dropout(p_drop),
        nn.Linear(in_channels, classes)
    )
def resnet(repetitions, classes, strides=None, p_drop=0.):
    if not strides: strides = [2] * (len(repetitions) + 1)
    return nn.Sequential(
        stem([3, 32, 32, 64], strides[0]),
        residual_body(64, repetitions, strides[1:]),
        head(64 * 2**(len(repetitions) - 1), classes, p_drop)
    )
model = resnet([2, 2, 2, 2], NUM_CLASSES, strides=[1, 1, 2, 2, 2], p_drop=0.3)
model.apply(init_linear);
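A quick forward-pass check (not in the original notebook) confirms the network maps a CIFAR-sized batch to per-class logits; the batch size of 8 is arbitrary.
with torch.no_grad():
    logits = model(torch.randn(8, 3, 32, 32))
print(logits.shape)  # expected: torch.Size([8, 10])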
class ClassificationTask(pl.LightningModule):
    def __init__(self, model, max_lr, epochs, steps_per_epoch):
        super().__init__()
        self.save_hyperparameters('max_lr', 'epochs')
        self.steps_per_epoch = steps_per_epoch
        self.model = model
        self.loss = nn.CrossEntropyLoss()
        # pl.metrics was later split out into the torchmetrics package;
        # on newer Lightning versions use torchmetrics.Accuracy instead
        self.train_acc = pl.metrics.Accuracy(compute_on_step=False)
        self.val_acc = pl.metrics.Accuracy(compute_on_step=False)

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch, metric, prefix):
        x, y = batch
        logits = self.model(x)
        loss = self.loss(logits, y)
        metric(logits, y)
        self.log(f'{prefix}_loss', loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, self.train_acc, 'train')
    def training_epoch_end(self, outs):
        self.log('train_acc', self.train_acc.compute())
        self.train_acc.reset()  # clear accumulated state so each epoch starts fresh

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, self.val_acc, 'val')

    def validation_epoch_end(self, val_outs):
        self.log('val_acc', self.val_acc.compute())
        self.val_acc.reset()
    def configure_optimizers(self):
        optimizer = optim.AdamW(self.model.parameters(), weight_decay=1e-2)
        lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.hparams.max_lr,
                                                     steps_per_epoch=self.steps_per_epoch,
                                                     epochs=self.hparams.epochs)
        # OneCycleLR is defined per batch, so step the scheduler every training step
        lr_dict = {'scheduler': lr_scheduler, 'interval': 'step'}
        return [optimizer], [lr_dict]
classifier = ClassificationTask(model, max_lr=1e-2, epochs=EPOCHS,
steps_per_epoch=len(dm.train_dataloader()))
trainer = pl.Trainer(gpus=1, max_epochs=EPOCHS)  # OneCycleLR is sized for exactly EPOCHS epochs
lr_finder = trainer.tuner.lr_find(classifier, datamodule=dm, min_lr=1e-6, max_lr=1e-1)
fig = lr_finder.plot(suggest=True)
fig.show()
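If desired, the learning rate suggested by the finder can be read out and compared with the max_lr chosen above; this step is optional and was not applied in the original run.
suggested_lr = lr_finder.suggestion()
print(f'suggested max_lr: {suggested_lr:.2e}')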
%tensorboard --logdir "lightning_logs/"
trainer.fit(classifier, datamodule=dm)
trainer.save_checkpoint('model.ckpt')
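A brief sketch of restoring the trained weights later, assuming the same model definition is available: since only max_lr and epochs were saved as hyperparameters, the model and steps_per_epoch arguments must be supplied again.
restored = ClassificationTask.load_from_checkpoint('model.ckpt', model=model,
                                                   steps_per_epoch=len(dm.train_dataloader()))
restored.eval()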