Swin Transformer, following arXiv:2103.14030 [cs.CV].
Original implementation: https://github.com/microsoft/Swin-Transformer
Imports
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
import ignite.metrics
import ignite.contrib.handlers
Configuration
DATA_DIR = './data'
IMAGE_SIZE = 32
NUM_CLASSES = 10
NUM_WORKERS = 8
BATCH_SIZE = 32
EPOCHS = 100
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-1
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float)
])
train_dset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_transform)
test_dset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transforms.ToTensor())
Files already downloaded and verified
Files already downloaded and verified
def dataset_show_image(dset, idx):
X, Y = dset[idx]
title = "Ground truth: {}".format(dset.classes[Y])
fig = plt.figure()
ax = fig.add_subplot(111)
ax.set_axis_off()
ax.imshow(np.moveaxis(X.numpy(), 0, -1))
ax.set_title(title)
plt.show()
dataset_show_image(test_dset, 1)
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
num_workers=NUM_WORKERS, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS, pin_memory=True)
Utilities
class Residual(nn.Module):
def __init__(self, *layers):
super().__init__()
self.residual = nn.Sequential(*layers)
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
return x + self.gamma * self.residual(x)
class GlobalAvgPool(nn.Module):
def forward(self, x):
return x.mean(dim=-2)
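Because gamma starts at zero, each residual branch is initially the identity, which stabilizes early training. A quick shape check (an illustrative sketch with arbitrary sizes, not part of the model configuration below):
x = torch.randn(2, 16, 64)                  # (batch, tokens, channels) - sizes are arbitrary
res = Residual(nn.Linear(64, 64))
assert torch.allclose(res(x), x)            # gamma == 0, so the branch is a no-op at init
assert GlobalAvgPool()(x).shape == (2, 64)  # mean over the token dimension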
Transformer
Shifted-Window Self-Attention
class ShiftedWindowAttention(nn.Module):
def __init__(self, dim, head_dim, shape, window_size, shift_size=0):
super().__init__()
self.heads = dim // head_dim
self.head_dim = head_dim
self.scale = head_dim**-0.5
self.shape = shape
self.window_size = window_size
self.shift_size = shift_size
self.to_qkv = nn.Linear(dim, dim * 3)
self.unifyheads = nn.Linear(dim, dim)
self.pos_enc = nn.Parameter(torch.Tensor(self.heads, (2 * window_size - 1)**2))
self.register_buffer("relative_indices", self.get_indices(window_size))
if shift_size > 0:
self.register_buffer("mask", self.generate_mask(shape, window_size, shift_size))
def forward(self, x):
shift_size, window_size = self.shift_size, self.window_size
x = self.to_windows(x, self.shape, window_size, shift_size) # partition into windows
# self attention
qkv = self.to_qkv(x).unflatten(-1, (3, self.heads, self.head_dim)).transpose(-2, 1)
queries, keys, values = qkv.unbind(dim=2)
att = queries @ keys.transpose(-2, -1)
        att = att * self.scale + self.get_rel_pos_enc(window_size) # add relative position encoding
# masking
if shift_size > 0:
att = self.mask_attention(att)
att = F.softmax(att, dim=-1)
x = att @ values
x = x.transpose(1, 2).contiguous().flatten(-2, -1) # move head back
x = self.unifyheads(x)
x = self.from_windows(x, self.shape, window_size, shift_size) # undo partitioning into windows
return x
def to_windows(self, x, shape, window_size, shift_size):
x = x.unflatten(1, shape)
if shift_size > 0:
x = x.roll((-shift_size, -shift_size), dims=(1, 2))
x = self.split_windows(x, window_size)
return x
def from_windows(self, x, shape, window_size, shift_size):
x = self.merge_windows(x, shape, window_size)
if shift_size > 0:
x = x.roll((shift_size, shift_size), dims=(1, 2))
x = x.flatten(1, 2)
return x
def mask_attention(self, att):
num_win = self.mask.size(1)
att = att.unflatten(0, (att.size(0) // num_win, num_win))
att = att.masked_fill(self.mask, float('-inf'))
att = att.flatten(0, 1)
return att
def get_rel_pos_enc(self, window_size):
indices = self.relative_indices.expand(self.heads, -1)
rel_pos_enc = self.pos_enc.gather(-1, indices)
rel_pos_enc = rel_pos_enc.unflatten(-1, (window_size**2, window_size**2))
return rel_pos_enc
# For explanation of mask regions see Figure 4 in the article
@staticmethod
def generate_mask(shape, window_size, shift_size):
region_mask = torch.zeros(1, *shape, 1)
slices = [slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None)]
region_num = 0
for i in slices:
for j in slices:
region_mask[:, i, j, :] = region_num
region_num += 1
mask_windows = ShiftedWindowAttention.split_windows(region_mask, window_size).squeeze(-1)
diff_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
mask = diff_mask != 0
mask = mask.unsqueeze(1).unsqueeze(0) # add heads and batch dimension
return mask
@staticmethod
def split_windows(x, window_size):
n_h, n_w = x.size(1) // window_size, x.size(2) // window_size
x = x.unflatten(1, (n_h, window_size)).unflatten(-2, (n_w, window_size)) # split into windows
x = x.transpose(2, 3).flatten(0, 2) # merge batch and window numbers
x = x.flatten(-3, -2)
return x
@staticmethod
def merge_windows(x, shape, window_size):
n_h, n_w = shape[0] // window_size, shape[1] // window_size
b = x.size(0) // (n_h * n_w)
x = x.unflatten(1, (window_size, window_size))
x = x.unflatten(0, (b, n_h, n_w)).transpose(2, 3) # separate batch and window numbers
x = x.flatten(1, 2).flatten(-3, -2) # merge windows
return x
@staticmethod
def get_indices(window_size):
x = torch.arange(window_size, dtype=torch.long)
y1, x1, y2, x2 = torch.meshgrid(x, x, x, x, indexing='ij')
indices = (y1 - y2 + window_size - 1) * (2 * window_size - 1) + x1 - x2 + window_size - 1
indices = indices.flatten()
return indices
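As a sanity check (illustrative only, with toy sizes), split_windows and merge_windows should be exact inverses, and get_indices should index into the (2 * window_size - 1)**2 table of learned relative offsets:
x = torch.randn(2, 8, 8, 32)                                # (batch, H, W, channels), toy sizes
w = ShiftedWindowAttention.split_windows(x, window_size=4)
assert w.shape == (8, 16, 32)                               # 2 batches x 4 windows, 16 tokens each
y = ShiftedWindowAttention.merge_windows(w, (8, 8), window_size=4)
assert torch.equal(x, y)                                    # merging undoes the split
idx = ShiftedWindowAttention.get_indices(4)
assert idx.shape == (16 * 16,)                              # one entry per (query, key) token pair
assert idx.max().item() == (2 * 4 - 1)**2 - 1               # indexes the 49-entry offset table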
class FeedForward(nn.Sequential):
def __init__(self, dim, mult=4):
hidden_dim = dim * mult
super().__init__(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim)
)
class TransformerBlock(nn.Sequential):
def __init__(self, dim, head_dim, shape, window_size, shift_size=0, p_drop=0.):
super().__init__(
Residual(
nn.LayerNorm(dim),
ShiftedWindowAttention(dim, head_dim, shape, window_size, shift_size),
nn.Dropout(p_drop)
),
Residual(
nn.LayerNorm(dim),
FeedForward(dim),
nn.Dropout(p_drop)
)
)
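Each block preserves the token shape, so it can be smoke-tested in isolation (an illustrative check; the sizes match the first stage of the model configured below):
blk = TransformerBlock(dim=128, head_dim=32, shape=(16, 16), window_size=4, shift_size=2)
tokens = torch.randn(2, 16 * 16, 128)       # (batch, tokens, dim)
assert blk(tokens).shape == (2, 256, 128)   # shape is preserved through the block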
Stage
class PatchMerging(nn.Module):
def __init__(self, in_dim, out_dim, shape):
super().__init__()
self.shape = shape
self.norm = nn.LayerNorm(4 * in_dim)
self.reduction = nn.Linear(4 * in_dim, out_dim, bias=False)
def forward(self, x):
x = x.unflatten(1, self.shape).movedim(-1, 1)
x = F.unfold(x, kernel_size=2, stride=2).movedim(1, -1)
x = self.norm(x)
x = self.reduction(x)
return x
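Unfold with kernel_size=2, stride=2 gathers each non-overlapping 2x2 neighbourhood of tokens, so the sequence length drops by a factor of 4 while the channel width is first quadrupled and then projected to out_dim. A quick check (illustrative sizes):
pm = PatchMerging(in_dim=128, out_dim=256, shape=(16, 16))
tokens = torch.randn(2, 16 * 16, 128)
assert pm(tokens).shape == (2, 8 * 8, 256)  # 16x16 grid of 128-d tokens -> 8x8 grid of 256-d tokens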
class Stage(nn.Sequential):
def __init__(self, num_blocks, in_dim, out_dim, head_dim, shape, window_size, p_drop=0.):
if out_dim != in_dim:
layers = [PatchMerging(in_dim, out_dim, shape)]
shape = (shape[0] // 2, shape[1] // 2)
else:
layers = []
shift_size = window_size // 2
layers += [TransformerBlock(out_dim, head_dim, shape, window_size, 0 if (num % 2 == 0) else shift_size,
p_drop) for num in range(num_blocks)]
super().__init__(*layers)
class StageStack(nn.Sequential):
def __init__(self, num_blocks_list, dims, head_dim, shape, window_size, p_drop=0.):
layers = []
in_dim = dims[0]
for num, out_dim in zip(num_blocks_list, dims[1:]):
layers.append(Stage(num, in_dim, out_dim, head_dim, shape, window_size, p_drop))
if in_dim != out_dim:
shape = (shape[0] // 2, shape[1] // 2)
in_dim = out_dim
super().__init__(*layers)
Embedding of patches
class ToPatches(nn.Module):
def __init__(self, in_channels, dim, patch_size):
super().__init__()
self.patch_size = patch_size
patch_dim = in_channels * patch_size**2
self.proj = nn.Linear(patch_dim, dim)
self.norm = nn.LayerNorm(dim)
def forward(self, x):
x = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).movedim(1, -1)
x = self.proj(x)
x = self.norm(x)
return x
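With IMAGE_SIZE = 32 and patch_size = 2 this yields a 16x16 grid of 256 patch tokens; a quick check (illustrative batch size):
tp = ToPatches(in_channels=3, dim=128, patch_size=2)
imgs = torch.randn(2, 3, 32, 32)
assert tp(imgs).shape == (2, 16 * 16, 128)  # one token per non-overlapping 2x2 patch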
class AddPositionEmbedding(nn.Module):
def __init__(self, dim, num_patches):
super().__init__()
self.pos_embedding = nn.Parameter(torch.Tensor(num_patches, dim))
def forward(self, x):
return x + self.pos_embedding
class ToEmbedding(nn.Sequential):
def __init__(self, in_channels, dim, patch_size, num_patches, p_drop=0.):
super().__init__(
ToPatches(in_channels, dim, patch_size),
AddPositionEmbedding(dim, num_patches),
nn.Dropout(p_drop)
)
Main model
class Head(nn.Sequential):
def __init__(self, dim, classes, p_drop=0.):
super().__init__(
nn.LayerNorm(dim),
nn.GELU(),
GlobalAvgPool(),
nn.Dropout(p_drop),
nn.Linear(dim, classes)
)
class SwinTransformer(nn.Sequential):
def __init__(self, classes, image_size, num_blocks_list, dims, head_dim, patch_size, window_size,
in_channels=3, emb_p_drop=0., trans_p_drop=0., head_p_drop=0.):
reduced_size = image_size // patch_size
shape = (reduced_size, reduced_size)
num_patches = shape[0] * shape[1]
super().__init__(
ToEmbedding(in_channels, dims[0], patch_size, num_patches, emb_p_drop),
StageStack(num_blocks_list, dims, head_dim, shape, window_size, trans_p_drop),
Head(dims[-1], classes, head_p_drop)
)
self.reset_parameters()
def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None: nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.weight, 1.)
nn.init.zeros_(m.bias)
elif isinstance(m, AddPositionEmbedding):
nn.init.normal_(m.pos_embedding, mean=0.0, std=0.02)
elif isinstance(m, ShiftedWindowAttention):
nn.init.normal_(m.pos_enc, mean=0.0, std=0.02)
elif isinstance(m, Residual):
nn.init.zeros_(m.gamma)
def separate_parameters(self):
parameters_decay = set()
parameters_no_decay = set()
modules_weight_decay = (nn.Linear, )
modules_no_weight_decay = (nn.LayerNorm,)
for m_name, m in self.named_modules():
for param_name, param in m.named_parameters():
full_param_name = f"{m_name}.{param_name}" if m_name else param_name
if isinstance(m, modules_no_weight_decay):
parameters_no_decay.add(full_param_name)
elif param_name.endswith("bias"):
parameters_no_decay.add(full_param_name)
elif isinstance(m, Residual) and param_name.endswith("gamma"):
parameters_no_decay.add(full_param_name)
elif isinstance(m, AddPositionEmbedding) and param_name.endswith("pos_embedding"):
parameters_no_decay.add(full_param_name)
elif isinstance(m, ShiftedWindowAttention) and param_name.endswith("pos_enc"):
parameters_no_decay.add(full_param_name)
elif isinstance(m, modules_weight_decay):
parameters_decay.add(full_param_name)
# sanity check
assert len(parameters_decay & parameters_no_decay) == 0
        assert len(parameters_decay) + len(parameters_no_decay) == len(list(self.parameters()))
return parameters_decay, parameters_no_decay
model = SwinTransformer(NUM_CLASSES, IMAGE_SIZE,
num_blocks_list=[4, 4], dims=[128, 128, 256],
head_dim=32, patch_size=2, window_size=4,
emb_p_drop=0., trans_p_drop=0., head_p_drop=0.3)
model.to(DEVICE);
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 4,124,362
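A forward pass on random input (a quick sanity check, not part of training) confirms the expected logit shape:
with torch.no_grad():
    logits = model(torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE))
assert logits.shape == (2, NUM_CLASSES)  # one logit per class for each image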
def get_optimizer(model, learning_rate, weight_decay):
param_dict = {pn: p for pn, p in model.named_parameters()}
parameters_decay, parameters_no_decay = model.separate_parameters()
optim_groups = [
{"params": [param_dict[pn] for pn in parameters_decay], "weight_decay": weight_decay},
{"params": [param_dict[pn] for pn in parameters_no_decay], "weight_decay": 0.0},
]
optimizer = optim.AdamW(optim_groups, lr=learning_rate)
return optimizer
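Only linear weights receive weight decay; biases, LayerNorm parameters, the residual gains, and the position tables are excluded, following common practice for transformers. The split can be inspected directly (illustrative check):
decay, no_decay = model.separate_parameters()
print(len(decay), len(no_decay))  # counts of decayed vs. non-decayed parameter tensors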
loss = nn.CrossEntropyLoss()
optimizer = get_optimizer(model, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
steps_per_epoch=len(train_loader), epochs=EPOCHS)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
val_metrics = {"accuracy": ignite.metrics.Accuracy(), "loss": ignite.metrics.Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
history = defaultdict(list)
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
train_state = engine.state
epoch = train_state.epoch
max_epochs = train_state.max_epochs
train_loss = train_state.metrics["loss"]
history['train loss'].append(train_loss)
evaluator.run(test_loader)
val_metrics = evaluator.state.metrics
val_loss = val_metrics["loss"]
val_acc = val_metrics["accuracy"]
history['val loss'].append(val_loss)
history['val acc'].append(val_acc)
print("{}/{} - train: loss {:.3f}; val: loss {:.3f} accuracy {:.3f}".format(
epoch, max_epochs, train_loss, val_loss, val_acc))
trainer.run(train_loader, max_epochs=EPOCHS);
1/100 - train: loss 1.866; val: loss 1.835 accuracy 0.301
2/100 - train: loss 1.676; val: loss 1.603 accuracy 0.412
3/100 - train: loss 1.510; val: loss 1.443 accuracy 0.472
4/100 - train: loss 1.372; val: loss 1.334 accuracy 0.515
5/100 - train: loss 1.293; val: loss 1.304 accuracy 0.529
6/100 - train: loss 1.256; val: loss 1.241 accuracy 0.548
7/100 - train: loss 1.197; val: loss 1.220 accuracy 0.555
8/100 - train: loss 1.160; val: loss 1.149 accuracy 0.581
9/100 - train: loss 1.066; val: loss 1.117 accuracy 0.605
10/100 - train: loss 1.063; val: loss 0.997 accuracy 0.643
11/100 - train: loss 1.038; val: loss 1.165 accuracy 0.596
12/100 - train: loss 1.019; val: loss 1.040 accuracy 0.628
13/100 - train: loss 1.005; val: loss 0.988 accuracy 0.645
14/100 - train: loss 0.898; val: loss 0.937 accuracy 0.664
15/100 - train: loss 0.893; val: loss 0.896 accuracy 0.691
16/100 - train: loss 0.863; val: loss 0.887 accuracy 0.695
17/100 - train: loss 0.832; val: loss 0.847 accuracy 0.704
18/100 - train: loss 0.898; val: loss 0.872 accuracy 0.701
19/100 - train: loss 0.816; val: loss 0.798 accuracy 0.720
20/100 - train: loss 0.820; val: loss 0.808 accuracy 0.718
21/100 - train: loss 0.779; val: loss 0.773 accuracy 0.730
22/100 - train: loss 0.827; val: loss 0.784 accuracy 0.720
23/100 - train: loss 0.804; val: loss 0.854 accuracy 0.711
24/100 - train: loss 0.801; val: loss 0.813 accuracy 0.714
25/100 - train: loss 0.791; val: loss 0.786 accuracy 0.726
26/100 - train: loss 0.790; val: loss 0.729 accuracy 0.750
27/100 - train: loss 0.760; val: loss 0.801 accuracy 0.729
28/100 - train: loss 0.803; val: loss 0.822 accuracy 0.712
29/100 - train: loss 0.745; val: loss 0.820 accuracy 0.717
30/100 - train: loss 0.765; val: loss 0.770 accuracy 0.734
31/100 - train: loss 0.747; val: loss 0.794 accuracy 0.731
32/100 - train: loss 0.720; val: loss 0.780 accuracy 0.729
33/100 - train: loss 0.722; val: loss 0.724 accuracy 0.751
34/100 - train: loss 0.753; val: loss 0.737 accuracy 0.741
35/100 - train: loss 0.736; val: loss 0.695 accuracy 0.756
36/100 - train: loss 0.744; val: loss 0.675 accuracy 0.763
37/100 - train: loss 0.696; val: loss 0.627 accuracy 0.780
38/100 - train: loss 0.665; val: loss 0.741 accuracy 0.749
39/100 - train: loss 0.684; val: loss 0.613 accuracy 0.787
40/100 - train: loss 0.667; val: loss 0.655 accuracy 0.780
41/100 - train: loss 0.645; val: loss 0.669 accuracy 0.766
42/100 - train: loss 0.652; val: loss 0.687 accuracy 0.768
43/100 - train: loss 0.602; val: loss 0.627 accuracy 0.785
44/100 - train: loss 0.609; val: loss 0.609 accuracy 0.785
45/100 - train: loss 0.582; val: loss 0.578 accuracy 0.798
46/100 - train: loss 0.572; val: loss 0.601 accuracy 0.795
47/100 - train: loss 0.582; val: loss 0.600 accuracy 0.797
48/100 - train: loss 0.571; val: loss 0.578 accuracy 0.808
49/100 - train: loss 0.577; val: loss 0.537 accuracy 0.810
50/100 - train: loss 0.544; val: loss 0.565 accuracy 0.813
51/100 - train: loss 0.497; val: loss 0.577 accuracy 0.807
52/100 - train: loss 0.526; val: loss 0.533 accuracy 0.820
53/100 - train: loss 0.558; val: loss 0.523 accuracy 0.818
54/100 - train: loss 0.505; val: loss 0.491 accuracy 0.831
55/100 - train: loss 0.483; val: loss 0.496 accuracy 0.828
56/100 - train: loss 0.478; val: loss 0.521 accuracy 0.823
57/100 - train: loss 0.467; val: loss 0.468 accuracy 0.837
58/100 - train: loss 0.449; val: loss 0.460 accuracy 0.843
59/100 - train: loss 0.431; val: loss 0.463 accuracy 0.842
60/100 - train: loss 0.424; val: loss 0.454 accuracy 0.846
61/100 - train: loss 0.385; val: loss 0.433 accuracy 0.852
62/100 - train: loss 0.396; val: loss 0.461 accuracy 0.846
63/100 - train: loss 0.401; val: loss 0.450 accuracy 0.845
64/100 - train: loss 0.400; val: loss 0.433 accuracy 0.850
65/100 - train: loss 0.365; val: loss 0.411 accuracy 0.865
66/100 - train: loss 0.347; val: loss 0.386 accuracy 0.868
67/100 - train: loss 0.348; val: loss 0.375 accuracy 0.874
68/100 - train: loss 0.323; val: loss 0.398 accuracy 0.865
69/100 - train: loss 0.311; val: loss 0.393 accuracy 0.867
70/100 - train: loss 0.310; val: loss 0.397 accuracy 0.867
71/100 - train: loss 0.287; val: loss 0.392 accuracy 0.872
72/100 - train: loss 0.280; val: loss 0.361 accuracy 0.879
73/100 - train: loss 0.249; val: loss 0.371 accuracy 0.880
74/100 - train: loss 0.213; val: loss 0.341 accuracy 0.888
75/100 - train: loss 0.234; val: loss 0.354 accuracy 0.881
76/100 - train: loss 0.204; val: loss 0.393 accuracy 0.873
77/100 - train: loss 0.207; val: loss 0.347 accuracy 0.887
78/100 - train: loss 0.192; val: loss 0.355 accuracy 0.886
79/100 - train: loss 0.183; val: loss 0.358 accuracy 0.887
80/100 - train: loss 0.156; val: loss 0.333 accuracy 0.896
81/100 - train: loss 0.143; val: loss 0.355 accuracy 0.891
82/100 - train: loss 0.136; val: loss 0.331 accuracy 0.900
83/100 - train: loss 0.115; val: loss 0.334 accuracy 0.898
84/100 - train: loss 0.107; val: loss 0.339 accuracy 0.900
85/100 - train: loss 0.090; val: loss 0.330 accuracy 0.900
86/100 - train: loss 0.082; val: loss 0.348 accuracy 0.899
87/100 - train: loss 0.071; val: loss 0.346 accuracy 0.900
88/100 - train: loss 0.070; val: loss 0.364 accuracy 0.900
89/100 - train: loss 0.054; val: loss 0.347 accuracy 0.905
90/100 - train: loss 0.052; val: loss 0.343 accuracy 0.904
91/100 - train: loss 0.038; val: loss 0.346 accuracy 0.909
92/100 - train: loss 0.037; val: loss 0.350 accuracy 0.906
93/100 - train: loss 0.043; val: loss 0.349 accuracy 0.909
94/100 - train: loss 0.029; val: loss 0.349 accuracy 0.910
95/100 - train: loss 0.022; val: loss 0.353 accuracy 0.909
96/100 - train: loss 0.022; val: loss 0.348 accuracy 0.910
97/100 - train: loss 0.023; val: loss 0.349 accuracy 0.911
98/100 - train: loss 0.017; val: loss 0.351 accuracy 0.911
99/100 - train: loss 0.019; val: loss 0.351 accuracy 0.912
100/100 - train: loss 0.025; val: loss 0.351 accuracy 0.912
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['train loss']) + 1)
ax.plot(xs, history['train loss'], '.-', label='train')
ax.plot(xs, history['val loss'], '.-', label='val')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.legend()
ax.grid()
plt.show()
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['val acc']) + 1)
ax.plot(xs, history['val acc'], '-')
ax.set_xlabel('epoch')
ax.set_ylabel('val acc')
ax.grid()
plt.show()