According to "Bottleneck Transformers for Visual Recognition", arXiv:2101.11605 [cs.CV]
Parts of code from https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
BoT50 warrants longer training in order to show significant improvement over R50.
BoTNet (self-attention) benefits more from extra augmentations such as multi-scale jitter compared to ResNet (pure convolutions).
Self-attention replacement is more efficient than stacking convolutions.
Replacing 3 spatial convolutions with all2all
attention gives more improvement in the metrics compared to stacking 50 more layers of convolutions (R101), and is competitive with stacking 100 more layers (R152).
BoT50 does not provide significant gains over R50 on ImageNet though it does provide the benefit of reducing the parameters while maintaining comparable computation.
ResNets and SENets perform really well in the lower accuracy regime, outperforming both EfficientNets and BoTNets.
EfficientNets may be better in terms of M.Adds, but do not map as well as BoTNets, onto the latest hardware accelerators such as TPUs.
ResNets and SENets achieve strong performance in the improved EfficientNet training setting. They are strong enough that they can outperform all the EfficientNets.
Pure convolutional models such as ResNets and SENets are still the best performing models until an accuracy regime of 83% top-1 accuracy.
We recommend using absolute position encodings for image classification.
Imports
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
import ignite.metrics
import ignite.contrib.handlers
Configuration
DATA_DIR='./data'
NUM_CLASSES = 10
NUM_WORKERS = 8
BATCH_SIZE = 32
EPOCHS = 100
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor()
])
train_dset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_transform)
test_dset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transforms.ToTensor())
Files already downloaded and verified
Files already downloaded and verified
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
num_workers=NUM_WORKERS, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS, pin_memory=True)
def dataset_show_image(dset, idx):
X, Y = dset[idx]
title = "Ground truth: {}".format(dset.classes[Y])
fig = plt.figure()
ax = fig.add_subplot(111)
ax.set_axis_off()
ax.imshow(np.moveaxis(X.numpy(), 0, -1))
ax.set_title(title)
plt.show()
dataset_show_image(test_dset, 1)
Attention: $$ O=V\mathrm{softmax}\left[\frac{1}{\sqrt{c}}(K^{\intercal}Q + P^{\intercal}Q)\right]\,. $$ Here $P$ represents position encoding. $V,K,P,Q\in\mathbb{R}^{c\times n}$, $c$ is the number of channels, $n$ is the number of elements. Thus $K^{\intercal}Q,P^{\intercal}Q\in\mathbb{R}^{n\times n}$.
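As a quick sanity check of the formula (an illustrative sketch with random tensors, not part of the model): the softmax normalizes over the key dimension, so each query's attention weights sum to one.
# Toy check of the attention formula above: V, K, Q, P are all (c, n);
# softmax runs over the key dimension (rows), so each column sums to one.
c, n = 8, 5
V, K, Q, P = (torch.randn(c, n) for _ in range(4))
att = F.softmax((K.t() @ Q + P.t() @ Q) / c**0.5, dim=0)   # (n, n)
O = V @ att                                                # (c, n)
print(O.shape, att.sum(dim=0))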
class RelativePosEnc(nn.Module):
    def forward(self, q):
        # q: (batch, heads, channels, height, width)
        if not hasattr(self, 'pos_h'):
            # learnable relative position embeddings, created lazily on the first
            # forward pass (run a dummy batch before building the optimizer)
            c, h, w = q.shape[-3:]
            self.pos_h = nn.Parameter(torch.zeros(2 * h - 1, c, device=q.device))
            self.pos_w = nn.Parameter(torch.zeros(2 * w - 1, c, device=q.device))
        rel_h = self.pos_h @ q.movedim(4, 2)              # (b, heads, w, 2h-1, h)
        rel_w = self.pos_w @ q.movedim(3, 2)              # (b, heads, h, 2w-1, w)
        rel_h = self.rel_to_abs(rel_h).movedim(2, 4)      # (b, heads, h_key, h_query, w_query)
        rel_w = self.rel_to_abs(rel_w).movedim(2, 3)      # (b, heads, w_key, h_query, w_query)
        pos_enc = rel_h[:, :, :, None] + rel_w[:, :, None, :]  # (b, heads, h_key, w_key, h_query, w_query)
        pos_enc = pos_enc.flatten(-2).flatten(2, 3)             # (b, heads, n_keys, n_queries)
        return pos_enc
    @staticmethod
    def rel_to_abs(x):
        """
        Converts relative indexing to absolute.
        Input shape: [..., 2 * length - 1, length] (relative offset, query position)
        Output shape: [..., length, length] (key position, query position)
        """
        length = x.shape[-1]
        # switch to the (query position, relative offset) layout of the standard "skew" trick
        x = x.transpose(-2, -1)                 # [..., length, 2*length-1]
        x = F.pad(x, (0, 1))                    # [..., length, 2*length]
        x = x.flatten(-2)                       # [..., 2*length**2]
        x = F.pad(x, (0, length - 1))           # [..., 2*length**2 + length - 1]
        x = x.view(*x.shape[:-1], length + 1, 2 * length - 1)
        # take the valid elements: row q now holds the logits for query q at key positions 0..length-1
        x = x[..., :length, length - 1:]        # [..., length, length]
        return x.transpose(-2, -1)              # back to (key position, query position)
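A small indexing check of rel_to_abs (illustration only): with the (relative offset, query position) layout described in the docstring, the value at relative row k - q + length - 1 should end up at key position k for query position q.
length = 3
x = torch.arange((2 * length - 1) * length, dtype=torch.float).view(2 * length - 1, length)
out = RelativePosEnc.rel_to_abs(x)
for k in range(length):
    for q in range(length):
        assert out[k, q] == x[k - q + length - 1, q]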
class AbsolutePosEnc(nn.Module):
def forward(self, q):
if not hasattr(self, 'pos_h'):
c, h, w = q.shape[-3:]
self.pos_h = nn.Parameter(torch.zeros(h, c, device=q.device))
self.pos_w = nn.Parameter(torch.zeros(w, c, device=q.device))
pos_enc = self.pos_h[:, None] + self.pos_w[None, :]
pos_enc = pos_enc.flatten(0, 1) @ q.flatten(-2)
return pos_enc
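A quick shape check (illustrative sizes, not part of the model): for queries of shape (batch, heads, channels, h, w) the positional term has shape (batch, heads, h*w, h*w), matching the attention logits it is added to.
pos = AbsolutePosEnc()
q = torch.zeros(2, 4, 16, 8, 8)   # (batch, heads, channels, h, w)
print(pos(q).shape)               # torch.Size([2, 4, 64, 64])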
class SelfAttention2d(nn.Module):
def __init__(self, in_channels, out_channels, q_channels, v_channels, heads, pos_enc, p_drop=0.):
super().__init__()
        self.heads = heads
self.q_channels = q_channels
self.scale = q_channels**-0.5
self.to_pos_enc = pos_enc()
self.to_keys = nn.Conv2d(in_channels, q_channels * heads, 1)
self.to_queries = nn.Conv2d(in_channels, q_channels * heads, 1)
self.to_values = nn.Conv2d(in_channels, v_channels * heads, 1)
self.unifyheads = nn.Conv2d(v_channels * heads, out_channels, 1)
self.attn_drop = nn.Dropout(p_drop)
self.resid_drop = nn.Dropout(p_drop)
    def forward(self, x):
        b, _, h, w = x.shape
        keys = self.to_keys(x).view(b, self.heads, self.q_channels, h * w)
        queries = self.to_queries(x).view(b, self.heads, self.q_channels, h, w)
        values = self.to_values(x).view(b, self.heads, -1, h * w)
        # positional logits P^T Q, shape (b, heads, n_keys, n_queries)
        pos_enc = self.to_pos_enc(queries)
        queries = queries.flatten(-2)
        # content logits K^T Q plus positional logits
        att = keys.transpose(-2, -1) @ queries + pos_enc
        # normalize over the key dimension
        att = F.softmax(att * self.scale, dim=-2)
        att = self.attn_drop(att)
        out = values @ att                  # (b, heads, v_channels, n_queries)
        out = out.view(b, -1, h, w)
        out = self.unifyheads(out)
        out = self.resid_drop(out)
        return out
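A minimal usage sketch with illustrative sizes (not part of the model below): the layer maps a feature map to another feature map of the same spatial size.
attn = SelfAttention2d(in_channels=64, out_channels=64, q_channels=16,
                       v_channels=16, heads=4, pos_enc=RelativePosEnc)
x = torch.randn(2, 64, 8, 8)
print(attn(x).shape)   # torch.Size([2, 64, 8, 8])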
class AttentionBlock(nn.Sequential):
def __init__(self, channels, heads=4, p_drop=0.):
q_channels = channels // heads
super().__init__(
SelfAttention2d(channels, channels, q_channels, q_channels, heads, RelativePosEnc, p_drop),
nn.BatchNorm2d(channels),
nn.ReLU(inplace=True)
)
class ConvBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, act=True):
padding = (kernel_size - 1) // 2
layers = [
nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False),
nn.BatchNorm2d(out_channels)
]
if act: layers.append(nn.ReLU(inplace=True))
super().__init__(*layers)
class BoTResidual(nn.Sequential):
def __init__(self, in_channels, out_channels, expansion=4, heads=4, p_drop=0.):
bottl_channels = out_channels // expansion
super().__init__(
ConvBlock(in_channels, bottl_channels, 1),
AttentionBlock(bottl_channels, heads, p_drop),
ConvBlock(bottl_channels, out_channels, 1, act=False)
)
class BottleneckResidual(nn.Sequential):
def __init__(self, in_channels, out_channels, expansion=4):
res_channels = out_channels // expansion
super().__init__(
ConvBlock(in_channels, res_channels, 1),
ConvBlock(res_channels, res_channels),
ConvBlock(res_channels, out_channels, 1, act=False)
)
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, residual):
super().__init__()
        self.shortcut = self.get_shortcut(in_channels, out_channels)
        self.residual = residual(in_channels, out_channels)
        self.act = nn.ReLU(inplace=True)
        # the residual branch is scaled by a learnable gamma initialized at zero,
        # so every block starts out behaving like its shortcut
        self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
out = self.shortcut(x) + self.gamma * self.residual(x)
return self.act(out)
def get_shortcut(self, in_channels, out_channels):
if in_channels != out_channels:
shortcut = ConvBlock(in_channels, out_channels, 1, act=False)
else:
shortcut = nn.Identity()
return shortcut
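Because gamma starts at zero, a freshly initialized block reduces to ReLU applied to its shortcut. A small check of this (illustration only; equal channel counts make the shortcut an identity):
block = ResidualBlock(8, 8, BottleneckResidual)
x = torch.randn(2, 8, 16, 16)
assert torch.allclose(block(x), F.relu(x))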
class ResidualStack(nn.Sequential):
def __init__(self, in_channels, out_channels, repetitions, strides, residual):
layers = []
for rep, stride in zip(repetitions, strides):
if stride > 1:
layers.append(nn.MaxPool2d(stride))
for _ in range(rep):
layers.append(ResidualBlock(in_channels, out_channels, residual))
in_channels = out_channels
out_channels = out_channels * 2
super().__init__(*layers)
class Stem(nn.Sequential):
def __init__(self, in_channels=3, channel_list=[32, 32, 64], stride=2):
layers = [ConvBlock(in_channels, channel_list[0], stride=stride)]
for in_channels, out_channels in zip(channel_list, channel_list[1:]):
layers.append(ConvBlock(in_channels, out_channels))
super().__init__(*layers)
class Head(nn.Sequential):
def __init__(self, in_channels, classes, p_drop=0.):
super().__init__(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Dropout(p_drop),
nn.Linear(in_channels, classes)
)
class BoTNet(nn.Sequential):
def __init__(self, repetitions_conv, repetitions_trans, classes, strides, p_drop=0.):
num_conv, num_trans = len(repetitions_conv), len(repetitions_trans)
strides_conv = strides[1:1+num_conv]
strides_trans = strides[1+num_conv:1+num_conv+num_trans]
        out_ch0 = 64                            # stem output channels
        out_ch1 = out_ch0 * 4                   # output of the first convolutional stage
        out_ch2 = out_ch1 * 2**(num_conv - 1)   # output of the convolutional stack
        out_ch3 = out_ch2 * 2**num_trans        # output of the transformer (BoT) stack
super().__init__(
Stem(stride=strides[0]),
ResidualStack(out_ch0, out_ch1, repetitions_conv, strides_conv, BottleneckResidual),
ResidualStack(out_ch2, out_ch2 * 2, repetitions_trans, strides_trans, BoTResidual),
Head(out_ch3, classes, p_drop)
)
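An illustrative walk through the stages for the configuration used below (a throwaway instance; the spatial size halves at every stride-2 stage while the channel count doubles):
net = BoTNet([2, 2], [2, 2], NUM_CLASSES, strides=[1, 1, 2, 2, 2])
x = torch.rand(1, 3, 32, 32)
for stage in net:
    x = stage(x)
    print(type(stage).__name__, tuple(x.shape))
# Stem (1, 64, 32, 32) -> ResidualStack (1, 512, 16, 16)
# -> ResidualStack (1, 2048, 4, 4) -> Head (1, 10)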
model = BoTNet([2, 2], [2, 2], NUM_CLASSES, strides=[1, 1, 2, 2, 2], p_drop=0.3)
@torch.no_grad()
def init_linear(m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None: nn.init.zeros_(m.bias)
model.apply(init_linear);
model.to(DEVICE);
Initialize the position-encoding parameters: they are created lazily on the first forward pass, so run a dummy batch before the optimizer is constructed (otherwise they would be missing from its parameter groups).
model(torch.rand(2, 3, 32, 32).to(DEVICE));
print("Number of parameters:", sum(p.numel() for p in model.parameters()))
Number of parameters: 10722674
class History:
def __init__(self):
self.values = defaultdict(list)
def append(self, key, value):
self.values[key].append(value)
def reset(self):
for k in self.values.keys():
self.values[k] = []
def _begin_plot(self):
self.fig = plt.figure()
self.ax = self.fig.add_subplot(111)
def _end_plot(self, ylabel):
self.ax.set_xlabel('epoch')
self.ax.set_ylabel(ylabel)
plt.show()
    def _plot(self, key, line_type='-', label=None):
        if label is None: label = key
        xs = np.arange(1, len(self.values[key]) + 1)
        self.ax.plot(xs, self.values[key], line_type, label=label)
def plot(self, key):
self._begin_plot()
self._plot(key, '-')
self._end_plot(key)
def plot_train_val(self, key):
self._begin_plot()
self._plot('train ' + key, '.-', 'train')
self._plot('val ' + key, '.-', 'val')
self.ax.legend()
self._end_plot(key)
def separate_parameters(model):
    # biases, batchnorm weights, residual gammas, and position encodings
    # will not be decayed for regularization
parameters_decay = set()
parameters_no_decay = set()
modules_weight_decay = (nn.Linear, nn.Conv2d)
modules_no_weight_decay = (nn.BatchNorm2d,)
for m_name, m in model.named_modules():
for param_name, param in m.named_parameters():
full_param_name = f"{m_name}.{param_name}" if m_name else param_name
if isinstance(m, modules_no_weight_decay):
parameters_no_decay.add(full_param_name)
elif param_name.endswith("bias"):
parameters_no_decay.add(full_param_name)
elif isinstance(m, ResidualBlock) and param_name.endswith("gamma"):
parameters_no_decay.add(full_param_name)
elif isinstance(m, (RelativePosEnc, AbsolutePosEnc)) and (
param_name.endswith("pos_h") or param_name.endswith("pos_w")):
parameters_no_decay.add(full_param_name)
elif isinstance(m, modules_weight_decay):
parameters_decay.add(full_param_name)
return parameters_decay, parameters_no_decay
def get_optimizer(model, learning_rate, weight_decay):
param_dict = {pn: p for pn, p in model.named_parameters()}
parameters_decay, parameters_no_decay = separate_parameters(model)
optim_groups = [
{"params": [param_dict[pn] for pn in parameters_decay], "weight_decay": weight_decay},
{"params": [param_dict[pn] for pn in parameters_no_decay], "weight_decay": 0.0},
]
optimizer = optim.AdamW(optim_groups, lr=learning_rate)
return optimizer
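An optional sanity check (illustration only): every parameter of the model should fall into exactly one of the two groups.
decay, no_decay = separate_parameters(model)
all_names = {name for name, _ in model.named_parameters()}
assert decay | no_decay == all_names
assert not (decay & no_decay)
print(len(decay), "parameters decayed,", len(no_decay), "not decayed")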
optimizer = get_optimizer(model, learning_rate=1e-6, weight_decay=1e-2)  # lr is set per step by OneCycleLR below
loss = nn.CrossEntropyLoss()
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3,
steps_per_epoch=len(train_loader), epochs=EPOCHS)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());
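To see the learning-rate schedule the trainer will follow, one can step a throwaway copy of the scheduler (illustration only; the real optimizer and scheduler are left untouched):
tmp_opt = optim.AdamW([nn.Parameter(torch.zeros(1))], lr=1e-6)
tmp_sched = optim.lr_scheduler.OneCycleLR(tmp_opt, max_lr=1e-3,
                                          steps_per_epoch=len(train_loader), epochs=EPOCHS)
lrs = []
for _ in range(len(train_loader) * EPOCHS):
    lrs.append(tmp_sched.get_last_lr()[0])
    tmp_opt.step()
    tmp_sched.step()
plt.plot(lrs)
plt.xlabel('iteration')
plt.ylabel('learning rate')
plt.show()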
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
val_metrics = {"accuracy": ignite.metrics.Accuracy(), "loss": ignite.metrics.Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
history = History()
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
train_state = engine.state
epoch = train_state.epoch
max_epochs = train_state.max_epochs
train_loss = train_state.metrics["loss"]
history.append('train loss', train_loss)
evaluator.run(test_loader)
val_metrics = evaluator.state.metrics
val_loss = val_metrics["loss"]
val_acc = val_metrics["accuracy"]
history.append('val loss', val_loss)
history.append('val acc', val_acc)
print("{}/{} - train: loss {:.3f}; val: loss {:.3f} accuracy {:.3f}".format(
epoch, max_epochs, train_loss, val_loss, val_acc))
trainer.run(train_loader, max_epochs=EPOCHS);
1/100 - train: loss 1.389; val: loss 1.256 accuracy 0.551 2/100 - train: loss 1.215; val: loss 1.102 accuracy 0.612 3/100 - train: loss 1.112; val: loss 1.015 accuracy 0.638 4/100 - train: loss 1.014; val: loss 0.924 accuracy 0.677 5/100 - train: loss 0.979; val: loss 0.858 accuracy 0.696 6/100 - train: loss 0.902; val: loss 0.853 accuracy 0.701 7/100 - train: loss 0.849; val: loss 0.784 accuracy 0.728 8/100 - train: loss 0.774; val: loss 0.741 accuracy 0.741 9/100 - train: loss 0.710; val: loss 0.708 accuracy 0.758 10/100 - train: loss 0.706; val: loss 0.632 accuracy 0.783 11/100 - train: loss 0.621; val: loss 0.635 accuracy 0.784 12/100 - train: loss 0.635; val: loss 0.613 accuracy 0.790 13/100 - train: loss 0.582; val: loss 0.654 accuracy 0.788 14/100 - train: loss 0.562; val: loss 0.516 accuracy 0.821 15/100 - train: loss 0.517; val: loss 0.542 accuracy 0.823 16/100 - train: loss 0.552; val: loss 0.551 accuracy 0.816 17/100 - train: loss 0.486; val: loss 0.653 accuracy 0.790 18/100 - train: loss 0.463; val: loss 0.472 accuracy 0.834 19/100 - train: loss 0.449; val: loss 0.520 accuracy 0.828 20/100 - train: loss 0.432; val: loss 0.463 accuracy 0.846 21/100 - train: loss 0.454; val: loss 0.504 accuracy 0.838 22/100 - train: loss 0.394; val: loss 0.495 accuracy 0.838 23/100 - train: loss 0.392; val: loss 0.468 accuracy 0.851 24/100 - train: loss 0.370; val: loss 0.422 accuracy 0.866 25/100 - train: loss 0.387; val: loss 0.436 accuracy 0.855 26/100 - train: loss 0.322; val: loss 0.417 accuracy 0.862 27/100 - train: loss 0.322; val: loss 0.445 accuracy 0.855 28/100 - train: loss 0.354; val: loss 0.387 accuracy 0.872 29/100 - train: loss 0.316; val: loss 0.363 accuracy 0.880 30/100 - train: loss 0.324; val: loss 0.369 accuracy 0.877 31/100 - train: loss 0.291; val: loss 0.355 accuracy 0.886 32/100 - train: loss 0.262; val: loss 0.348 accuracy 0.892 33/100 - train: loss 0.284; val: loss 0.330 accuracy 0.894 34/100 - train: loss 0.228; val: loss 0.310 accuracy 0.902 35/100 - train: loss 0.244; val: loss 0.367 accuracy 0.890 36/100 - train: loss 0.232; val: loss 0.320 accuracy 0.898 37/100 - train: loss 0.210; val: loss 0.303 accuracy 0.906 38/100 - train: loss 0.223; val: loss 0.382 accuracy 0.889 39/100 - train: loss 0.189; val: loss 0.326 accuracy 0.904 40/100 - train: loss 0.176; val: loss 0.285 accuracy 0.910 41/100 - train: loss 0.174; val: loss 0.324 accuracy 0.901 42/100 - train: loss 0.185; val: loss 0.321 accuracy 0.903 43/100 - train: loss 0.168; val: loss 0.325 accuracy 0.900 44/100 - train: loss 0.191; val: loss 0.324 accuracy 0.905 45/100 - train: loss 0.144; val: loss 0.330 accuracy 0.907 46/100 - train: loss 0.136; val: loss 0.309 accuracy 0.913 47/100 - train: loss 0.148; val: loss 0.317 accuracy 0.911 48/100 - train: loss 0.134; val: loss 0.432 accuracy 0.878 49/100 - train: loss 0.123; val: loss 0.316 accuracy 0.911 50/100 - train: loss 0.146; val: loss 0.331 accuracy 0.913 51/100 - train: loss 0.114; val: loss 0.325 accuracy 0.913 52/100 - train: loss 0.117; val: loss 0.318 accuracy 0.915 53/100 - train: loss 0.105; val: loss 0.334 accuracy 0.913 54/100 - train: loss 0.082; val: loss 0.300 accuracy 0.923 55/100 - train: loss 0.087; val: loss 0.323 accuracy 0.917 56/100 - train: loss 0.082; val: loss 0.310 accuracy 0.919 57/100 - train: loss 0.074; val: loss 0.311 accuracy 0.917 58/100 - train: loss 0.084; val: loss 0.320 accuracy 0.920 59/100 - train: loss 0.064; val: loss 0.297 accuracy 0.926 60/100 - train: loss 0.062; val: loss 0.318 accuracy 0.920 61/100 - train: loss 
0.065; val: loss 0.351 accuracy 0.912 62/100 - train: loss 0.062; val: loss 0.322 accuracy 0.919 63/100 - train: loss 0.072; val: loss 0.331 accuracy 0.920 64/100 - train: loss 0.055; val: loss 0.319 accuracy 0.925 65/100 - train: loss 0.065; val: loss 0.331 accuracy 0.925 66/100 - train: loss 0.057; val: loss 0.329 accuracy 0.921 67/100 - train: loss 0.052; val: loss 0.350 accuracy 0.918 68/100 - train: loss 0.044; val: loss 0.332 accuracy 0.924 69/100 - train: loss 0.040; val: loss 0.323 accuracy 0.926 70/100 - train: loss 0.040; val: loss 0.321 accuracy 0.927 71/100 - train: loss 0.031; val: loss 0.327 accuracy 0.928 72/100 - train: loss 0.029; val: loss 0.327 accuracy 0.933 73/100 - train: loss 0.024; val: loss 0.321 accuracy 0.927 74/100 - train: loss 0.026; val: loss 0.311 accuracy 0.933 75/100 - train: loss 0.028; val: loss 0.338 accuracy 0.927 76/100 - train: loss 0.021; val: loss 0.347 accuracy 0.931 77/100 - train: loss 0.018; val: loss 0.333 accuracy 0.934 78/100 - train: loss 0.018; val: loss 0.349 accuracy 0.933 79/100 - train: loss 0.017; val: loss 0.333 accuracy 0.932 80/100 - train: loss 0.012; val: loss 0.349 accuracy 0.930 81/100 - train: loss 0.015; val: loss 0.357 accuracy 0.931 82/100 - train: loss 0.013; val: loss 0.356 accuracy 0.935 83/100 - train: loss 0.013; val: loss 0.364 accuracy 0.932 84/100 - train: loss 0.010; val: loss 0.371 accuracy 0.933 85/100 - train: loss 0.006; val: loss 0.350 accuracy 0.937 86/100 - train: loss 0.009; val: loss 0.346 accuracy 0.938 87/100 - train: loss 0.004; val: loss 0.345 accuracy 0.935 88/100 - train: loss 0.005; val: loss 0.349 accuracy 0.937 89/100 - train: loss 0.005; val: loss 0.352 accuracy 0.938 90/100 - train: loss 0.004; val: loss 0.363 accuracy 0.938 91/100 - train: loss 0.006; val: loss 0.349 accuracy 0.939 92/100 - train: loss 0.005; val: loss 0.357 accuracy 0.937 93/100 - train: loss 0.005; val: loss 0.347 accuracy 0.938 94/100 - train: loss 0.002; val: loss 0.345 accuracy 0.939 95/100 - train: loss 0.003; val: loss 0.347 accuracy 0.938 96/100 - train: loss 0.002; val: loss 0.347 accuracy 0.938 97/100 - train: loss 0.003; val: loss 0.337 accuracy 0.940 98/100 - train: loss 0.003; val: loss 0.347 accuracy 0.940 99/100 - train: loss 0.001; val: loss 0.342 accuracy 0.939 100/100 - train: loss 0.002; val: loss 0.342 accuracy 0.940
history.plot_train_val('loss')
history.plot('val acc')