Imports
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
import ignite.metrics
import ignite.contrib.handlers
Configuration
DATA_DIR='./data'
NUM_CLASSES = 10
NUM_WORKERS = 8
BATCH_SIZE = 32
EPOCHS = 100
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor()
])
train_dset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_transform)
test_dset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transforms.ToTensor())
Files already downloaded and verified
Files already downloaded and verified
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
num_workers=NUM_WORKERS, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS, pin_memory=True)
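A quick sanity check on the loaders (optional; assumes the cells above have run): one batch should contain 32 images of shape 3×32×32 with integer class labels, and ToTensor scales pixels to [0, 1].
xb, yb = next(iter(train_loader))
print(xb.shape, yb.shape)                # torch.Size([32, 3, 32, 32]) torch.Size([32])
print(xb.min().item(), xb.max().item())  # pixel range after ToTensor: within [0, 1]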
def dataset_show_image(dset, idx):
X, Y = dset[idx]
title = "Ground truth: {}".format(dset.classes[Y])
fig = plt.figure()
ax = fig.add_subplot(111)
ax.set_axis_off()
ax.imshow(np.moveaxis(X.numpy(), 0, -1))
ax.set_title(title)
plt.show()
dataset_show_image(test_dset, 1)
Utilities
def init_linear(m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None: nn.init.zeros_(m.bias)
class Residual(nn.Module):
def __init__(self, residual, shortcut=None):
super().__init__()
self.shortcut = nn.Identity() if shortcut is None else shortcut
self.residual = residual
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
return self.shortcut(x) + self.gamma * self.residual(x)
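Since gamma starts at zero, a freshly constructed Residual block reduces to its shortcut, so the network begins training close to an identity mapping. A minimal illustrative check (not part of the training code):
blk = Residual(nn.Conv2d(8, 8, 3, padding=1))
x = torch.randn(2, 8, 16, 16)
assert torch.allclose(blk(x), x)  # gamma == 0, so the residual branch contributes nothing yet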
class NormAct(nn.Sequential):
def __init__(self, channels):
super().__init__(
nn.BatchNorm2d(channels),
nn.SiLU(inplace=True)
)
class ConvBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1):
padding = (kernel_size - 1) // 2
super().__init__(
NormAct(in_channels),
nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups),
)
class SqueezeExciteBlock(nn.Module):
def __init__(self, channels, reduced_channels):
super().__init__()
self.se = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, reduced_channels, kernel_size=1),
nn.SiLU(inplace=True),
nn.Conv2d(reduced_channels, channels, kernel_size=1),
nn.Sigmoid()
)
def forward(self, x):
return x * self.se(x)
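The squeeze-and-excitation block only rescales channels, so the output shape matches the input; a small illustrative check:
se = SqueezeExciteBlock(channels=64, reduced_channels=16)
x = torch.randn(2, 64, 8, 8)
print(se(x).shape)  # torch.Size([2, 64, 8, 8]) -- per-channel gating, spatial shape preserved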
EfficientNet
Inverted residual block
class MBConvResidual(nn.Sequential):
def __init__(self, in_channels, out_channels, expansion, kernel_size=3, stride=1):
mid_channels = in_channels * expansion
squeeze_channels = in_channels // 4
super().__init__(
            ConvBlock(in_channels, mid_channels, 1), # Pointwise expansion
ConvBlock(mid_channels, mid_channels, kernel_size, stride=stride, groups=mid_channels), # Depthwise
NormAct(mid_channels),
SqueezeExciteBlock(mid_channels, squeeze_channels),
            nn.Conv2d(mid_channels, out_channels, kernel_size=1) # Pointwise projection
)
class MBConvBlock(Residual):
def __init__(self, in_channels, out_channels, expansion, kernel_size=3, stride=1):
residual = MBConvResidual(in_channels, out_channels, expansion, kernel_size, stride)
shortcut = self.get_shortcut(in_channels, out_channels, stride)
super().__init__(residual, shortcut)
def get_shortcut(self, in_channels, out_channels, stride):
if in_channels != out_channels:
shortcut = nn.Conv2d(in_channels, out_channels, 1)
if stride > 1:
shortcut = nn.Sequential(nn.AvgPool2d(stride), shortcut)
elif stride > 1:
shortcut = nn.AvgPool2d(stride)
else:
shortcut = nn.Identity()
return shortcut
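The shortcut mirrors whatever the residual branch does to the tensor shape: average pooling when stride > 1 and a 1x1 convolution when the channel count changes, so the two branches can always be summed. An illustrative shape check of a downsampling block:
blk = MBConvBlock(in_channels=32, out_channels=64, expansion=4, stride=2)
x = torch.randn(2, 32, 16, 16)
print(blk(x).shape)  # torch.Size([2, 64, 8, 8]) -- both branches halve the resolution and widen to 64 channels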
class BlockStack(nn.Sequential):
def __init__(self, num_layers, channel_list, strides, expansion=4, kernel_size=3):
layers = []
for num, in_channels, out_channels, stride in zip(num_layers, channel_list, channel_list[1:], strides):
for _ in range(num):
layers.append(MBConvBlock(in_channels, out_channels, expansion, kernel_size, stride))
in_channels = out_channels
stride = 1
super().__init__(*layers)
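Within each stage produced by the zip, only the first block applies the stage's stride and channel change; the remaining blocks keep stride 1 and equal in/out channels. A small two-stage example (illustrative, not the configuration used below):
stack = BlockStack(num_layers=[2, 2], channel_list=[16, 32, 64], strides=[1, 2], expansion=4)
x = torch.randn(2, 16, 32, 32)
print(stack(x).shape)  # torch.Size([2, 64, 16, 16]) -- only the stride-2 stage halves the resolution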
class Head(nn.Sequential):
def __init__(self, in_channels, classes, mult=4, p_drop=0.):
mid_channels = in_channels * mult
super().__init__(
ConvBlock(in_channels, mid_channels, 1),
NormAct(mid_channels),
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Dropout(p_drop),
nn.Linear(mid_channels, classes)
)
class Stem(nn.Sequential):
def __init__(self, in_channels, mid_channels, out_channels, stride):
squeeze_channels = mid_channels // 4
super().__init__(
nn.Conv2d(in_channels, mid_channels, 3, stride=stride, padding=1),
ConvBlock(mid_channels, mid_channels, 3, groups=mid_channels), # Depthwise
NormAct(mid_channels),
SqueezeExciteBlock(mid_channels, squeeze_channels),
nn.Conv2d(mid_channels, out_channels, kernel_size=1) # Pointwise
)
class EfficientNet(nn.Sequential):
def __init__(self, classes, num_layers, channel_list, strides, expansion=4,
in_channels=3, head_p_drop=0.):
super().__init__(
Stem(in_channels, *channel_list[:2], stride=strides[0]),
BlockStack(num_layers, channel_list[1:], strides[1:], expansion),
Head(channel_list[-1], classes, p_drop=head_p_drop)
)
model = EfficientNet(NUM_CLASSES,
num_layers = [4, 4, 4, 4],
channel_list = [32, 16, 32, 64, 128, 256],
strides = [1, 1, 2, 2, 2],
expansion = 4,
head_p_drop = 0.3)
model.apply(init_linear);
model.to(DEVICE);
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 3,351,478
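Before training, a forward pass on random data confirms that the network maps a CIFAR-10-sized batch to 10 class logits (optional check; eval mode avoids touching the BatchNorm running statistics):
model.eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 32, 32, device=DEVICE))
print(logits.shape)  # torch.Size([2, 10])
model.train();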
class History:
def __init__(self):
self.values = defaultdict(list)
def append(self, key, value):
self.values[key].append(value)
def reset(self):
for k in self.values.keys():
self.values[k] = []
def _begin_plot(self):
self.fig = plt.figure()
self.ax = self.fig.add_subplot(111)
def _end_plot(self, ylabel):
self.ax.set_xlabel('epoch')
self.ax.set_ylabel(ylabel)
plt.show()
def _plot(self, key, line_type='-', label=None):
if label is None: label=key
xs = np.arange(1, len(self.values[key])+1)
self.ax.plot(xs, self.values[key], line_type, label=label)
def plot(self, key):
self._begin_plot()
self._plot(key, '-')
self._end_plot(key)
def plot_train_val(self, key):
self._begin_plot()
self._plot('train ' + key, '.-', 'train')
self._plot('val ' + key, '.-', 'val')
self.ax.legend()
self._end_plot(key)
loss = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-6, weight_decay=1e-2)
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-2,
steps_per_epoch=len(train_loader), epochs=EPOCHS)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());
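OneCycleLR is stepped once per iteration, which is why the handler is attached to ITERATION_COMPLETED; the schedule overrides the small lr=1e-6 given to AdamW, warming up towards max_lr=1e-2 and then annealing. If desired, the full schedule can be previewed on a throwaway parameter and optimizer so the real one is left untouched (illustrative sketch):
_p = [nn.Parameter(torch.zeros(1))]
_opt = optim.AdamW(_p, lr=1e-6)
_sched = optim.lr_scheduler.OneCycleLR(_opt, max_lr=1e-2,
                                       steps_per_epoch=len(train_loader), epochs=EPOCHS)
lrs = []
for _ in range(len(train_loader) * EPOCHS):
    lrs.append(_sched.get_last_lr()[0])
    _opt.step()
    _sched.step()
plt.figure()
plt.plot(lrs)
plt.xlabel('iteration')
plt.ylabel('learning rate')
plt.show()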
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
val_metrics = {"accuracy": ignite.metrics.Accuracy(), "loss": ignite.metrics.Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
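The ignite.contrib.handlers import above is what provides a progress bar; attaching one to the trainer shows the running loss per iteration. This is optional and requires tqdm to be installed:
ignite.contrib.handlers.ProgressBar(persist=False).attach(trainer, metric_names=["loss"])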
history = History()
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
train_state = engine.state
epoch = train_state.epoch
max_epochs = train_state.max_epochs
train_loss = train_state.metrics["loss"]
history.append('train loss', train_loss)
evaluator.run(test_loader)
val_metrics = evaluator.state.metrics
val_loss = val_metrics["loss"]
val_acc = val_metrics["accuracy"]
history.append('val loss', val_loss)
history.append('val acc', val_acc)
print("{}/{} - train: loss {:.3f}; val: loss {:.3f} accuracy {:.3f}".format(
epoch, max_epochs, train_loss, val_loss, val_acc))
trainer.run(train_loader, max_epochs=EPOCHS);
1/100 - train: loss 1.085; val: loss 0.977 accuracy 0.653
2/100 - train: loss 0.779; val: loss 0.679 accuracy 0.762
3/100 - train: loss 0.648; val: loss 0.593 accuracy 0.792
4/100 - train: loss 0.569; val: loss 0.535 accuracy 0.815
5/100 - train: loss 0.560; val: loss 0.573 accuracy 0.808
6/100 - train: loss 0.518; val: loss 0.514 accuracy 0.822
7/100 - train: loss 0.500; val: loss 0.557 accuracy 0.817
8/100 - train: loss 0.473; val: loss 0.649 accuracy 0.785
9/100 - train: loss 0.454; val: loss 0.504 accuracy 0.831
10/100 - train: loss 0.462; val: loss 0.519 accuracy 0.824
11/100 - train: loss 0.449; val: loss 0.504 accuracy 0.833
12/100 - train: loss 0.447; val: loss 0.486 accuracy 0.836
13/100 - train: loss 0.458; val: loss 0.512 accuracy 0.833
14/100 - train: loss 0.445; val: loss 0.454 accuracy 0.846
15/100 - train: loss 0.412; val: loss 0.519 accuracy 0.833
16/100 - train: loss 0.423; val: loss 0.415 accuracy 0.864
17/100 - train: loss 0.456; val: loss 0.550 accuracy 0.824
18/100 - train: loss 0.454; val: loss 0.515 accuracy 0.828
19/100 - train: loss 0.454; val: loss 0.486 accuracy 0.835
20/100 - train: loss 0.454; val: loss 0.475 accuracy 0.841
21/100 - train: loss 0.474; val: loss 0.448 accuracy 0.849
22/100 - train: loss 0.462; val: loss 0.520 accuracy 0.835
23/100 - train: loss 0.468; val: loss 0.544 accuracy 0.820
24/100 - train: loss 0.442; val: loss 0.550 accuracy 0.828
25/100 - train: loss 0.428; val: loss 0.541 accuracy 0.823
26/100 - train: loss 0.455; val: loss 0.579 accuracy 0.816
27/100 - train: loss 0.451; val: loss 0.463 accuracy 0.842
28/100 - train: loss 0.447; val: loss 0.929 accuracy 0.735
29/100 - train: loss 0.439; val: loss 0.487 accuracy 0.840
30/100 - train: loss 0.417; val: loss 0.423 accuracy 0.857
31/100 - train: loss 0.434; val: loss 0.598 accuracy 0.808
32/100 - train: loss 0.453; val: loss 0.544 accuracy 0.816
33/100 - train: loss 0.439; val: loss 0.453 accuracy 0.851
34/100 - train: loss 0.422; val: loss 0.505 accuracy 0.831
35/100 - train: loss 0.430; val: loss 0.575 accuracy 0.812
36/100 - train: loss 0.432; val: loss 0.469 accuracy 0.848
37/100 - train: loss 0.439; val: loss 0.524 accuracy 0.830
38/100 - train: loss 0.437; val: loss 0.548 accuracy 0.822
39/100 - train: loss 0.434; val: loss 0.639 accuracy 0.808
40/100 - train: loss 0.401; val: loss 0.472 accuracy 0.843
41/100 - train: loss 0.394; val: loss 0.515 accuracy 0.842
42/100 - train: loss 0.430; val: loss 0.626 accuracy 0.805
43/100 - train: loss 0.386; val: loss 0.472 accuracy 0.848
44/100 - train: loss 0.418; val: loss 0.673 accuracy 0.789
45/100 - train: loss 0.405; val: loss 0.409 accuracy 0.868
46/100 - train: loss 0.358; val: loss 0.437 accuracy 0.856
47/100 - train: loss 0.375; val: loss 0.407 accuracy 0.863
48/100 - train: loss 0.383; val: loss 0.469 accuracy 0.851
49/100 - train: loss 0.363; val: loss 0.404 accuracy 0.862
50/100 - train: loss 0.366; val: loss 0.400 accuracy 0.869
51/100 - train: loss 0.350; val: loss 0.523 accuracy 0.837
52/100 - train: loss 0.318; val: loss 0.392 accuracy 0.867
53/100 - train: loss 0.358; val: loss 0.433 accuracy 0.858
54/100 - train: loss 0.357; val: loss 0.379 accuracy 0.876
55/100 - train: loss 0.328; val: loss 0.427 accuracy 0.860
56/100 - train: loss 0.306; val: loss 0.545 accuracy 0.828
57/100 - train: loss 0.256; val: loss 0.386 accuracy 0.876
58/100 - train: loss 0.329; val: loss 0.384 accuracy 0.875
59/100 - train: loss 0.302; val: loss 0.378 accuracy 0.875
60/100 - train: loss 0.290; val: loss 0.389 accuracy 0.873
61/100 - train: loss 0.327; val: loss 0.424 accuracy 0.863
62/100 - train: loss 0.263; val: loss 0.331 accuracy 0.895
63/100 - train: loss 0.245; val: loss 0.394 accuracy 0.876
64/100 - train: loss 0.276; val: loss 0.418 accuracy 0.864
65/100 - train: loss 0.264; val: loss 0.347 accuracy 0.887
66/100 - train: loss 0.263; val: loss 0.343 accuracy 0.887
67/100 - train: loss 0.219; val: loss 0.331 accuracy 0.892
68/100 - train: loss 0.229; val: loss 0.288 accuracy 0.905
69/100 - train: loss 0.205; val: loss 0.330 accuracy 0.898
70/100 - train: loss 0.194; val: loss 0.272 accuracy 0.914
71/100 - train: loss 0.176; val: loss 0.295 accuracy 0.905
72/100 - train: loss 0.164; val: loss 0.261 accuracy 0.917
73/100 - train: loss 0.197; val: loss 0.276 accuracy 0.915
74/100 - train: loss 0.165; val: loss 0.274 accuracy 0.914
75/100 - train: loss 0.168; val: loss 0.264 accuracy 0.918
76/100 - train: loss 0.123; val: loss 0.257 accuracy 0.920
77/100 - train: loss 0.140; val: loss 0.267 accuracy 0.917
78/100 - train: loss 0.130; val: loss 0.302 accuracy 0.911
79/100 - train: loss 0.100; val: loss 0.234 accuracy 0.928
80/100 - train: loss 0.106; val: loss 0.249 accuracy 0.926
81/100 - train: loss 0.106; val: loss 0.247 accuracy 0.928
82/100 - train: loss 0.088; val: loss 0.233 accuracy 0.934
83/100 - train: loss 0.066; val: loss 0.230 accuracy 0.934
84/100 - train: loss 0.084; val: loss 0.239 accuracy 0.933
85/100 - train: loss 0.053; val: loss 0.227 accuracy 0.936
86/100 - train: loss 0.053; val: loss 0.222 accuracy 0.938
87/100 - train: loss 0.042; val: loss 0.233 accuracy 0.939
88/100 - train: loss 0.039; val: loss 0.220 accuracy 0.941
89/100 - train: loss 0.031; val: loss 0.217 accuracy 0.946
90/100 - train: loss 0.026; val: loss 0.226 accuracy 0.942
91/100 - train: loss 0.022; val: loss 0.223 accuracy 0.946
92/100 - train: loss 0.025; val: loss 0.226 accuracy 0.945
93/100 - train: loss 0.022; val: loss 0.224 accuracy 0.945
94/100 - train: loss 0.018; val: loss 0.216 accuracy 0.949
95/100 - train: loss 0.016; val: loss 0.211 accuracy 0.949
96/100 - train: loss 0.009; val: loss 0.217 accuracy 0.949
97/100 - train: loss 0.009; val: loss 0.213 accuracy 0.950
98/100 - train: loss 0.010; val: loss 0.213 accuracy 0.951
99/100 - train: loss 0.009; val: loss 0.211 accuracy 0.950
100/100 - train: loss 0.011; val: loss 0.214 accuracy 0.950
history.plot_train_val('loss')
history.plot('val acc')