Tiny model for CIFAR10
Configuration
Imports
In [1]:
import warnings
In [2]:
warnings.simplefilter('ignore')
In [3]:
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import tqdm
import tqdm.autonotebook
tqdm.autonotebook.tqdm = tqdm.tqdm # hack: make ignite's ProgressBar use the plain-text tqdm instead of the notebook widget
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets
import torchvision.transforms.v2 as transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
import ignite.metrics
import ignite.contrib.handlers
Configuration
In [4]:
DATA_DIR='./data'
NUM_CLASSES = 10
NUM_WORKERS = 8
BATCH_SIZE = 32
EPOCHS = 2000
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-3
In [5]:
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
Data
In [6]:
train_transform = transforms.Compose([
    transforms.ToImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4, fill=127),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToDtype(torch.float, scale=True),
    transforms.RandomErasing(p=0.5, value=0.5)
])
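The training pipeline converts each PIL image to a tensor, applies flip, padded-crop and color-jitter augmentation, scales pixels to floats in [0, 1], and finally erases a random patch. Since transforms.v2 pipelines are plain callables, they are easy to preview; a minimal sketch (each call to train_transform draws a fresh random augmentation):

raw_dset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True)  # untransformed PIL images
img, _ = raw_dset[0]
fig, axes = plt.subplots(1, 4, figsize=(8, 2))
for ax in axes:
    ax.imshow(np.moveaxis(train_transform(img).numpy(), 0, -1))  # one random view per call
    ax.set_axis_off()
plt.show()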
In [7]:
val_transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float, scale=True),
])
In [8]:
train_dset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_transform)
test_dset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=val_transform)
Files already downloaded and verified
Files already downloaded and verified
In [9]:
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
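A quick sanity check on the loaders (a sketch): every training batch should hold BATCH_SIZE float images of shape 3×32×32 plus integer labels.

xb, yb = next(iter(train_loader))
print(xb.shape, xb.dtype)  # torch.Size([32, 3, 32, 32]) torch.float32
print(yb.shape, yb.dtype)  # torch.Size([32]) torch.int64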
In [10]:
def dataset_show_image(dset, idx):
    X, Y = dset[idx]
    title = "Ground truth: {}".format(dset.classes[Y])
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_axis_off()
    ax.imshow(np.moveaxis(X.numpy(), 0, -1))
    ax.set_title(title)
    plt.show()
In [11]:
dataset_show_image(test_dset, 1)
Model
In [12]:
class NormAct(nn.Sequential):
    def __init__(self, channels):
        super().__init__(
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True)
        )
In [13]:
class ResidualBlock(nn.Module):
    def __init__(self, channels, stride=1, p_drop=0.):
        super().__init__()
        self.shortcut = nn.Dropout(p_drop)
        if stride > 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(channels, channels, stride, stride=stride, groups=channels, bias=False),
                self.shortcut
            )
        self.residual = nn.Sequential(
            NormAct(channels),
            nn.Conv2d(channels, channels, 2 + stride, stride=stride, padding=1, groups=channels, bias=False),
            NormAct(channels),
            nn.Conv2d(channels, channels, 1, bias=False),
        )
        self.γ = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        out = self.shortcut(x) + self.γ * self.residual(x)
        return out
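The learnable scalar γ gates the residual branch and is created at zero (and reset to zero again in reset_parameters below), a ReZero-style initialization: every stride-1 block starts out as exactly its shortcut, so the untrained network is a near-identity function on top of the stem, and the residual branches fade in during training. A quick check of that claim (a sketch; eval mode turns the dropout shortcut into a no-op):

block = ResidualBlock(8)         # stride=1, γ starts at 0
block.eval()                     # dropout is the identity in eval mode
x = torch.randn(2, 8, 16, 16)
print(torch.equal(block(x), x))  # True: the block is the identity at initialization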
In [14]:
class Head(nn.Sequential):
    def __init__(self, channels, classes, p_drop=0.):
        super().__init__(
            NormAct(channels),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(p_drop),
            nn.Linear(channels, classes)
        )
In [15]:
def Stem(in_channels, out_channels):
    return nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False)
In [16]:
class Net(nn.Sequential):
    def __init__(self, classes, width=32, in_channels=3, res_p_drop=0., head_p_drop=0.):
        strides = [1, 2, 1, 2, 1, 2, 1]
        super().__init__(
            Stem(in_channels, width),
            *[ResidualBlock(width, stride=stride, p_drop=res_p_drop) for stride in strides],
            Head(width, classes, p_drop=head_p_drop)
        )
In [17]:
def reset_parameters(model):
    for m in model.modules():
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None: nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1.)
            nn.init.zeros_(m.bias)
        elif isinstance(m, ResidualBlock):
            nn.init.zeros_(m.γ)
In [18]:
model = Net(NUM_CLASSES, width=96, res_p_drop=0.1, head_p_drop=0.1).to(DEVICE)
In [19]:
reset_parameters(model)
In [20]:
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 80,177
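With width=96 the network is only about 80k parameters, tiny by CIFAR-10 standards. The strides [1, 2, 1, 2, 1, 2, 1] shrink the 32×32 input to 4×4 before the head's global average pooling; a shape check (a sketch):

model.eval()  # avoid touching BatchNorm running stats; the trainer puts the model back in train mode
with torch.no_grad():
    out = model(torch.zeros(1, 3, 32, 32, device=DEVICE))
print(out.shape)  # torch.Size([1, 10]), one logit per class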
Training
Trainer setup
Trainer
In [21]:
loss = nn.CrossEntropyLoss()
In [22]:
params = [p for p in model.parameters() if p.requires_grad]
In [23]:
optimizer = optim.AdamW(params, lr=1e-6, weight_decay=WEIGHT_DECAY)
In [24]:
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE,
                                    output_transform=lambda x, y, y_pred, loss: (y_pred, y, loss.item()))
In [25]:
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
                                             steps_per_epoch=len(train_loader), epochs=EPOCHS)
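The optimizer's lr=1e-6 above is only a placeholder: OneCycleLR overwrites the learning rate on every step, warming up to max_lr=LEARNING_RATE and then annealing back down over the full len(train_loader) * EPOCHS iterations. To preview the schedule without disturbing the real optimizer, a sketch on a throwaway copy:

_opt = optim.AdamW([torch.zeros(1, requires_grad=True)], lr=1e-6)
_sched = optim.lr_scheduler.OneCycleLR(_opt, max_lr=LEARNING_RATE,
                                       steps_per_epoch=len(train_loader), epochs=EPOCHS)
lrs = []
for step in range(len(train_loader) * EPOCHS):
    if step % 1000 == 0:           # subsample: the full schedule is ~3.1M steps
        lrs.append(_sched.get_last_lr()[0])
    _opt.step()                    # keep PyTorch from warning about scheduler-before-optimizer steps
    _sched.step()
plt.plot(np.arange(len(lrs)) * 1000, lrs)
plt.xlabel('iteration')
plt.ylabel('learning rate')
plt.show()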
In [26]:
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());
In [27]:
ignite.metrics.RunningAverage(output_transform=lambda output: output[2]).attach(trainer, "loss")
In [28]:
ignite.metrics.Accuracy(output_transform=lambda output: (output[0], output[1])).attach(trainer, "accuracy")
In [29]:
pbar = ignite.contrib.handlers.ProgressBar(persist=True, ncols=140)
In [30]:
pbar.attach(trainer, event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.COMPLETED)
Evaluator
In [31]:
val_metrics = {"accuracy": ignite.metrics.Accuracy(), "loss": ignite.metrics.Loss(loss)}
In [32]:
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
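The evaluator is an ordinary ignite Engine, so it can also be run by hand, e.g. for a baseline pass over the test set before training; run() returns the engine state with the metrics dict filled in. A sketch:

state = evaluator.run(test_loader)
print(state.metrics)  # e.g. {'accuracy': ~0.1, 'loss': ~2.3} for the untrained model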
In [33]:
history = defaultdict(list)
In [34]:
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    train_state = engine.state
    epoch = train_state.epoch
    max_epochs = train_state.max_epochs
    train_loss = train_state.metrics["loss"]
    train_acc = train_state.metrics["accuracy"]
    history['train loss'].append(train_loss)
    history['train acc'].append(train_acc)
    evaluator.run(test_loader)
    val_metrics = evaluator.state.metrics
    val_loss = val_metrics["loss"]
    val_acc = val_metrics["accuracy"]
    history['val loss'].append(val_loss)
    history['val acc'].append(val_acc)
    pbar.pbar.set_postfix({"loss": f"{train_loss:.3f}",
                           "acc": f"{train_acc:.3f}",
                           "val loss": f"{val_loss:.3f}",
                           "val acc": f"{val_acc:.3f}"})
Start training
In [35]:
trainer.run(train_loader, max_epochs=EPOCHS);
Epoch: [2000/2000] 100%|██████████, loss=0.100, acc=0.964, val loss=0.270, val acc=0.937 [8:54:31<00:00]
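Not part of the original run, but after roughly nine hours of training it is worth persisting the weights (a sketch; the filename is just an example):

torch.save(model.state_dict(), 'cifar10_tiny.pt')
# later: model.load_state_dict(torch.load('cifar10_tiny.pt'))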
In [36]:
def history_plot_train_val(history, key):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    xs = np.arange(1, len(history['train ' + key]) + 1)
    ax.plot(xs, history['train ' + key], '.-', label='train')
    ax.plot(xs, history['val ' + key], '.-', label='val')
    ax.set_xlabel('epoch')
    ax.set_ylabel(key)
    ax.legend()
    ax.grid()
    plt.show()
In [37]:
history_plot_train_val(history, 'loss')
In [38]:
history_plot_train_val(history, 'acc')