ConvNeXt for CIFAR-10 classification, following arXiv:2201.03545 [cs.CV] ("A ConvNet for the 2020s").
Original implementation: https://github.com/facebookresearch/ConvNeXt
Imports
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
import ignite.metrics
import ignite.contrib.handlers
Configuration
DATA_DIR='./data'
IMAGE_SIZE = 32
NUM_CLASSES = 10
NUM_WORKERS = 8
BATCH_SIZE = 32
EPOCHS = 100
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-1
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(IMAGE_SIZE, padding=4),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor()
])
train_dset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=train_transform)
test_dset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transforms.ToTensor())
Files already downloaded and verified
Files already downloaded and verified
def dataset_show_image(dset, idx):
X, Y = dset[idx]
title = "Ground truth: {}".format(dset.classes[Y])
fig = plt.figure()
ax = fig.add_subplot(111)
ax.set_axis_off()
ax.imshow(np.moveaxis(X.numpy(), 0, -1))
ax.set_title(title)
plt.show()
dataset_show_image(test_dset, 1)
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
num_workers=NUM_WORKERS, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS, pin_memory=True)
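As a quick check (not part of the original notebook), one batch can be drawn from the loader to confirm the tensor shapes:
# Illustrative check: CIFAR-10 batches should be BATCH_SIZE x 3 x 32 x 32 images with integer class labels.
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([32, 3, 32, 32]) torch.Size([32])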
Utilities
class LayerNormChannels(nn.Module):
def __init__(self, channels):
super().__init__()
self.norm = nn.LayerNorm(channels)
def forward(self, x):
x = x.transpose(1, -1)
x = self.norm(x)
x = x.transpose(-1, 1)
return x
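nn.LayerNorm normalizes over the last dimension, but convolutional feature maps are channels-first (N, C, H, W), so LayerNormChannels moves the channel axis to the end, normalizes, and moves it back. A small illustrative check (not in the original notebook):
# Illustrative: shape is preserved, and each spatial position is normalized across
# its channels (mean ~ 0 at initialization, since the affine bias starts at 0).
ln = LayerNormChannels(64)
x = torch.randn(2, 64, 8, 8)
y = ln(x)
print(y.shape, y.mean(dim=1).abs().max().item())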
class Residual(nn.Module):
def __init__(self, *layers):
super().__init__()
self.residual = nn.Sequential(*layers)
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
return x + self.gamma * self.residual(x)
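The Residual wrapper scales its branch by a learnable scalar gamma initialized to zero (similar in spirit to LayerScale), so every residual block starts out as the identity and the branch is blended in during training. Illustrative usage with a hypothetical layer:
# Illustrative: with gamma = 0 the wrapped branch contributes nothing, so the
# block is an exact identity at initialization.
block = Residual(nn.Conv2d(16, 16, 3, padding=1))
x = torch.randn(1, 16, 8, 8)
print(torch.equal(block(x), x))  # True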
ConvNeXt stages
class ConvNeXtBlock(Residual):
def __init__(self, channels, kernel_size, mult=4, p_drop=0.):
padding = (kernel_size - 1) // 2
hidden_channels = channels * mult
super().__init__(
nn.Conv2d(channels, channels, kernel_size, padding=padding, groups=channels),
LayerNormChannels(channels),
nn.Conv2d(channels, hidden_channels, 1),
nn.GELU(),
nn.Conv2d(hidden_channels, channels, 1),
nn.Dropout(p_drop)
)
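Each ConvNeXtBlock follows the paper's inverted-bottleneck layout: a depthwise kernel_size x kernel_size convolution (groups=channels), LayerNorm, a 1x1 expansion to mult * channels, GELU, a 1x1 projection back, and Dropout, all inside the zero-initialized residual above. Spatial resolution and channel count are preserved, as a quick illustrative check shows:
# Illustrative: a block keeps both the spatial resolution and the channel count.
blk = ConvNeXtBlock(channels=64, kernel_size=7)
x = torch.randn(2, 64, 16, 16)
print(blk(x).shape)  # torch.Size([2, 64, 16, 16])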
class DownsampleBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, stride=2):
super().__init__(
LayerNormChannels(in_channels),
nn.Conv2d(in_channels, out_channels, stride, stride=stride)
)
class Stage(nn.Sequential):
def __init__(self, in_channels, out_channels, num_blocks, kernel_size, p_drop=0.):
layers = [] if in_channels == out_channels else [DownsampleBlock(in_channels, out_channels)]
layers += [ConvNeXtBlock(out_channels, kernel_size, p_drop=p_drop) for _ in range(num_blocks)]
super().__init__(*layers)
class ConvNeXtBody(nn.Sequential):
def __init__(self, in_channels, channel_list, num_blocks_list, kernel_size, p_drop=0.):
layers = []
for out_channels, num_blocks in zip(channel_list, num_blocks_list):
layers.append(Stage(in_channels, out_channels, num_blocks, kernel_size, p_drop))
in_channels = out_channels
super().__init__(*layers)
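ConvNeXtBody chains the stages; only stages that change the channel count start with a DownsampleBlock, so the resolution halves between stages while the first stage keeps the input resolution. An illustrative shape trace on a small, hypothetical configuration:
# Illustrative shape trace: the first stage keeps the input resolution, each
# later stage halves it while increasing the channel count.
body = ConvNeXtBody(16, [16, 32, 64, 128], [1, 1, 1, 1], kernel_size=7)
x = torch.randn(1, 16, 32, 32)
for stage in body:
    x = stage(x)
    print(tuple(x.shape))
# prints (1, 16, 32, 32), (1, 32, 16, 16), (1, 64, 8, 8), (1, 128, 4, 4)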
Main model
class Stem(nn.Sequential):
def __init__(self, in_channels, out_channels, patch_size):
super().__init__(
nn.Conv2d(in_channels, out_channels, patch_size, stride=patch_size),
LayerNormChannels(out_channels)
)
class Head(nn.Sequential):
def __init__(self, in_channels, classes):
super().__init__(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.LayerNorm(in_channels),
nn.Linear(in_channels, classes)
)
class ConvNeXt(nn.Sequential):
def __init__(self, classes, channel_list, num_blocks_list, kernel_size, patch_size,
in_channels=3, res_p_drop=0.):
super().__init__(
Stem(in_channels, channel_list[0], patch_size),
ConvNeXtBody(channel_list[0], channel_list, num_blocks_list, kernel_size, res_p_drop),
Head(channel_list[-1], classes)
)
self.reset_parameters()
def reset_parameters(self):
for m in self.modules():
if isinstance(m, (nn.Linear, nn.Conv2d)):
nn.init.normal_(m.weight, std=0.02)
if m.bias is not None: nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.weight, 1.)
nn.init.zeros_(m.bias)
elif isinstance(m, Residual):
nn.init.zeros_(m.gamma)
def separate_parameters(self):
parameters_decay = set()
parameters_no_decay = set()
modules_weight_decay = (nn.Linear, nn.Conv2d)
modules_no_weight_decay = (nn.LayerNorm,)
for m_name, m in self.named_modules():
for param_name, param in m.named_parameters():
full_param_name = f"{m_name}.{param_name}" if m_name else param_name
if isinstance(m, modules_no_weight_decay):
parameters_no_decay.add(full_param_name)
elif param_name.endswith("bias"):
parameters_no_decay.add(full_param_name)
elif isinstance(m, Residual) and param_name.endswith("gamma"):
parameters_no_decay.add(full_param_name)
elif isinstance(m, modules_weight_decay):
parameters_decay.add(full_param_name)
# sanity check
assert len(parameters_decay & parameters_no_decay) == 0
assert len(parameters_decay) + len(parameters_no_decay) == len(list(self.parameters()))
return parameters_decay, parameters_no_decay
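reset_parameters draws Linear/Conv2d weights from N(0, 0.02) (the reference implementation uses a truncated normal with the same std), sets LayerNorm to the identity affine, and zeros every residual gamma; separate_parameters then splits parameter names so that only Linear/Conv2d weights receive weight decay. A quick illustrative check, on a hypothetical smaller model, that the gammas really start at zero:
# Illustrative: every residual gamma should be exactly zero after initialization.
tmp = ConvNeXt(NUM_CLASSES, [32, 64], [1, 1], kernel_size=7, patch_size=1)
print(all(m.gamma.item() == 0.0 for m in tmp.modules() if isinstance(m, Residual)))  # True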
model = ConvNeXt(NUM_CLASSES,
channel_list = [64, 128, 256, 512],
num_blocks_list = [2, 2, 2, 2],
kernel_size=7, patch_size=1,
res_p_drop=0.)
model.to(DEVICE);
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 6,376,466
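Note that patch_size=1 keeps the stem at the full 32x32 resolution, which suits CIFAR-10's small images. A forward pass on a dummy batch confirms the output shape (illustrative check, not in the original notebook):
# Illustrative: a batch of CIFAR-sized images maps to NUM_CLASSES logits.
with torch.no_grad():
    logits = model(torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE))
print(logits.shape)  # torch.Size([2, 10])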
def get_optimizer(model, learning_rate, weight_decay):
param_dict = {pn: p for pn, p in model.named_parameters()}
parameters_decay, parameters_no_decay = model.separate_parameters()
optim_groups = [
{"params": [param_dict[pn] for pn in parameters_decay], "weight_decay": weight_decay},
{"params": [param_dict[pn] for pn in parameters_no_decay], "weight_decay": 0.0},
]
optimizer = optim.AdamW(optim_groups, lr=learning_rate)
return optimizer
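get_optimizer feeds the two parameter sets into AdamW as separate groups: decayed weights (Linear/Conv2d) with weight_decay, everything else (biases, LayerNorm parameters, residual gammas) with 0. The grouping can be inspected directly (illustrative, not in the original notebook):
# Illustrative: the optimizer has exactly two groups, one decayed and one not.
opt = get_optimizer(model, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
for group in opt.param_groups:
    print(len(group["params"]), "params, weight_decay =", group["weight_decay"])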
loss = nn.CrossEntropyLoss()
optimizer = get_optimizer(model, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
steps_per_epoch=len(train_loader), epochs=EPOCHS)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());
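The scheduler is stepped on every iteration, so steps_per_epoch=len(train_loader) and epochs=EPOCHS together give OneCycleLR its total step count; the learning rate warms up towards max_lr and is then annealed over the rest of training. The schedule can be previewed with a throwaway optimizer (illustrative sketch, not in the original notebook):
# Illustrative sketch: preview the one-cycle schedule without touching the real
# optimizer or scheduler.
tmp_opt = optim.AdamW([torch.zeros(1, requires_grad=True)], lr=LEARNING_RATE)
tmp_sched = optim.lr_scheduler.OneCycleLR(tmp_opt, max_lr=LEARNING_RATE,
                                          steps_per_epoch=len(train_loader), epochs=EPOCHS)
lrs = []
for _ in range(len(train_loader) * EPOCHS):
    tmp_opt.step()
    tmp_sched.step()
    lrs.append(tmp_sched.get_last_lr()[0])
print(max(lrs), lrs[-1])  # peak near LEARNING_RATE, annealed towards zero at the end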
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
Evaluator
val_metrics = {"accuracy": ignite.metrics.Accuracy(), "loss": ignite.metrics.Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
history = defaultdict(list)
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
train_state = engine.state
epoch = train_state.epoch
max_epochs = train_state.max_epochs
train_loss = train_state.metrics["loss"]
history['train loss'].append(train_loss)
evaluator.run(test_loader)
val_metrics = evaluator.state.metrics
val_loss = val_metrics["loss"]
val_acc = val_metrics["accuracy"]
history['val loss'].append(val_loss)
history['val acc'].append(val_acc)
print("{}/{} - train: loss {:.3f}; val: loss {:.3f} accuracy {:.3f}".format(
epoch, max_epochs, train_loss, val_loss, val_acc))
trainer.run(train_loader, max_epochs=EPOCHS);
1/100 - train: loss 1.855; val: loss 1.797 accuracy 0.339
2/100 - train: loss 1.648; val: loss 1.647 accuracy 0.409
3/100 - train: loss 1.566; val: loss 1.536 accuracy 0.446
4/100 - train: loss 1.531; val: loss 1.461 accuracy 0.469
5/100 - train: loss 1.465; val: loss 1.388 accuracy 0.501
6/100 - train: loss 1.385; val: loss 1.372 accuracy 0.502
7/100 - train: loss 1.330; val: loss 1.286 accuracy 0.540
8/100 - train: loss 1.258; val: loss 1.252 accuracy 0.548
9/100 - train: loss 1.164; val: loss 1.207 accuracy 0.568
10/100 - train: loss 1.100; val: loss 1.046 accuracy 0.631
11/100 - train: loss 1.058; val: loss 0.979 accuracy 0.653
12/100 - train: loss 0.957; val: loss 0.922 accuracy 0.673
13/100 - train: loss 0.925; val: loss 0.877 accuracy 0.692
14/100 - train: loss 0.867; val: loss 0.811 accuracy 0.713
15/100 - train: loss 0.788; val: loss 0.772 accuracy 0.734
16/100 - train: loss 0.744; val: loss 0.711 accuracy 0.753
17/100 - train: loss 0.725; val: loss 0.656 accuracy 0.768
18/100 - train: loss 0.689; val: loss 0.645 accuracy 0.779
19/100 - train: loss 0.582; val: loss 0.595 accuracy 0.790
20/100 - train: loss 0.588; val: loss 0.581 accuracy 0.798
21/100 - train: loss 0.536; val: loss 0.535 accuracy 0.816
22/100 - train: loss 0.530; val: loss 0.578 accuracy 0.801
23/100 - train: loss 0.514; val: loss 0.544 accuracy 0.812
24/100 - train: loss 0.465; val: loss 0.495 accuracy 0.833
25/100 - train: loss 0.481; val: loss 0.507 accuracy 0.831
26/100 - train: loss 0.483; val: loss 0.573 accuracy 0.802
27/100 - train: loss 0.464; val: loss 0.497 accuracy 0.829
28/100 - train: loss 0.486; val: loss 0.529 accuracy 0.816
29/100 - train: loss 0.402; val: loss 0.486 accuracy 0.833
30/100 - train: loss 0.428; val: loss 0.530 accuracy 0.821
31/100 - train: loss 0.410; val: loss 0.532 accuracy 0.815
32/100 - train: loss 0.418; val: loss 0.498 accuracy 0.829
33/100 - train: loss 0.403; val: loss 0.437 accuracy 0.852
34/100 - train: loss 0.372; val: loss 0.502 accuracy 0.830
35/100 - train: loss 0.359; val: loss 0.437 accuracy 0.851
36/100 - train: loss 0.353; val: loss 0.437 accuracy 0.854
37/100 - train: loss 0.347; val: loss 0.421 accuracy 0.855
38/100 - train: loss 0.360; val: loss 0.440 accuracy 0.848
39/100 - train: loss 0.376; val: loss 0.416 accuracy 0.857
40/100 - train: loss 0.328; val: loss 0.423 accuracy 0.855
41/100 - train: loss 0.334; val: loss 0.406 accuracy 0.866
42/100 - train: loss 0.333; val: loss 0.420 accuracy 0.858
43/100 - train: loss 0.337; val: loss 0.441 accuracy 0.852
44/100 - train: loss 0.305; val: loss 0.403 accuracy 0.865
45/100 - train: loss 0.325; val: loss 0.489 accuracy 0.836
46/100 - train: loss 0.280; val: loss 0.395 accuracy 0.867
47/100 - train: loss 0.267; val: loss 0.387 accuracy 0.871
48/100 - train: loss 0.282; val: loss 0.383 accuracy 0.875
49/100 - train: loss 0.274; val: loss 0.398 accuracy 0.867
50/100 - train: loss 0.257; val: loss 0.390 accuracy 0.877
51/100 - train: loss 0.257; val: loss 0.373 accuracy 0.877
52/100 - train: loss 0.234; val: loss 0.388 accuracy 0.871
53/100 - train: loss 0.248; val: loss 0.405 accuracy 0.867
54/100 - train: loss 0.262; val: loss 0.368 accuracy 0.879
55/100 - train: loss 0.229; val: loss 0.346 accuracy 0.885
56/100 - train: loss 0.228; val: loss 0.362 accuracy 0.885
57/100 - train: loss 0.197; val: loss 0.390 accuracy 0.873
58/100 - train: loss 0.199; val: loss 0.377 accuracy 0.880
59/100 - train: loss 0.204; val: loss 0.367 accuracy 0.882
60/100 - train: loss 0.185; val: loss 0.363 accuracy 0.884
61/100 - train: loss 0.167; val: loss 0.383 accuracy 0.881
62/100 - train: loss 0.180; val: loss 0.380 accuracy 0.878
63/100 - train: loss 0.150; val: loss 0.382 accuracy 0.883
64/100 - train: loss 0.138; val: loss 0.363 accuracy 0.890
65/100 - train: loss 0.140; val: loss 0.325 accuracy 0.899
66/100 - train: loss 0.133; val: loss 0.361 accuracy 0.894
67/100 - train: loss 0.136; val: loss 0.357 accuracy 0.893
68/100 - train: loss 0.104; val: loss 0.357 accuracy 0.897
69/100 - train: loss 0.125; val: loss 0.395 accuracy 0.885
70/100 - train: loss 0.104; val: loss 0.374 accuracy 0.895
71/100 - train: loss 0.087; val: loss 0.369 accuracy 0.896
72/100 - train: loss 0.087; val: loss 0.344 accuracy 0.898
73/100 - train: loss 0.069; val: loss 0.325 accuracy 0.906
74/100 - train: loss 0.076; val: loss 0.364 accuracy 0.902
75/100 - train: loss 0.058; val: loss 0.367 accuracy 0.902
76/100 - train: loss 0.049; val: loss 0.371 accuracy 0.903
77/100 - train: loss 0.056; val: loss 0.395 accuracy 0.899
78/100 - train: loss 0.048; val: loss 0.382 accuracy 0.901
79/100 - train: loss 0.042; val: loss 0.372 accuracy 0.904
80/100 - train: loss 0.044; val: loss 0.378 accuracy 0.906
81/100 - train: loss 0.031; val: loss 0.387 accuracy 0.903
82/100 - train: loss 0.019; val: loss 0.390 accuracy 0.911
83/100 - train: loss 0.020; val: loss 0.406 accuracy 0.907
84/100 - train: loss 0.013; val: loss 0.412 accuracy 0.909
85/100 - train: loss 0.019; val: loss 0.409 accuracy 0.910
86/100 - train: loss 0.016; val: loss 0.406 accuracy 0.909
87/100 - train: loss 0.007; val: loss 0.389 accuracy 0.913
88/100 - train: loss 0.010; val: loss 0.412 accuracy 0.912
89/100 - train: loss 0.005; val: loss 0.399 accuracy 0.916
90/100 - train: loss 0.003; val: loss 0.406 accuracy 0.917
91/100 - train: loss 0.004; val: loss 0.403 accuracy 0.915
92/100 - train: loss 0.002; val: loss 0.404 accuracy 0.918
93/100 - train: loss 0.002; val: loss 0.396 accuracy 0.919
94/100 - train: loss 0.001; val: loss 0.398 accuracy 0.922
95/100 - train: loss 0.001; val: loss 0.396 accuracy 0.922
96/100 - train: loss 0.001; val: loss 0.394 accuracy 0.921
97/100 - train: loss 0.001; val: loss 0.397 accuracy 0.922
98/100 - train: loss 0.000; val: loss 0.395 accuracy 0.922
99/100 - train: loss 0.001; val: loss 0.396 accuracy 0.922
100/100 - train: loss 0.001; val: loss 0.396 accuracy 0.922
def plot_history_train_val(history, key):
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['train ' + key]) + 1)
ax.plot(xs, history['train ' + key], '.-', label='train')
ax.plot(xs, history['val ' + key], '.-', label='val')
ax.set_xlabel('epoch')
ax.set_ylabel(key)
ax.legend()
ax.grid()
plt.show()
def plot_history(history, key):
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history[key]) + 1)
ax.plot(xs, history[key], '-')
ax.set_xlabel('epoch')
ax.set_ylabel(key)
ax.grid()
plt.show()
plot_history_train_val(history, 'loss')
plot_history(history, 'val acc')
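After training, the model can be used for single-image prediction; a minimal sketch (not in the original notebook) using the trained model and test_dset from above:
# Minimal inference sketch: classify one test image with the trained model.
model.eval()
img, label = test_dset[1]
with torch.no_grad():
    logits = model(img.unsqueeze(0).to(DEVICE))
pred = logits.argmax(dim=1).item()
print("predicted:", test_dset.classes[pred], "| ground truth:", test_dset.classes[label])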