Fast Forward Computer Vision (FFCV): train models with accelerated data loading
GitHub: https://github.com/libffcv/ffcv
Installation
conda create -y -n ffcv python=3.9 cupy pkg-config compilers libjpeg-turbo opencv pytorch torchvision cudatoolkit=11.3 numba -c pytorch -c conda-forge
conda activate ffcv
pip install ffcv
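Before moving on, it's worth a quick sanity check that the environment resolved correctly, since FFCV builds against the conda-provided libjpeg-turbo and OpenCV (a minimal sketch; run it inside the activated ffcv environment):

import ffcv
import torch

print(ffcv.__file__)              # confirms ffcv is importable from this environment
print(torch.cuda.is_available())  # the pipelines below assume a CUDA device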
Imports
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torchvision
import ffcv
import ffcv.fields as fields
import ffcv.fields.decoders as decoders
import ffcv.transforms as transforms
Configuration
DATA_DIR='./data'
NUM_CLASSES = 10
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
BATCH_SIZE = 512
NUM_WORKERS = 20
EPOCHS = 50
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
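These statistics are on the raw 0-255 pixel scale because FFCV decodes images to uint8 and the pipeline below normalizes only after converting to float16. The mean can be recomputed from the raw CIFAR-10 array as a check (a sketch reusing DATA_DIR; note that CIFAR_STD corresponds to the widely used 255 * [0.2023, 0.1994, 0.2010] values rather than the exact per-pixel std):

import numpy as np
import torchvision

# CIFAR10.data is a (50000, 32, 32, 3) uint8 array; average over all but the channel axis.
raw = torchvision.datasets.CIFAR10(DATA_DIR, train=True, download=True).data
print(raw.mean(axis=(0, 1, 2)))  # ~[125.31, 122.95, 113.87], matching CIFAR_MEAN up to rounding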
def convert_dataset(dset, name):
    writer = ffcv.writer.DatasetWriter(name + '.beton', {
        'image': fields.RGBImageField(),
        'label': fields.IntField()
    })
    writer.from_indexed_dataset(dset)
train_dset = torchvision.datasets.CIFAR10(DATA_DIR, train=True, download=True)
test_dset = torchvision.datasets.CIFAR10(DATA_DIR, train=False, download=True)
Files already downloaded and verified
Files already downloaded and verified
Convert to FFCV format
convert_dataset(train_dset, 'cifar_train')
convert_dataset(test_dset, 'cifar_test')
100%|██████████████████████████████████| 50000/50000 [00:00<00:00, 82795.60it/s]
100%|██████████████████████████████████| 10000/10000 [00:00<00:00, 24838.28it/s]
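The writer produces one self-contained .beton file per split, which is all the loaders below need. A quick way to confirm the conversion worked (a minimal sketch using only the standard library):

import os

# Each .beton bundles the encoded images and labels for one split.
for name in ['cifar_train.beton', 'cifar_test.beton']:
    print(name, f'{os.path.getsize(name) / 2**20:.1f} MiB')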
def get_image_pipeline(train=True):
    augmentation_pipeline = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomTranslate(padding=2),
        transforms.Cutout(8, tuple(map(int, CIFAR_MEAN)))
    ] if train else []

    image_pipeline = [
        decoders.SimpleRGBImageDecoder()
    ] + augmentation_pipeline + [
        transforms.ToTensor(),
        transforms.ToDevice(DEVICE, non_blocking=True),
        transforms.ToTorchImage(),
        transforms.Convert(torch.float16),
        torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD)
    ]
    return image_pipeline
label_pipeline = [
    decoders.IntDecoder(),
    transforms.ToTensor(),
    transforms.ToDevice(DEVICE),
    transforms.Squeeze()
]
train_image_pipeline = get_image_pipeline(train=True)
test_image_pipeline = get_image_pipeline(train=False)
train_loader = ffcv.loader.Loader('cifar_train.beton',
                                  batch_size=BATCH_SIZE,
                                  num_workers=NUM_WORKERS,
                                  order=ffcv.loader.OrderOption.RANDOM,
                                  drop_last=True,
                                  pipelines={'image': train_image_pipeline,
                                             'label': label_pipeline})

test_loader = ffcv.loader.Loader('cifar_test.beton',
                                 batch_size=BATCH_SIZE,
                                 num_workers=NUM_WORKERS,
                                 order=ffcv.loader.OrderOption.SEQUENTIAL,
                                 drop_last=False,
                                 pipelines={'image': test_image_pipeline,
                                            'label': label_pipeline})
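Because ToDevice and Convert run inside the FFCV pipeline, batches arrive on the GPU as float16 tensors and the labels as a flat integer vector, so the training loop below needs no .to(DEVICE) calls. A one-batch sanity check (a sketch; the expected values assume the CUDA setup above):

images, labels = next(iter(train_loader))
print(images.shape, images.dtype, images.device)  # torch.Size([512, 3, 32, 32]) torch.float16 cuda:0
print(labels.shape, labels.dtype, labels.device)  # torch.Size([512]) torch.int64 cuda:0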
class ConvBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, act=True):
        padding = (kernel_size - 1) // 2
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels)
        ]
        if act: layers.append(nn.SiLU(inplace=True))
        super().__init__(*layers)
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.residual = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels, act=False)
        )
        self.shortcut = self.get_shortcut(in_channels, out_channels)
        self.act = nn.SiLU(inplace=True)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        out = self.shortcut(x) + self.gamma * self.residual(x)
        return self.act(out)

    def get_shortcut(self, in_channels, out_channels):
        if in_channels != out_channels:
            shortcut = ConvBlock(in_channels, out_channels, 1, act=False)
        else:
            shortcut = nn.Identity()
        return shortcut
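The learnable gamma is initialized to zero, so at initialization every block reduces to (the activation of) its shortcut, a zero-initialized residual scaling in the spirit of ReZero/SkipInit that tends to stabilize early training. A small check of that property (a sketch reusing the ResidualBlock defined above):

block = ResidualBlock(64, 64).eval()  # eval() so BatchNorm uses its running statistics
x = torch.randn(2, 64, 8, 8)
with torch.no_grad():
    out = block(x)
# gamma == 0 and the shortcut is Identity for equal channel counts,
# so the block initially computes SiLU(x).
print(torch.allclose(out, nn.functional.silu(x)))  # True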
class ResidualStack(nn.Sequential):
    def __init__(self, in_channels, channels_list, num_blocks_list, strides):
        layers = []
        for num_blocks, out_channels, stride in zip(num_blocks_list, channels_list, strides):
            if stride > 1: layers.append(nn.MaxPool2d(stride))
            for _ in range(num_blocks):
                layers.append(ResidualBlock(in_channels, out_channels))
                in_channels = out_channels
        super().__init__(*layers)
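With strides [1, 2, 2] the downsampling happens in MaxPool2d layers inserted between stages rather than inside the residual blocks. A quick structural check (a sketch reusing the classes above, with the same configuration the model uses below):

stack = ResidualStack(64, [64, 128, 256], [2, 2, 2], [1, 2, 2])
with torch.no_grad():
    print(stack(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 256, 8, 8])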
class Stem(nn.Sequential):
    def __init__(self, in_channels, channels_list, stride):
        layers = []
        for out_channels in channels_list:
            layers.append(ConvBlock(in_channels, out_channels, 3, stride=stride))
            in_channels = out_channels
            stride = 1
        super().__init__(*layers)
class Head(nn.Sequential):
    def __init__(self, in_channels, classes, p_drop=0.):
        super().__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(p_drop),
            nn.Linear(in_channels, classes)
        )
class ResNet(nn.Sequential):
    def __init__(self, classes, num_blocks_list, channels_list, strides, in_channels=3, head_p_drop=0.):
        super().__init__(
            Stem(in_channels, [32, 32, 64], strides[0]),
            ResidualStack(64, channels_list, num_blocks_list, strides[1:]),
            Head(channels_list[-1], classes, head_p_drop)
        )
def init_linear(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None: nn.init.zeros_(m.bias)
model = ResNet(NUM_CLASSES,
               num_blocks_list=[2, 2, 2],
               channels_list=[64, 128, 256],
               strides=[1, 1, 2, 2],
               head_p_drop=0.3)
model.apply(init_linear);
model = model.to(DEVICE, memory_format=torch.channels_last)
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 2,804,592
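A dummy forward pass confirms the shapes: the stem keeps the 32x32 resolution (strides[0] = 1), the two pooling stages reduce it to 8x8, and the head pools that down to the 10 class logits (a sketch; it assumes the model has been moved to DEVICE as above):

with torch.no_grad():
    dummy = torch.randn(2, 3, 32, 32, device=DEVICE)
    print(model(dummy).shape)  # torch.Size([2, 10])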
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(model.parameters(), lr=1e-2, weight_decay=1e-2)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-2, steps_per_epoch=len(train_loader), epochs=EPOCHS)
scaler = amp.GradScaler()
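OneCycleLR warms the learning rate up to max_lr and then anneals it towards zero over all EPOCHS * len(train_loader) steps, which is why scheduler.step() is called once per batch in the loop below. The shape of the schedule can be previewed without touching the real optimizer (a minimal sketch; the probe optimizer and its dummy parameter are purely illustrative):

probe_opt = optim.AdamW([torch.zeros(1, requires_grad=True)], lr=1e-2)
probe_sched = optim.lr_scheduler.OneCycleLR(probe_opt, max_lr=1e-2,
                                            steps_per_epoch=len(train_loader),
                                            epochs=EPOCHS)
lrs = []
for _ in range(EPOCHS * len(train_loader)):
    probe_opt.step()    # step the optimizer first to avoid a scheduler warning
    probe_sched.step()
    lrs.append(probe_sched.get_last_lr()[0])
print(f'start {lrs[0]:.2e}, peak {max(lrs):.2e}, end {lrs[-1]:.2e}')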
for e in range(EPOCHS):
    print(f'{e}/{EPOCHS}: ', end='')

    model.train()
    total_loss, total_num = 0., 0
    for images, labels in train_loader:
        optimizer.zero_grad(set_to_none=True)
        with amp.autocast():
            out = model(images)
            loss = loss_fn(out, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        batch_size = images.shape[0]
        total_loss += batch_size * loss.item()
        total_num += batch_size
    print(f'train loss: {total_loss / total_num:.3f}; ', end='')
    model.eval()
    with torch.no_grad():
        total_correct, total_num = 0., 0.
        for images, labels in test_loader:
            with amp.autocast():
                # Test-time augmentation: average predictions over the image and its horizontal flip.
                # torch.fliplr would flip dim 1 (the channels of an NCHW batch), so flip the width dim instead.
                out = (model(images) + model(torch.flip(images, dims=[-1]))) / 2.
            total_correct += (out.argmax(1) == labels).sum().cpu().item()
            total_num += images.shape[0]
    print(f'val accuracy: {total_correct / total_num:.3f}')
0/50: train loss: 2.023; val accuracy: 0.379
1/50: train loss: 1.616; val accuracy: 0.524
2/50: train loss: 1.400; val accuracy: 0.583
3/50: train loss: 1.255; val accuracy: 0.654
4/50: train loss: 1.145; val accuracy: 0.664
5/50: train loss: 1.073; val accuracy: 0.745
6/50: train loss: 1.001; val accuracy: 0.669
7/50: train loss: 0.952; val accuracy: 0.775
8/50: train loss: 0.908; val accuracy: 0.795
9/50: train loss: 0.872; val accuracy: 0.802
10/50: train loss: 0.839; val accuracy: 0.825
11/50: train loss: 0.815; val accuracy: 0.787
12/50: train loss: 0.786; val accuracy: 0.808
13/50: train loss: 0.764; val accuracy: 0.850
14/50: train loss: 0.745; val accuracy: 0.873
15/50: train loss: 0.722; val accuracy: 0.869
16/50: train loss: 0.704; val accuracy: 0.876
17/50: train loss: 0.687; val accuracy: 0.873
18/50: train loss: 0.678; val accuracy: 0.868
19/50: train loss: 0.664; val accuracy: 0.884
20/50: train loss: 0.650; val accuracy: 0.889
21/50: train loss: 0.640; val accuracy: 0.889
22/50: train loss: 0.630; val accuracy: 0.883
23/50: train loss: 0.616; val accuracy: 0.897
24/50: train loss: 0.613; val accuracy: 0.896
25/50: train loss: 0.601; val accuracy: 0.904
26/50: train loss: 0.591; val accuracy: 0.906
27/50: train loss: 0.582; val accuracy: 0.911
28/50: train loss: 0.576; val accuracy: 0.906
29/50: train loss: 0.569; val accuracy: 0.915
30/50: train loss: 0.563; val accuracy: 0.918
31/50: train loss: 0.559; val accuracy: 0.915
32/50: train loss: 0.554; val accuracy: 0.920
33/50: train loss: 0.549; val accuracy: 0.923
34/50: train loss: 0.543; val accuracy: 0.920
35/50: train loss: 0.538; val accuracy: 0.926
36/50: train loss: 0.536; val accuracy: 0.930
37/50: train loss: 0.533; val accuracy: 0.927
38/50: train loss: 0.529; val accuracy: 0.928
39/50: train loss: 0.527; val accuracy: 0.930
40/50: train loss: 0.526; val accuracy: 0.931
41/50: train loss: 0.524; val accuracy: 0.932
42/50: train loss: 0.522; val accuracy: 0.932
43/50: train loss: 0.521; val accuracy: 0.933
44/50: train loss: 0.521; val accuracy: 0.933
45/50: train loss: 0.520; val accuracy: 0.934
46/50: train loss: 0.519; val accuracy: 0.934
47/50: train loss: 0.519; val accuracy: 0.933
48/50: train loss: 0.518; val accuracy: 0.934
49/50: train loss: 0.518; val accuracy: 0.933
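With training finished at roughly 93% test accuracy, the weights can be saved for later reuse (a standard PyTorch pattern, not FFCV-specific; the filename is arbitrary):

torch.save(model.state_dict(), 'resnet_cifar10.pt')
# To reload later: rebuild the model as above, then
# model.load_state_dict(torch.load('resnet_cifar10.pt', map_location=DEVICE))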