CIFAR-10 classification using tinygrad

tinygrad is an end-to-end deep learning stack inspired by PyTorch, JAX, and TVM.

GitHub repo: https://github.com/tinygrad/tinygrad
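
A minimal taste of the API (a sketch, not one of the notebook's cells): tinygrad tensors are lazy, so operations only build a graph until a realization such as .numpy() forces compilation and execution.

from tinygrad import Tensor

a = Tensor.randn(2, 3)     # lazy: nothing has been computed yet
b = (a @ a.T).relu()       # still lazy: just extends the graph
print(b.numpy())           # realization compiles and runs the kernels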

Configuration

Imports

In [1]:
from functools import partial
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt

from tinygrad import Tensor, nn, dtypes, TinyJit, Variable
from tinygrad.helpers import trange, Context, BEAM, WINO

Configuration

In [2]:
NUM_CLASSES = 10
IMAGE_SIZE = 32

BATCH_SIZE = 32
EPOCHS = 2000
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-3

Model

In [3]:
class NormAct:
    def __init__(self, channels):
        self.norm = nn.BatchNorm(channels)

    def __call__(self, x):
        return self.norm(x).relu()
In [4]:
class ResidualBlock:
    def __init__(self, channels, stride=1, p_drop=0.):
        self.p_drop = p_drop
        
        if stride > 1:
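            # a strided depthwise conv downsamples the shortcut path to match the residual branch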
            self.shortcut_conv = nn.Conv2d(channels, channels, stride, stride=stride, groups=channels, bias=False)
        else:
            self.shortcut_conv = lambda x: x
            
        self.residual = [
            NormAct(channels),
            nn.Conv2d(channels, channels, 2 + stride, stride=stride, padding=1, groups=channels, bias=False),
            NormAct(channels),
            nn.Conv2d(channels, channels, 1, bias=False)
        ]
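        # learnable scale for the residual branch, zero-initialized so each block starts out as its shortcut path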

        self.γ = Tensor.zeros(1)

    def __call__(self, x):
        out = self.shortcut_conv(x).dropout(self.p_drop) + self.γ * x.sequential(self.residual)
        return out
In [5]:
class Head:
    def __init__(self, channels, classes, p_drop=0.):
        self.p_drop = p_drop
        self.norm = NormAct(channels)
        self.linear = nn.Linear(channels, classes)

    def __call__(self, x):
        x = self.norm(x).mean((2, 3)).dropout(self.p_drop)
        x = self.linear(x)
        return x
In [6]:
def Stem(in_channels, out_channels):
    return nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False)
In [7]:
class Net:
    def __init__(self, classes, width=32, in_channels=3, res_p_drop=0., head_p_drop=0.):
        strides = [1, 2, 1, 2, 1, 2, 1]
        self.layers = [
            Stem(in_channels, width),
            *[ResidualBlock(width, stride=stride, p_drop=res_p_drop) for stride in strides],
            Head(width, classes, p_drop=head_p_drop)
        ]
    
    def __call__(self, x):
        return x.sequential(self.layers)
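
A quick shape check (a sketch, not an original cell): the three stride-2 blocks reduce 32x32 inputs to 4x4 feature maps before the Head's global average pool, so the network maps a batch of images to one logit per class.

tiny = Net(NUM_CLASSES, width=8)                                 # hypothetical small instance
print(tiny(Tensor.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE)).shape)    # expected: (2, 10)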
In [8]:
def reset_parameters(state_dict):
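    # BatchNorm: weight 1, bias 0; residual scale γ: 0; other weights: Glorot uniform; other biases: 0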
    for name, param in state_dict.items():
        if "norm" in name:
            if "weight" in name:
                param.assign(Tensor.ones(*param.shape))
            elif "bias" in name:
                param.assign(Tensor.zeros(*param.shape))
        elif "γ" in name:
            param.assign(Tensor.zeros(*param.shape))
        else:
            if "weight" in name:
                param.assign(Tensor.glorot_uniform(*param.shape))
            elif "bias" in name:
                param.assign(Tensor.zeros(*param.shape))
In [9]:
model = Net(NUM_CLASSES, width=96, res_p_drop=0.1, head_p_drop=0.1)
In [10]:
state_dict = nn.state.get_state_dict(model)
In [11]:
reset_parameters(state_dict)
In [12]:
print("Number of parameters: {:,}".format(sum(p.numel() for p in nn.state.get_parameters(model)
                                              if p.requires_grad is None or p.requires_grad)))
Number of parameters: 80,177

Data

In [13]:
def plot_tensor_image(X):
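    # expects a CHW float image in [0, 1]; converts it to HWC uint8 for matplotlib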
    img = (X * 255).cast(dtypes.uint8).numpy()
    img = np.moveaxis(img, 0, -1)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_axis_off()
    ax.imshow(img)
    fig.show()
In [14]:
def reflection_pad(X, padding):
    X = X[..., :, 1:padding + 1].flip(-1).cat(X, X[..., :, -(padding + 1):-1].flip(-1), dim=-1)
    X = X[..., 1:padding + 1, :].flip(-2).cat(X, X[..., -(padding + 1):-1, :].flip(-2), dim=-2)
    return X
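
A quick check of the padding helper (a sketch, not an original cell): with padding=4, each 32x32 image is mirrored by 4 pixels on every side.

print(reflection_pad(Tensor.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE), padding=4).shape)   # expected: (1, 3, 40, 40)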
In [15]:
train_transform = [
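    # scale uint8 pixels to [0, 1], reflection-pad 32x32 -> 40x40, then cast to the default compute dtype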
    lambda x: x.float() / 255.0,
    partial(reflection_pad, padding=4),
    lambda x: x.cast(dtypes.default_float)
]
In [16]:
val_transform = [
    lambda x: x.float() / 255.0,
    lambda x: x.cast(dtypes.default_float)
]
In [17]:
X_train, Y_train, X_test, Y_test = nn.datasets.cifar()
In [18]:
X_train, X_test = X_train.sequential(train_transform), X_test.sequential(val_transform)
In [19]:
plot_tensor_image(X_test[1])
[Output: a sample CIFAR-10 test image]

Data loading

In [20]:
def flip_LR(X, prob=0.5):
    X = (Tensor.rand(X.shape[0], 1, 1, 1) < prob).where(X.flip(-1), X)
    return X
In [21]:
def random_crop(X, crop_size):
    b, c, h, w = X.shape
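    # draw one random crop offset per image, then gather along width and height using index grids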
    low_x = Tensor.randint(b, low = 0, high = w - crop_size).reshape(b, 1, 1, 1)
    low_y = Tensor.randint(b, low = 0, high = h - crop_size).reshape(b, 1, 1, 1)

    idx_x = Tensor.arange(crop_size, dtype = dtypes.int32).reshape((1, 1, 1, crop_size))
    idx_y = Tensor.arange(crop_size, dtype = dtypes.int32).reshape((1, 1, crop_size, 1))

    idx_x = (low_x + idx_x).expand(-1, c, h, -1)
    idx_y = (low_y + idx_y).expand(-1, c, crop_size, crop_size)
    X = X.gather(-1, idx_x).gather(-2, idx_y)
    return X
In [22]:
def random_brightness(X, low=0.7, high=1.3):
    factor = Tensor.uniform(X.shape[0], 1, 1, 1, low=low, high=high)
    X = (X * factor).clamp(0., 1.)
    return X
In [23]:
@TinyJit
def augmentations(X):
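    # TinyJit compiles this pipeline; it is applied to the full (fixed-shape) training set once per epoch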
    X = flip_LR(X)
    X = random_crop(X, crop_size=IMAGE_SIZE)
    X = random_brightness(X, low=0.7, high = 1.3)
    return X
In [24]:
def fetch_batches(X_in, Y_in, batch_size, is_train):
    X, Y = X_in, Y_in
    data_size = X.shape[0]
    if is_train:
        perms = Tensor.randperm(data_size, device=X.device)
        X, Y = X[perms], Y[perms]
        X = augmentations(X)

    full_batches = (data_size // batch_size) * batch_size
    i_var = Variable("i", 0, full_batches - batch_size)
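    # binding the batch offset to a symbolic Variable keeps slice shapes static, so the JIT-compiled step is reused across batches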
    for i in range(0, full_batches, batch_size):
        i_var_b = i_var.bind(i)
        X_batch, Y_batch = X[i_var_b:i_var_b + batch_size], Y[i_var_b:i_var_b + batch_size]
        yield X_batch, Y_batch

Training

Loss

In [25]:
def loss_fn(out, Y):
    loss = out.sparse_categorical_crossentropy(Y)
    return loss

Scheduler

In [26]:
class LR_Scheduler:
    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.epoch_counter = Tensor([0], requires_grad=False, device=self.optimizer.device)
    
    def get_lr(self):
        pass
    
    def schedule_step(self):
        return [self.epoch_counter.assign(self.epoch_counter + 1), self.optimizer.lr.assign(self.get_lr())]
    
    def step(self):
        Tensor.realize(*self.schedule_step())
In [27]:
class OneCycleLR(LR_Scheduler):
    def __init__(self, optimizer, max_lr, total_steps,
                 pct_start=0.3, div_factor=25.0, final_div_factor=10000.0):
        super().__init__(optimizer)
        self.initial_lr = max_lr / div_factor
        self.max_lr = max_lr
        self.min_lr = self.initial_lr / final_div_factor
        self.increase_steps = total_steps * pct_start
        self.decrease_steps = total_steps - self.increase_steps
        self.optimizer.lr.assign(self.get_lr()).realize() # update the initial LR

    @staticmethod
    def _annealing_linear(start, end, pct):
        return pct * (end - start) + start
    
    def get_lr(self):
        is_increasing = self.epoch_counter < self.increase_steps
        lr = is_increasing.where(
            self._annealing_linear(
                self.initial_lr,
                self.max_lr,
                self.epoch_counter / self.increase_steps
            ),
            self._annealing_linear(
                self.max_lr,
                self.min_lr,
                (self.epoch_counter - self.increase_steps) / self.decrease_steps
            )
        )
        lr = lr.cast(self.optimizer.lr.dtype)
        return lr
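
For intuition with the settings used below (a sketch, not an original cell), the one-cycle schedule is piecewise linear:

# with max_lr = LEARNING_RATE = 1e-2, div_factor = 25, final_div_factor = 1e4:
#   initial lr = 1e-2 / 25  = 4e-4   (start of training)
#   peak lr    = 1e-2                (reached after 30% of total steps, pct_start=0.3)
#   final lr   = 4e-4 / 1e4 = 4e-8   (end of training)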

Training functions

In [28]:
def iterate(step_fn, batcher):
    num_samples = 0
    total_loss = 0.
    num_correct = 0
    for x, y in batcher:
        loss, out = step_fn(x, y)
        pred = out.argmax(axis=-1)
        correct = (pred == y)
        loss, correct = loss.numpy(), correct.numpy()
        num_samples += correct.shape[0]
        total_loss += loss
        num_correct += np.sum(correct)
    
    avg_loss = total_loss / num_samples
    acc = num_correct / num_samples
    metrics = {"loss": avg_loss, "acc": acc}
    return metrics
In [29]:
def train(model, loss_fn, optimizer, batcher, batch_scheduler):
    @TinyJit
    @Tensor.train()
    def train_step(x, y):
        out = model(x)
        loss = loss_fn(out, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_scheduler.step()
        return loss.realize(), out.realize()

    with Context(BEAM=BEAM.value, WINO=WINO.value):
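        # BEAM kernel search and Winograd convolution settings stay active while the JIT captures and compiles train_step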
        metrics = iterate(train_step, batcher)
    return metrics
In [30]:
def evaluate(model, loss_fn, batcher):
    @TinyJit
    def eval_step(x, y):
        out = model(x)
        loss = loss_fn(out, y)
        return loss.realize(), out.realize()

    metrics = iterate(eval_step, batcher)
    return metrics
In [31]:
def update_history(history, metrics, name):
    for key, val in metrics.items():
        history[name + ' ' + key].append(val)
In [32]:
def history_plot_train_val(history, key):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    xs = np.arange(1, len(history['train ' + key]) + 1)
    ax.plot(xs, history['train ' + key], '.-', label='train')
    ax.plot(xs, history['val ' + key], '.-', label='val')
    ax.set_xlabel('epoch')
    ax.set_ylabel(key)
    ax.legend()
    ax.grid()
    plt.show()

Start training

In [33]:
num_train_samples = X_train.shape[0]
In [34]:
num_steps_per_epoch = num_train_samples // BATCH_SIZE
In [35]:
total_train_steps = num_steps_per_epoch * EPOCHS
In [36]:
optimizer = nn.optim.AdamW(nn.state.get_parameters(model), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
In [37]:
lr_scheduler = OneCycleLR(optimizer, max_lr=LEARNING_RATE, total_steps=total_train_steps)
In [38]:
history = defaultdict(list)
pbar = trange(EPOCHS)
for epoch in pbar:
    train_batcher = fetch_batches(X_train, Y_train, BATCH_SIZE, is_train=True)
    train_metrics = train(model, loss_fn, optimizer, train_batcher, lr_scheduler)
    update_history(history, train_metrics, "train")
    
    val_batcher = fetch_batches(X_test, Y_test, BATCH_SIZE, is_train=False)
    val_metrics = evaluate(model, loss_fn, val_batcher)
    update_history(history, val_metrics, "val")
    pbar.set_description(f"acc={train_metrics['acc']:.3f}, val acc={val_metrics['acc']:.3f}")
acc=0.997, val acc=0.928: 100%|██████████| 2000/2000 [17:18:34<00:00,  0.03it/s]
In [39]:
history_plot_train_val(history, 'loss')
[Output: training and validation loss curves]
In [40]:
history_plot_train_val(history, 'acc')
[Output: training and validation accuracy curves]