According to "Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet", arXiv:2101.11986 [cs.CV]
ViT achieves inferior performance to CNNs when trained from scratch on a midsize dataset like ImageNet.
Proposals:
Imports
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
import ignite.metrics
import ignite.contrib.handlers
Configuration
DATA_DIR='./data'
IMAGE_SIZE = 32
NUM_CLASSES = 10
NUM_WORKERS = 20
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-1
EPOCHS = 100
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(IMAGE_SIZE, padding=4),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor()
])
train_dset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_transform)
test_dset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transforms.ToTensor())
Files already downloaded and verified
Files already downloaded and verified
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
num_workers=NUM_WORKERS, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS, pin_memory=True)
def dataset_show_image(dset, idx):
X, Y = dset[idx]
title = "Ground truth: {}".format(dset.classes[Y])
fig = plt.figure()
ax = fig.add_subplot(111)
ax.set_axis_off()
ax.imshow(np.moveaxis(X.numpy(), 0, -1))
ax.set_title(title)
plt.show()
dataset_show_image(test_dset, 1)
Utilities
def init_linear(m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None: nn.init.zeros_(m.bias)
class Residual(nn.Module):
def __init__(self, *layers, shortcut=None):
super().__init__()
self.shortcut = nn.Identity() if shortcut is None else shortcut
self.residual = nn.Sequential(*layers)
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
return self.shortcut(x) + self.gamma * self.residual(x)
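The gamma parameter scales the residual branch and is initialized to zero, so every Residual block starts out as the identity map; this ReZero-style trick helps deep stacks train from scratch. A quick check (a minimal sketch, not part of the training pipeline):
block = Residual(nn.LayerNorm(8), nn.Linear(8, 8))
x = torch.randn(2, 8)
assert torch.equal(block(x), x)  # gamma == 0, so the block is exactly the identity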
class TakeFirst(nn.Module):
def forward(self, x):
return x[:, 0]
Attention
class SelfAttention(nn.Module):
def __init__(self, dim, head_dim, heads=8, p_drop=0.):
super().__init__()
inner_dim = head_dim * heads
self.head_shape = (heads, head_dim)
self.scale = head_dim**-0.5
self.to_keys = nn.Linear(dim, inner_dim)
self.to_queries = nn.Linear(dim, inner_dim)
self.to_values = nn.Linear(dim, inner_dim)
self.unifyheads = nn.Linear(inner_dim, dim)
self.drop = nn.Dropout(p_drop)
def forward(self, x):
q_shape = x.shape[:-1] + self.head_shape
keys = self.to_keys(x).view(q_shape).transpose(1, 2) # move head forward to the batch dim
queries = self.to_queries(x).view(q_shape).transpose(1, 2)
values = self.to_values(x).view(q_shape).transpose(1, 2)
att = queries @ keys.transpose(-2, -1)
att = F.softmax(att * self.scale, dim=-1)
out = att @ values
out = out.transpose(1, 2).contiguous().flatten(2) # move head back
out = self.unifyheads(out)
out = self.drop(out)
return out
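A quick shape check (illustrative only; the dimensions match the backbone configuration used later):
attn = SelfAttention(dim=256, head_dim=64, heads=4)
x = torch.randn(2, 16, 256)  # (batch, tokens, dim)
print(attn(x).shape)         # torch.Size([2, 16, 256])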
Transformer
class FeedForward(nn.Sequential):
def __init__(self, dim, mlp_mult=4, p_drop=0.):
hidden_dim = dim * mlp_mult
super().__init__(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
nn.Dropout(p_drop)
)
class TransformerBlock(nn.Sequential):
def __init__(self, dim, head_dim, heads, mlp_mult=4, p_drop=0.):
super().__init__(
Residual(nn.LayerNorm(dim), SelfAttention(dim, head_dim, heads, p_drop)),
Residual(nn.LayerNorm(dim), FeedForward(dim, mlp_mult, p_drop=p_drop))
)
T2T module
class SoftSplit(nn.Module):
def __init__(self, in_channels, dim, kernel_size=3, stride=2):
super().__init__()
padding = (kernel_size - 1) // 2
self.unfold = nn.Unfold(kernel_size, stride=stride, padding=padding)
self.project = nn.Linear(in_channels * kernel_size**2, dim)
def forward(self, x):
out = self.unfold(x).transpose(1, 2)
out = self.project(out)
return out
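SoftSplit is the overlapping counterpart of ViT's patch embedding: nn.Unfold extracts overlapping kernel_size x kernel_size patches (overlap comes from stride < kernel_size), and each flattened patch is linearly projected to a token. An illustrative shape check with hypothetical dimensions:
ss = SoftSplit(in_channels=3, dim=64, kernel_size=3, stride=2)
x = torch.randn(2, 3, 32, 32)
print(ss(x).shape)  # torch.Size([2, 256, 64]) -- a 16x16 grid of tokens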
class Reshape(nn.Module):
def __init__(self, shape):
super().__init__()
self.shape = shape
def forward(self, x):
out = x.transpose(1, 2).unflatten(2, self.shape)
return out
class T2TBlock(nn.Sequential):
def __init__(self, image_size, token_dim, embed_dim, heads=1, mlp_mult=1, stride=2, p_drop=0.):
super().__init__(
            TransformerBlock(token_dim, token_dim // heads, heads, mlp_mult, p_drop),
Reshape((image_size, image_size)),
SoftSplit(token_dim, embed_dim, stride=stride)
)
class T2TModule(nn.Sequential):
def __init__(self, in_channels, image_size, strides, token_dim, embed_dim, p_drop=0.):
stride = strides[0]
layers = [SoftSplit(in_channels, token_dim, stride=stride)]
image_size = image_size // stride
for stride in strides[1:-1]:
layers.append(T2TBlock(image_size, token_dim, token_dim, stride=stride, p_drop=p_drop))
image_size = image_size // stride
stride = strides[-1]
layers.append(T2TBlock(image_size, token_dim, embed_dim, stride=stride, p_drop=p_drop))
super().__init__(*layers)
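With the strides [1, 1, 2] used below, the token length progresses as 32x32 = 1024 -> 1024 -> 16x16 = 256, and the final SoftSplit projects to the backbone dimension. A shape trace (a sketch with the same hyperparameters as the model below):
t2t = T2TModule(in_channels=3, image_size=32, strides=[1, 1, 2], token_dim=64, embed_dim=256)
x = torch.randn(2, 3, 32, 32)
print(t2t(x).shape)  # torch.Size([2, 256, 256]) -- 256 tokens of dimension 256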
T2T-ViT
class PositionEmbedding(nn.Module):
def __init__(self, image_size, dim):
super().__init__()
self.pos_embedding = nn.Parameter(torch.zeros(1, image_size**2, dim))
self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
def forward(self, x):
# add positional embedding
x = x + self.pos_embedding
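        # (the cls token is prepended after the positions are added, so it carries no
        # positional term; being a learned parameter itself, it does not need one)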
# add classification token
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
return x
class TransformerBackbone(nn.Sequential):
def __init__(self, dim, head_dim, heads, depth, mlp_mult=4, p_drop=0.):
layers = [TransformerBlock(dim, head_dim, heads, mlp_mult, p_drop) for _ in range(depth)]
super().__init__(*layers)
class Head(nn.Sequential):
def __init__(self, dim, classes, p_drop=0.):
super().__init__(
nn.LayerNorm(dim),
nn.Dropout(p_drop),
nn.Linear(dim, classes)
)
class T2TViT(nn.Sequential):
def __init__(self, classes, image_size, strides, token_dim, dim, head_dim, heads, backbone_depth, mlp_mult,
in_channels=3, trans_p_drop=0., head_p_drop=0.):
reduced_size = image_size // np.prod(strides)
super().__init__(
T2TModule(in_channels, image_size, strides, token_dim, dim, p_drop=trans_p_drop),
PositionEmbedding(reduced_size, dim),
TransformerBackbone(dim, head_dim, heads, backbone_depth, p_drop=trans_p_drop),
TakeFirst(),
Head(dim, classes, p_drop=head_p_drop)
)
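For the configuration instantiated below, the shapes flow as follows (a sketch; the numbers follow from strides=[1, 1, 2], dim=256 and NUM_CLASSES=10):
# image                    (B, 3, 32, 32)
# -> T2TModule          -> (B, 256, 256)   # 16*16 tokens of dimension 256
# -> PositionEmbedding  -> (B, 257, 256)   # cls token prepended
# -> TransformerBackbone -> (B, 257, 256)
# -> TakeFirst          -> (B, 256)        # keep only the cls token
# -> Head               -> (B, 10)         # class logits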
model = T2TViT(NUM_CLASSES, IMAGE_SIZE, strides=[1, 1, 2], token_dim=64, dim=256, head_dim=64, heads=4,
backbone_depth=8, mlp_mult=2, trans_p_drop=0.3, head_p_drop=0.3)
model.apply(init_linear);
model.to(DEVICE);
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 6,623,838
class History:
def __init__(self):
self.values = defaultdict(list)
def append(self, key, value):
self.values[key].append(value)
def reset(self):
for k in self.values.keys():
self.values[k] = []
def _begin_plot(self):
self.fig = plt.figure()
self.ax = self.fig.add_subplot(111)
def _end_plot(self, ylabel):
self.ax.set_xlabel('epoch')
self.ax.set_ylabel(ylabel)
plt.show()
def _plot(self, key, line_type='-', label=None):
if label is None: label=key
xs = np.arange(1, len(self.values[key])+1)
self.ax.plot(xs, self.values[key], line_type, label=label)
def plot(self, key):
self._begin_plot()
self._plot(key, '-')
self._end_plot(key)
def plot_train_val(self, key):
self._begin_plot()
self._plot('train ' + key, '.-', 'train')
self._plot('val ' + key, '.-', 'val')
self.ax.legend()
self._end_plot(key)
def separate_parameters(model):
parameters_decay = set()
parameters_no_decay = set()
modules_weight_decay = (nn.Linear, nn.Conv2d)
modules_no_weight_decay = (nn.LayerNorm, PositionEmbedding)
for m_name, m in model.named_modules():
for param_name, param in m.named_parameters():
full_param_name = f"{m_name}.{param_name}" if m_name else param_name
if isinstance(m, modules_no_weight_decay):
parameters_no_decay.add(full_param_name)
elif param_name.endswith("bias"):
parameters_no_decay.add(full_param_name)
elif isinstance(m, Residual) and param_name.endswith("gamma"):
parameters_no_decay.add(full_param_name)
elif isinstance(m, modules_weight_decay):
parameters_decay.add(full_param_name)
# sanity check
assert len(parameters_decay & parameters_no_decay) == 0
assert len(parameters_decay) + len(parameters_no_decay) == len(list(model.parameters()))
return parameters_decay, parameters_no_decay
def get_optimizer(model, learning_rate, weight_decay):
param_dict = {pn: p for pn, p in model.named_parameters()}
parameters_decay, parameters_no_decay = separate_parameters(model)
optim_groups = [
{"params": [param_dict[pn] for pn in parameters_decay], "weight_decay": weight_decay},
{"params": [param_dict[pn] for pn in parameters_no_decay], "weight_decay": 0.0},
]
optimizer = optim.AdamW(optim_groups, lr=learning_rate)
return optimizer
loss = nn.CrossEntropyLoss()
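# the lr passed to the optimizer is only an initial placeholder: the OneCycleLR
# scheduler below overwrites it every iteration, peaking at LEARNING_RATE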
optimizer = get_optimizer(model, learning_rate=1e-6, weight_decay=WEIGHT_DECAY)
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
steps_per_epoch=len(train_loader), epochs=EPOCHS)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
val_metrics = {"accuracy": ignite.metrics.Accuracy(), "loss": ignite.metrics.Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
history = History()
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
train_state = engine.state
epoch = train_state.epoch
max_epochs = train_state.max_epochs
train_loss = train_state.metrics["loss"]
history.append('train loss', train_loss)
evaluator.run(test_loader)
val_metrics = evaluator.state.metrics
val_loss = val_metrics["loss"]
val_acc = val_metrics["accuracy"]
history.append('val loss', val_loss)
history.append('val acc', val_acc)
print("{}/{} - train: loss {:.3f}; val: loss {:.3f} accuracy {:.3f}".format(
epoch, max_epochs, train_loss, val_loss, val_acc))
trainer.run(train_loader, max_epochs=EPOCHS);
1/100 - train: loss 1.743; val: loss 1.580 accuracy 0.420
2/100 - train: loss 1.565; val: loss 1.538 accuracy 0.426
3/100 - train: loss 1.456; val: loss 1.337 accuracy 0.510
4/100 - train: loss 1.383; val: loss 1.255 accuracy 0.546
5/100 - train: loss 1.260; val: loss 1.223 accuracy 0.568
6/100 - train: loss 1.261; val: loss 1.222 accuracy 0.564
7/100 - train: loss 1.253; val: loss 1.102 accuracy 0.605
8/100 - train: loss 1.181; val: loss 1.157 accuracy 0.591
9/100 - train: loss 1.144; val: loss 1.017 accuracy 0.637
10/100 - train: loss 1.108; val: loss 1.048 accuracy 0.617
11/100 - train: loss 1.102; val: loss 0.967 accuracy 0.650
12/100 - train: loss 1.034; val: loss 0.956 accuracy 0.658
13/100 - train: loss 1.018; val: loss 0.914 accuracy 0.671
14/100 - train: loss 1.029; val: loss 0.951 accuracy 0.662
15/100 - train: loss 0.981; val: loss 0.929 accuracy 0.666
16/100 - train: loss 0.987; val: loss 0.900 accuracy 0.676
17/100 - train: loss 0.909; val: loss 0.818 accuracy 0.716
18/100 - train: loss 0.944; val: loss 0.882 accuracy 0.683
19/100 - train: loss 0.879; val: loss 0.783 accuracy 0.722
20/100 - train: loss 0.862; val: loss 0.840 accuracy 0.706
21/100 - train: loss 0.838; val: loss 0.779 accuracy 0.722
22/100 - train: loss 0.814; val: loss 0.882 accuracy 0.705
23/100 - train: loss 0.795; val: loss 0.710 accuracy 0.754
24/100 - train: loss 0.794; val: loss 0.713 accuracy 0.749
25/100 - train: loss 0.737; val: loss 0.644 accuracy 0.777
26/100 - train: loss 0.747; val: loss 0.661 accuracy 0.772
27/100 - train: loss 0.720; val: loss 0.686 accuracy 0.768
28/100 - train: loss 0.709; val: loss 0.703 accuracy 0.761
29/100 - train: loss 0.716; val: loss 0.592 accuracy 0.796
30/100 - train: loss 0.672; val: loss 0.603 accuracy 0.791
31/100 - train: loss 0.647; val: loss 0.603 accuracy 0.794
32/100 - train: loss 0.655; val: loss 0.659 accuracy 0.783
33/100 - train: loss 0.638; val: loss 0.599 accuracy 0.795
34/100 - train: loss 0.632; val: loss 0.602 accuracy 0.793
35/100 - train: loss 0.591; val: loss 0.534 accuracy 0.815
36/100 - train: loss 0.625; val: loss 0.588 accuracy 0.794
37/100 - train: loss 0.566; val: loss 0.625 accuracy 0.783
38/100 - train: loss 0.542; val: loss 0.578 accuracy 0.806
39/100 - train: loss 0.543; val: loss 0.531 accuracy 0.816
40/100 - train: loss 0.562; val: loss 0.560 accuracy 0.813
41/100 - train: loss 0.549; val: loss 0.497 accuracy 0.833
42/100 - train: loss 0.510; val: loss 0.519 accuracy 0.827
43/100 - train: loss 0.515; val: loss 0.501 accuracy 0.833
44/100 - train: loss 0.509; val: loss 0.489 accuracy 0.832
45/100 - train: loss 0.488; val: loss 0.462 accuracy 0.843
46/100 - train: loss 0.476; val: loss 0.468 accuracy 0.845
47/100 - train: loss 0.520; val: loss 0.469 accuracy 0.838
48/100 - train: loss 0.485; val: loss 0.449 accuracy 0.846
49/100 - train: loss 0.506; val: loss 0.442 accuracy 0.851
50/100 - train: loss 0.457; val: loss 0.475 accuracy 0.843
51/100 - train: loss 0.452; val: loss 0.449 accuracy 0.851
52/100 - train: loss 0.409; val: loss 0.449 accuracy 0.853
53/100 - train: loss 0.447; val: loss 0.419 accuracy 0.859
54/100 - train: loss 0.393; val: loss 0.472 accuracy 0.846
55/100 - train: loss 0.425; val: loss 0.373 accuracy 0.875
56/100 - train: loss 0.402; val: loss 0.413 accuracy 0.863
57/100 - train: loss 0.392; val: loss 0.374 accuracy 0.874
58/100 - train: loss 0.392; val: loss 0.402 accuracy 0.866
59/100 - train: loss 0.384; val: loss 0.377 accuracy 0.877
60/100 - train: loss 0.363; val: loss 0.384 accuracy 0.873
61/100 - train: loss 0.336; val: loss 0.335 accuracy 0.890
62/100 - train: loss 0.315; val: loss 0.378 accuracy 0.880
63/100 - train: loss 0.335; val: loss 0.357 accuracy 0.885
64/100 - train: loss 0.303; val: loss 0.376 accuracy 0.879
65/100 - train: loss 0.293; val: loss 0.349 accuracy 0.884
66/100 - train: loss 0.309; val: loss 0.406 accuracy 0.869
67/100 - train: loss 0.276; val: loss 0.319 accuracy 0.895
68/100 - train: loss 0.295; val: loss 0.375 accuracy 0.881
69/100 - train: loss 0.273; val: loss 0.332 accuracy 0.892
70/100 - train: loss 0.251; val: loss 0.311 accuracy 0.903
71/100 - train: loss 0.228; val: loss 0.328 accuracy 0.900
72/100 - train: loss 0.236; val: loss 0.351 accuracy 0.893
73/100 - train: loss 0.201; val: loss 0.317 accuracy 0.902
74/100 - train: loss 0.199; val: loss 0.324 accuracy 0.905
75/100 - train: loss 0.197; val: loss 0.312 accuracy 0.905
76/100 - train: loss 0.171; val: loss 0.299 accuracy 0.913
77/100 - train: loss 0.188; val: loss 0.332 accuracy 0.902
78/100 - train: loss 0.150; val: loss 0.313 accuracy 0.914
79/100 - train: loss 0.126; val: loss 0.338 accuracy 0.909
80/100 - train: loss 0.125; val: loss 0.327 accuracy 0.909
81/100 - train: loss 0.100; val: loss 0.322 accuracy 0.913
82/100 - train: loss 0.095; val: loss 0.311 accuracy 0.918
83/100 - train: loss 0.086; val: loss 0.332 accuracy 0.916
84/100 - train: loss 0.083; val: loss 0.303 accuracy 0.921
85/100 - train: loss 0.069; val: loss 0.321 accuracy 0.924
86/100 - train: loss 0.059; val: loss 0.338 accuracy 0.921
87/100 - train: loss 0.042; val: loss 0.334 accuracy 0.923
88/100 - train: loss 0.047; val: loss 0.338 accuracy 0.925
89/100 - train: loss 0.030; val: loss 0.361 accuracy 0.924
90/100 - train: loss 0.045; val: loss 0.349 accuracy 0.923
91/100 - train: loss 0.023; val: loss 0.337 accuracy 0.926
92/100 - train: loss 0.022; val: loss 0.345 accuracy 0.928
93/100 - train: loss 0.016; val: loss 0.352 accuracy 0.929
94/100 - train: loss 0.020; val: loss 0.351 accuracy 0.929
95/100 - train: loss 0.014; val: loss 0.356 accuracy 0.931
96/100 - train: loss 0.015; val: loss 0.359 accuracy 0.931
97/100 - train: loss 0.012; val: loss 0.354 accuracy 0.932
98/100 - train: loss 0.010; val: loss 0.350 accuracy 0.931
99/100 - train: loss 0.008; val: loss 0.352 accuracy 0.932
100/100 - train: loss 0.012; val: loss 0.353 accuracy 0.932
history.plot_train_val('loss')
history.plot('val acc')