Main ideas:
- Pre-activation residual blocks: BatchNorm + SiLU before each convolution.
- Depthwise 3×3 plus pointwise 1×1 convolutions to keep the parameter count low (~1.6M).
- ECA channel attention (arXiv:1910.03151) inside every block.
- A learnable, zero-initialized scale (gamma) on each residual branch, so the network starts out as the identity.
- AdamW with weight decay and a OneCycleLR schedule, trained for 100 epochs on CIFAR-10.
Imports
import math
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
import ignite.metrics
import ignite.contrib.handlers
Configuration
DATA_DIR='./data'
NUM_CLASSES = 10
NUM_WORKERS = 8
BATCH_SIZE = 32
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-2
EPOCHS = 100
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor()
])
train_dset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_transform)
test_dset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transforms.ToTensor())
Files already downloaded and verified
Files already downloaded and verified
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
num_workers=NUM_WORKERS, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS, pin_memory=True)
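A quick sanity check one might run here: one batch from the train loader should come out at (BATCH_SIZE, 3, 32, 32), since RandomCrop(32, padding=4) returns 32×32 images.

# pull a single augmented batch and confirm its shape
images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([32, 3, 32, 32])
print(labels.shape)   # torch.Size([32])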
Utilities
def init_linear(m):
    # Kaiming-normal init for all conv / linear weights; zero the biases
    if isinstance(m, (nn.Conv2d, nn.Conv1d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None: nn.init.zeros_(m.bias)
class NormAct(nn.Sequential):
def __init__(self, channels):
super().__init__(nn.BatchNorm2d(channels),
nn.SiLU(inplace=True))
ECA channel attention, arXiv:1910.03151 [cs.CV]
class ECA(nn.Module):
    def __init__(self, channels, gamma=2, b=1):
        super().__init__()
        # adaptive kernel size from the ECA paper: k grows with log2(channels),
        # rounded up to the nearest odd number
        t = int(abs((math.log(channels, 2) + b) / gamma))
        k = t if t % 2 else t + 1
        padding = (k - 1) // 2
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=padding)

    def forward(self, x):
        s = self.pool(x)          # global average pool: (N, C, 1, 1)
        c = s.size(1)
        s = s.view(-1, 1, c)      # treat channels as a 1D sequence: (N, 1, C)
        s = self.conv(s)          # local cross-channel interaction
        s = s.view(-1, c, 1, 1)
        s = torch.sigmoid(s)      # per-channel gates in (0, 1)
        return x * s
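For reference, the adaptive rule picks these kernel sizes for the channel widths used in this model; a throwaway loop with the same formula as in __init__ (gamma=2, b=1) makes it concrete:

# kernel size chosen by ECA at each channel width in this network
for c in [32, 64, 128, 256, 512]:
    t = int(abs((math.log(c, 2) + 1) / 2))
    k = t if t % 2 else t + 1
    print(c, '->', k)   # 32 -> 3, 64 -> 3, 128 -> 5, 256 -> 5, 512 -> 5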
Residual block
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual = nn.Sequential(
            NormAct(in_channels),
            # depthwise 3x3 convolution (groups=in_channels)
            nn.Conv2d(in_channels, in_channels, 3, padding=1, stride=stride, groups=in_channels),
            NormAct(in_channels),
            ECA(in_channels),
            # pointwise 1x1 convolution mixes channels and changes width
            nn.Conv2d(in_channels, out_channels, 1)
        )
        self.shortcut = self.get_shortcut(in_channels, out_channels, stride)
        # zero-initialized scale on the residual branch: each block starts as its shortcut
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.shortcut(x) + self.gamma * self.residual(x)

    @staticmethod
    def get_shortcut(in_channels, out_channels, stride):
        # 1x1 conv to match channels; average pooling to match spatial size
        if in_channels != out_channels:
            shortcut = nn.Conv2d(in_channels, out_channels, 1)
            if stride > 1:
                shortcut = nn.Sequential(nn.AvgPool2d(stride), shortcut)
        elif stride > 1:
            shortcut = nn.AvgPool2d(stride)
        else:
            shortcut = nn.Identity()
        return shortcut
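Since gamma is zero at initialization, every block initially computes only its shortcut; a same-shape block is then exactly the identity, so the freshly built network behaves like a very shallow model. A minimal check:

# gamma == 0 at init, so a same-shape block passes its input straight through
block = Block(64, 64)
x = torch.randn(2, 64, 8, 8)
assert torch.allclose(block(x), x)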
Main model
class Stage(nn.Sequential):
def __init__(self, in_channels, out_channels, num_blocks, stride=1):
super().__init__(
Block(in_channels, out_channels, stride),
*[Block(out_channels, out_channels) for _ in range(num_blocks - 1)]
)
class Body(nn.Sequential):
def __init__(self, in_channels, channel_list, num_blocks_list, strides):
layers = []
for out_channels, num_blocks, stride in zip(channel_list, num_blocks_list, strides):
layers.append(Stage(in_channels, out_channels, num_blocks, stride))
in_channels = out_channels
super().__init__(*layers)
class Head(nn.Sequential):
def __init__(self, in_channels, classes, p_drop=0.):
super().__init__(
NormAct(in_channels),
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Dropout(p_drop),
nn.Linear(in_channels, classes)
)
def Stem(in_channels, out_channels, stride):
return nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
class Net(nn.Sequential):
def __init__(self, classes, num_blocks_list, channel_list, strides, in_channels=3, head_p_drop=0.):
super().__init__(
Stem(in_channels, channel_list[0], strides[0]),
Body(channel_list[0], channel_list[1:], num_blocks_list, strides[1:]),
Head(channel_list[-1], classes, head_p_drop)
)
model = Net(NUM_CLASSES,
num_blocks_list = [6, 6, 6, 4],
channel_list = [32, 64, 128, 256, 512],
strides = [1, 1, 2, 2, 2],
head_p_drop=0.3)
model.apply(init_linear);
model.to(DEVICE);
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 1,637,142
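A minimal shape check, assuming the model built above: a CIFAR-sized batch should map to one logit per class.

with torch.no_grad():
    logits = model(torch.randn(2, 3, 32, 32, device=DEVICE))
print(logits.shape)   # torch.Size([2, 10])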
loss = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
steps_per_epoch=len(train_loader), epochs=EPOCHS)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());
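The one-cycle schedule warms the learning rate up to max_lr and then anneals it. It can be traced with a throwaway optimizer and scheduler, independent of the real training state; a sketch:

# trace the LR over all len(train_loader) * EPOCHS steps of the schedule
_opt = optim.AdamW([nn.Parameter(torch.zeros(1))], lr=LEARNING_RATE)
_sched = optim.lr_scheduler.OneCycleLR(_opt, max_lr=LEARNING_RATE,
                                       steps_per_epoch=len(train_loader), epochs=EPOCHS)
lrs = []
for _ in range(len(train_loader) * EPOCHS):
    lrs.append(_sched.get_last_lr()[0])
    _opt.step()       # step the optimizer first to avoid a scheduler warning
    _sched.step()
plt.plot(lrs)
plt.xlabel('step'); plt.ylabel('learning rate')
plt.show()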
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
val_metrics = {"accuracy": ignite.metrics.Accuracy(), "loss": ignite.metrics.Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
history = defaultdict(list)
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
train_state = engine.state
epoch = train_state.epoch
max_epochs = train_state.max_epochs
train_loss = train_state.metrics["loss"]
history['train loss'].append(train_loss)
evaluator.run(test_loader)
val_metrics = evaluator.state.metrics
val_loss = val_metrics["loss"]
val_acc = val_metrics["accuracy"]
history['val loss'].append(val_loss)
history['val acc'].append(val_acc)
print("{}/{} - train: loss {:.3f}; val: loss {:.3f} accuracy {:.3f}".format(
epoch, max_epochs, train_loss, val_loss, val_acc))
trainer.run(train_loader, max_epochs=EPOCHS);
1/100 - train: loss 1.085; val: loss 0.988 accuracy 0.648
2/100 - train: loss 0.792; val: loss 0.804 accuracy 0.720
3/100 - train: loss 0.656; val: loss 0.625 accuracy 0.782
4/100 - train: loss 0.625; val: loss 0.578 accuracy 0.799
5/100 - train: loss 0.558; val: loss 0.531 accuracy 0.818
6/100 - train: loss 0.561; val: loss 0.490 accuracy 0.830
7/100 - train: loss 0.470; val: loss 0.457 accuracy 0.844
8/100 - train: loss 0.476; val: loss 0.467 accuracy 0.839
9/100 - train: loss 0.443; val: loss 0.516 accuracy 0.827
10/100 - train: loss 0.475; val: loss 0.482 accuracy 0.834
11/100 - train: loss 0.430; val: loss 0.469 accuracy 0.842
12/100 - train: loss 0.442; val: loss 0.439 accuracy 0.849
13/100 - train: loss 0.442; val: loss 0.434 accuracy 0.854
14/100 - train: loss 0.442; val: loss 0.472 accuracy 0.835
15/100 - train: loss 0.457; val: loss 0.432 accuracy 0.853
16/100 - train: loss 0.455; val: loss 0.436 accuracy 0.857
17/100 - train: loss 0.464; val: loss 0.442 accuracy 0.851
18/100 - train: loss 0.416; val: loss 0.460 accuracy 0.846
19/100 - train: loss 0.416; val: loss 0.445 accuracy 0.848
20/100 - train: loss 0.430; val: loss 0.502 accuracy 0.832
21/100 - train: loss 0.427; val: loss 0.520 accuracy 0.823
22/100 - train: loss 0.406; val: loss 0.470 accuracy 0.848
23/100 - train: loss 0.427; val: loss 0.487 accuracy 0.835
24/100 - train: loss 0.397; val: loss 0.420 accuracy 0.859
25/100 - train: loss 0.423; val: loss 0.420 accuracy 0.860
26/100 - train: loss 0.437; val: loss 0.405 accuracy 0.864
27/100 - train: loss 0.418; val: loss 0.742 accuracy 0.759
28/100 - train: loss 0.407; val: loss 0.467 accuracy 0.842
29/100 - train: loss 0.411; val: loss 0.498 accuracy 0.835
30/100 - train: loss 0.395; val: loss 0.456 accuracy 0.851
31/100 - train: loss 0.389; val: loss 0.396 accuracy 0.868
32/100 - train: loss 0.411; val: loss 0.526 accuracy 0.832
33/100 - train: loss 0.404; val: loss 0.479 accuracy 0.844
34/100 - train: loss 0.414; val: loss 0.471 accuracy 0.840
35/100 - train: loss 0.358; val: loss 0.456 accuracy 0.853
36/100 - train: loss 0.410; val: loss 0.424 accuracy 0.860
37/100 - train: loss 0.382; val: loss 0.468 accuracy 0.849
38/100 - train: loss 0.363; val: loss 0.552 accuracy 0.829
39/100 - train: loss 0.420; val: loss 0.507 accuracy 0.826
40/100 - train: loss 0.353; val: loss 0.463 accuracy 0.849
41/100 - train: loss 0.343; val: loss 0.366 accuracy 0.876
42/100 - train: loss 0.354; val: loss 0.428 accuracy 0.856
43/100 - train: loss 0.356; val: loss 0.378 accuracy 0.870
44/100 - train: loss 0.344; val: loss 0.372 accuracy 0.877
45/100 - train: loss 0.362; val: loss 0.369 accuracy 0.875
46/100 - train: loss 0.336; val: loss 0.434 accuracy 0.859
47/100 - train: loss 0.327; val: loss 0.326 accuracy 0.888
48/100 - train: loss 0.330; val: loss 0.533 accuracy 0.832
49/100 - train: loss 0.350; val: loss 0.346 accuracy 0.886
50/100 - train: loss 0.325; val: loss 0.427 accuracy 0.857
51/100 - train: loss 0.300; val: loss 0.317 accuracy 0.896
52/100 - train: loss 0.296; val: loss 0.542 accuracy 0.830
53/100 - train: loss 0.321; val: loss 0.382 accuracy 0.868
54/100 - train: loss 0.294; val: loss 0.359 accuracy 0.882
55/100 - train: loss 0.294; val: loss 0.382 accuracy 0.876
56/100 - train: loss 0.276; val: loss 0.358 accuracy 0.884
57/100 - train: loss 0.276; val: loss 0.293 accuracy 0.902
58/100 - train: loss 0.262; val: loss 0.293 accuracy 0.905
59/100 - train: loss 0.255; val: loss 0.338 accuracy 0.892
60/100 - train: loss 0.233; val: loss 0.337 accuracy 0.894
61/100 - train: loss 0.244; val: loss 0.297 accuracy 0.902
62/100 - train: loss 0.259; val: loss 0.364 accuracy 0.886
63/100 - train: loss 0.237; val: loss 0.315 accuracy 0.899
64/100 - train: loss 0.204; val: loss 0.267 accuracy 0.912
65/100 - train: loss 0.196; val: loss 0.305 accuracy 0.903
66/100 - train: loss 0.184; val: loss 0.302 accuracy 0.902
67/100 - train: loss 0.192; val: loss 0.279 accuracy 0.910
68/100 - train: loss 0.176; val: loss 0.302 accuracy 0.907
69/100 - train: loss 0.155; val: loss 0.259 accuracy 0.916
70/100 - train: loss 0.154; val: loss 0.277 accuracy 0.912
71/100 - train: loss 0.157; val: loss 0.260 accuracy 0.919
72/100 - train: loss 0.152; val: loss 0.247 accuracy 0.925
73/100 - train: loss 0.116; val: loss 0.265 accuracy 0.921
74/100 - train: loss 0.132; val: loss 0.241 accuracy 0.926
75/100 - train: loss 0.105; val: loss 0.250 accuracy 0.923
76/100 - train: loss 0.129; val: loss 0.321 accuracy 0.910
77/100 - train: loss 0.087; val: loss 0.252 accuracy 0.928
78/100 - train: loss 0.100; val: loss 0.262 accuracy 0.924
79/100 - train: loss 0.105; val: loss 0.231 accuracy 0.932
80/100 - train: loss 0.084; val: loss 0.254 accuracy 0.930
81/100 - train: loss 0.076; val: loss 0.247 accuracy 0.937
82/100 - train: loss 0.056; val: loss 0.247 accuracy 0.933
83/100 - train: loss 0.059; val: loss 0.255 accuracy 0.934
84/100 - train: loss 0.040; val: loss 0.244 accuracy 0.937
85/100 - train: loss 0.038; val: loss 0.240 accuracy 0.937
86/100 - train: loss 0.038; val: loss 0.262 accuracy 0.936
87/100 - train: loss 0.035; val: loss 0.232 accuracy 0.942
88/100 - train: loss 0.023; val: loss 0.224 accuracy 0.943
89/100 - train: loss 0.017; val: loss 0.241 accuracy 0.941
90/100 - train: loss 0.015; val: loss 0.231 accuracy 0.944
91/100 - train: loss 0.009; val: loss 0.240 accuracy 0.944
92/100 - train: loss 0.020; val: loss 0.240 accuracy 0.945
93/100 - train: loss 0.009; val: loss 0.242 accuracy 0.946
94/100 - train: loss 0.006; val: loss 0.237 accuracy 0.948
95/100 - train: loss 0.007; val: loss 0.240 accuracy 0.949
96/100 - train: loss 0.007; val: loss 0.238 accuracy 0.950
97/100 - train: loss 0.005; val: loss 0.239 accuracy 0.950
98/100 - train: loss 0.005; val: loss 0.238 accuracy 0.950
99/100 - train: loss 0.005; val: loss 0.240 accuracy 0.951
100/100 - train: loss 0.004; val: loss 0.237 accuracy 0.950
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['train loss']) + 1)
ax.plot(xs, history['train loss'], '.-', label='train')
ax.plot(xs, history['val loss'], '.-', label='val')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.legend()
ax.grid()
plt.show()
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['val acc']) + 1)
ax.plot(xs, history['val acc'], '-')
ax.set_xlabel('epoch')
ax.set_ylabel('val acc')
ax.grid()
plt.show()