Simple Residual GAN¶

Configuration¶

Imports

In [1]:
import math
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.utils as vision_utils

Configuration

In [2]:
DATA_DIR = "data"

NUM_WORKERS = 4
BATCH_SIZE = 128

IMAGE_SIZE = 128
IMAGE_CHANNELS = 3
NOISE_DIM = 100 # Size of z latent vector
GENERATOR_HIDDEN_CHANNELS = 16
DISCRIMINATOR_HIDDEN_CHANNELS = 16

EPOCHS = 40
LEARNING_RATE_G = 1e-4
WEIGHT_DECAY_G = 1e-1
LEARNING_RATE_D = 1e-4
WEIGHT_DECAY_D = 1e-1
β1 = 0.5 # Adam β1 for both optimizers

PRINT_FREQ = 500
In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Utils¶

In [4]:
def plot_batch(ax, batch, title=None, **kwargs):
    imgs = vision_utils.make_grid(batch, padding=2, normalize=True)
    imgs = np.moveaxis(imgs.numpy(), 0, -1)
    ax.set_axis_off()
    if title is not None: ax.set_title(title)
    return ax.imshow(imgs, **kwargs)
In [5]:
def show_images(batch, title):
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111)
    plot_batch(ax, batch, title)
    plt.show()

Data¶

CelebA dataset

In [6]:
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
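
The Normalize transform with mean = std = 0.5 maps pixel values from [0, 1] to [-1, 1], matching the Tanh output range of the generator. A minimal check (an illustrative snippet, not one of the original cells):

x = torch.tensor([0.0, 0.5, 1.0])
print((x - 0.5) / 0.5)  # tensor([-1., 0., 1.])
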
In [7]:
dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)
In [8]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
In [9]:
real_batch, _ = next(iter(dataloader))
show_images(real_batch[:64], "Training Images")

Model¶

In [10]:
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None: nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
        if m.weight is not None: nn.init.ones_(m.weight)
        if m.bias is not None: nn.init.zeros_(m.bias)

Generator¶

InstanceNorm, arXiv:1607.08022 [cs.CV]
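
InstanceNorm2d normalizes each sample and each channel independently over its spatial dimensions, so the statistics of one image never depend on the rest of the batch. A minimal sketch reproducing the non-affine case by hand (an illustrative snippet, not one of the original cells):

x = torch.randn(2, 3, 8, 8)
mu = x.mean(dim=(2, 3), keepdim=True)                  # per-sample, per-channel mean
var = x.var(dim=(2, 3), keepdim=True, unbiased=False)  # per-sample, per-channel variance
manual = (x - mu) / torch.sqrt(var + 1e-5)
print(torch.allclose(nn.InstanceNorm2d(3)(x), manual, atol=1e-5))  # True
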

In [11]:
class PreConvTransposeBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        padding = (kernel_size - 1) // 2
        output_padding = stride + 2 * padding - kernel_size # chosen so the block upsamples by exactly `stride`
        super().__init__(
            nn.InstanceNorm2d(in_channels, affine=True),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding)
        )
In [12]:
class PreConvBlockG(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        padding = (kernel_size - 1) // 2
        super().__init__(
            nn.InstanceNorm2d(in_channels, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        )
In [13]:
class GeneratorBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.shortcut = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
                                      nn.Conv2d(in_channels, out_channels, 1))
        self.residual = nn.Sequential(PreConvTransposeBlock(in_channels, out_channels, kernel_size=4, stride=2),
                                      PreConvBlockG(out_channels, out_channels, kernel_size=3))
        self.γ = nn.Parameter(torch.tensor(0.0)) # learnable residual gate; at init the block reduces to its shortcut
    
    def forward(self, x):
        out = self.shortcut(x) + self.γ * self.residual(x)
        return out
In [14]:
class GeneratorHead(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super().__init__(
            PreConvBlockG(in_channels, out_channels, 1),
            nn.Tanh()
        )
In [15]:
class Generator(nn.Sequential):
    def __init__(self, noise_dim, hidden_channels, out_channels, num_blocks):
        channels = [hidden_channels * 2**x for x in reversed(range(num_blocks + 1))]
        layers = [nn.ConvTranspose2d(noise_dim, channels[0], kernel_size=4)] # 1x1 noise -> 4x4 feature map
        layers += [GeneratorBlock(c1, c2) for c1, c2 in zip(channels, channels[1:])]
        layers.append(GeneratorHead(channels[-1], out_channels))
        super().__init__(*layers)
In [16]:
generator = Generator(NOISE_DIM, GENERATOR_HIDDEN_CHANNELS, IMAGE_CHANNELS, num_blocks=5).to(DEVICE)
In [17]:
generator.apply(weights_init);
In [18]:
print("Number of generator parameters: {:,}".format(sum(p.numel() for p in generator.parameters())))
Number of generator parameters: 4,577,992
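
A quick sanity check of the generator's output shape (an illustrative snippet reusing the names defined above, not one of the original cells): the initial transposed convolution maps the 1×1 noise vector to a 4×4 feature map, and each of the five blocks doubles the resolution, so the output should be 128×128.

with torch.no_grad():
    z = torch.randn(2, NOISE_DIM, 1, 1, device=DEVICE)
    print(generator(z).shape)  # expected: torch.Size([2, 3, 128, 128])
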

Discriminator¶

PatchGAN discriminator, arXiv:1611.07004 [cs.CV]

In [19]:
class PreConvBlockD(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        padding = (kernel_size - 1) // 2
        super().__init__(
            nn.GroupNorm(1, in_channels), # single group: LayerNorm-style normalization over (C, H, W)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        )
In [20]:
class DiscriminatorBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.shortcut = nn.Sequential(nn.AvgPool2d(2),
                                      nn.Conv2d(in_channels, out_channels, 1))
        self.residual = nn.Sequential(PreConvBlockD(in_channels, in_channels, kernel_size=3, stride=1),
                                      PreConvBlockD(in_channels, out_channels, kernel_size=4, stride=2))
        self.γ = nn.Parameter(torch.tensor(0.0)) # learnable residual gate; at init the block reduces to its shortcut
    
    def forward(self, x):
        out = self.shortcut(x) + self.γ * self.residual(x)
        return out
In [21]:
class Discriminator(nn.Sequential):
    def __init__(self, in_channels, hidden_channels, num_blocks):
        channels = [hidden_channels * 2**x for x in range(num_blocks + 1)]
        layers = [nn.Conv2d(in_channels, channels[0], 3, padding=1)]
        layers += [DiscriminatorBlock(c1, c2) for c1, c2 in zip(channels, channels[1:])]
        layers.append(PreConvBlockD(channels[-1], 1, kernel_size=1))
        super().__init__(*layers)
In [22]:
discriminator = Discriminator(IMAGE_CHANNELS, DISCRIMINATOR_HIDDEN_CHANNELS, num_blocks=5).to(DEVICE)
In [23]:
discriminator.apply(weights_init);
In [24]:
print("Number of discriminator parameters: {:,}".format(sum(p.numel() for p in discriminator.parameters())))
Number of discriminator parameters: 3,760,182
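
Because the discriminator is fully convolutional (PatchGAN-style), it returns a grid of per-patch logits rather than a single scalar; the hinge loss below simply averages over that grid. A quick shape check (an illustrative snippet reusing the names defined above, not one of the original cells): with IMAGE_SIZE = 128 and five halving blocks the grid should be 4×4.

with torch.no_grad():
    dummy = torch.zeros(2, IMAGE_CHANNELS, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
    print(discriminator(dummy).shape)  # expected: torch.Size([2, 1, 4, 4])
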

Loss¶

Hinge loss for GANs, arXiv:1705.02894 [stat.ML]
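
The hinge loss trains the discriminator to push real logits above +1 and fake logits below −1, while the generator maximizes the discriminator's score on its samples. In the standard formulation (matching the class below; the trainer later averages the two discriminator terms with a factor of 0.5):

$$\mathcal{L}_D = \mathbb{E}_{x}\big[\max(0,\, 1 - D(x))\big] + \mathbb{E}_{z}\big[\max(0,\, 1 + D(G(z)))\big], \qquad \mathcal{L}_G = -\,\mathbb{E}_{z}\big[D(G(z))\big]$$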

In [25]:
class GANLossHinge:
    @staticmethod
    def D_real(logits):
        return torch.mean(F.relu(1. - logits, inplace=True))
    
    @staticmethod
    def D_fake(logits):
        return torch.mean(F.relu(1. + logits, inplace=True))
    
    @staticmethod
    def G(logits):
        return -torch.mean(logits)

Training¶

Trainer¶

In [26]:
class Trainer:
    def __init__(self, generator, discriminator, criterion_D_real, criterion_D_fake, criterion_G,
                 optimizer_G, optimizer_D, dataloader, device,
                 batch_scheduler_G=None, batch_scheduler_D=None):
        self.generator = generator
        self.discriminator = discriminator
        self.criterion_D_real = criterion_D_real
        self.criterion_D_fake = criterion_D_fake
        self.criterion_G = criterion_G
        self.optimizer_G = optimizer_G
        self.optimizer_D = optimizer_D
        self.dataloader = dataloader
        self.device = device
        self.batch_scheduler_G = batch_scheduler_G
        self.batch_scheduler_D = batch_scheduler_D
        self.history = defaultdict(list)
        self.fixed_noise = self.get_noise(64)
    
    def get_noise(self, batch_size):
        return torch.randn(batch_size, NOISE_DIM, 1, 1, device=self.device)
    
    def save_fake_images(self):
        with torch.no_grad():
            fake_images = self.generator(self.fixed_noise).detach().cpu()
        self.history['fake_images'].append(fake_images)
    
    
    def update_discriminator(self, real_images):
        self.optimizer_D.zero_grad()
        
        # All-real batch
        output_real = self.discriminator(real_images)
        loss_D_real = self.criterion_D_real(output_real)
        
        # All-fake batch
        batch_size = real_images.size(0)
        noise = self.get_noise(batch_size)
        fake_images = self.generator(noise)
        output_fake = self.discriminator(fake_images.detach())
        loss_D_fake = self.criterion_D_fake(output_fake)

        loss_D = 0.5 * (loss_D_real + loss_D_fake)
        loss_D.backward()
        self.optimizer_D.step()
        
        if self.batch_scheduler_D is not None:
            self.batch_scheduler_D.step()
        
        self.history['loss_D'].append(loss_D.item())
        self.history['logits_real'].append(torch.mean(output_real).item())
        self.history['logits_fake'].append(torch.mean(output_fake).item())
        
        return fake_images
    
    
    def update_generator(self, fake_images):
        self.optimizer_G.zero_grad()
        output_fake_G = self.discriminator(fake_images) # the discriminator was just updated, so these logits differ from output_fake
        loss_G = self.criterion_G(output_fake_G)
        loss_G.backward()
        self.optimizer_G.step()
        
        if self.batch_scheduler_G is not None:
            self.batch_scheduler_G.step()
        
        self.history['loss_G'].append(loss_G.item())
    
    
    def print_metrics(self, iters, epoch, epochs):
        print('{} [{}/{}]: D loss {:.4f}, G loss {:.4f}, logits real {:.4f}, logits fake {:.4f}'.format(
            iters, epoch+1, epochs, self.history['loss_D'][-1], self.history['loss_G'][-1],
            self.history['logits_real'][-1], self.history['logits_fake'][-1]))
    
    
    def train(self, epochs, print_freq):
        iters = 0
        for epoch in range(epochs):
            for real_images, _ in self.dataloader:
                real_images = real_images.to(self.device)
                 
                fake_images = self.update_discriminator(real_images)
                self.update_generator(fake_images)
                
                if iters % print_freq == 0:
                    self.print_metrics(iters, epoch, epochs)
                    self.save_fake_images()
                iters += 1
        self.save_fake_images()

Start Training¶

In [27]:
criterion = GANLossHinge()
In [28]:
optimizer_G = optim.AdamW(generator.parameters(), lr=LEARNING_RATE_G, weight_decay=WEIGHT_DECAY_G, betas=(β1, 0.999))
optimizer_D = optim.AdamW(discriminator.parameters(), lr=LEARNING_RATE_D, weight_decay=WEIGHT_DECAY_D, betas=(β1, 0.999))
In [29]:
lr_scheduler_G = optim.lr_scheduler.OneCycleLR(optimizer_G, max_lr=LEARNING_RATE_G,
                                               steps_per_epoch=len(dataloader), epochs=EPOCHS, cycle_momentum=False)

lr_scheduler_D = optim.lr_scheduler.OneCycleLR(optimizer_D, max_lr=LEARNING_RATE_D,
                                               steps_per_epoch=len(dataloader), epochs=EPOCHS, cycle_momentum=False)
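
OneCycleLR is stepped once per batch, which is why the schedulers are passed to the Trainer as batch schedulers. The snippet below previews the schedule on a throwaway optimizer so the real schedulers above are not advanced (an illustrative snippet, not one of the original cells):

_opt = optim.AdamW([nn.Parameter(torch.zeros(1))], lr=LEARNING_RATE_G)
_sched = optim.lr_scheduler.OneCycleLR(_opt, max_lr=LEARNING_RATE_G,
                                       steps_per_epoch=len(dataloader), epochs=EPOCHS,
                                       cycle_momentum=False)
lrs = []
for _ in range(len(dataloader) * EPOCHS):
    _opt.step()                         # step the optimizer first to avoid a scheduler warning
    lrs.append(_sched.get_last_lr()[0])
    _sched.step()
plt.plot(lrs)
plt.xlabel("Iteration"); plt.ylabel("Learning rate"); plt.title("One-cycle LR schedule")
plt.show()
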
In [30]:
trainer = Trainer(generator, discriminator,
                  criterion.D_real, criterion.D_fake, criterion.G,
                  optimizer_G, optimizer_D, dataloader, DEVICE,
                  lr_scheduler_G, lr_scheduler_D)
In [35]:
trainer.train(EPOCHS, PRINT_FREQ)
0 [1/40]: D loss 0.9989, G loss 0.3954, logits real -0.3282, logits fake -0.3912
500 [1/40]: D loss 0.7014, G loss 0.2357, logits real 0.4687, logits fake -0.2347
1000 [1/40]: D loss 0.5863, G loss 0.4419, logits real 0.4525, logits fake -0.4432
1500 [1/40]: D loss 0.3566, G loss 0.8974, logits real 0.9940, logits fake -0.8913
2000 [2/40]: D loss 0.1845, G loss 1.3546, logits real 2.0549, logits fake -1.2681
2500 [2/40]: D loss 0.1481, G loss 1.4282, logits real 2.4298, logits fake -1.3870
3000 [2/40]: D loss 0.2417, G loss 1.0233, logits real 2.5179, logits fake -1.4250
3500 [3/40]: D loss 0.3093, G loss 0.9324, logits real 2.0659, logits fake -1.4532
4000 [3/40]: D loss 0.2683, G loss 1.3177, logits real 2.7029, logits fake -1.1631
4500 [3/40]: D loss 0.4232, G loss 0.6033, logits real 1.2999, logits fake -1.8271
5000 [4/40]: D loss 0.5720, G loss 0.5053, logits real 0.6629, logits fake -1.6923
5500 [4/40]: D loss 0.4714, G loss 1.3494, logits real 1.7170, logits fake -0.5421
6000 [4/40]: D loss 0.5726, G loss 0.9775, logits real 1.0604, logits fake -0.6620
6500 [5/40]: D loss 0.5137, G loss 1.2993, logits real 1.3112, logits fake -0.6964
7000 [5/40]: D loss 0.5943, G loss 0.5668, logits real 0.6176, logits fake -1.0909
7500 [5/40]: D loss 0.5794, G loss 0.7274, logits real 0.7761, logits fake -0.8558
8000 [6/40]: D loss 0.6849, G loss 1.0214, logits real 0.7167, logits fake -0.4562
8500 [6/40]: D loss 0.7422, G loss 0.7130, logits real 0.5398, logits fake -0.2807
9000 [6/40]: D loss 1.0444, G loss 0.0565, logits real -0.2227, logits fake -0.8144
9500 [7/40]: D loss 0.8576, G loss 0.6265, logits real 0.6711, logits fake 0.2051
10000 [7/40]: D loss 1.0358, G loss 1.2432, logits real 1.2872, logits fake 0.8543
10500 [7/40]: D loss 0.8886, G loss 0.4129, logits real 0.3960, logits fake 0.0152
11000 [7/40]: D loss 0.7947, G loss 0.4750, logits real 0.0196, logits fake -0.5449
11500 [8/40]: D loss 0.8153, G loss 0.3000, logits real -0.2129, logits fake -0.6956
12000 [8/40]: D loss 0.8724, G loss 0.6919, logits real 0.6640, logits fake 0.2493
12500 [8/40]: D loss 0.8003, G loss -0.1362, logits real 0.2031, logits fake -0.3534
13000 [9/40]: D loss 0.8788, G loss 0.2050, logits real 0.2743, logits fake -0.0200
13500 [9/40]: D loss 0.7762, G loss 0.1777, logits real -0.1695, logits fake -0.7547
14000 [9/40]: D loss 0.8769, G loss 0.7025, logits real 0.5766, logits fake 0.1467
14500 [10/40]: D loss 0.7299, G loss 0.1577, logits real -0.0842, logits fake -0.8057
15000 [10/40]: D loss 0.7486, G loss 0.4825, logits real -0.0989, logits fake -0.7559
15500 [10/40]: D loss 0.7297, G loss 0.9733, logits real 0.3832, logits fake -0.4175
16000 [11/40]: D loss 0.7445, G loss 0.6441, logits real 0.3155, logits fake -0.3619
16500 [11/40]: D loss 0.5995, G loss 0.6764, logits real 0.2452, logits fake -0.8217
17000 [11/40]: D loss 0.7081, G loss 0.3917, logits real 0.2261, logits fake -0.6709
17500 [12/40]: D loss 0.6753, G loss 0.3999, logits real 0.1981, logits fake -0.8832
18000 [12/40]: D loss 0.7094, G loss 0.9524, logits real 0.5742, logits fake -0.3612
18500 [12/40]: D loss 0.6171, G loss -0.0637, logits real 0.2310, logits fake -0.9874
19000 [13/40]: D loss 0.7940, G loss 0.8039, logits real 1.3070, logits fake 0.2818
19500 [13/40]: D loss 0.6391, G loss 1.8423, logits real 1.3986, logits fake -0.0020
20000 [13/40]: D loss 0.5311, G loss 1.2937, logits real 0.9359, logits fake -0.4635
20500 [13/40]: D loss 0.5016, G loss 0.2299, logits real 0.4690, logits fake -1.0579
21000 [14/40]: D loss 0.4962, G loss 1.2784, logits real 0.6414, logits fake -0.8633
21500 [14/40]: D loss 0.4749, G loss 1.2514, logits real 0.8434, logits fake -0.7096
22000 [14/40]: D loss 0.4860, G loss 1.1279, logits real 0.7917, logits fake -0.7321
22500 [15/40]: D loss 0.5746, G loss 1.6527, logits real 1.4934, logits fake -0.1603
23000 [15/40]: D loss 0.4619, G loss 0.4736, logits real 0.5584, logits fake -1.1429
23500 [15/40]: D loss 0.5745, G loss 1.6893, logits real 1.6235, logits fake -0.0346
24000 [16/40]: D loss 0.4645, G loss 0.7335, logits real 0.5216, logits fake -1.1867
24500 [16/40]: D loss 0.6208, G loss 0.1134, logits real 0.0455, logits fake -1.8317
25000 [16/40]: D loss 0.6879, G loss 0.1680, logits real -0.1937, logits fake -1.9556
25500 [17/40]: D loss 0.6609, G loss 0.0826, logits real -0.0558, logits fake -2.0029
26000 [17/40]: D loss 0.3729, G loss 1.0357, logits real 0.8876, logits fake -1.1824
26500 [17/40]: D loss 0.4249, G loss 0.5028, logits real 0.7529, logits fake -1.1155
27000 [18/40]: D loss 0.5695, G loss 1.6246, logits real 1.5684, logits fake -0.1833
27500 [18/40]: D loss 0.5068, G loss 1.0806, logits real 1.2727, logits fake -0.4243
28000 [18/40]: D loss 0.4889, G loss 2.4949, logits real 1.6947, logits fake -0.2728
28500 [19/40]: D loss 0.3801, G loss 1.4602, logits real 1.1103, logits fake -0.7853
29000 [19/40]: D loss 0.4014, G loss 0.7702, logits real 0.7030, logits fake -1.2260
29500 [19/40]: D loss 0.3435, G loss 1.5383, logits real 1.5182, logits fake -0.6815
30000 [19/40]: D loss 0.3775, G loss 0.7085, logits real 0.6682, logits fake -1.7505
30500 [20/40]: D loss 0.4322, G loss 0.7698, logits real 0.6033, logits fake -1.4577
31000 [20/40]: D loss 0.3519, G loss 0.6006, logits real 0.9485, logits fake -1.6102
31500 [20/40]: D loss 0.6809, G loss 2.0072, logits real 2.0127, logits fake 0.2053
32000 [21/40]: D loss 0.4041, G loss 0.4265, logits real 0.6394, logits fake -1.7824
32500 [21/40]: D loss 0.4826, G loss 1.0501, logits real 0.4670, logits fake -1.3259
33000 [21/40]: D loss 0.4672, G loss 0.1853, logits real 0.5160, logits fake -1.7496
33500 [22/40]: D loss 0.7057, G loss 0.8682, logits real -0.2159, logits fake -2.3037
34000 [22/40]: D loss 0.3408, G loss 0.6700, logits real 1.0602, logits fake -1.3491
34500 [22/40]: D loss 0.3739, G loss 0.7048, logits real 0.6753, logits fake -1.8668
35000 [23/40]: D loss 0.5090, G loss 0.7230, logits real 0.3656, logits fake -1.9331
35500 [23/40]: D loss 0.3348, G loss 1.0852, logits real 1.0737, logits fake -1.4878
36000 [23/40]: D loss 0.3689, G loss 1.5058, logits real 1.7223, logits fake -0.6053
36500 [24/40]: D loss 0.3999, G loss 0.6561, logits real 0.6604, logits fake -1.8169
37000 [24/40]: D loss 0.2978, G loss 1.7424, logits real 1.7653, logits fake -0.9796
37500 [24/40]: D loss 0.2608, G loss 1.5015, logits real 1.8587, logits fake -0.9827
38000 [25/40]: D loss 0.3400, G loss 1.2158, logits real 1.2459, logits fake -1.2976
38500 [25/40]: D loss 0.3823, G loss 0.4329, logits real 0.8489, logits fake -1.8379
39000 [25/40]: D loss 0.5078, G loss 2.2481, logits real 2.2970, logits fake -0.1910
39500 [25/40]: D loss 0.2824, G loss 2.2805, logits real 2.0370, logits fake -1.0266
40000 [26/40]: D loss 0.2144, G loss 1.4994, logits real 1.6908, logits fake -1.4804
40500 [26/40]: D loss 0.3072, G loss 1.1498, logits real 1.2462, logits fake -1.3053
41000 [26/40]: D loss 0.2757, G loss 1.9425, logits real 2.1266, logits fake -0.9000
41500 [27/40]: D loss 0.2767, G loss 2.2231, logits real 2.2715, logits fake -0.9104
42000 [27/40]: D loss 0.2267, G loss 0.8177, logits real 1.8538, logits fake -1.4692
42500 [27/40]: D loss 0.2475, G loss 1.6527, logits real 1.5741, logits fake -1.8913
43000 [28/40]: D loss 0.2429, G loss 1.5730, logits real 2.2205, logits fake -1.1504
43500 [28/40]: D loss 0.2432, G loss 0.8140, logits real 1.5316, logits fake -1.9669
44000 [28/40]: D loss 0.3046, G loss 1.0124, logits real 1.2057, logits fake -2.0693
44500 [29/40]: D loss 0.2197, G loss 1.4000, logits real 2.2178, logits fake -1.1229
45000 [29/40]: D loss 0.2386, G loss 1.7064, logits real 2.1172, logits fake -1.3072
45500 [29/40]: D loss 0.1910, G loss 1.9443, logits real 2.2899, logits fake -1.4480
46000 [30/40]: D loss 0.2758, G loss 2.0807, logits real 2.3684, logits fake -0.9439
46500 [30/40]: D loss 0.2032, G loss 1.7891, logits real 2.5569, logits fake -1.3563
47000 [30/40]: D loss 0.2314, G loss 1.2728, logits real 1.9433, logits fake -1.6194
47500 [31/40]: D loss 0.2233, G loss 1.8861, logits real 2.2580, logits fake -1.5123
48000 [31/40]: D loss 0.2375, G loss 1.1658, logits real 1.5939, logits fake -2.2996
48500 [31/40]: D loss 0.1570, G loss 1.8415, logits real 2.1304, logits fake -1.8227
49000 [31/40]: D loss 0.1455, G loss 2.2480, logits real 2.6633, logits fake -1.6189
49500 [32/40]: D loss 0.1965, G loss 1.6342, logits real 2.2894, logits fake -1.4202
50000 [32/40]: D loss 0.2165, G loss 1.5066, logits real 2.2227, logits fake -1.8342
50500 [32/40]: D loss 0.1535, G loss 1.6225, logits real 2.4805, logits fake -1.5775
51000 [33/40]: D loss 0.2074, G loss 1.7396, logits real 1.9575, logits fake -2.2553
51500 [33/40]: D loss 0.1582, G loss 1.7959, logits real 2.4480, logits fake -1.7960
52000 [33/40]: D loss 0.1645, G loss 1.9366, logits real 2.5978, logits fake -1.6151
52500 [34/40]: D loss 0.1661, G loss 1.7223, logits real 2.6000, logits fake -1.7144
53000 [34/40]: D loss 0.1480, G loss 2.1309, logits real 3.1120, logits fake -1.5588
53500 [34/40]: D loss 0.1204, G loss 1.8373, logits real 2.6827, logits fake -1.6749
54000 [35/40]: D loss 0.1643, G loss 1.8832, logits real 2.5383, logits fake -1.6997
54500 [35/40]: D loss 0.1970, G loss 1.7919, logits real 2.6112, logits fake -1.5330
55000 [35/40]: D loss 0.1821, G loss 1.5336, logits real 2.3185, logits fake -1.8593
55500 [36/40]: D loss 0.1365, G loss 1.9283, logits real 2.6313, logits fake -1.8469
56000 [36/40]: D loss 0.1418, G loss 1.8703, logits real 2.4342, logits fake -1.8000
56500 [36/40]: D loss 0.1354, G loss 2.0224, logits real 2.9076, logits fake -1.5236
57000 [37/40]: D loss 0.1590, G loss 1.8018, logits real 2.6975, logits fake -1.7003
57500 [37/40]: D loss 0.1649, G loss 1.8073, logits real 2.7424, logits fake -1.7931
58000 [37/40]: D loss 0.1426, G loss 2.0696, logits real 2.6659, logits fake -2.0334
58500 [37/40]: D loss 0.1731, G loss 1.7543, logits real 2.3579, logits fake -1.8464
59000 [38/40]: D loss 0.1296, G loss 1.9206, logits real 2.6974, logits fake -1.8775
59500 [38/40]: D loss 0.1909, G loss 1.8966, logits real 2.4968, logits fake -1.8848
60000 [38/40]: D loss 0.1201, G loss 1.9811, logits real 3.0063, logits fake -1.9123
60500 [39/40]: D loss 0.1573, G loss 1.7648, logits real 2.8817, logits fake -1.6806
61000 [39/40]: D loss 0.1277, G loss 1.8872, logits real 2.8245, logits fake -1.8597
61500 [39/40]: D loss 0.1415, G loss 1.9578, logits real 2.7728, logits fake -1.9543
62000 [40/40]: D loss 0.1094, G loss 1.9513, logits real 2.7964, logits fake -1.9538
62500 [40/40]: D loss 0.1221, G loss 1.8556, logits real 2.7409, logits fake -1.8492
63000 [40/40]: D loss 0.1267, G loss 1.9557, logits real 2.5516, logits fake -1.9559

Plotting¶

In [32]:
def plot_losses(G_losses, D_losses):
    fig = plt.figure(figsize=(10, 5))
    ax = fig.add_subplot(111)
    
    iters = np.arange(1, len(G_losses) + 1)
    ax.scatter(iters, G_losses, marker='.', label="Generator")
    ax.scatter(iters, D_losses, marker='.',  label="Discriminator")
    
    ax.set_xlabel("Iterations")
    ax.set_ylabel("Loss")
    ax.set_title("Generator and Discriminator Loss During Training")
    ax.legend()
    plt.show()
In [33]:
def animate_progress(img_list):
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111)
    ax.set_axis_off()
    frames = [[plot_batch(ax, batch, animated=True)] for batch in img_list]
    ani = animation.ArtistAnimation(fig, frames, interval=1000, repeat_delay=1000, blit=True)
    plt.close()
    return ani
In [34]:
def show_fake_batch(fixed_noise):
    with torch.no_grad():
        fake_batch = generator(fixed_noise).detach().cpu()
    show_images(fake_batch, "Generated Images")
In [39]:
plot_losses(trainer.history['loss_G'], trainer.history['loss_D'])
In [37]:
ani = animate_progress(trainer.history['fake_images'][:32])
HTML(ani.to_jshtml())
Out[37]:
In [38]:
generator.eval();
In [42]:
show_fake_batch(trainer.fixed_noise)