Imports
import math
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.utils as vision_utils
Configuration
DATA_DIR = "data"
NUM_WORKERS = 4
BATCH_SIZE = 128
IMAGE_SIZE = 128
IMAGE_CHANNELS = 3
NOISE_DIM = 100 # Size of z latent vector
GENERATOR_HIDDEN_CHANNELS = 16
DISCRIMINATOR_HIDDEN_CHANNELS = 16
EPOCHS = 40
LEARNING_RATE_G = 1e-4
WEIGHT_DECAY_G = 1e-1
LEARNING_RATE_D = 1e-4
WEIGHT_DECAY_D = 1e-1
β1 = 0.5
PRINT_FREQ = 500
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def plot_batch(ax, batch, title=None, **kwargs):
imgs = vision_utils.make_grid(batch, padding=2, normalize=True)
imgs = np.moveaxis(imgs.numpy(), 0, -1)
ax.set_axis_off()
if title is not None: ax.set_title(title)
return ax.imshow(imgs, **kwargs)
def show_images(batch, title):
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
plot_batch(ax, batch, title)
plt.show()
CelebA dataset
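Note that datasets.ImageFolder expects the images to live inside at least one subfolder of DATA_DIR (e.g. data/celeba/; any folder name works, since the labels are discarded below), not directly in DATA_DIR itself.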
transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.CenterCrop(IMAGE_SIZE),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
real_batch, _ = next(iter(dataloader))
show_images(real_batch[:64], "Training Images")
def weights_init(m):
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None: nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None: nn.init.ones_(m.weight)
if m.bias is not None: nn.init.zeros_(m.bias)
InstanceNorm, arXiv:1607.08022 [cs.CV]
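As a quick illustration (added here, not part of the original notebook; the throwaway names are ours), InstanceNorm2d standardizes each channel of every sample over its spatial dimensions, so the one-off check below should print True.
x_check = torch.randn(2, 3, 8, 8)
mu_check = x_check.mean(dim=(2, 3), keepdim=True)
var_check = x_check.var(dim=(2, 3), unbiased=False, keepdim=True)
print(torch.allclose(nn.InstanceNorm2d(3)(x_check), (x_check - mu_check) / torch.sqrt(var_check + 1e-5), atol=1e-5))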
class PreConvTransposeBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size, stride=1):
padding = (kernel_size - 1) // 2
output_padding = stride + 2 * padding - kernel_size  # makes the output exactly stride × the input size
super().__init__(
nn.InstanceNorm2d(in_channels, affine=True),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding)
)
class PreConvBlockG(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size, stride=1):
padding = (kernel_size - 1) // 2
super().__init__(
nn.InstanceNorm2d(in_channels, affine=True),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
)
class GeneratorBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.shortcut = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv2d(in_channels, out_channels, 1))
self.residual = nn.Sequential(PreConvTransposeBlock(in_channels, out_channels, kernel_size=4, stride=2),
PreConvBlockG(out_channels, out_channels, kernel_size=3))
self.γ = nn.Parameter(torch.tensor(0.0))  # zero-init gate: the block starts out as its shortcut path
def forward(self, x):
out = self.shortcut(x) + self.γ * self.residual(x)
return out
class GeneratorHead(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
PreConvBlockG(in_channels, out_channels, 1),
nn.Tanh()
)
class Generator(nn.Sequential):
def __init__(self, noise_dim, hidden_channels, out_channels, num_blocks):
channels = [hidden_channels * 2**x for x in reversed(range(num_blocks + 1))]
layers = [nn.ConvTranspose2d(noise_dim, channels[0], kernel_size=4)]
layers += [GeneratorBlock(c1, c2) for c1, c2 in zip(channels, channels[1:])]
layers.append(GeneratorHead(channels[-1], out_channels))
super().__init__(*layers)
generator = Generator(NOISE_DIM, GENERATOR_HIDDEN_CHANNELS, IMAGE_CHANNELS, num_blocks=5).to(DEVICE)
generator.apply(weights_init);
print("Number of generator parameters: {:,}".format(sum(p.numel() for p in generator.parameters())))
Number of generator parameters: 4,577,992
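A quick shape check (ours, not part of the original run): the initial transposed convolution maps the 1×1 latent to a 4×4 feature map and each of the 5 blocks doubles the resolution, so a batch of noise vectors should come out as 3×128×128 images.
with torch.no_grad():
    sample_batch = generator(torch.randn(2, NOISE_DIM, 1, 1, device=DEVICE))
print(sample_batch.shape)  # expected: torch.Size([2, 3, 128, 128])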
PatchGAN discriminator, arXiv:1611.07004 [cs.CV]
class PreConvBlockD(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
padding = (kernel_size - 1) // 2
super().__init__(
nn.GroupNorm(1, in_channels), # GroupNorm with a single group ≈ LayerNorm over (C, H, W)
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
)
class DiscriminatorBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.shortcut = nn.Sequential(nn.AvgPool2d(2),
nn.Conv2d(in_channels, out_channels, 1))
self.residual = nn.Sequential(PreConvBlockD(in_channels, in_channels, kernel_size=3, stride=1),
PreConvBlockD(in_channels, out_channels, kernel_size=4, stride=2))
self.γ = nn.Parameter(torch.tensor(0.0))  # zero-init gate, as in the generator blocks
def forward(self, x):
out = self.shortcut(x) + self.γ * self.residual(x)
return out
class Discriminator(nn.Sequential):
def __init__(self, in_channels, hidden_channels, num_blocks):
channels = [hidden_channels * 2**x for x in range(num_blocks + 1)]
layers = [nn.Conv2d(in_channels, channels[0], 3, padding=1)]
layers += [DiscriminatorBlock(c1, c2) for c1, c2 in zip(channels, channels[1:])]
layers.append(PreConvBlockD(channels[-1], 1, kernel_size=1))
super().__init__(*layers)
discriminator = Discriminator(IMAGE_CHANNELS, DISCRIMINATOR_HIDDEN_CHANNELS, num_blocks=5).to(DEVICE)
discriminator.apply(weights_init);
print("Number of discriminator parameters: {:,}".format(sum(p.numel() for p in discriminator.parameters())))
Number of discriminator parameters: 3,760,182
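And the matching check for the discriminator (also ours): being a PatchGAN, it returns a grid of per-patch logits rather than a single scalar per image; with 128×128 inputs and 5 halving blocks the grid is 4×4.
with torch.no_grad():
    patch_logits = discriminator(torch.randn(2, IMAGE_CHANNELS, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE))
print(patch_logits.shape)  # expected: torch.Size([2, 1, 4, 4])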
Hinge loss for GANs, arXiv:1705.02894 [stat.ML]
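The hinge formulation below trains the discriminator to push real logits above +1 and fake logits below -1, L_D = E[max(0, 1 - D(x))] + E[max(0, 1 + D(G(z)))] (the trainer averages the two terms with a factor of 0.5), while the generator simply maximizes the logits of its fakes, L_G = -E[D(G(z))].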
class GANLossHinge:
@staticmethod
def D_real(logits):
return torch.mean(F.relu(1. - logits, inplace=True))
@staticmethod
def D_fake(logits):
return torch.mean(F.relu(1. + logits, inplace=True))
@staticmethod
def G(logits):
return - torch.mean(logits)
class Trainer:
def __init__(self, generator, discriminator, criterion_D_real, criterion_D_fake, criterion_G,
optimizer_G, optimizer_D, dataloader, device,
batch_scheduler_G=None, batch_scheduler_D=None):
self.generator = generator
self.discriminator = discriminator
self.criterion_D_real = criterion_D_real
self.criterion_D_fake = criterion_D_fake
self.criterion_G = criterion_G
self.optimizer_G = optimizer_G
self.optimizer_D = optimizer_D
self.dataloader = dataloader
self.device = device
self.batch_scheduler_G = batch_scheduler_G
self.batch_scheduler_D = batch_scheduler_D
self.history = defaultdict(list)
self.fixed_noise = self.get_noise(64)
def get_noise(self, batch_size):
return torch.randn(batch_size, NOISE_DIM, 1, 1, device=self.device)
def save_fake_images(self):
with torch.no_grad():
fake_images = self.generator(self.fixed_noise).detach().cpu()
self.history['fake_images'].append(fake_images)
def update_discriminator(self, real_images):
self.optimizer_D.zero_grad()
# All-real batch
output_real = self.discriminator(real_images)
loss_D_real = self.criterion_D_real(output_real)
# All-fake batch
batch_size = real_images.size(0)
noise = self.get_noise(batch_size)
fake_images = self.generator(noise)
output_fake = self.discriminator(fake_images.detach())
loss_D_fake = self.criterion_D_fake(output_fake)
loss_D = 0.5 * (loss_D_real + loss_D_fake)
loss_D.backward()
self.optimizer_D.step()
if self.batch_scheduler_D is not None:
self.batch_scheduler_D.step()
self.history['loss_D'].append(loss_D.item())
self.history['logits_real'].append(torch.mean(output_real).item())
self.history['logits_fake'].append(torch.mean(output_fake).item())
return fake_images
def update_generator(self, fake_images):
self.optimizer_G.zero_grad()
output_fake_G = self.discriminator(fake_images)  # the discriminator was just updated, so these logits differ from output_fake
loss_G = self.criterion_G(output_fake_G)
loss_G.backward()
self.optimizer_G.step()
if self.batch_scheduler_G is not None:
self.batch_scheduler_G.step()
self.history['loss_G'].append(loss_G.item())
def print_metrics(self, iters, epoch, epochs):
print('{} [{}/{}]: D loss {:.4f}, G loss {:.4f}, logits real {:.4f}, logits fake {:.4f}'.format(
iters, epoch+1, epochs, self.history['loss_D'][-1], self.history['loss_G'][-1],
self.history['logits_real'][-1], self.history['logits_fake'][-1]))
def train(self, epochs, print_freq):
iters = 0
for epoch in range(epochs):
for real_images, _ in self.dataloader:
real_images = real_images.to(self.device)
fake_images = self.update_discriminator(real_images)
self.update_generator(fake_images)
if iters % print_freq == 0:
self.print_metrics(iters, epoch, epochs)
self.save_fake_images()
iters += 1
self.save_fake_images()
criterion = GANLossHinge()
optimizer_G = optim.AdamW(generator.parameters(), lr=LEARNING_RATE_G, weight_decay=WEIGHT_DECAY_G, betas=(β1, 0.999))
optimizer_D = optim.AdamW(discriminator.parameters(), lr=LEARNING_RATE_D, weight_decay=WEIGHT_DECAY_D, betas=(β1, 0.999))
lr_scheduler_G = optim.lr_scheduler.OneCycleLR(optimizer_G, max_lr=LEARNING_RATE_G,
steps_per_epoch=len(dataloader), epochs=EPOCHS, cycle_momentum=False)
lr_scheduler_D = optim.lr_scheduler.OneCycleLR(optimizer_D, max_lr=LEARNING_RATE_D,
steps_per_epoch=len(dataloader), epochs=EPOCHS, cycle_momentum=False)
trainer = Trainer(generator, discriminator,
criterion.D_real, criterion.D_fake, criterion.G,
optimizer_G, optimizer_D, dataloader, DEVICE,
lr_scheduler_G, lr_scheduler_D)
trainer.train(EPOCHS, PRINT_FREQ)
0 [1/40]: D loss 0.9989, G loss 0.3954, logits real -0.3282, logits fake -0.3912
500 [1/40]: D loss 0.7014, G loss 0.2357, logits real 0.4687, logits fake -0.2347
1000 [1/40]: D loss 0.5863, G loss 0.4419, logits real 0.4525, logits fake -0.4432
...
15000 [10/40]: D loss 0.7486, G loss 0.4825, logits real -0.0989, logits fake -0.7559
...
31000 [20/40]: D loss 0.3519, G loss 0.6006, logits real 0.9485, logits fake -1.6102
...
46500 [30/40]: D loss 0.2032, G loss 1.7891, logits real 2.5569, logits fake -1.3563
...
62000 [40/40]: D loss 0.1094, G loss 1.9513, logits real 2.7964, logits fake -1.9538
62500 [40/40]: D loss 0.1221, G loss 1.8556, logits real 2.7409, logits fake -1.8492
63000 [40/40]: D loss 0.1267, G loss 1.9557, logits real 2.5516, logits fake -1.9559
def plot_losses(G_losses, D_losses):
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(111)
iters = np.arange(1, len(G_losses) + 1)
ax.scatter(iters, G_losses, marker='.', label="Generator")
ax.scatter(iters, D_losses, marker='.', label="Discriminator")
ax.set_xlabel("Iterations")
ax.set_ylabel("Loss")
ax.set_title("Generator and Discriminator Loss During Training")
ax.legend()
plt.show()
def animate_progress(img_list):
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
ax.set_axis_off()
frames = [[plot_batch(ax, batch, animated=True)] for batch in img_list]
ani = animation.ArtistAnimation(fig, frames, interval=1000, repeat_delay=1000, blit=True)
plt.close()
return ani
def show_fake_batch(fixed_noise):
with torch.no_grad():
fake_batch = generator(fixed_noise).detach().cpu()
show_images(fake_batch, "Generated Images")
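show_fake_batch is not called in the recorded run; for a one-off look at the final samples it could be invoked as show_fake_batch(trainer.fixed_noise), reusing the fixed noise the trainer keeps for the progress animation.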
plot_losses(trainer.history['loss_G'], trainer.history['loss_D'])
ani = animate_progress(trainer.history['fake_images'][:32])
HTML(ani.to_jshtml())