Imports
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
import ignite.metrics
import ignite.contrib.handlers
Configuration
DATA_DIR='./data'
NUM_CLASSES = 10
NUM_WORKERS = 8
BATCH_SIZE = 32
EPOCHS = 100
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor()
])
train_dset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_transform)
test_dset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transforms.ToTensor())
Files already downloaded and verified
Files already downloaded and verified
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
                                           num_workers=NUM_WORKERS, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False,
                                          num_workers=NUM_WORKERS, pin_memory=True)
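CIFAR-10 ships 50,000 training and 10,000 test images, so a batch size of 32 gives 1,563 optimizer steps per epoch (50,000 / 32 rounded up, since the loader keeps the last partial batch). This count matters later for the OneCycleLR schedule; a quick check:
print(len(train_dset), len(test_dset))  # 50000 10000
print(len(train_loader))                # 1563 steps per epoch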
def dataset_show_image(dset, idx):
    X, Y = dset[idx]
    title = "Ground truth: {}".format(dset.classes[Y])
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_axis_off()
    ax.imshow(np.moveaxis(X.numpy(), 0, -1))  # CHW tensor -> HWC image for matplotlib
    ax.set_title(title)
    plt.show()
dataset_show_image(test_dset, 1)
def init_linear(m):
    # Kaiming initialization for all conv and linear layers; biases start at zero
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None: nn.init.zeros_(m.bias)
class ConvBlock(nn.Sequential):
    # pre-activation block: BatchNorm -> ReLU -> convolution
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        padding = (kernel_size - 1) // 2  # "same" padding for odd kernel sizes
        super().__init__(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        )
class BasicResidual(nn.Sequential):
    # the residual branch: two pre-activation conv blocks followed by dropout
    def __init__(self, in_channels, out_channels, stride=1, p_drop=0.):
        super().__init__(
            ConvBlock(in_channels, out_channels, stride=stride),
            ConvBlock(out_channels, out_channels),
            nn.Dropout(p_drop)
        )
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, p_drop=0.):
        super().__init__()
        self.shortcut = self.get_shortcut(in_channels, out_channels, stride)
        self.residual = BasicResidual(in_channels, out_channels, stride, p_drop)
        # gamma scales the residual branch and is initialized to zero,
        # so every block starts out computing just its shortcut
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.shortcut(x) + self.gamma * self.residual(x)

    def get_shortcut(self, in_channels, out_channels, stride):
        # 1x1 conv projection when the width changes; average pooling when downsampling
        if in_channels != out_channels:
            shortcut = nn.Conv2d(in_channels, out_channels, 1)
            if stride > 1:
                shortcut = nn.Sequential(nn.AvgPool2d(stride), shortcut)
        elif stride > 1:
            shortcut = nn.AvgPool2d(stride)
        else:
            shortcut = nn.Identity()
        return shortcut
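Since gamma starts at zero, a freshly built block reduces to its shortcut, so the whole network begins as a shallow, near-identity mapping and the residual branches fade in during training (the same idea as ReZero/SkipInit). A minimal sanity check, for the case where the shortcut is the identity:
block = ResidualBlock(64, 64)
x = torch.randn(2, 64, 8, 8)
assert torch.allclose(block(x), x)  # gamma == 0, so the residual branch contributes nothing yet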
class ResidualStack(nn.Sequential):
    def __init__(self, in_channels, repetitions, strides, p_drop=0.):
        layers = []
        out_channels = in_channels
        for rep, stride in zip(repetitions, strides):
            for _ in range(rep):
                layers.append(ResidualBlock(in_channels, out_channels, stride, p_drop))
                in_channels = out_channels
                stride = 1  # only the first block of a stage downsamples
            out_channels *= 2  # double the width for the next stage
        super().__init__(*layers)
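Only the first block of each stage changes the channel count and applies the stride; the width then doubles for the next stage. A hypothetical two-stage stack illustrates the resulting shapes:
stack = ResidualStack(64, repetitions=[2, 2], strides=[1, 2])
y = stack(torch.randn(1, 64, 32, 32))
print(y.shape)  # torch.Size([1, 128, 16, 16]): width doubled once, resolution halved once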
class Stem(nn.Sequential):
    # a plain 3x3 conv followed by pre-activation blocks that widen to the stack's input width
    def __init__(self, in_channels, channel_list, stride):
        layers = [nn.Conv2d(in_channels, channel_list[0], 3, padding=1, stride=stride)]
        for c_in, c_out in zip(channel_list, channel_list[1:]):
            layers.append(ConvBlock(c_in, c_out, 3))
        super().__init__(*layers)
class Head(nn.Sequential):
    # final BN-ReLU, global average pooling, dropout, and the linear classifier
    def __init__(self, in_channels, classes, p_drop=0.):
        super().__init__(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(p_drop),
            nn.Linear(in_channels, classes)
        )
class ResNet(nn.Sequential):
    def __init__(self, classes, repetitions, strides=None, in_channels=3, res_p_drop=0., head_p_drop=0.):
        if strides is None: strides = [2] * (len(repetitions) + 1)
        super().__init__(
            Stem(in_channels, [32, 32, 64], strides[0]),
            ResidualStack(64, repetitions, strides[1:], res_p_drop),
            Head(64 * 2**(len(repetitions) - 1), classes, head_p_drop)  # width after the last stage
        )
model = ResNet(NUM_CLASSES, [2, 2, 2, 2], strides=[1, 1, 2, 2, 2], res_p_drop=0., head_p_drop=0.3)
model.apply(init_linear);
model.to(DEVICE);
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 11,203,954
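As an optional smoke test, push a dummy batch through the model and confirm the output shape; Ignite's supervised trainer switches the model back to train mode on each iteration, so calling eval() here should not affect training:
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 3, 32, 32, device=DEVICE))
print(out.shape)  # torch.Size([2, 10])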
class History:
    def __init__(self):
        self.values = defaultdict(list)

    def append(self, key, value):
        self.values[key].append(value)

    def reset(self):
        for k in self.values.keys():
            self.values[k] = []

    def _begin_plot(self):
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)

    def _end_plot(self, ylabel):
        self.ax.set_xlabel('epoch')
        self.ax.set_ylabel(ylabel)
        plt.show()

    def _plot(self, key, line_type='-', label=None):
        if label is None: label = key
        xs = np.arange(1, len(self.values[key]) + 1)
        self.ax.plot(xs, self.values[key], line_type, label=label)

    def plot(self, key):
        self._begin_plot()
        self._plot(key, '-')
        self._end_plot(key)

    def plot_train_val(self, key):
        self._begin_plot()
        self._plot('train ' + key, '.-', 'train')
        self._plot('val ' + key, '.-', 'val')
        self.ax.legend()
        self._end_plot(key)
loss = nn.CrossEntropyLoss()
# the initial lr is only a placeholder; the OneCycleLR schedule below sets the actual lr at every step
optimizer = optim.AdamW(model.parameters(), lr=1e-6, weight_decay=1e-2)
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE)
# OneCycleLR expects one step per optimizer update, so it is stepped on every iteration
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-2,
                                             steps_per_epoch=len(train_loader), epochs=EPOCHS)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());
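To preview the one-cycle shape without disturbing the real optimizer's state, a throwaway optimizer can be stepped through the same number of iterations (a sketch; tmp_opt, tmp_sched, and lrs are names introduced here):
tmp_opt = optim.AdamW([torch.zeros(1, requires_grad=True)], lr=1e-6)
tmp_sched = optim.lr_scheduler.OneCycleLR(tmp_opt, max_lr=1e-2,
                                          steps_per_epoch=len(train_loader), epochs=EPOCHS)
lrs = [tmp_sched.get_last_lr()[0]]
for _ in range(len(train_loader) * EPOCHS - 1):
    tmp_opt.step()       # no gradients, so this only satisfies the step-order check
    tmp_sched.step()
    lrs.append(tmp_sched.get_last_lr()[0])
plt.plot(lrs)
plt.xlabel('iteration')
plt.ylabel('learning rate')
plt.show()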
# running average of the per-iteration training loss, exposed as trainer.state.metrics["loss"]
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
val_metrics = {"accuracy": ignite.metrics.Accuracy(), "loss": ignite.metrics.Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
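Running the evaluator by hand before training gives the chance-level baseline (a sketch; with 10 balanced classes, expect accuracy near 0.1 and loss near ln 10 ≈ 2.3):
evaluator.run(test_loader)
print(evaluator.state.metrics)  # e.g. {'accuracy': ~0.1, 'loss': ~2.3}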
history = History()
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    train_state = engine.state
    epoch = train_state.epoch
    max_epochs = train_state.max_epochs
    train_loss = train_state.metrics["loss"]
    history.append('train loss', train_loss)
    evaluator.run(test_loader)
    metrics = evaluator.state.metrics
    val_loss = metrics["loss"]
    val_acc = metrics["accuracy"]
    history.append('val loss', val_loss)
    history.append('val acc', val_acc)
    print("{}/{} - train: loss {:.3f}; val: loss {:.3f} accuracy {:.3f}".format(
        epoch, max_epochs, train_loss, val_loss, val_acc))
trainer.run(train_loader, max_epochs=EPOCHS);
1/100 - train: loss 1.173; val: loss 1.053 accuracy 0.622
2/100 - train: loss 0.882; val: loss 0.845 accuracy 0.706
3/100 - train: loss 0.727; val: loss 0.695 accuracy 0.763
4/100 - train: loss 0.652; val: loss 0.651 accuracy 0.784
5/100 - train: loss 0.638; val: loss 0.587 accuracy 0.798
6/100 - train: loss 0.582; val: loss 0.589 accuracy 0.803
7/100 - train: loss 0.554; val: loss 0.586 accuracy 0.813
8/100 - train: loss 0.528; val: loss 0.577 accuracy 0.810
9/100 - train: loss 0.517; val: loss 0.506 accuracy 0.833
10/100 - train: loss 0.458; val: loss 0.590 accuracy 0.806
11/100 - train: loss 0.466; val: loss 0.478 accuracy 0.839
12/100 - train: loss 0.431; val: loss 0.521 accuracy 0.831
13/100 - train: loss 0.450; val: loss 0.443 accuracy 0.845
14/100 - train: loss 0.414; val: loss 0.416 accuracy 0.864
15/100 - train: loss 0.417; val: loss 0.531 accuracy 0.828
16/100 - train: loss 0.411; val: loss 0.499 accuracy 0.838
17/100 - train: loss 0.426; val: loss 0.462 accuracy 0.849
18/100 - train: loss 0.400; val: loss 0.663 accuracy 0.792
19/100 - train: loss 0.414; val: loss 0.533 accuracy 0.819
20/100 - train: loss 0.407; val: loss 0.431 accuracy 0.859
21/100 - train: loss 0.434; val: loss 0.488 accuracy 0.842
22/100 - train: loss 0.380; val: loss 0.471 accuracy 0.845
23/100 - train: loss 0.428; val: loss 0.475 accuracy 0.840
24/100 - train: loss 0.414; val: loss 0.393 accuracy 0.872
25/100 - train: loss 0.406; val: loss 0.480 accuracy 0.843
26/100 - train: loss 0.368; val: loss 0.623 accuracy 0.809
27/100 - train: loss 0.384; val: loss 0.518 accuracy 0.833
28/100 - train: loss 0.383; val: loss 0.500 accuracy 0.841
29/100 - train: loss 0.388; val: loss 0.479 accuracy 0.843
30/100 - train: loss 0.384; val: loss 0.407 accuracy 0.866
31/100 - train: loss 0.360; val: loss 0.439 accuracy 0.859
32/100 - train: loss 0.366; val: loss 0.466 accuracy 0.848
33/100 - train: loss 0.382; val: loss 0.406 accuracy 0.863
34/100 - train: loss 0.363; val: loss 0.690 accuracy 0.793
35/100 - train: loss 0.354; val: loss 0.482 accuracy 0.845
36/100 - train: loss 0.340; val: loss 0.514 accuracy 0.832
37/100 - train: loss 0.351; val: loss 0.498 accuracy 0.840
38/100 - train: loss 0.379; val: loss 0.577 accuracy 0.812
39/100 - train: loss 0.365; val: loss 0.471 accuracy 0.851
40/100 - train: loss 0.345; val: loss 0.484 accuracy 0.836
41/100 - train: loss 0.344; val: loss 0.454 accuracy 0.857
42/100 - train: loss 0.322; val: loss 0.423 accuracy 0.865
43/100 - train: loss 0.295; val: loss 0.480 accuracy 0.853
44/100 - train: loss 0.366; val: loss 0.442 accuracy 0.859
45/100 - train: loss 0.348; val: loss 0.378 accuracy 0.877
46/100 - train: loss 0.307; val: loss 0.419 accuracy 0.861
47/100 - train: loss 0.349; val: loss 0.328 accuracy 0.889
48/100 - train: loss 0.324; val: loss 0.406 accuracy 0.865
49/100 - train: loss 0.330; val: loss 0.555 accuracy 0.828
50/100 - train: loss 0.270; val: loss 0.423 accuracy 0.865
51/100 - train: loss 0.306; val: loss 0.449 accuracy 0.859
52/100 - train: loss 0.262; val: loss 0.297 accuracy 0.897
53/100 - train: loss 0.287; val: loss 0.472 accuracy 0.850
54/100 - train: loss 0.278; val: loss 0.440 accuracy 0.858
55/100 - train: loss 0.279; val: loss 0.324 accuracy 0.893
56/100 - train: loss 0.263; val: loss 0.425 accuracy 0.862
57/100 - train: loss 0.227; val: loss 0.347 accuracy 0.888
58/100 - train: loss 0.236; val: loss 0.392 accuracy 0.876
59/100 - train: loss 0.234; val: loss 0.396 accuracy 0.872
60/100 - train: loss 0.246; val: loss 0.401 accuracy 0.878
61/100 - train: loss 0.237; val: loss 0.286 accuracy 0.902
62/100 - train: loss 0.209; val: loss 0.327 accuracy 0.895
63/100 - train: loss 0.224; val: loss 0.355 accuracy 0.886
64/100 - train: loss 0.178; val: loss 0.359 accuracy 0.889
65/100 - train: loss 0.215; val: loss 0.283 accuracy 0.910
66/100 - train: loss 0.180; val: loss 0.303 accuracy 0.906
67/100 - train: loss 0.174; val: loss 0.300 accuracy 0.906
68/100 - train: loss 0.173; val: loss 0.277 accuracy 0.912
69/100 - train: loss 0.175; val: loss 0.315 accuracy 0.906
70/100 - train: loss 0.141; val: loss 0.270 accuracy 0.914
71/100 - train: loss 0.149; val: loss 0.312 accuracy 0.912
72/100 - train: loss 0.143; val: loss 0.279 accuracy 0.918
73/100 - train: loss 0.125; val: loss 0.320 accuracy 0.907
74/100 - train: loss 0.120; val: loss 0.264 accuracy 0.921
75/100 - train: loss 0.112; val: loss 0.260 accuracy 0.924
76/100 - train: loss 0.102; val: loss 0.248 accuracy 0.927
77/100 - train: loss 0.090; val: loss 0.287 accuracy 0.922
78/100 - train: loss 0.081; val: loss 0.254 accuracy 0.928
79/100 - train: loss 0.059; val: loss 0.256 accuracy 0.931
80/100 - train: loss 0.071; val: loss 0.264 accuracy 0.930
81/100 - train: loss 0.073; val: loss 0.252 accuracy 0.931
82/100 - train: loss 0.048; val: loss 0.262 accuracy 0.937
83/100 - train: loss 0.054; val: loss 0.272 accuracy 0.934
84/100 - train: loss 0.036; val: loss 0.276 accuracy 0.934
85/100 - train: loss 0.036; val: loss 0.256 accuracy 0.936
86/100 - train: loss 0.027; val: loss 0.267 accuracy 0.938
87/100 - train: loss 0.025; val: loss 0.275 accuracy 0.937
88/100 - train: loss 0.018; val: loss 0.256 accuracy 0.942
89/100 - train: loss 0.017; val: loss 0.242 accuracy 0.945
90/100 - train: loss 0.015; val: loss 0.246 accuracy 0.947
91/100 - train: loss 0.008; val: loss 0.245 accuracy 0.947
92/100 - train: loss 0.011; val: loss 0.254 accuracy 0.947
93/100 - train: loss 0.008; val: loss 0.252 accuracy 0.947
94/100 - train: loss 0.009; val: loss 0.255 accuracy 0.949
95/100 - train: loss 0.005; val: loss 0.249 accuracy 0.950
96/100 - train: loss 0.007; val: loss 0.242 accuracy 0.951
97/100 - train: loss 0.004; val: loss 0.251 accuracy 0.950
98/100 - train: loss 0.009; val: loss 0.251 accuracy 0.951
99/100 - train: loss 0.003; val: loss 0.244 accuracy 0.952
100/100 - train: loss 0.003; val: loss 0.250 accuracy 0.951
history.plot_train_val('loss')
history.plot('val acc')