Main ideas:
Imports
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.utils import convert_tensor
import ignite.metrics
import ignite.contrib.handlers
Configuration
DATA_DIR='./data'
IMAGE_SIZE = 32
NUM_CLASSES = 10
NUM_WORKERS = 8
BATCH_SIZE = 32
EPOCHS = 200
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-1
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
train_transform = transforms.Compose([
    transforms.TrivialAugmentWide(interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.RandomErasing(p=0.1)
])
train_dset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_transform)
test_dset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transforms.ToTensor())
Files already downloaded and verified
Files already downloaded and verified
def dataset_show_image(dset, idx):
    X, Y = dset[idx]
    title = "Ground truth: {}".format(dset.classes[Y])
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_axis_off()
    ax.imshow(np.moveaxis(X.numpy(), 0, -1))
    ax.set_title(title)
    plt.show()
dataset_show_image(test_dset, 1)
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
                                           num_workers=NUM_WORKERS, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False,
                                          num_workers=NUM_WORKERS, pin_memory=True)
Utilities
class Residual(nn.Module):
    def __init__(self, *layers):
        super().__init__()
        self.residual = nn.Sequential(*layers)
        # gamma scales the residual branch; initialized to zero, so each block starts as the identity
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return x + self.gamma * self.residual(x)

class LayerNormChannels(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.LayerNorm(channels)

    def forward(self, x):
        # apply LayerNorm over the channel dimension of an NCHW tensor
        x = x.transpose(1, -1)
        x = self.norm(x)
        x = x.transpose(-1, 1)
        return x
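LayerNormChannels simply applies LayerNorm across the channel dimension of an NCHW feature map. A quick equivalence check (illustrative only, not part of the original notebook):

x = torch.randn(2, 8, 4, 4)
ln = LayerNormChannels(8)
ref = nn.LayerNorm(8)
ref.load_state_dict(ln.norm.state_dict())
# normalizing the channel-last view and permuting back gives the same result
out_ref = ref(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(ln(x), out_ref, atol=1e-6))  # True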
Transformer
Attention
class SelfAttention2d(nn.Module):
    def __init__(self, in_channels, out_channels, head_channels, shape):
        super().__init__()
        self.heads = out_channels // head_channels
        self.head_channels = head_channels
        self.scale = head_channels**-0.5

        self.to_keys = nn.Conv2d(in_channels, out_channels, 1)
        self.to_queries = nn.Conv2d(in_channels, out_channels, 1)
        self.to_values = nn.Conv2d(in_channels, out_channels, 1)
        self.unifyheads = nn.Conv2d(out_channels, out_channels, 1)

        # one learned scalar per head for each of the (2h-1) x (2w-1) possible relative offsets
        height, width = shape
        self.pos_enc = nn.Parameter(torch.Tensor(self.heads, (2 * height - 1) * (2 * width - 1)))
        self.register_buffer("relative_indices", self.get_indices(height, width))

    def forward(self, x):
        b, _, h, w = x.shape

        keys = self.to_keys(x).view(b, self.heads, self.head_channels, -1)
        values = self.to_values(x).view(b, self.heads, self.head_channels, -1)
        queries = self.to_queries(x).view(b, self.heads, self.head_channels, -1)

        att = keys.transpose(-2, -1) @ queries

        # look up the relative position bias for every (key, query) pair
        indices = self.relative_indices.expand(self.heads, -1)
        rel_pos_enc = self.pos_enc.gather(-1, indices)
        rel_pos_enc = rel_pos_enc.unflatten(-1, (h * w, h * w))

        att = att * self.scale + rel_pos_enc
        att = F.softmax(att, dim=-2)

        out = values @ att
        out = out.view(b, -1, h, w)
        out = self.unifyheads(out)
        return out

    @staticmethod
    def get_indices(h, w):
        y = torch.arange(h, dtype=torch.long)
        x = torch.arange(w, dtype=torch.long)

        # map each offset (y1 - y2, x1 - x2) to a row-major index into the (2h-1) x (2w-1) table
        y1, x1, y2, x2 = torch.meshgrid(y, x, y, x, indexing='ij')
        indices = (y1 - y2 + h - 1) * (2 * w - 1) + x1 - x2 + w - 1
        indices = indices.flatten()
        return indices
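The pos_enc table holds one scalar per head for each of the (2h-1)·(2w-1) relative offsets, and get_indices maps every (key, query) pair to its offset via index = (Δy + h - 1)·(2w - 1) + (Δx + w - 1). A small shape check of the module (illustrative values, not from the original notebook; pos_enc is allocated uninitialized, so it is filled in explicitly here):

attn = SelfAttention2d(in_channels=64, out_channels=64, head_channels=32, shape=(8, 8))
nn.init.normal_(attn.pos_enc, mean=0.0, std=0.02)  # normally done by RelViT.reset_parameters
x = torch.randn(2, 64, 8, 8)
print(attn.relative_indices.shape)  # torch.Size([4096]) -- one index per (key, query) pair
print(attn.pos_enc.shape)           # torch.Size([2, 225]) -- 2 heads x (2*8-1)*(2*8-1) offsets
print(attn(x).shape)                # torch.Size([2, 64, 8, 8])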
class FeedForward(nn.Sequential):
    def __init__(self, in_channels, out_channels, mult=4):
        hidden_channels = in_channels * mult
        super().__init__(
            nn.Conv2d(in_channels, hidden_channels, 1),
            nn.GELU(),
            nn.Conv2d(hidden_channels, out_channels, 1)
        )

class TransformerBlock(nn.Sequential):
    def __init__(self, channels, head_channels, shape, p_drop=0.):
        super().__init__(
            Residual(
                LayerNormChannels(channels),
                SelfAttention2d(channels, channels, head_channels, shape),
                nn.Dropout(p_drop)
            ),
            Residual(
                LayerNormChannels(channels),
                FeedForward(channels, channels),
                nn.Dropout(p_drop)
            )
        )

class TransformerStack(nn.Sequential):
    def __init__(self, num_blocks, channels, head_channels, shape, p_drop=0.):
        layers = [TransformerBlock(channels, head_channels, shape, p_drop) for _ in range(num_blocks)]
        super().__init__(*layers)
Embedding of patches
class ToPatches(nn.Sequential):
    def __init__(self, in_channels, channels, patch_size, hidden_channels=32):
        super().__init__(
            nn.Conv2d(in_channels, hidden_channels, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden_channels, channels, patch_size, stride=patch_size)
        )

class AddPositionEmbedding(nn.Module):
    def __init__(self, channels, shape):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.Tensor(channels, *shape))

    def forward(self, x):
        return x + self.pos_embedding

class ToEmbedding(nn.Sequential):
    def __init__(self, in_channels, channels, patch_size, shape, p_drop=0.):
        super().__init__(
            ToPatches(in_channels, channels, patch_size),
            AddPositionEmbedding(channels, shape),
            nn.Dropout(p_drop)
        )
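With a patch size of 2, each 32×32 CIFAR-10 image becomes a 16×16 grid of patch embeddings. A quick shape check (illustrative, not in the original notebook; pos_embedding is allocated uninitialized, so it is filled in explicitly here):

emb = ToEmbedding(in_channels=3, channels=256, patch_size=2, shape=(16, 16))
nn.init.normal_(emb[1].pos_embedding, mean=0.0, std=0.02)  # normally done by RelViT.reset_parameters
x = torch.randn(2, 3, 32, 32)
print(emb(x).shape)  # torch.Size([2, 256, 16, 16])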
Main model
class Head(nn.Sequential):
    def __init__(self, in_channels, classes, p_drop=0.):
        super().__init__(
            LayerNormChannels(in_channels),
            nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(p_drop),
            nn.Linear(in_channels, classes)
        )
class RelViT(nn.Sequential):
    def __init__(self, classes, image_size, channels, head_channels, num_blocks, patch_size,
                 in_channels=3, emb_p_drop=0., trans_p_drop=0., head_p_drop=0.):
        reduced_size = image_size // patch_size
        shape = (reduced_size, reduced_size)
        super().__init__(
            ToEmbedding(in_channels, channels, patch_size, shape, emb_p_drop),
            TransformerStack(num_blocks, channels, head_channels, shape, trans_p_drop),
            Head(channels, classes, head_p_drop)
        )
        self.reset_parameters()

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1.)
                nn.init.zeros_(m.bias)
            elif isinstance(m, AddPositionEmbedding):
                nn.init.normal_(m.pos_embedding, mean=0.0, std=0.02)
            elif isinstance(m, SelfAttention2d):
                nn.init.normal_(m.pos_enc, mean=0.0, std=0.02)
            elif isinstance(m, Residual):
                nn.init.zeros_(m.gamma)

    def separate_parameters(self):
        parameters_decay = set()
        parameters_no_decay = set()
        modules_weight_decay = (nn.Linear, nn.Conv2d)
        modules_no_weight_decay = (nn.LayerNorm,)

        for m_name, m in self.named_modules():
            for param_name, param in m.named_parameters():
                full_param_name = f"{m_name}.{param_name}" if m_name else param_name

                if isinstance(m, modules_no_weight_decay):
                    parameters_no_decay.add(full_param_name)
                elif param_name.endswith("bias"):
                    parameters_no_decay.add(full_param_name)
                elif isinstance(m, Residual) and param_name.endswith("gamma"):
                    parameters_no_decay.add(full_param_name)
                elif isinstance(m, AddPositionEmbedding) and param_name.endswith("pos_embedding"):
                    parameters_no_decay.add(full_param_name)
                elif isinstance(m, SelfAttention2d) and param_name.endswith("pos_enc"):
                    parameters_no_decay.add(full_param_name)
                elif isinstance(m, modules_weight_decay):
                    parameters_decay.add(full_param_name)

        # sanity check: every parameter lands in exactly one group
        assert len(parameters_decay & parameters_no_decay) == 0
        assert len(parameters_decay) + len(parameters_no_decay) == len(list(self.parameters()))

        return parameters_decay, parameters_no_decay
model = RelViT(NUM_CLASSES, IMAGE_SIZE, channels=256, head_channels=32, num_blocks=8, patch_size=2,
               emb_p_drop=0., trans_p_drop=0., head_p_drop=0.3)
model.to(DEVICE);
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 6,482,138
def reduce_loss(loss, reduction='mean'):
    return loss.mean() if reduction=='mean' else loss.sum() if reduction=='sum' else loss

class CutMix(nn.Module):
    def __init__(self, loss, α=1.0):
        super().__init__()
        self.loss = loss
        self.α = α
        self.rng = np.random.default_rng()

    def prepare_batch(self, batch, device, non_blocking):
        x, y = batch
        x = convert_tensor(x, device=device, non_blocking=non_blocking)
        y = convert_tensor(y, device=device, non_blocking=non_blocking)

        batch_size = x.size(0)
        self.index = torch.randperm(batch_size).to(device)

        # mixing weight λ ~ Beta(α, α); a random box is cut from each image and
        # filled with the corresponding region of a shuffled partner image
        self.λ = self.rng.beta(self.α, self.α)
        y1, x1, y2, x2 = self.cut_bounding_box(x.shape[-2:], self.λ)
        x[:, :, y1:y2, x1:x2] = x[self.index, :, y1:y2, x1:x2]

        # adjust lambda to exactly match pixel ratio
        area = x.size(2) * x.size(3)
        self.λ = 1. - (x2 - x1) * (y2 - y1) / area
        return x, y

    def cut_bounding_box(self, shape, λ):
        cut_size_2 = 0.5 * np.sqrt(1. - λ)
        center_yx = self.rng.random(2)
        y1x1 = (np.clip(center_yx - cut_size_2, 0., 1.) * shape).astype(int)
        y2x2 = (np.clip(center_yx + cut_size_2, 0., 1.) * shape).astype(int)
        return np.concatenate((y1x1, y2x2))

    def forward(self, pred, target):
        orig_reduction = self.loss.reduction
        self.loss.reduction = 'none'
        batch_loss = self.λ * self.loss(pred, target) + (1. - self.λ) * self.loss(pred, target[self.index])
        self.loss.reduction = orig_reduction
        return reduce_loss(batch_loss, orig_reduction)
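The prepare_batch hook pastes into each image a box from a randomly chosen partner whose area fraction is roughly 1-λ, then recomputes λ from the actual pixel count so the two loss terms are weighted exactly. A standalone sketch on CPU (hypothetical values, not part of the original notebook):

cm = CutMix(nn.CrossEntropyLoss(), α=1.0)
xb, yb = torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))
xm, ym = cm.prepare_batch((xb, yb), device=torch.device("cpu"), non_blocking=False)
print(xm.shape, round(cm.λ, 3))  # mixed batch and the pixel-exact mixing weight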
def get_optimizer(model, learning_rate, weight_decay):
    param_dict = {pn: p for pn, p in model.named_parameters()}
    parameters_decay, parameters_no_decay = model.separate_parameters()

    optim_groups = [
        {"params": [param_dict[pn] for pn in parameters_decay], "weight_decay": weight_decay},
        {"params": [param_dict[pn] for pn in parameters_no_decay], "weight_decay": 0.0},
    ]
    optimizer = optim.AdamW(optim_groups, lr=learning_rate)
    return optimizer
loss = nn.CrossEntropyLoss(label_smoothing=0.1)
cutmix = CutMix(loss, α=1.0)
optimizer = get_optimizer(model, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
trainer = create_supervised_trainer(model, optimizer, cutmix, device=DEVICE, prepare_batch=cutmix.prepare_batch)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
                                             steps_per_epoch=len(train_loader), epochs=EPOCHS)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
val_metrics = {"accuracy": ignite.metrics.Accuracy(), "loss": ignite.metrics.Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
history = defaultdict(list)
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    train_state = engine.state
    epoch = train_state.epoch
    max_epochs = train_state.max_epochs
    train_loss = train_state.metrics["loss"]
    history['train loss'].append(train_loss)

    evaluator.run(test_loader)
    val_metrics = evaluator.state.metrics
    val_loss = val_metrics["loss"]
    val_acc = val_metrics["accuracy"]
    history['val loss'].append(val_loss)
    history['val acc'].append(val_acc)

    print("{}/{} - train: loss {:.3f}; val: loss {:.3f} accuracy {:.3f}".format(
        epoch, max_epochs, train_loss, val_loss, val_acc))
trainer.run(train_loader, max_epochs=EPOCHS);
1/200 - train: loss 2.203; val: loss 2.018 accuracy 0.299 2/200 - train: loss 2.143; val: loss 1.856 accuracy 0.368 3/200 - train: loss 2.111; val: loss 1.757 accuracy 0.448 4/200 - train: loss 2.042; val: loss 1.635 accuracy 0.484 5/200 - train: loss 2.033; val: loss 1.530 accuracy 0.549 6/200 - train: loss 1.980; val: loss 1.472 accuracy 0.553 7/200 - train: loss 1.945; val: loss 1.435 accuracy 0.576 8/200 - train: loss 1.950; val: loss 1.382 accuracy 0.607 9/200 - train: loss 1.948; val: loss 1.388 accuracy 0.611 10/200 - train: loss 1.896; val: loss 1.351 accuracy 0.620 11/200 - train: loss 1.915; val: loss 1.356 accuracy 0.614 12/200 - train: loss 1.917; val: loss 1.309 accuracy 0.648 13/200 - train: loss 1.875; val: loss 1.258 accuracy 0.662 14/200 - train: loss 1.902; val: loss 1.265 accuracy 0.675 15/200 - train: loss 1.867; val: loss 1.261 accuracy 0.660 16/200 - train: loss 1.830; val: loss 1.230 accuracy 0.678 17/200 - train: loss 1.874; val: loss 1.215 accuracy 0.692 18/200 - train: loss 1.839; val: loss 1.178 accuracy 0.704 19/200 - train: loss 1.853; val: loss 1.204 accuracy 0.693 20/200 - train: loss 1.839; val: loss 1.150 accuracy 0.729 21/200 - train: loss 1.862; val: loss 1.207 accuracy 0.700 22/200 - train: loss 1.860; val: loss 1.179 accuracy 0.699 23/200 - train: loss 1.857; val: loss 1.137 accuracy 0.729 24/200 - train: loss 1.827; val: loss 1.122 accuracy 0.731 25/200 - train: loss 1.822; val: loss 1.215 accuracy 0.678 26/200 - train: loss 1.835; val: loss 1.201 accuracy 0.691 27/200 - train: loss 1.819; val: loss 1.137 accuracy 0.718 28/200 - train: loss 1.841; val: loss 1.129 accuracy 0.726 29/200 - train: loss 1.821; val: loss 1.142 accuracy 0.720 30/200 - train: loss 1.775; val: loss 1.138 accuracy 0.720 31/200 - train: loss 1.786; val: loss 1.132 accuracy 0.727 32/200 - train: loss 1.785; val: loss 1.124 accuracy 0.728 33/200 - train: loss 1.820; val: loss 1.112 accuracy 0.743 34/200 - train: loss 1.783; val: loss 1.100 accuracy 0.752 35/200 - train: loss 1.820; val: loss 1.077 accuracy 0.762 36/200 - train: loss 1.791; val: loss 1.075 accuracy 0.755 37/200 - train: loss 1.796; val: loss 1.072 accuracy 0.765 38/200 - train: loss 1.754; val: loss 1.064 accuracy 0.752 39/200 - train: loss 1.794; val: loss 1.066 accuracy 0.763 40/200 - train: loss 1.817; val: loss 1.073 accuracy 0.763 41/200 - train: loss 1.749; val: loss 1.067 accuracy 0.754 42/200 - train: loss 1.778; val: loss 1.065 accuracy 0.757 43/200 - train: loss 1.760; val: loss 1.053 accuracy 0.759 44/200 - train: loss 1.764; val: loss 1.003 accuracy 0.785 45/200 - train: loss 1.716; val: loss 1.033 accuracy 0.769 46/200 - train: loss 1.733; val: loss 1.020 accuracy 0.772 47/200 - train: loss 1.742; val: loss 1.067 accuracy 0.763 48/200 - train: loss 1.720; val: loss 0.970 accuracy 0.804 49/200 - train: loss 1.748; val: loss 1.012 accuracy 0.786 50/200 - train: loss 1.739; val: loss 0.994 accuracy 0.794 51/200 - train: loss 1.708; val: loss 0.970 accuracy 0.798 52/200 - train: loss 1.693; val: loss 0.963 accuracy 0.805 53/200 - train: loss 1.734; val: loss 1.074 accuracy 0.750 54/200 - train: loss 1.713; val: loss 0.945 accuracy 0.810 55/200 - train: loss 1.757; val: loss 1.020 accuracy 0.785 56/200 - train: loss 1.648; val: loss 0.905 accuracy 0.830 57/200 - train: loss 1.703; val: loss 0.953 accuracy 0.814 58/200 - train: loss 1.662; val: loss 0.952 accuracy 0.807 59/200 - train: loss 1.683; val: loss 0.932 accuracy 0.819 60/200 - train: loss 1.634; val: loss 0.973 accuracy 0.800 61/200 - train: loss 
1.683; val: loss 0.904 accuracy 0.834 62/200 - train: loss 1.695; val: loss 0.887 accuracy 0.844 63/200 - train: loss 1.651; val: loss 0.888 accuracy 0.836 64/200 - train: loss 1.643; val: loss 0.871 accuracy 0.845 65/200 - train: loss 1.657; val: loss 0.965 accuracy 0.796 66/200 - train: loss 1.587; val: loss 0.921 accuracy 0.820 67/200 - train: loss 1.626; val: loss 0.876 accuracy 0.844 68/200 - train: loss 1.636; val: loss 0.880 accuracy 0.841 69/200 - train: loss 1.653; val: loss 0.889 accuracy 0.835 70/200 - train: loss 1.641; val: loss 0.893 accuracy 0.838 71/200 - train: loss 1.596; val: loss 0.848 accuracy 0.856 72/200 - train: loss 1.623; val: loss 0.846 accuracy 0.860 73/200 - train: loss 1.614; val: loss 0.825 accuracy 0.864 74/200 - train: loss 1.630; val: loss 0.825 accuracy 0.868 75/200 - train: loss 1.608; val: loss 0.837 accuracy 0.860 76/200 - train: loss 1.581; val: loss 0.865 accuracy 0.850 77/200 - train: loss 1.603; val: loss 0.852 accuracy 0.852 78/200 - train: loss 1.548; val: loss 0.856 accuracy 0.846 79/200 - train: loss 1.561; val: loss 0.831 accuracy 0.865 80/200 - train: loss 1.555; val: loss 0.810 accuracy 0.870 81/200 - train: loss 1.599; val: loss 0.807 accuracy 0.878 82/200 - train: loss 1.572; val: loss 0.799 accuracy 0.872 83/200 - train: loss 1.587; val: loss 0.826 accuracy 0.865 84/200 - train: loss 1.574; val: loss 0.836 accuracy 0.853 85/200 - train: loss 1.629; val: loss 0.820 accuracy 0.877 86/200 - train: loss 1.564; val: loss 0.792 accuracy 0.879 87/200 - train: loss 1.573; val: loss 0.804 accuracy 0.875 88/200 - train: loss 1.563; val: loss 0.796 accuracy 0.876 89/200 - train: loss 1.551; val: loss 0.794 accuracy 0.880 90/200 - train: loss 1.567; val: loss 0.791 accuracy 0.881 91/200 - train: loss 1.564; val: loss 0.829 accuracy 0.862 92/200 - train: loss 1.573; val: loss 0.792 accuracy 0.878 93/200 - train: loss 1.546; val: loss 0.763 accuracy 0.891 94/200 - train: loss 1.566; val: loss 0.791 accuracy 0.878 95/200 - train: loss 1.565; val: loss 0.819 accuracy 0.869 96/200 - train: loss 1.526; val: loss 0.832 accuracy 0.858 97/200 - train: loss 1.544; val: loss 0.778 accuracy 0.888 98/200 - train: loss 1.534; val: loss 0.761 accuracy 0.891 99/200 - train: loss 1.545; val: loss 0.782 accuracy 0.887 100/200 - train: loss 1.502; val: loss 0.761 accuracy 0.891 101/200 - train: loss 1.556; val: loss 0.776 accuracy 0.884 102/200 - train: loss 1.521; val: loss 0.762 accuracy 0.892 103/200 - train: loss 1.491; val: loss 0.765 accuracy 0.889 104/200 - train: loss 1.551; val: loss 0.803 accuracy 0.877 105/200 - train: loss 1.522; val: loss 0.785 accuracy 0.876 106/200 - train: loss 1.476; val: loss 0.762 accuracy 0.890 107/200 - train: loss 1.527; val: loss 0.750 accuracy 0.899 108/200 - train: loss 1.505; val: loss 0.737 accuracy 0.899 109/200 - train: loss 1.533; val: loss 0.745 accuracy 0.896 110/200 - train: loss 1.502; val: loss 0.776 accuracy 0.886 111/200 - train: loss 1.472; val: loss 0.725 accuracy 0.907 112/200 - train: loss 1.474; val: loss 0.733 accuracy 0.905 113/200 - train: loss 1.516; val: loss 0.760 accuracy 0.892 114/200 - train: loss 1.499; val: loss 0.719 accuracy 0.911 115/200 - train: loss 1.462; val: loss 0.727 accuracy 0.906 116/200 - train: loss 1.452; val: loss 0.722 accuracy 0.906 117/200 - train: loss 1.473; val: loss 0.750 accuracy 0.890 118/200 - train: loss 1.487; val: loss 0.739 accuracy 0.900 119/200 - train: loss 1.490; val: loss 0.729 accuracy 0.905 120/200 - train: loss 1.436; val: loss 0.731 accuracy 0.902 121/200 - 
train: loss 1.487; val: loss 0.731 accuracy 0.903 122/200 - train: loss 1.442; val: loss 0.705 accuracy 0.914 123/200 - train: loss 1.445; val: loss 0.745 accuracy 0.895 124/200 - train: loss 1.423; val: loss 0.701 accuracy 0.917 125/200 - train: loss 1.445; val: loss 0.707 accuracy 0.911 126/200 - train: loss 1.380; val: loss 0.699 accuracy 0.916 127/200 - train: loss 1.461; val: loss 0.704 accuracy 0.913 128/200 - train: loss 1.433; val: loss 0.693 accuracy 0.916 129/200 - train: loss 1.425; val: loss 0.703 accuracy 0.914 130/200 - train: loss 1.416; val: loss 0.679 accuracy 0.928 131/200 - train: loss 1.404; val: loss 0.671 accuracy 0.930 132/200 - train: loss 1.440; val: loss 0.688 accuracy 0.928 133/200 - train: loss 1.440; val: loss 0.704 accuracy 0.912 134/200 - train: loss 1.421; val: loss 0.694 accuracy 0.922 135/200 - train: loss 1.406; val: loss 0.668 accuracy 0.933 136/200 - train: loss 1.419; val: loss 0.707 accuracy 0.914 137/200 - train: loss 1.427; val: loss 0.668 accuracy 0.931 138/200 - train: loss 1.379; val: loss 0.670 accuracy 0.930 139/200 - train: loss 1.420; val: loss 0.675 accuracy 0.931 140/200 - train: loss 1.371; val: loss 0.665 accuracy 0.933 141/200 - train: loss 1.359; val: loss 0.666 accuracy 0.932 142/200 - train: loss 1.412; val: loss 0.655 accuracy 0.935 143/200 - train: loss 1.378; val: loss 0.666 accuracy 0.935 144/200 - train: loss 1.378; val: loss 0.648 accuracy 0.938 145/200 - train: loss 1.331; val: loss 0.661 accuracy 0.934 146/200 - train: loss 1.340; val: loss 0.641 accuracy 0.942 147/200 - train: loss 1.353; val: loss 0.660 accuracy 0.935 148/200 - train: loss 1.358; val: loss 0.644 accuracy 0.944 149/200 - train: loss 1.351; val: loss 0.646 accuracy 0.941 150/200 - train: loss 1.343; val: loss 0.648 accuracy 0.937 151/200 - train: loss 1.342; val: loss 0.633 accuracy 0.944 152/200 - train: loss 1.347; val: loss 0.643 accuracy 0.943 153/200 - train: loss 1.341; val: loss 0.648 accuracy 0.937 154/200 - train: loss 1.308; val: loss 0.638 accuracy 0.942 155/200 - train: loss 1.304; val: loss 0.635 accuracy 0.944 156/200 - train: loss 1.347; val: loss 0.628 accuracy 0.948 157/200 - train: loss 1.301; val: loss 0.621 accuracy 0.952 158/200 - train: loss 1.319; val: loss 0.616 accuracy 0.954 159/200 - train: loss 1.328; val: loss 0.617 accuracy 0.953 160/200 - train: loss 1.302; val: loss 0.620 accuracy 0.953 161/200 - train: loss 1.290; val: loss 0.618 accuracy 0.953 162/200 - train: loss 1.268; val: loss 0.615 accuracy 0.954 163/200 - train: loss 1.276; val: loss 0.616 accuracy 0.953 164/200 - train: loss 1.249; val: loss 0.611 accuracy 0.957 165/200 - train: loss 1.224; val: loss 0.607 accuracy 0.957 166/200 - train: loss 1.258; val: loss 0.610 accuracy 0.954 167/200 - train: loss 1.255; val: loss 0.604 accuracy 0.957 168/200 - train: loss 1.250; val: loss 0.607 accuracy 0.955 169/200 - train: loss 1.218; val: loss 0.598 accuracy 0.961 170/200 - train: loss 1.248; val: loss 0.593 accuracy 0.963 171/200 - train: loss 1.251; val: loss 0.593 accuracy 0.964 172/200 - train: loss 1.269; val: loss 0.595 accuracy 0.962 173/200 - train: loss 1.239; val: loss 0.592 accuracy 0.963 174/200 - train: loss 1.255; val: loss 0.595 accuracy 0.963 175/200 - train: loss 1.262; val: loss 0.589 accuracy 0.966 176/200 - train: loss 1.216; val: loss 0.591 accuracy 0.964 177/200 - train: loss 1.211; val: loss 0.587 accuracy 0.965 178/200 - train: loss 1.220; val: loss 0.587 accuracy 0.964 179/200 - train: loss 1.215; val: loss 0.582 accuracy 0.968 180/200 - train: loss 
1.189; val: loss 0.581 accuracy 0.967 181/200 - train: loss 1.197; val: loss 0.580 accuracy 0.968 182/200 - train: loss 1.157; val: loss 0.581 accuracy 0.967 183/200 - train: loss 1.150; val: loss 0.575 accuracy 0.968 184/200 - train: loss 1.212; val: loss 0.575 accuracy 0.972 185/200 - train: loss 1.205; val: loss 0.576 accuracy 0.969 186/200 - train: loss 1.180; val: loss 0.575 accuracy 0.970 187/200 - train: loss 1.171; val: loss 0.574 accuracy 0.971 188/200 - train: loss 1.200; val: loss 0.571 accuracy 0.972 189/200 - train: loss 1.171; val: loss 0.571 accuracy 0.971 190/200 - train: loss 1.160; val: loss 0.570 accuracy 0.972 191/200 - train: loss 1.161; val: loss 0.572 accuracy 0.971 192/200 - train: loss 1.167; val: loss 0.571 accuracy 0.971 193/200 - train: loss 1.174; val: loss 0.570 accuracy 0.972 194/200 - train: loss 1.199; val: loss 0.572 accuracy 0.971 195/200 - train: loss 1.170; val: loss 0.572 accuracy 0.971 196/200 - train: loss 1.138; val: loss 0.570 accuracy 0.971 197/200 - train: loss 1.211; val: loss 0.570 accuracy 0.971 198/200 - train: loss 1.183; val: loss 0.570 accuracy 0.971 199/200 - train: loss 1.173; val: loss 0.570 accuracy 0.971 200/200 - train: loss 1.159; val: loss 0.570 accuracy 0.971
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['train loss']) + 1)
ax.plot(xs, history['train loss'], '.-', label='train')
ax.plot(xs, history['val loss'], '.-', label='val')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.legend()
ax.grid()
plt.show()
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['val acc']) + 1)
ax.plot(xs, history['val acc'], '-')
ax.set_xlabel('epoch')
ax.set_ylabel('val acc')
ax.grid()
plt.show()
As an ablation, the queries and keys are removed from the attention, so only the input-independent weights (the relative position encoding) remain to form the attention pattern:
class SelfAttention2d(nn.Module):
    def __init__(self, in_channels, out_channels, head_channels, shape):
        super().__init__()
        self.heads = out_channels // head_channels
        self.head_channels = head_channels
        self.scale = head_channels**-0.5

        self.to_values = nn.Conv2d(in_channels, out_channels, 1)
        self.unifyheads = nn.Conv2d(out_channels, out_channels, 1)

        height, width = shape
        self.pos_enc = nn.Parameter(torch.Tensor(self.heads, (2 * height - 1) * (2 * width - 1)))
        self.register_buffer("relative_indices", self.get_indices(height, width))

    def forward(self, x):
        b, _, h, w = x.shape
        values = self.to_values(x).view(b, self.heads, self.head_channels, -1)

        indices = self.relative_indices.expand(self.heads, -1)
        rel_pos_enc = self.pos_enc.gather(-1, indices)
        rel_pos_enc = rel_pos_enc.unflatten(-1, (h * w, h * w))

        att = F.softmax(rel_pos_enc, dim=-2)

        out = values @ att
        out = out.view(b, -1, h, w)
        out = self.unifyheads(out)
        return out

    @staticmethod
    def get_indices(h, w):
        y = torch.arange(h, dtype=torch.long)
        x = torch.arange(w, dtype=torch.long)

        y1, x1, y2, x2 = torch.meshgrid(y, x, y, x, indexing='ij')
        indices = (y1 - y2 + h - 1) * (2 * w - 1) + x1 - x2 + w - 1
        indices = indices.flatten()
        return indices
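Redefining the class does not change the model that is already instantiated, so before the second run the model and training objects presumably need to be rebuilt with the same hyperparameters (the drop of about 1.05M parameters below matches removing to_keys and to_queries from all eight blocks). A minimal sketch of that step, assuming the settings used above:

model = RelViT(NUM_CLASSES, IMAGE_SIZE, channels=256, head_channels=32, num_blocks=8, patch_size=2,
               emb_p_drop=0., trans_p_drop=0., head_p_drop=0.3)
model.to(DEVICE)
cutmix = CutMix(loss, α=1.0)
optimizer = get_optimizer(model, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
trainer = create_supervised_trainer(model, optimizer, cutmix, device=DEVICE, prepare_batch=cutmix.prepare_batch)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
                                             steps_per_epoch=len(train_loader), epochs=EPOCHS)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step())
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
trainer.add_event_handler(Events.EPOCH_COMPLETED, log_validation_results)  # re-attach the epoch logger
history = defaultdict(list)  # reset so the plots below show only this run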
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 5,429,466
trainer.run(train_loader, max_epochs=EPOCHS);
1/200 - train: loss 2.248; val: loss 2.092 accuracy 0.262 2/200 - train: loss 2.203; val: loss 1.968 accuracy 0.318 3/200 - train: loss 2.155; val: loss 1.874 accuracy 0.350 4/200 - train: loss 2.121; val: loss 1.838 accuracy 0.374 5/200 - train: loss 2.093; val: loss 1.768 accuracy 0.404 6/200 - train: loss 2.068; val: loss 1.691 accuracy 0.472 7/200 - train: loss 2.027; val: loss 1.597 accuracy 0.502 8/200 - train: loss 2.013; val: loss 1.560 accuracy 0.519 9/200 - train: loss 2.019; val: loss 1.539 accuracy 0.518 10/200 - train: loss 1.979; val: loss 1.484 accuracy 0.559 11/200 - train: loss 1.991; val: loss 1.443 accuracy 0.580 12/200 - train: loss 1.925; val: loss 1.406 accuracy 0.607 13/200 - train: loss 1.920; val: loss 1.411 accuracy 0.592 14/200 - train: loss 1.949; val: loss 1.398 accuracy 0.591 15/200 - train: loss 1.903; val: loss 1.333 accuracy 0.629 16/200 - train: loss 1.872; val: loss 1.272 accuracy 0.666 17/200 - train: loss 1.865; val: loss 1.273 accuracy 0.665 18/200 - train: loss 1.862; val: loss 1.255 accuracy 0.673 19/200 - train: loss 1.860; val: loss 1.255 accuracy 0.664 20/200 - train: loss 1.832; val: loss 1.209 accuracy 0.695 21/200 - train: loss 1.826; val: loss 1.170 accuracy 0.709 22/200 - train: loss 1.809; val: loss 1.162 accuracy 0.713 23/200 - train: loss 1.794; val: loss 1.161 accuracy 0.713 24/200 - train: loss 1.801; val: loss 1.157 accuracy 0.731 25/200 - train: loss 1.757; val: loss 1.145 accuracy 0.713 26/200 - train: loss 1.771; val: loss 1.170 accuracy 0.710 27/200 - train: loss 1.788; val: loss 1.122 accuracy 0.732 28/200 - train: loss 1.786; val: loss 1.112 accuracy 0.738 29/200 - train: loss 1.737; val: loss 1.144 accuracy 0.717 30/200 - train: loss 1.768; val: loss 1.110 accuracy 0.739 31/200 - train: loss 1.751; val: loss 1.077 accuracy 0.749 32/200 - train: loss 1.733; val: loss 1.082 accuracy 0.747 33/200 - train: loss 1.740; val: loss 1.063 accuracy 0.771 34/200 - train: loss 1.761; val: loss 1.074 accuracy 0.763 35/200 - train: loss 1.735; val: loss 1.072 accuracy 0.765 36/200 - train: loss 1.721; val: loss 1.044 accuracy 0.760 37/200 - train: loss 1.750; val: loss 1.009 accuracy 0.795 38/200 - train: loss 1.773; val: loss 1.043 accuracy 0.779 39/200 - train: loss 1.718; val: loss 0.989 accuracy 0.799 40/200 - train: loss 1.740; val: loss 0.974 accuracy 0.801 41/200 - train: loss 1.698; val: loss 0.977 accuracy 0.797 42/200 - train: loss 1.715; val: loss 1.033 accuracy 0.769 43/200 - train: loss 1.690; val: loss 0.955 accuracy 0.812 44/200 - train: loss 1.702; val: loss 0.934 accuracy 0.821 45/200 - train: loss 1.683; val: loss 0.961 accuracy 0.808 46/200 - train: loss 1.703; val: loss 0.965 accuracy 0.819 47/200 - train: loss 1.726; val: loss 0.946 accuracy 0.816 48/200 - train: loss 1.704; val: loss 0.943 accuracy 0.813 49/200 - train: loss 1.696; val: loss 0.950 accuracy 0.815 50/200 - train: loss 1.689; val: loss 0.980 accuracy 0.805 51/200 - train: loss 1.666; val: loss 0.934 accuracy 0.817 52/200 - train: loss 1.686; val: loss 0.914 accuracy 0.828 53/200 - train: loss 1.708; val: loss 0.979 accuracy 0.815 54/200 - train: loss 1.678; val: loss 0.945 accuracy 0.815 55/200 - train: loss 1.655; val: loss 0.923 accuracy 0.823 56/200 - train: loss 1.683; val: loss 0.927 accuracy 0.823 57/200 - train: loss 1.660; val: loss 0.905 accuracy 0.836 58/200 - train: loss 1.637; val: loss 0.907 accuracy 0.832 59/200 - train: loss 1.684; val: loss 0.947 accuracy 0.817 60/200 - train: loss 1.671; val: loss 0.909 accuracy 0.834 61/200 - train: loss 
1.659; val: loss 0.919 accuracy 0.827 62/200 - train: loss 1.653; val: loss 0.914 accuracy 0.833 63/200 - train: loss 1.644; val: loss 0.885 accuracy 0.840 64/200 - train: loss 1.676; val: loss 0.878 accuracy 0.846 65/200 - train: loss 1.635; val: loss 0.856 accuracy 0.854 66/200 - train: loss 1.646; val: loss 0.897 accuracy 0.836 67/200 - train: loss 1.638; val: loss 0.894 accuracy 0.836 68/200 - train: loss 1.614; val: loss 0.934 accuracy 0.807 69/200 - train: loss 1.628; val: loss 0.879 accuracy 0.845 70/200 - train: loss 1.642; val: loss 0.884 accuracy 0.837 71/200 - train: loss 1.647; val: loss 0.857 accuracy 0.856 72/200 - train: loss 1.633; val: loss 0.895 accuracy 0.839 73/200 - train: loss 1.636; val: loss 0.855 accuracy 0.855 74/200 - train: loss 1.604; val: loss 0.858 accuracy 0.863 75/200 - train: loss 1.626; val: loss 0.881 accuracy 0.845 76/200 - train: loss 1.592; val: loss 0.852 accuracy 0.850 77/200 - train: loss 1.646; val: loss 0.866 accuracy 0.854 78/200 - train: loss 1.611; val: loss 0.872 accuracy 0.847 79/200 - train: loss 1.636; val: loss 0.883 accuracy 0.847 80/200 - train: loss 1.611; val: loss 0.861 accuracy 0.846 81/200 - train: loss 1.562; val: loss 0.828 accuracy 0.861 82/200 - train: loss 1.595; val: loss 0.817 accuracy 0.871 83/200 - train: loss 1.604; val: loss 0.846 accuracy 0.854 84/200 - train: loss 1.566; val: loss 0.803 accuracy 0.876 85/200 - train: loss 1.600; val: loss 0.834 accuracy 0.858 86/200 - train: loss 1.594; val: loss 0.828 accuracy 0.868 87/200 - train: loss 1.577; val: loss 0.853 accuracy 0.855 88/200 - train: loss 1.611; val: loss 0.877 accuracy 0.843 89/200 - train: loss 1.602; val: loss 0.803 accuracy 0.879 90/200 - train: loss 1.578; val: loss 0.845 accuracy 0.860 91/200 - train: loss 1.585; val: loss 0.838 accuracy 0.859 92/200 - train: loss 1.592; val: loss 0.803 accuracy 0.875 93/200 - train: loss 1.583; val: loss 0.829 accuracy 0.862 94/200 - train: loss 1.589; val: loss 0.799 accuracy 0.879 95/200 - train: loss 1.565; val: loss 0.798 accuracy 0.879 96/200 - train: loss 1.599; val: loss 0.788 accuracy 0.885 97/200 - train: loss 1.545; val: loss 0.816 accuracy 0.867 98/200 - train: loss 1.582; val: loss 0.791 accuracy 0.882 99/200 - train: loss 1.551; val: loss 0.805 accuracy 0.877 100/200 - train: loss 1.596; val: loss 0.820 accuracy 0.866 101/200 - train: loss 1.592; val: loss 0.832 accuracy 0.856 102/200 - train: loss 1.572; val: loss 0.812 accuracy 0.872 103/200 - train: loss 1.531; val: loss 0.784 accuracy 0.880 104/200 - train: loss 1.571; val: loss 0.778 accuracy 0.883 105/200 - train: loss 1.545; val: loss 0.787 accuracy 0.879 106/200 - train: loss 1.562; val: loss 0.793 accuracy 0.882 107/200 - train: loss 1.521; val: loss 0.782 accuracy 0.882 108/200 - train: loss 1.525; val: loss 0.762 accuracy 0.887 109/200 - train: loss 1.556; val: loss 0.781 accuracy 0.879 110/200 - train: loss 1.581; val: loss 0.805 accuracy 0.869 111/200 - train: loss 1.540; val: loss 0.751 accuracy 0.895 112/200 - train: loss 1.539; val: loss 0.741 accuracy 0.899 113/200 - train: loss 1.530; val: loss 0.747 accuracy 0.893 114/200 - train: loss 1.540; val: loss 0.756 accuracy 0.891 115/200 - train: loss 1.521; val: loss 0.741 accuracy 0.900 116/200 - train: loss 1.530; val: loss 0.734 accuracy 0.902 117/200 - train: loss 1.531; val: loss 0.797 accuracy 0.873 118/200 - train: loss 1.504; val: loss 0.758 accuracy 0.890 119/200 - train: loss 1.547; val: loss 0.761 accuracy 0.894 120/200 - train: loss 1.527; val: loss 0.737 accuracy 0.902 121/200 - 
train: loss 1.498; val: loss 0.763 accuracy 0.884 122/200 - train: loss 1.541; val: loss 0.752 accuracy 0.896 123/200 - train: loss 1.505; val: loss 0.729 accuracy 0.901 124/200 - train: loss 1.511; val: loss 0.759 accuracy 0.891 125/200 - train: loss 1.488; val: loss 0.733 accuracy 0.905 126/200 - train: loss 1.509; val: loss 0.718 accuracy 0.910 127/200 - train: loss 1.480; val: loss 0.737 accuracy 0.898 128/200 - train: loss 1.478; val: loss 0.724 accuracy 0.906 129/200 - train: loss 1.509; val: loss 0.722 accuracy 0.905 130/200 - train: loss 1.447; val: loss 0.697 accuracy 0.917 131/200 - train: loss 1.442; val: loss 0.722 accuracy 0.907 132/200 - train: loss 1.463; val: loss 0.724 accuracy 0.905 133/200 - train: loss 1.476; val: loss 0.723 accuracy 0.906 134/200 - train: loss 1.439; val: loss 0.702 accuracy 0.913 135/200 - train: loss 1.447; val: loss 0.715 accuracy 0.911 136/200 - train: loss 1.469; val: loss 0.738 accuracy 0.901 137/200 - train: loss 1.445; val: loss 0.706 accuracy 0.913 138/200 - train: loss 1.471; val: loss 0.686 accuracy 0.923 139/200 - train: loss 1.438; val: loss 0.694 accuracy 0.921 140/200 - train: loss 1.458; val: loss 0.686 accuracy 0.924 141/200 - train: loss 1.457; val: loss 0.701 accuracy 0.915 142/200 - train: loss 1.466; val: loss 0.688 accuracy 0.921 143/200 - train: loss 1.455; val: loss 0.676 accuracy 0.925 144/200 - train: loss 1.418; val: loss 0.684 accuracy 0.925 145/200 - train: loss 1.475; val: loss 0.675 accuracy 0.928 146/200 - train: loss 1.433; val: loss 0.688 accuracy 0.922 147/200 - train: loss 1.426; val: loss 0.682 accuracy 0.923 148/200 - train: loss 1.407; val: loss 0.670 accuracy 0.929 149/200 - train: loss 1.417; val: loss 0.653 accuracy 0.937 150/200 - train: loss 1.408; val: loss 0.666 accuracy 0.928 151/200 - train: loss 1.389; val: loss 0.664 accuracy 0.932 152/200 - train: loss 1.389; val: loss 0.661 accuracy 0.933 153/200 - train: loss 1.405; val: loss 0.646 accuracy 0.939 154/200 - train: loss 1.396; val: loss 0.658 accuracy 0.935 155/200 - train: loss 1.446; val: loss 0.650 accuracy 0.936 156/200 - train: loss 1.384; val: loss 0.648 accuracy 0.939 157/200 - train: loss 1.399; val: loss 0.650 accuracy 0.937 158/200 - train: loss 1.388; val: loss 0.649 accuracy 0.939 159/200 - train: loss 1.386; val: loss 0.646 accuracy 0.939 160/200 - train: loss 1.347; val: loss 0.650 accuracy 0.937 161/200 - train: loss 1.369; val: loss 0.636 accuracy 0.944 162/200 - train: loss 1.358; val: loss 0.627 accuracy 0.947 163/200 - train: loss 1.365; val: loss 0.643 accuracy 0.938 164/200 - train: loss 1.357; val: loss 0.630 accuracy 0.945 165/200 - train: loss 1.311; val: loss 0.622 accuracy 0.946 166/200 - train: loss 1.306; val: loss 0.635 accuracy 0.944 167/200 - train: loss 1.296; val: loss 0.617 accuracy 0.949 168/200 - train: loss 1.307; val: loss 0.616 accuracy 0.952 169/200 - train: loss 1.294; val: loss 0.612 accuracy 0.954 170/200 - train: loss 1.308; val: loss 0.616 accuracy 0.953 171/200 - train: loss 1.295; val: loss 0.612 accuracy 0.954 172/200 - train: loss 1.277; val: loss 0.605 accuracy 0.957 173/200 - train: loss 1.309; val: loss 0.615 accuracy 0.953 174/200 - train: loss 1.267; val: loss 0.607 accuracy 0.955 175/200 - train: loss 1.256; val: loss 0.606 accuracy 0.957 176/200 - train: loss 1.273; val: loss 0.603 accuracy 0.958 177/200 - train: loss 1.290; val: loss 0.605 accuracy 0.957 178/200 - train: loss 1.268; val: loss 0.601 accuracy 0.958 179/200 - train: loss 1.253; val: loss 0.597 accuracy 0.959 180/200 - train: loss 
1.232; val: loss 0.599 accuracy 0.959 181/200 - train: loss 1.248; val: loss 0.598 accuracy 0.961 182/200 - train: loss 1.288; val: loss 0.596 accuracy 0.959 183/200 - train: loss 1.242; val: loss 0.597 accuracy 0.960 184/200 - train: loss 1.237; val: loss 0.597 accuracy 0.960 185/200 - train: loss 1.283; val: loss 0.590 accuracy 0.963 186/200 - train: loss 1.273; val: loss 0.595 accuracy 0.961 187/200 - train: loss 1.251; val: loss 0.590 accuracy 0.963 188/200 - train: loss 1.219; val: loss 0.590 accuracy 0.964 189/200 - train: loss 1.198; val: loss 0.587 accuracy 0.964 190/200 - train: loss 1.261; val: loss 0.587 accuracy 0.964 191/200 - train: loss 1.252; val: loss 0.589 accuracy 0.963 192/200 - train: loss 1.210; val: loss 0.587 accuracy 0.964 193/200 - train: loss 1.220; val: loss 0.585 accuracy 0.966 194/200 - train: loss 1.227; val: loss 0.584 accuracy 0.965 195/200 - train: loss 1.222; val: loss 0.585 accuracy 0.966 196/200 - train: loss 1.239; val: loss 0.585 accuracy 0.965 197/200 - train: loss 1.227; val: loss 0.584 accuracy 0.966 198/200 - train: loss 1.237; val: loss 0.584 accuracy 0.965 199/200 - train: loss 1.210; val: loss 0.585 accuracy 0.965 200/200 - train: loss 1.174; val: loss 0.585 accuracy 0.965
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['train loss']) + 1)
ax.plot(xs, history['train loss'], '.-', label='train')
ax.plot(xs, history['val loss'], '.-', label='val')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.legend()
ax.grid()
plt.show()
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['val acc']) + 1)
ax.plot(xs, history['val acc'], '-')
ax.set_xlabel('epoch')
ax.set_ylabel('val acc')
ax.grid()
plt.show()