According to CoAtNet: Marrying Convolution and Attention for All Data Sizes, arXiv:2106.04803 [cs.CV], the candidate stage layouts compare as follows.
In terms of generalization capability: $$ \text{C-C-C-C} \approx \text{C-C-C-T} \geq \text{C-C-T-T} > \text{C-T-T-T} \gg \text{ViT}_{\mathrm{REL}} $$
In terms of model capacity: $$ \text{C-C-T-T} \approx \text{C-T-T-T} > \text{ViT}_{\mathrm{REL}} > \text{C-C-C-T} > \text{C-C-C-C} $$
In terms of transferability: $$ \text{C-C-T-T} > \text{C-T-T-T} $$
C-C-T-T (convolutional blocks in the early stages, transformer blocks in the last two) therefore offers the best overall trade-off, and it is the layout implemented below.
Imports
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
import ignite.metrics
import ignite.contrib.handlers
Configuration
DATA_DIR='./data'
IMAGE_SIZE = 32
NUM_CLASSES = 10
NUM_WORKERS = 20
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-1
EPOCHS = 100
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(IMAGE_SIZE, padding=4),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor()
])
train_dset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_transform)
test_dset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transforms.ToTensor())
Files already downloaded and verified
Files already downloaded and verified
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
num_workers=NUM_WORKERS, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS, pin_memory=True)
def dataset_show_image(dset, idx):
X, Y = dset[idx]
title = "Ground truth: {}".format(dset.classes[Y])
fig = plt.figure()
ax = fig.add_subplot(111)
ax.set_axis_off()
ax.imshow(np.moveaxis(X.numpy(), 0, -1))
ax.set_title(title)
plt.show()
dataset_show_image(test_dset, 1)
Utilities
def init_linear(m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None: nn.init.zeros_(m.bias)
class Partial:
    """Like functools.partial, but the stored positional arguments are appended
    after the call-time ones, so block factories still receive
    (in_channels, out_channels, ...) first."""
    def __init__(self, module, *args, **kwargs):
        self.module = module
        self.args = args
        self.kwargs = kwargs
    def __call__(self, *args_c, **kwargs_c):
        return self.module(*args_c, *self.args, **kwargs_c, **self.kwargs)
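Unlike functools.partial, the stored arguments go after the call-time ones. A tiny hypothetical demo (f is not part of the model):
# Hypothetical demo: stored args are appended after call-time args.
def f(a, b, c):
    return (a, b, c)

p = Partial(f, 3)   # stores c=3 as a trailing positional argument
print(p(1, 2))      # (1, 2, 3) -- functools.partial(f, 3)(1, 2) would give (3, 1, 2)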
class LayerNormChannels(nn.Module):
def __init__(self, channels):
super().__init__()
self.norm = nn.LayerNorm(channels)
def forward(self, x):
x = x.transpose(1, -1)
x = self.norm(x)
x = x.transpose(-1, 1)
return x
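nn.LayerNorm normalizes over the last dimension, so the wrapper transposes an NCHW tensor to put channels last and back again. A quick sanity check at initialization:
# Sanity check: normalization runs over channels, separately at each spatial position.
ln = LayerNormChannels(16)
x = torch.randn(2, 16, 8, 8)
y = ln(x)
print(y.shape)                    # torch.Size([2, 16, 8, 8])
print(y.mean(dim=1).abs().max())  # ~0 at init (affine params start as identity)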
class Residual(nn.Module):
def __init__(self, *layers, shortcut=None):
super().__init__()
self.shortcut = nn.Identity() if shortcut is None else shortcut
self.residual = nn.Sequential(*layers)
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
return self.shortcut(x) + self.gamma * self.residual(x)
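Because gamma starts at zero, every residual block is exactly its shortcut at initialization (in the spirit of ReZero/SkipInit), and the residual branch is blended in as gamma is learned. For example:
# With gamma initialized to zero, a fresh block reduces to its shortcut.
block = Residual(nn.Conv2d(8, 8, 3, padding=1))
x = torch.randn(1, 8, 4, 4)
print(torch.equal(block(x), x))   # True: identity shortcut + 0 * residual branch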
class ConvBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1):
padding = (kernel_size - 1) // 2
super().__init__(
nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
groups=groups, bias=False),
nn.BatchNorm2d(out_channels),
nn.GELU()
)
def get_shortcut(in_channels, out_channels, stride):
if (in_channels == out_channels and stride == 1):
shortcut = nn.Identity()
else:
shortcut = nn.Conv2d(in_channels, out_channels, 1)
if stride > 1:
shortcut = nn.Sequential(nn.MaxPool2d(stride), shortcut)
return shortcut
Squeeze-and-Excitation, arXiv:1709.01507 [cs.CV]
class SqueezeExciteBlock(nn.Module):
def __init__(self, channels, reduction=4):
super().__init__()
self.out_channels = channels
channels_r = channels // reduction
self.se = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, channels_r, kernel_size=1),
nn.GELU(),
nn.Conv2d(channels_r, channels, kernel_size=1),
nn.Sigmoid()
)
def forward(self, x):
return x * self.se(x)
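The SE branch squeezes the feature map to a per-channel gate in (0, 1) computed from globally pooled features. A quick shape check:
# Shape check of the gating (not part of the original code).
seb = SqueezeExciteBlock(32)
x = torch.randn(2, 32, 16, 16)
print(seb.se(x).shape)   # torch.Size([2, 32, 1, 1]): one gate per channel
print(seb(x).shape)      # torch.Size([2, 32, 16, 16]): input rescaled channel-wise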
MobileNetV2, arXiv:1801.04381 [cs.CV]
MnasNet, arXiv:1807.11626 [cs.CV]
class MBConv(Residual):
    def __init__(self, in_channels, out_channels, shape, kernel_size=3, stride=1, expansion_factor=4):
        # 'shape' is unused here; it is accepted so MBConv and TransformerBlock share a signature
        mid_channels = in_channels * expansion_factor
super().__init__(
nn.BatchNorm2d(in_channels),
nn.GELU(),
ConvBlock(in_channels, mid_channels, 1), # Pointwise
ConvBlock(mid_channels, mid_channels, kernel_size, stride=stride, groups=mid_channels), # Depthwise
SqueezeExciteBlock(mid_channels),
nn.Conv2d(mid_channels, out_channels, 1), # Pointwise
shortcut = get_shortcut(in_channels, out_channels, stride)
)
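A quick shape check of a strided block (downsampling happens both in the depthwise convolution and in the pooled shortcut):
# Shape check: a strided MBConv halves the spatial size and changes the width.
block = MBConv(32, 64, shape=(16, 16), stride=2)
x = torch.randn(2, 32, 16, 16)
print(block(x).shape)   # torch.Size([2, 64, 8, 8])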
Attention
Calculation of indices for relative position encoding
Flattening of indices: $i,j \rightarrow W i +j$
We want $P[W i + j, W i' + j'] = w[i - i', j - j']$
Since the elements of $w$ are indexed from $0$, this becomes $P[W i + j, W i' + j'] = w[i - i' + H - 1, j - j' + W - 1]$
Flattening: $$ P[(W i + j) H W + W i' + j'] = w[(i - i' + H - 1) (2 W - 1) + j - j' + W - 1] $$
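A tiny numeric check of this formula (H = 2, W = 3), built directly from the definition:
# Check: entries of P depend only on the relative offset (i - i', j - j').
H, W = 2, 3
w = torch.randn((2 * H - 1) * (2 * W - 1))   # flattened table of relative weights
P = torch.empty(H * W, H * W)
for i in range(H):
    for j in range(W):
        for i2 in range(H):
            for j2 in range(W):
                P[W * i + j, W * i2 + j2] = w[(i - i2 + H - 1) * (2 * W - 1) + (j - j2 + W - 1)]
# (0,0)->(1,1) and (0,1)->(1,2) share the offset (-1,-1), hence the same bias:
print(P[0, W + 1] == P[1, W + 2])   # tensor(True)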
class SelfAttention2d(nn.Module):
def __init__(self, in_channels, out_channels, head_channels, shape, p_drop=0.):
super().__init__()
self.heads = out_channels // head_channels
self.head_channels = head_channels
self.scale = head_channels**-0.5
self.to_keys = nn.Conv2d(in_channels, out_channels, 1)
self.to_queries = nn.Conv2d(in_channels, out_channels, 1)
self.to_values = nn.Conv2d(in_channels, out_channels, 1)
self.unifyheads = nn.Conv2d(out_channels, out_channels, 1)
height, width = shape
self.pos_enc = nn.Parameter(torch.randn(self.heads, (2 * height - 1) * (2 * width - 1)))
self.register_buffer("relative_indices", self.get_indices(height, width))
self.drop = nn.Dropout(p_drop)
    def forward(self, x):
        b, _, h, w = x.shape
        keys = self.to_keys(x).view(b, self.heads, self.head_channels, -1)
        values = self.to_values(x).view(b, self.heads, self.head_channels, -1)
        queries = self.to_queries(x).view(b, self.heads, self.head_channels, -1)
        att = keys.transpose(-2, -1) @ queries  # (b, heads, hw, hw); entry [k, q] = key_k . query_q
        indices = self.relative_indices.expand(self.heads, -1)
        rel_pos_enc = self.pos_enc.gather(-1, indices)
        rel_pos_enc = rel_pos_enc.unflatten(-1, (h * w, h * w))
        att = att * self.scale + rel_pos_enc
        att = F.softmax(att, dim=-2)  # normalize over keys: each column holds one query's weights
        out = values @ att            # per-query weighted sum of values
        out = out.view(b, -1, h, w)
        out = self.unifyheads(out)
        out = self.drop(out)
        return out
    @staticmethod
    def get_indices(h, w):
        y = torch.arange(h, dtype=torch.long)
        x = torch.arange(w, dtype=torch.long)
        y1, x1, y2, x2 = torch.meshgrid(y, x, y, x, indexing="ij")  # explicit 'ij' (PyTorch >= 1.10) avoids a deprecation warning
        indices = (y1 - y2 + h - 1) * (2 * w - 1) + x1 - x2 + w - 1
        indices = indices.flatten()
        return indices
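A quick cross-check of get_indices against the formula above, plus a forward-pass shape test:
# Cross-check the index computation, then run a smoke test on random input.
H, W = 2, 3
idx = SelfAttention2d.get_indices(H, W)
expected = torch.tensor([(i - i2 + H - 1) * (2 * W - 1) + (j - j2 + W - 1)
                         for i in range(H) for j in range(W)
                         for i2 in range(H) for j2 in range(W)])
print(torch.equal(idx, expected))              # True

attn = SelfAttention2d(32, 64, head_channels=32, shape=(8, 8))
print(attn(torch.randn(2, 32, 8, 8)).shape)    # torch.Size([2, 64, 8, 8])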
Transformer
class FeedForward(nn.Sequential):
def __init__(self, in_channels, out_channels, mult=4, p_drop=0.):
hidden_channels = in_channels * mult
super().__init__(
nn.Conv2d(in_channels, hidden_channels, 1),
nn.GELU(),
nn.Conv2d(hidden_channels, out_channels, 1),
nn.Dropout(p_drop)
)
class TransformerBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, head_channels, shape, stride=1, p_drop=0.):
shape = (shape[0] // stride, shape[1] // stride)
super().__init__(
Residual(
LayerNormChannels(in_channels),
nn.MaxPool2d(stride) if stride > 1 else nn.Identity(),
SelfAttention2d(in_channels, out_channels, head_channels, shape, p_drop=p_drop),
shortcut = get_shortcut(in_channels, out_channels, stride)
),
Residual(
LayerNormChannels(out_channels),
FeedForward(out_channels, out_channels, p_drop=p_drop)
)
)
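Shape check for a strided block (the attention runs on the max-pooled resolution, and the pooled 1x1-conv shortcut matches it):
# A strided transformer block downsamples before attention.
tb = TransformerBlock(32, 64, head_channels=32, shape=(16, 16), stride=2)
print(tb(torch.randn(2, 32, 16, 16)).shape)    # torch.Size([2, 64, 8, 8])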
Full model
class Stem(nn.Sequential):
def __init__(self, in_channels, out_channels, stride=1):
super().__init__(
ConvBlock(in_channels, out_channels, 3, stride=stride),
nn.Conv2d(out_channels, out_channels, 3, padding=1)
)
class Head(nn.Sequential):
def __init__(self, channels, classes, p_drop=0.):
super().__init__(
LayerNormChannels(channels),
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Dropout(p_drop),
nn.Linear(channels, classes)
)
class BlockStack(nn.Sequential):
def __init__(self, num_blocks, shape, in_channels, out_channels, stride, block):
layers = []
for _ in range(num_blocks):
layers.append(block(in_channels, out_channels, shape=shape, stride=stride))
shape = (shape[0] // stride, shape[1] // stride)
in_channels = out_channels
stride=1
super().__init__(*layers)
class CoAtNet(nn.Sequential):
def __init__(self, classes, image_size, head_channels, channel_list, num_blocks, strides=None,
in_channels=3, trans_p_drop=0., head_p_drop=0.):
if strides is None: strides = [2] * len(num_blocks)
block_list = [MBConv, # S1
MBConv, # S2
Partial(TransformerBlock, head_channels, p_drop=trans_p_drop), # S3
Partial(TransformerBlock, head_channels, p_drop=trans_p_drop)] # S4
layers = [Stem(in_channels, channel_list[0], strides[0])] # S0
in_channels = channel_list[0]
shape = (image_size, image_size)
        for num, out_channels, stride, block in zip(num_blocks, channel_list[1:], strides[1:], block_list):  # zip stops after the 4 stage blocks
layers.append(BlockStack(num, shape, in_channels, out_channels, stride, block))
shape = (shape[0] // stride, shape[1] // stride)
in_channels = out_channels
layers.append(Head(in_channels, classes, p_drop=head_p_drop))
super().__init__(*layers)
model = CoAtNet(NUM_CLASSES, IMAGE_SIZE, head_channels=32, channel_list=[64, 64, 128, 256, 512],
num_blocks=[2, 2, 2, 2, 2], strides=[1, 1, 2, 2, 2],
trans_p_drop=0.3, head_p_drop=0.3)
model.apply(init_linear);
model.to(DEVICE);
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 8,109,446
def separate_parameters(model):
parameters_decay = set()
parameters_no_decay = set()
modules_weight_decay = (nn.Linear, nn.Conv2d)
modules_no_weight_decay = (nn.LayerNorm, nn.BatchNorm2d)
for m_name, m in model.named_modules():
for param_name, param in m.named_parameters():
full_param_name = f"{m_name}.{param_name}" if m_name else param_name
if isinstance(m, modules_no_weight_decay):
parameters_no_decay.add(full_param_name)
elif param_name.endswith("bias"):
parameters_no_decay.add(full_param_name)
elif isinstance(m, Residual) and param_name.endswith("gamma"):
parameters_no_decay.add(full_param_name)
elif isinstance(m, SelfAttention2d) and param_name.endswith("pos_enc"):
parameters_no_decay.add(full_param_name)
elif isinstance(m, modules_weight_decay):
parameters_decay.add(full_param_name)
# sanity check
assert len(parameters_decay & parameters_no_decay) == 0
assert len(parameters_decay) + len(parameters_no_decay) == len(list(model.parameters()))
return parameters_decay, parameters_no_decay
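A spot check on the model defined above (the asserts inside the function already enforce this):
# The two sets partition the model's parameters.
decay, no_decay = separate_parameters(model)
print(len(decay), len(no_decay), len(list(model.parameters())))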
def get_optimizer(model, learning_rate, weight_decay):
param_dict = {pn: p for pn, p in model.named_parameters()}
parameters_decay, parameters_no_decay = separate_parameters(model)
optim_groups = [
{"params": [param_dict[pn] for pn in parameters_decay], "weight_decay": weight_decay},
{"params": [param_dict[pn] for pn in parameters_no_decay], "weight_decay": 0.0},
]
optimizer = optim.AdamW(optim_groups, lr=learning_rate)
return optimizer
loss = nn.CrossEntropyLoss()
# The lr given here is only a placeholder; the OneCycleLR schedule below overrides it.
optimizer = get_optimizer(model, learning_rate=1e-6, weight_decay=WEIGHT_DECAY)
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
steps_per_epoch=len(train_loader), epochs=EPOCHS)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());
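To see the shape of the schedule without disturbing the real optimizer, one can trace a throwaway copy (a sketch; dummy_opt and dummy_sched are scratch objects, not part of the training setup):
# Sketch: trace the OneCycleLR shape on a dummy optimizer (training state untouched).
dummy_opt = optim.AdamW([torch.zeros(1, requires_grad=True)], lr=1e-6)
dummy_sched = optim.lr_scheduler.OneCycleLR(dummy_opt, max_lr=LEARNING_RATE,
                                            steps_per_epoch=len(train_loader), epochs=EPOCHS)
lrs = []
for _ in range(len(train_loader) * EPOCHS):
    lrs.append(dummy_sched.get_last_lr()[0])
    dummy_opt.step()
    dummy_sched.step()
plt.plot(lrs)
plt.xlabel('step')
plt.ylabel('learning rate')
plt.show()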
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
val_metrics = {"accuracy": ignite.metrics.Accuracy(), "loss": ignite.metrics.Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
history = defaultdict(list)
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
train_state = engine.state
epoch = train_state.epoch
max_epochs = train_state.max_epochs
train_loss = train_state.metrics["loss"]
history['train loss'].append(train_loss)
evaluator.run(test_loader)
val_metrics = evaluator.state.metrics
val_loss = val_metrics["loss"]
val_acc = val_metrics["accuracy"]
history['val loss'].append(val_loss)
history['val acc'].append(val_acc)
print("{}/{} - train: loss {:.3f}; val: loss {:.3f} accuracy {:.3f}".format(
epoch, max_epochs, train_loss, val_loss, val_acc))
trainer.run(train_loader, max_epochs=EPOCHS);
1/100 - train: loss 1.673; val: loss 1.506 accuracy 0.466
2/100 - train: loss 1.366; val: loss 1.320 accuracy 0.534
3/100 - train: loss 1.238; val: loss 1.153 accuracy 0.596
...
50/100 - train: loss 0.315; val: loss 0.334 accuracy 0.890
...
98/100 - train: loss 0.005; val: loss 0.310 accuracy 0.942
99/100 - train: loss 0.004; val: loss 0.306 accuracy 0.943
100/100 - train: loss 0.004; val: loss 0.310 accuracy 0.942
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['train loss']) + 1)
ax.plot(xs, history['train loss'], '.-', label='train')
ax.plot(xs, history['val loss'], '.-', label='val')
ax.legend()
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
plt.show()
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['val acc']) + 1)
ax.plot(xs, history['val acc'], '-')
ax.set_xlabel('epoch')
ax.set_ylabel('val acc')
plt.show()
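A possible follow-up: per-class accuracy on the test set (a sketch, not part of the training loop above):
# Sketch: per-class accuracy on the test set.
model.eval()
correct = torch.zeros(NUM_CLASSES)
total = torch.zeros(NUM_CLASSES)
with torch.no_grad():
    for X, Y in test_loader:
        preds = model(X.to(DEVICE)).argmax(dim=1).cpu()
        for c in range(NUM_CLASSES):
            mask = (Y == c)
            total[c] += mask.sum()
            correct[c] += (preds[mask] == c).sum()
for c, name in enumerate(test_dset.classes):
    print(f"{name}: {correct[c] / total[c]:.3f}")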