Vision Transformer (ViT)

Based on the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (paper on OpenReview).

Original Transformer: "Attention Is All You Need", arXiv:1706.03762 [cs.CL]

Figure: ViT architecture (vit.png)

Setup

Imports

In [1]:
from collections import defaultdict
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from tqdm import tqdm

import matplotlib.pyplot as plt

Configuration

In [2]:
IMAGE_SIZE = 32
PATCH_SIZE = 4
NUM_CLASSES = 10

BATCH_SIZE = 128
NUM_WORKERS = 4
EPOCHS = 50
LEARNING_RATE = 3e-4

SAVE_PATH = 'weights.pkl'
In [3]:
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda

Dataset

In [4]:
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor()
])
In [5]:
train_dset = datasets.CIFAR10(root='.', train=True, download=True, transform=train_transform)
test_dset = datasets.CIFAR10(root='.', train=False, download=True, transform=transforms.ToTensor())
Files already downloaded and verified
Files already downloaded and verified
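
To eyeball the augmentation pipeline, one can preview a single training sample (a quick sketch):

img, label = train_dset[0]            # tensor in [0, 1], shape (3, 32, 32)
plt.imshow(img.permute(1, 2, 0))      # channels-last layout for matplotlib
plt.title(train_dset.classes[label])
plt.show()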
In [6]:
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
                                           num_workers=NUM_WORKERS)
val_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False,
                                         num_workers=NUM_WORKERS)

Model

In [7]:
class Residual(nn.Module):
    def __init__(self, *layers):
        super().__init__()
        self.residual = nn.Sequential(*layers)
    
    def forward(self, x):
        return x + self.residual(x)
In [8]:
class FeedForward(nn.Sequential):
    def __init__(self, dim, hidden_dim, p_drop=0.):
        super().__init__(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(p_drop)
        )
In [9]:
class SelfAttention(nn.Module):
    def __init__(self, dim, heads=8, p_drop=0.):
        super().__init__()
        self.heads = heads

        self.to_keys = nn.Linear(dim, dim)
        self.to_queries = nn.Linear(dim, dim)
        self.to_values = nn.Linear(dim, dim)
        self.unifyheads = nn.Linear(dim, dim)

        self.attn_drop = nn.Dropout(p_drop)
        self.resid_drop = nn.Dropout(p_drop)

    def forward(self, x):
        b, t, d = x.size()
        h, d_q = self.heads, d // self.heads

        keys = self.to_keys(x).view(b, t, h, d_q).transpose(1, 2) # move head forward to the batch dim
        queries = self.to_queries(x).view(b, t, h, d_q).transpose(1, 2)
        values = self.to_values(x).view(b, t, h, d_q).transpose(1, 2)

        att = queries @ keys.transpose(-2, -1)
        att = F.softmax(att * d_q**-0.5, dim=-1)
        att = self.attn_drop(att)
        
        out = att @ values
        out = out.transpose(1, 2).contiguous().view(b, t, d) # move head back
        out = self.unifyheads(out)
        out = self.resid_drop(out)
        return out
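
A minimal shape check for the attention layer (a sketch assuming dim=512 and heads=8 as used below; dim must be divisible by heads, and 65 tokens correspond to 64 patches plus the class token):

attn = SelfAttention(dim=512, heads=8)
x = torch.randn(2, 65, 512)   # (batch, tokens, dim)
print(attn(x).shape)          # torch.Size([2, 65, 512]) -- token and feature dims are preserved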
In [10]:
class Block(nn.Sequential):
    def __init__(self, dim, heads, mlp_dim, p_drop):
        super().__init__(
            Residual(nn.LayerNorm(dim), SelfAttention(dim, heads, p_drop)),
            Residual(nn.LayerNorm(dim), FeedForward(dim, mlp_dim, p_drop))
        )
In [11]:
class Transformer(nn.Sequential):
    def __init__(self, dim, depth, heads, mlp_dim, p_drop):
        layers = [Block(dim, heads, mlp_dim, p_drop) for _ in range(depth)]
        super().__init__(*layers)
In [12]:
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, dim, channels=3, emb_p_drop=0.):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.pos_embedding = nn.Parameter(torch.Tensor(1, num_patches, dim))  # initialized in ViT.reset_parameters()
        self.cls_token = nn.Parameter(torch.Tensor(1, 1, dim))                # initialized in ViT.reset_parameters()
        self.emb_dropout = nn.Dropout(emb_p_drop)
    
    def forward(self, x):
        p = self.patch_size
        x = F.unfold(x, p, stride=p).transpose(-1, -2) # extract patches
        x = self.patch_to_embedding(x)
        x += self.pos_embedding # add positional embedding
        # add classification token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.emb_dropout(x)
        return x
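
F.unfold slices the image into non-overlapping p x p patches and flattens each one into a vector of length channels * p * p. Note that here the positional embedding is added to the patch tokens only; the class token is prepended afterwards without a positional entry of its own. A shape walkthrough (a sketch with the settings used below; the parameters are initialized manually here because that normally happens in ViT.reset_parameters):

pe = PatchEmbedding(image_size=32, patch_size=4, dim=512)
nn.init.normal_(pe.pos_embedding, std=0.02)
nn.init.normal_(pe.cls_token, std=0.02)
imgs = torch.randn(2, 3, 32, 32)
print(F.unfold(imgs, 4, stride=4).shape)  # torch.Size([2, 48, 64]): 48 = 3*4*4 values per patch, 64 patches
print(pe(imgs).shape)                     # torch.Size([2, 65, 512]): 64 patch tokens + 1 class token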
In [13]:
class TakeFirst(nn.Module):
    def forward(self, x):
        return x[:, 0]
In [14]:
class Head(nn.Sequential):
    def __init__(self, dim, hidden_dim, num_classes, p_drop=0.):
        super().__init__(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden_dim, num_classes)
        )
In [15]:
class ViT(nn.Sequential):
    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3,
                 p_drop=0., emb_p_drop=0.):
        super().__init__(
            PatchEmbedding(image_size, patch_size, dim, channels, emb_p_drop),
            Transformer(dim, depth, heads, mlp_dim, p_drop),
            TakeFirst(),
            Head(dim, mlp_dim, num_classes, p_drop)
        )
        self.reset_parameters()
    
    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1.)
                nn.init.zeros_(m.bias)
            elif isinstance(m, PatchEmbedding):
                nn.init.normal_(m.pos_embedding, mean=0.0, std=0.02)
                nn.init.normal_(m.cls_token, mean=0.0, std=0.02)
In [16]:
model = ViT(IMAGE_SIZE, PATCH_SIZE, NUM_CLASSES, dim=512, depth=8, heads=8, mlp_dim=2048,
            p_drop=0.1, emb_p_drop=0.1).to(DEVICE)
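
A rough sanity check of model size and output shape (a sketch; with these hyperparameters the count comes out to roughly 26M trainable parameters):

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {num_params:,}")
with torch.no_grad():
    logits = model(torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE))
print(logits.shape)  # torch.Size([2, 10]): one logit per CIFAR-10 class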

Training

Training loop

In [18]:
def show_or_save(fig, filename=None):
    if filename:
        fig.savefig(filename, bbox_inches='tight', pad_inches=0.05)
        plt.close(fig)
    else:
        plt.show()
In [19]:
class History:
    def __init__(self):
        self.values = defaultdict(list)

    def append(self, key, value):
        self.values[key].append(value)

    def reset(self):
        for k in self.values.keys():
            self.values[k] = []

    def _plot(self, key, line_type='-', label=None):
        if not label: label=key
        xs = np.arange(1, len(self.values[key])+1)
        self.ax.plot(xs, self.values[key], line_type, label=label)

    def plot_train_val(self, key, x_is_batch=False, ylog=False, filename=None):
        fig = plt.figure()
        self.ax = fig.add_subplot(111)
        self._plot('train ' + key, '.-', 'train')
        self._plot('val ' + key, '.-', 'val')
        self.ax.legend()
        if ylog: self.ax.set_yscale('log')
        self.ax.set_xlabel('batch' if x_is_batch else 'epoch')
        self.ax.set_ylabel(key)
        show_or_save(fig, filename)
In [20]:
class Learner:
    def __init__(self, model, loss, optimizer, train_loader, val_loader, device,
                 epoch_scheduler=None, batch_scheduler=None):
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.epoch_scheduler = epoch_scheduler
        self.batch_scheduler = batch_scheduler
        self.history = History()
    
    
    def iterate(self, loader, msg="", backward_pass=False):
        total_loss = 0.0
        num_samples = 0
        num_correct = 0
        
        pbar = tqdm(enumerate(loader), total=len(loader))
        for it, (X, Y) in pbar:
            X, Y = X.to(self.device), Y.to(self.device)
            Y_pred = self.model(X)
            batch_size = X.size(0)
            batch_loss = self.loss(Y_pred, Y)
            if backward_pass:
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()
                if self.batch_scheduler is not None:
                    self.batch_scheduler.step()
            
            Y_pred.detach_() # conserve memory
            labels_pred = torch.argmax(Y_pred, -1)
            total_loss += batch_size * batch_loss.item()
            num_correct += (labels_pred == Y).sum()
            num_samples += batch_size
            
            pbar.set_description("{} iter {}: loss {:.3f}, acc {:.3f}".format(
                msg, it, total_loss / num_samples, float(num_correct) / num_samples))
    
        avg_loss = total_loss / num_samples
        accuracy = float(num_correct) / num_samples
        return avg_loss, accuracy
    
    
    def train(self, msg):
        self.model.train()
        train_loss, train_acc = self.iterate(self.train_loader, msg + ' train:', backward_pass=True)
        self.history.append('train loss', train_loss)
        self.history.append('train acc', train_acc)

        
    def validate(self, msg):
        self.model.eval()
        with torch.no_grad():
            val_loss, val_acc = self.iterate(self.val_loader, msg + ' val:')
        self.history.append('val loss', val_loss)
        self.history.append('val acc', val_acc)


    def fit(self, epochs):
        for e in range(epochs):
            msg = f'epoch {e+1}/{epochs}'
            self.train(msg)
            self.validate(msg)
            if self.epoch_scheduler is not None:
                self.epoch_scheduler.step()

Start training

In [21]:
loss = nn.CrossEntropyLoss()
In [22]:
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)  # lr is overridden by the OneCycleLR schedule below
In [25]:
learner = Learner(model, loss, optimizer, train_loader, val_loader, DEVICE)
In [26]:
learner.batch_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
                                                        steps_per_epoch=len(train_loader),
                                                        epochs=EPOCHS)
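
To see what the one-cycle policy does over the whole run, one can step a throwaway scheduler on a dummy optimizer (a sketch; stepping the real scheduler here would disturb the training learning rate):

dummy_opt = optim.AdamW([torch.zeros(1, requires_grad=True)], lr=LEARNING_RATE)
dummy_sched = optim.lr_scheduler.OneCycleLR(dummy_opt, max_lr=LEARNING_RATE,
                                            steps_per_epoch=len(train_loader), epochs=EPOCHS)
lrs = []
for _ in range(len(train_loader) * EPOCHS):
    lrs.append(dummy_sched.get_last_lr()[0])
    dummy_opt.step()
    dummy_sched.step()
plt.plot(lrs)
plt.xlabel('step'); plt.ylabel('learning rate')
plt.show()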
In [27]:
learner.fit(EPOCHS)
epoch 1/50 train: iter 390: loss 2.049, acc 0.224: 100%|██████████| 391/391 [02:29<00:00,  2.62it/s]
epoch 1/50 val: iter 78: loss 1.913, acc 0.274: 100%|██████████| 79/79 [00:09<00:00,  7.94it/s]
epoch 2/50 train: iter 390: loss 1.845, acc 0.299: 100%|██████████| 391/391 [02:31<00:00,  2.57it/s]
epoch 2/50 val: iter 78: loss 1.693, acc 0.348: 100%|██████████| 79/79 [00:10<00:00,  7.88it/s]
epoch 3/50 train: iter 390: loss 1.724, acc 0.353: 100%|██████████| 391/391 [02:32<00:00,  2.57it/s]
epoch 3/50 val: iter 78: loss 1.603, acc 0.419: 100%|██████████| 79/79 [00:10<00:00,  7.87it/s]
epoch 4/50 train: iter 390: loss 1.614, acc 0.404: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 4/50 val: iter 78: loss 1.526, acc 0.443: 100%|██████████| 79/79 [00:10<00:00,  7.86it/s]
epoch 5/50 train: iter 390: loss 1.533, acc 0.440: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 5/50 val: iter 78: loss 1.407, acc 0.489: 100%|██████████| 79/79 [00:10<00:00,  7.86it/s]
epoch 6/50 train: iter 390: loss 1.454, acc 0.470: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 6/50 val: iter 78: loss 1.350, acc 0.511: 100%|██████████| 79/79 [00:10<00:00,  7.86it/s]
epoch 7/50 train: iter 390: loss 1.412, acc 0.486: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 7/50 val: iter 78: loss 1.292, acc 0.527: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 8/50 train: iter 390: loss 1.372, acc 0.505: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 8/50 val: iter 78: loss 1.269, acc 0.531: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 9/50 train: iter 390: loss 1.347, acc 0.512: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 9/50 val: iter 78: loss 1.285, acc 0.540: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 10/50 train: iter 390: loss 1.321, acc 0.522: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 10/50 val: iter 78: loss 1.232, acc 0.549: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 11/50 train: iter 390: loss 1.307, acc 0.528: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 11/50 val: iter 78: loss 1.270, acc 0.548: 100%|██████████| 79/79 [00:10<00:00,  7.86it/s]
epoch 12/50 train: iter 390: loss 1.286, acc 0.537: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 12/50 val: iter 78: loss 1.243, acc 0.554: 100%|██████████| 79/79 [00:10<00:00,  7.86it/s]
epoch 13/50 train: iter 390: loss 1.268, acc 0.544: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 13/50 val: iter 78: loss 1.194, acc 0.576: 100%|██████████| 79/79 [00:10<00:00,  7.84it/s]
epoch 14/50 train: iter 390: loss 1.245, acc 0.552: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 14/50 val: iter 78: loss 1.159, acc 0.584: 100%|██████████| 79/79 [00:10<00:00,  7.84it/s]
epoch 15/50 train: iter 390: loss 1.241, acc 0.554: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 15/50 val: iter 78: loss 1.159, acc 0.584: 100%|██████████| 79/79 [00:10<00:00,  7.84it/s]
epoch 16/50 train: iter 390: loss 1.222, acc 0.559: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 16/50 val: iter 78: loss 1.109, acc 0.594: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 17/50 train: iter 390: loss 1.201, acc 0.568: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 17/50 val: iter 78: loss 1.131, acc 0.595: 100%|██████████| 79/79 [00:10<00:00,  7.86it/s]
epoch 18/50 train: iter 390: loss 1.180, acc 0.573: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 18/50 val: iter 78: loss 1.163, acc 0.575: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 19/50 train: iter 390: loss 1.172, acc 0.578: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 19/50 val: iter 78: loss 1.103, acc 0.606: 100%|██████████| 79/79 [00:10<00:00,  7.84it/s]
epoch 20/50 train: iter 390: loss 1.149, acc 0.584: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 20/50 val: iter 78: loss 1.064, acc 0.612: 100%|██████████| 79/79 [00:10<00:00,  7.84it/s]
epoch 21/50 train: iter 390: loss 1.134, acc 0.591: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 21/50 val: iter 78: loss 1.096, acc 0.612: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 22/50 train: iter 390: loss 1.115, acc 0.598: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 22/50 val: iter 78: loss 1.031, acc 0.636: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 23/50 train: iter 390: loss 1.097, acc 0.604: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 23/50 val: iter 78: loss 1.027, acc 0.634: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 24/50 train: iter 390: loss 1.076, acc 0.609: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 24/50 val: iter 78: loss 1.008, acc 0.637: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 25/50 train: iter 390: loss 1.061, acc 0.621: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 25/50 val: iter 78: loss 0.996, acc 0.645: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 26/50 train: iter 390: loss 1.046, acc 0.622: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 26/50 val: iter 78: loss 0.966, acc 0.658: 100%|██████████| 79/79 [00:10<00:00,  7.84it/s]
epoch 27/50 train: iter 390: loss 1.020, acc 0.634: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 27/50 val: iter 78: loss 0.969, acc 0.654: 100%|██████████| 79/79 [00:10<00:00,  7.86it/s]
epoch 28/50 train: iter 390: loss 1.010, acc 0.636: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 28/50 val: iter 78: loss 0.953, acc 0.656: 100%|██████████| 79/79 [00:10<00:00,  7.86it/s]
epoch 29/50 train: iter 390: loss 0.988, acc 0.643: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 29/50 val: iter 78: loss 0.966, acc 0.658: 100%|██████████| 79/79 [00:10<00:00,  7.84it/s]
epoch 30/50 train: iter 390: loss 0.965, acc 0.655: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 30/50 val: iter 78: loss 0.925, acc 0.672: 100%|██████████| 79/79 [00:10<00:00,  7.86it/s]
epoch 31/50 train: iter 390: loss 0.942, acc 0.659: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 31/50 val: iter 78: loss 0.899, acc 0.680: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 32/50 train: iter 390: loss 0.923, acc 0.667: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 32/50 val: iter 78: loss 0.868, acc 0.693: 100%|██████████| 79/79 [00:10<00:00,  7.86it/s]
epoch 33/50 train: iter 390: loss 0.904, acc 0.674: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 33/50 val: iter 78: loss 0.873, acc 0.690: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 34/50 train: iter 390: loss 0.876, acc 0.686: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 34/50 val: iter 78: loss 0.896, acc 0.681: 100%|██████████| 79/79 [00:10<00:00,  7.84it/s]
epoch 35/50 train: iter 390: loss 0.857, acc 0.690: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 35/50 val: iter 78: loss 0.858, acc 0.697: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 36/50 train: iter 390: loss 0.836, acc 0.697: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 36/50 val: iter 78: loss 0.856, acc 0.702: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 37/50 train: iter 390: loss 0.816, acc 0.705: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 37/50 val: iter 78: loss 0.828, acc 0.710: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 38/50 train: iter 390: loss 0.793, acc 0.715: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 38/50 val: iter 78: loss 0.830, acc 0.708: 100%|██████████| 79/79 [00:10<00:00,  7.86it/s]
epoch 39/50 train: iter 390: loss 0.773, acc 0.723: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 39/50 val: iter 78: loss 0.824, acc 0.710: 100%|██████████| 79/79 [00:10<00:00,  7.84it/s]
epoch 40/50 train: iter 390: loss 0.747, acc 0.731: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 40/50 val: iter 78: loss 0.816, acc 0.717: 100%|██████████| 79/79 [00:10<00:00,  7.86it/s]
epoch 41/50 train: iter 390: loss 0.730, acc 0.737: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 41/50 val: iter 78: loss 0.794, acc 0.723: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 42/50 train: iter 390: loss 0.717, acc 0.743: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 42/50 val: iter 78: loss 0.802, acc 0.722: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 43/50 train: iter 390: loss 0.696, acc 0.750: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 43/50 val: iter 78: loss 0.795, acc 0.725: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 44/50 train: iter 390: loss 0.677, acc 0.754: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 44/50 val: iter 78: loss 0.779, acc 0.732: 100%|██████████| 79/79 [00:10<00:00,  7.84it/s]
epoch 45/50 train: iter 390: loss 0.670, acc 0.757: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 45/50 val: iter 78: loss 0.777, acc 0.734: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 46/50 train: iter 390: loss 0.662, acc 0.760: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 46/50 val: iter 78: loss 0.787, acc 0.730: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 47/50 train: iter 390: loss 0.652, acc 0.765: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 47/50 val: iter 78: loss 0.785, acc 0.733: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 48/50 train: iter 390: loss 0.647, acc 0.766: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 48/50 val: iter 78: loss 0.782, acc 0.732: 100%|██████████| 79/79 [00:10<00:00,  7.86it/s]
epoch 49/50 train: iter 390: loss 0.642, acc 0.769: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 49/50 val: iter 78: loss 0.784, acc 0.732: 100%|██████████| 79/79 [00:10<00:00,  7.85it/s]
epoch 50/50 train: iter 390: loss 0.638, acc 0.768: 100%|██████████| 391/391 [02:32<00:00,  2.56it/s]
epoch 50/50 val: iter 78: loss 0.783, acc 0.732: 100%|██████████| 79/79 [00:10<00:00,  7.84it/s]
In [28]:
learner.history.plot_train_val('loss')
In [29]:
learner.history.plot_train_val('acc')
In [30]:
torch.save(model.state_dict(), SAVE_PATH)
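
To reuse the trained weights later, rebuild the model with the same hyperparameters and load the saved state dict (a sketch):

model = ViT(IMAGE_SIZE, PATCH_SIZE, NUM_CLASSES, dim=512, depth=8, heads=8, mlp_dim=2048,
            p_drop=0.1, emb_p_drop=0.1).to(DEVICE)
model.load_state_dict(torch.load(SAVE_PATH, map_location=DEVICE))
model.eval()  # disable dropout for inference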