Architecture based on the general scheme of transformers: each block alternates a spatial mixer (a depthwise convolution) and a channel mixer (a pointwise-convolution MLP), both wrapped in residual connections.
Imports
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
import ignite.metrics
import ignite.contrib.handlers
Configuration
DATA_DIR='./data'
IMAGE_SIZE = 32
NUM_CLASSES = 10
NUM_WORKERS = 8
BATCH_SIZE = 512
EPOCHS = 200
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-1
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
train_transform = transforms.Compose([
transforms.TrivialAugmentWide(interpolation=transforms.InterpolationMode.BILINEAR),
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.RandomErasing(p=0.1)
])
train_dset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_transform)
val_dset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transforms.ToTensor())
Files already downloaded and verified
Files already downloaded and verified
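TrivialAugmentWide operates on PIL images, which is why PILToTensor and ConvertImageDtype (which also rescales values to [0, 1]) come after it, while the tensor-only RandomErasing comes last. A quick, illustrative peek at the pipeline on one raw image (a fresh dataset handle without a transform, so indexing yields PIL images):
raw = datasets.CIFAR10(root=DATA_DIR, train=True)   # no transform: items are (PIL image, label)
img, _ = raw[0]
print(type(img))                    # <class 'PIL.Image.Image'>
print(train_transform(img).shape)   # torch.Size([3, 32, 32]), float32 in [0, 1]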
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
num_workers=NUM_WORKERS, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_dset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS, pin_memory=True)
class Residual(nn.Module):
def __init__(self, residual, shortcut=None):
super().__init__()
self.residual = residual
self.shortcut = shortcut if shortcut is not None else nn.Identity()
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
return self.shortcut(x) + self.gamma * self.residual(x)
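Since gamma is initialized to zero, every residual branch starts switched off and each block acts as the identity map (up to its shortcut) at initialization, in the spirit of SkipInit/LayerScale; training then gradually turns the branch on. A quick sanity check of this property (toy sizes, illustrative only):
# With gamma == 0, the block reduces to its shortcut (here the default nn.Identity()).
res = Residual(nn.Conv2d(8, 8, 3, padding=1))
x = torch.randn(2, 8, 16, 16)
assert torch.equal(res(x), x)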
Block
class SpatialMixer(nn.Sequential):
def __init__(self, channels, kernel_size, stride=1):
padding = (kernel_size - 1) // 2
super().__init__(
nn.BatchNorm2d(channels),
nn.Conv2d(channels, channels, kernel_size, padding=padding, stride=stride, groups=channels)
)
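With groups=channels the convolution is depthwise: each channel is filtered spatially on its own, with no cross-channel mixing, so it costs channels * k * k weights rather than channels * channels * k * k. A quick parameter count (sizes are illustrative):
# Depthwise 5x5 conv over 32 channels: 32*5*5 weights + 32 biases = 832 parameters.
sm = SpatialMixer(channels=32, kernel_size=5)
print(sum(p.numel() for p in sm[1].parameters()))   # 832 (sm[1] is the Conv2d)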
class ChannelMixer(nn.Sequential):
def __init__(self, in_channels, out_channels, mult=4):
mid_channels = in_channels * mult
super().__init__(
nn.BatchNorm2d(in_channels),
nn.Conv2d(in_channels, mid_channels, 1),
nn.GELU(),
nn.Conv2d(mid_channels, out_channels, 1)
)
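The channel mixer is the convolutional analogue of a transformer's position-wise feed-forward network: two 1x1 convolutions with a GELU in between, applied independently at every spatial location, with mult as the expansion factor. A quick shape check (illustrative sizes):
# Spatial size is unchanged; only the channel dimension is remixed.
cm = ChannelMixer(in_channels=32, out_channels=64)
print(cm(torch.randn(2, 32, 16, 16)).shape)   # torch.Size([2, 64, 16, 16])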
class Block(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size, stride=1):
super().__init__(
Residual(SpatialMixer(in_channels, kernel_size, stride),
shortcut = nn.AvgPool2d(stride) if stride > 1 else None),
Residual(ChannelMixer(in_channels, out_channels),
shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else None)
)
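A block therefore changes resolution only in its spatial mixer (with an average-pooling shortcut) and changes width only in its channel mixer (with a 1x1-convolution shortcut). A quick shape check exercising both at once (illustrative sizes):
# stride=2 halves the resolution; out_channels=64 doubles the width.
block = Block(in_channels=32, out_channels=64, kernel_size=5, stride=2)
print(block(torch.randn(2, 32, 32, 32)).shape)   # torch.Size([2, 64, 16, 16])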
Stage
class Stage(nn.Sequential):
def __init__(self, in_channels, out_channels, num_blocks, kernel_size, stride=1):
super().__init__(
Block(in_channels, out_channels, kernel_size, stride),
*[Block(out_channels, out_channels, kernel_size) for _ in range(num_blocks - 1)]
)
class StageStack(nn.Sequential):
def __init__(self, in_channels, channels_list, num_blocks_list, strides, kernel_size):
layers = []
for num, out_channels, stride in zip(num_blocks_list, channels_list, strides):
layers.append(Stage(in_channels, out_channels, num, kernel_size, stride))
in_channels = out_channels
super().__init__(*layers)
Main model
def Stem(in_channels, out_channels, kernel_size=3, stride=1):
padding = (kernel_size - 1) // 2
return nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, stride=stride)
class Head(nn.Sequential):
def __init__(self, channels, classes, p_drop=0.):
super().__init__(
nn.BatchNorm2d(channels),
nn.GELU(),
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Dropout(p_drop),
nn.Linear(channels, classes)
)
class Net(nn.Sequential):
def __init__(self, classes, num_blocks_list, channels_list, strides, kernel_size, in_channels=3, head_p_drop=0.):
super().__init__(
Stem(in_channels, channels_list[0], stride=strides[0]),
StageStack(channels_list[0], channels_list, num_blocks_list, strides[1:], kernel_size),
Head(channels_list[-1], classes, head_p_drop)
)
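A tiny configuration makes a CPU smoke test cheap; the sizes below are illustrative, not the configuration trained later:
# Stem + two stages (the second one downsampling) + head, on a CIFAR-sized batch.
tiny = Net(10, num_blocks_list=[1, 1], channels_list=[8, 16], strides=[1, 1, 2], kernel_size=3)
print(tiny(torch.randn(2, 3, 32, 32)).shape)   # torch.Size([2, 10])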
def init_linear(m):
if isinstance(m, (nn.Conv2d, nn.Conv1d, nn.Linear)):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None: nn.init.zeros_(m.bias)
model = Net(NUM_CLASSES,
num_blocks_list = [4, 4, 2, 2],
channels_list = [32, 64, 128, 256],
strides = [1, 1, 2, 2, 2],
kernel_size = 5,
head_p_drop = 0.3)
model.apply(init_linear);
model.to(DEVICE);
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 1,124,642
loss = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
steps_per_epoch=len(train_loader), epochs=EPOCHS)
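OneCycleLR is stepped once per batch (hence steps_per_epoch), warming the learning rate up and then annealing it over the full run; the handler attached to the trainer below calls lr_scheduler.step() on every iteration. A sketch of the schedule using a throwaway optimizer, so the real one stays untouched (the dummy parameter is illustrative only):
# Replay the schedule on a dummy optimizer and plot the learning-rate curve.
_opt = optim.AdamW([torch.zeros(1, requires_grad=True)], lr=LEARNING_RATE)
_sched = optim.lr_scheduler.OneCycleLR(_opt, max_lr=LEARNING_RATE,
                                       steps_per_epoch=len(train_loader), epochs=EPOCHS)
lrs = []
for _ in range(len(train_loader) * EPOCHS):
    lrs.append(_sched.get_last_lr()[0])
    _opt.step()
    _sched.step()
plt.plot(lrs); plt.xlabel('iteration'); plt.ylabel('learning rate'); plt.show()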
Trainer
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
Evaluator
val_metrics = {"accuracy": ignite.metrics.Accuracy(), "loss": ignite.metrics.Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
history = defaultdict(list)
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
train_state = engine.state
epoch = train_state.epoch
max_epochs = train_state.max_epochs
train_loss = train_state.metrics["loss"]
history['train loss'].append(train_loss)
evaluator.run(val_loader)
val_metrics = evaluator.state.metrics
val_loss = val_metrics["loss"]
val_acc = val_metrics["accuracy"]
history['val loss'].append(val_loss)
history['val acc'].append(val_acc)
print("{}/{} - train: loss {:.3f}; val: loss {:.3f} accuracy {:.3f}".format(
epoch, max_epochs, train_loss, val_loss, val_acc))
trainer.run(train_loader, max_epochs=EPOCHS);
1/200 - train: loss 2.228; val: loss 1.970 accuracy 0.302
2/200 - train: loss 2.049; val: loss 1.775 accuracy 0.411
3/200 - train: loss 1.907; val: loss 1.769 accuracy 0.419
4/200 - train: loss 1.795; val: loss 1.558 accuracy 0.524
5/200 - train: loss 1.682; val: loss 1.414 accuracy 0.594
...
196/200 - train: loss 0.643; val: loss 0.604 accuracy 0.958
197/200 - train: loss 0.647; val: loss 0.604 accuracy 0.959
198/200 - train: loss 0.645; val: loss 0.604 accuracy 0.959
199/200 - train: loss 0.639; val: loss 0.604 accuracy 0.959
200/200 - train: loss 0.646; val: loss 0.604 accuracy 0.958
def history_plot_train_val(history, key):
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['train ' + key]) + 1)
ax.plot(xs, history['train ' + key], '.-', label='train')
ax.plot(xs, history['val ' + key], '.-', label='val')
ax.set_xlabel('epoch')
ax.set_ylabel(key)
ax.legend()
ax.grid()
plt.show()
def history_plot(history, key):
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history[key]) + 1)
ax.plot(xs, history[key], '.-')
ax.set_xlabel('epoch')
ax.set_ylabel(key)
ax.grid()
plt.show()
history_plot_train_val(history, 'loss')
history_plot(history, 'val acc')