Patches Are All You Need? 🤷
ICLR 2022 submission
ConvMixer, as given by the paper's complete 280-character implementation (verbatim apart from whitespace; it assumes from torch.nn import *):
def ConvMixr(h,d,k,p,n):
    S,C,A = Sequential, Conv2d, lambda x: S(x,GELU(),BatchNorm2d(h))
    R = type('',(S,),{'forward':lambda s,x: s[0](x)+x})
    return S(A(C(3,h,p,p)),
             *[S(R(A(C(h,h,k,groups=h,padding=k//2))), A(C(h,h,1))) for i in range(d)],
             AdaptiveAvgPool2d((1,1)), Flatten(), Linear(h,n))
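As a quick smoke test (a sketch, assuming the wildcard import the paper's snippet relies on), the terse definition builds and produces class logits of the expected shape:

from torch.nn import *  # the 280-character snippet assumes this wildcard import
import torch

m = ConvMixr(h=256, d=8, k=7, p=1, n=10)   # hidden dim, depth, kernel, patch size, classes
print(m(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])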
Imports
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
import ignite.metrics
import ignite.contrib.handlers
Configuration
DATA_DIR='./data'
NUM_CLASSES = 10
NUM_WORKERS = 8
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-1
EPOCHS = 100
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor()
])
train_dset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_transform)
test_dset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transforms.ToTensor())
Files already downloaded and verified
Files already downloaded and verified
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
num_workers=NUM_WORKERS, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS, pin_memory=True)
Utilities
def init_linear(m):
    # Kaiming-normal init for conv and linear weights; biases start at zero
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None: nn.init.zeros_(m.bias)
class Residual(nn.Module):
def __init__(self, residual):
super().__init__()
self.residual = residual
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
return x + self.gamma * self.residual(x)
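Because gamma is initialized to zero, every Residual block starts out as the identity map, a ReZero/SkipInit-style trick that eases early optimization. A minimal check:

block = Residual(nn.Conv2d(8, 8, 3, padding=1))
x = torch.randn(1, 8, 4, 4)
assert torch.equal(block(x), x)  # gamma == 0, so the branch contributes nothing at init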
class ConvBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, groups=1):
padding = (kernel_size - 1) // 2
super().__init__(
nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups),
nn.GELU(),
nn.BatchNorm2d(out_channels)
)
ConvMixer
class ConvMixerLayer(nn.Sequential):
def __init__(self, channels, kernel_size):
super().__init__(
Residual(ConvBlock(channels, channels, kernel_size, groups=channels)), # Depthwise
ConvBlock(channels, channels, 1) # Pointwise
)
class Head(nn.Sequential):
def __init__(self, in_channels, classes, p_drop=0.):
super().__init__(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Dropout(p_drop),
nn.Linear(in_channels, classes)
)
class ConvMixer(nn.Sequential):
def __init__(self, classes, channels, depth, patch_size, kernel_size, in_channels=3, head_p_drop=0.):
super().__init__(
ConvBlock(in_channels, channels, patch_size, stride=patch_size), # patch embedding
*[ ConvMixerLayer(channels, kernel_size) for _ in range(depth) ],
Head(channels, classes, head_p_drop)
)
model = ConvMixer(NUM_CLASSES, channels=256, depth=8, patch_size=1, kernel_size=7, head_p_drop=0.3)
model.apply(init_linear);
model.to(DEVICE);
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 641,042
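A forward-pass sanity check on a CIFAR-sized batch (shapes only, no training):

with torch.no_grad():
    logits = model(torch.randn(2, 3, 32, 32, device=DEVICE))
print(logits.shape)  # torch.Size([2, 10])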
def separate_parameters(model):
parameters_decay = set()
parameters_no_decay = set()
modules_weight_decay = (nn.Linear, nn.Conv2d)
modules_no_weight_decay = (nn.BatchNorm2d,)
for m_name, m in model.named_modules():
for param_name, param in m.named_parameters():
full_param_name = f"{m_name}.{param_name}" if m_name else param_name
if isinstance(m, modules_no_weight_decay):
parameters_no_decay.add(full_param_name)
elif param_name.endswith("bias"):
parameters_no_decay.add(full_param_name)
elif isinstance(m, Residual) and param_name.endswith("gamma"):
parameters_no_decay.add(full_param_name)
elif isinstance(m, modules_weight_decay):
parameters_decay.add(full_param_name)
# sanity check
assert len(parameters_decay & parameters_no_decay) == 0
assert len(parameters_decay) + len(parameters_no_decay) == len(list(model.parameters()))
return parameters_decay, parameters_no_decay
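This follows the minGPT-style weight-decay separation: conv and linear weights decay, while biases, BatchNorm parameters, and the residual gammas do not. A quick look at the split (exact names depend on module nesting):

decay, no_decay = separate_parameters(model)
print(len(decay), len(no_decay))  # every parameter lands in exactly one set
print(sorted(no_decay)[:4])       # e.g. BatchNorm weights/biases and Residual gammas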
def get_optimizer(model, learning_rate, weight_decay):
param_dict = {pn: p for pn, p in model.named_parameters()}
parameters_decay, parameters_no_decay = separate_parameters(model)
optim_groups = [
{"params": [param_dict[pn] for pn in parameters_decay], "weight_decay": weight_decay},
{"params": [param_dict[pn] for pn in parameters_no_decay], "weight_decay": 0.0},
]
optimizer = optim.AdamW(optim_groups, lr=learning_rate)
return optimizer
loss = nn.CrossEntropyLoss()
optimizer = get_optimizer(model, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
steps_per_epoch=len(train_loader), epochs=EPOCHS)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());
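OneCycleLR ramps the learning rate up to max_lr and then anneals it over the full run. As an optional sketch (using a throwaway optimizer so the real schedule is untouched), the schedule can be visualized before training:

tmp_opt = optim.AdamW([nn.Parameter(torch.zeros(1))], lr=LEARNING_RATE)
tmp_sched = optim.lr_scheduler.OneCycleLR(tmp_opt, max_lr=LEARNING_RATE,
                                          steps_per_epoch=len(train_loader), epochs=EPOCHS)
lrs = []
for _ in range(len(train_loader) * EPOCHS):
    lrs.append(tmp_sched.get_last_lr()[0])
    tmp_opt.step()
    tmp_sched.step()
plt.plot(lrs); plt.xlabel('iteration'); plt.ylabel('learning rate'); plt.show()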
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
val_metrics = {"accuracy": ignite.metrics.Accuracy(), "loss": ignite.metrics.Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
history = defaultdict(list)
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
train_state = engine.state
epoch = train_state.epoch
max_epochs = train_state.max_epochs
train_loss = train_state.metrics["loss"]
history['train loss'].append(train_loss)
evaluator.run(test_loader)
val_metrics = evaluator.state.metrics
val_loss = val_metrics["loss"]
val_acc = val_metrics["accuracy"]
history['val loss'].append(val_loss)
history['val acc'].append(val_acc)
print("{}/{} - train: loss {:.3f}; val: loss {:.3f} accuracy {:.3f}".format(
epoch, max_epochs, train_loss, val_loss, val_acc))
trainer.run(train_loader, max_epochs=EPOCHS);
1/100 - train: loss 1.966; val: loss 1.921 accuracy 0.297
2/100 - train: loss 1.821; val: loss 1.803 accuracy 0.345
3/100 - train: loss 1.663; val: loss 1.623 accuracy 0.401
... (epochs 4-97 elided) ...
98/100 - train: loss 0.002; val: loss 0.305 accuracy 0.936
99/100 - train: loss 0.003; val: loss 0.305 accuracy 0.937
100/100 - train: loss 0.002; val: loss 0.308 accuracy 0.936
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['train loss']) + 1)
ax.plot(xs, history['train loss'], '.-', label='train')
ax.plot(xs, history['val loss'], '.-', label='val')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.legend()
ax.grid()
plt.show()
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['val acc']) + 1)
ax.plot(xs, history['val acc'], '-')
ax.set_xlabel('epoch')
ax.set_ylabel('val acc')
ax.grid()
plt.show()
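ConvMixer with stacked 3×3 convolutions
The next variant makes two changes: blocks use pre-activation ordering (BatchNorm → GELU → Conv), and each k×k depthwise convolution is replaced by a stack of 3×3 depthwise convolutions with the same receptive field. The pointwise convolution also moves inside its own residual connection, and the Head gains a BatchNorm → GELU prefix to normalize the final convolution's output.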
class ConvBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, groups=1):
padding = (kernel_size - 1) // 2
super().__init__(
nn.BatchNorm2d(in_channels),
nn.GELU(),
nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups),
)
class ConvMixerLayer(nn.Sequential):
def __init__(self, channels, kernel_size):
num_conv = (kernel_size - 3) // 2 + 1
super().__init__(
Residual( # Depthwise
nn.Sequential(*[ConvBlock(channels, channels, 3, groups=channels) for _ in range(num_conv)])
),
Residual(ConvBlock(channels, channels, 1)) # Pointwise
)
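The num_conv formula relies on a stack of n 3×3 convolutions having a (2n+1)×(2n+1) receptive field, so (kernel_size - 3) // 2 + 1 layers reproduce the coverage of one kernel_size convolution; a quick check:

for k in (3, 5, 7, 9):
    n = (k - 3) // 2 + 1
    assert 2 * n + 1 == k  # n stacked 3x3 convs see a k x k window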
class Head(nn.Sequential):
def __init__(self, in_channels, classes, p_drop=0.):
super().__init__(
nn.BatchNorm2d(in_channels),
nn.GELU(),
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Dropout(p_drop),
nn.Linear(in_channels, classes)
)
class ConvMixer(nn.Sequential):
def __init__(self, classes, channels, depth, patch_size, kernel_size, in_channels=3, head_p_drop=0.):
super().__init__(
nn.Conv2d(in_channels, channels, patch_size, stride=patch_size), # patch embedding
*[ ConvMixerLayer(channels, kernel_size) for _ in range(depth) ],
Head(channels, classes, head_p_drop)
)
model = ConvMixer(NUM_CLASSES, channels=256, depth=8, patch_size=1, kernel_size=7, head_p_drop=0.3)
model.apply(init_linear);
model.to(DEVICE);
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 608,282
Configuration
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-2
Setup trainer
optimizer = get_optimizer(model, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
steps_per_epoch=len(train_loader), epochs=EPOCHS)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
val_metrics = {"accuracy": ignite.metrics.Accuracy(), "loss": ignite.metrics.Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
history = defaultdict(list)
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
train_state = engine.state
epoch = train_state.epoch
max_epochs = train_state.max_epochs
train_loss = train_state.metrics["loss"]
history['train loss'].append(train_loss)
evaluator.run(test_loader)
val_metrics = evaluator.state.metrics
val_loss = val_metrics["loss"]
val_acc = val_metrics["accuracy"]
history['val loss'].append(val_loss)
history['val acc'].append(val_acc)
print("{}/{} - train: loss {:.3f}; val: loss {:.3f} accuracy {:.3f}".format(
epoch, max_epochs, train_loss, val_loss, val_acc))
Start training
trainer.run(train_loader, max_epochs=EPOCHS);
1/100 - train: loss 1.311; val: loss 1.245 accuracy 0.552
2/100 - train: loss 1.096; val: loss 1.008 accuracy 0.638
3/100 - train: loss 0.911; val: loss 0.874 accuracy 0.691
... (epochs 4-97 elided; note transient val-loss spikes at epochs 66 and 71) ...
98/100 - train: loss 0.006; val: loss 0.277 accuracy 0.941
99/100 - train: loss 0.004; val: loss 0.280 accuracy 0.941
100/100 - train: loss 0.004; val: loss 0.280 accuracy 0.941
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['train loss']) + 1)
ax.plot(xs, history['train loss'], '.-', label='train')
ax.plot(xs, history['val loss'], '.-', label='val')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.legend()
ax.grid()
plt.show()
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['val acc']) + 1)
ax.plot(xs, history['val acc'], '-')
ax.set_xlabel('epoch')
ax.set_ylabel('val acc')
ax.grid()
plt.show()
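ConvMixer with separate spatial and channel mixers
The final variant factors each layer into a SpatialMixer (the stacked 3×3 depthwise convolutions) and a ChannelMixer (a two-layer 1×1 inverted bottleneck with a 4× expansion, analogous to a transformer feed-forward block), each wrapped in its own residual connection.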
class SpatialMixer(nn.Sequential):
def __init__(self, channels, kernel_size):
num_conv = (kernel_size - 3) // 2 + 1
super().__init__(*[ConvBlock(channels, channels, 3, groups=channels) for _ in range(num_conv)])
class ChannelMixer(nn.Sequential):
def __init__(self, channels, mult=4):
mid_channels = channels * mult
super().__init__(
ConvBlock(channels, mid_channels, 1),
ConvBlock(mid_channels, channels, 1),
)
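ChannelMixer acts per pixel: it expands the channel dimension by mult and projects back, leaving the spatial dimensions untouched. A shape check with toy sizes:

cm = ChannelMixer(16)  # 16 -> 64 -> 16 channels
print(cm(torch.randn(1, 16, 8, 8)).shape)  # torch.Size([1, 16, 8, 8])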
class ConvMixerLayer(nn.Sequential):
def __init__(self, channels, kernel_size):
super().__init__(
Residual(SpatialMixer(channels, kernel_size)),
Residual(ChannelMixer(channels))
)
class ConvMixer(nn.Sequential):
def __init__(self, classes, channels, depth, patch_size, kernel_size, in_channels=3, head_p_drop=0.):
super().__init__(
nn.Conv2d(in_channels, channels, patch_size, stride=patch_size), # patch embedding
*[ ConvMixerLayer(channels, kernel_size) for _ in range(depth) ],
Head(channels, classes, head_p_drop)
)
model = ConvMixer(NUM_CLASSES, channels=256, depth=8, patch_size=1, kernel_size=7, head_p_drop=0.3)
model.apply(init_linear);
model.to(DEVICE);
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 4,302,874
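The parameter count grows from about 0.6M to 4.3M: with the 4× expansion, each ChannelMixer's two 1×1 convolutions cost roughly 2 × 256 × 1024 ≈ 0.5M parameters per layer, versus a single 256 × 256 pointwise convolution (about 66K) before.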
Configuration
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-2
Setup trainer
optimizer = get_optimizer(model, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
steps_per_epoch=len(train_loader), epochs=EPOCHS)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
val_metrics = {"accuracy": ignite.metrics.Accuracy(), "loss": ignite.metrics.Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
history = defaultdict(list)
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
train_state = engine.state
epoch = train_state.epoch
max_epochs = train_state.max_epochs
train_loss = train_state.metrics["loss"]
history['train loss'].append(train_loss)
evaluator.run(test_loader)
val_metrics = evaluator.state.metrics
val_loss = val_metrics["loss"]
val_acc = val_metrics["accuracy"]
history['val loss'].append(val_loss)
history['val acc'].append(val_acc)
print("{}/{} - train: loss {:.3f}; val: loss {:.3f} accuracy {:.3f}".format(
epoch, max_epochs, train_loss, val_loss, val_acc))
Start training
trainer.run(train_loader, max_epochs=EPOCHS);
1/100 - train: loss 1.212; val: loss 1.113 accuracy 0.603
2/100 - train: loss 0.906; val: loss 0.832 accuracy 0.709
3/100 - train: loss 0.756; val: loss 0.689 accuracy 0.759
... (epochs 4-97 elided) ...
98/100 - train: loss 0.001; val: loss 0.262 accuracy 0.950
99/100 - train: loss 0.001; val: loss 0.263 accuracy 0.949
100/100 - train: loss 0.000; val: loss 0.256 accuracy 0.951
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['train loss']) + 1)
ax.plot(xs, history['train loss'], '.-', label='train')
ax.plot(xs, history['val loss'], '.-', label='val')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.legend()
ax.grid()
plt.show()
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['val acc']) + 1)
ax.plot(xs, history['val acc'], '-')
ax.set_xlabel('epoch')
ax.set_ylabel('val acc')
ax.grid()
plt.show()
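To summarize the three runs after 100 epochs: the baseline ConvMixer reaches 93.6% validation accuracy, the pre-activation variant with stacked 3×3 kernels reaches 94.1%, and the spatial/channel-mixer variant reaches 95.1%.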