VAE: D. P. Kingma, M. Welling, Auto-Encoding Variational Bayes, arXiv:1312.6114 [stat.ML].
$\beta$-VAE: I. Higgins et al., β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework, ICLR (2017).
$\sigma$-VAE: O. Rybkin, K. Daniilidis, S. Levine, Simple and Effective VAE Training with Calibrated Decoders, arXiv:2006.13202 [cs.LG].
MCMC sampling: arXiv:1610.09296 [cs.LG].
Data $\boldsymbol{x}$ are generated by a random process involving continuous latent random variables $\boldsymbol{z}$. The encoder $q_{\boldsymbol{\phi}}(\boldsymbol{z}|\boldsymbol{x})$ produces a distribution over the values of $\boldsymbol{z}$ from which the datapoint $\boldsymbol{x}$ could have been generated. The decoder $p_{\boldsymbol{\theta}}(\boldsymbol{x}|\boldsymbol{z})$ produces a distribution over the values of $\boldsymbol{x}$ corresponding to $\boldsymbol{z}$. The marginal distribution of the data is $p_{\boldsymbol{\theta}}(\boldsymbol{x})=\int p(\boldsymbol{z})p_{\boldsymbol{\theta}}(\boldsymbol{x}|\boldsymbol{z})\,\mathrm{d}\boldsymbol{z}$. The marginal log-likelihood of a datapoint $\boldsymbol{x}$ satisfies the lower bound $$ \log p_{\boldsymbol{\theta}}(\boldsymbol{x})\geq \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{z}|\boldsymbol{x})} \left[\log p_{\boldsymbol{\theta}}(\boldsymbol{x}|\boldsymbol{z})\right] -D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(\boldsymbol{z}|\boldsymbol{x})\,||\,p(\boldsymbol{z}))\,. $$ Here $D_{\mathrm{KL}}$ is the Kullback-Leibler divergence, and the first term is the negative reconstruction loss. The model is fitted to a given dataset by maximizing the right-hand side, the evidence lower bound (ELBO).
Assuming that the prior over the latent variables is the centered isotropic multivariate Gaussian $p(\boldsymbol{z})=\mathcal{N}(\boldsymbol{z};\boldsymbol{0},\boldsymbol{I})$ and that the approximate posterior $q_{\boldsymbol{\phi}}(\boldsymbol{z}|\boldsymbol{x})$ is a multivariate Gaussian with diagonal covariance, $q_{\boldsymbol{\phi}}(\boldsymbol{z}|\boldsymbol{x})=\mathcal{N}(\boldsymbol{z};\boldsymbol{\mu},\boldsymbol{\sigma}^{2}\boldsymbol{I})$, the KL divergence becomes $$ D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(\boldsymbol{z}|\boldsymbol{x})\,||\,p(\boldsymbol{z}))= -\frac{1}{2}\sum_{j=1}^{N_{\mathrm{latent}}}(1+\log\sigma_{j}^{2}-\mu_{j}^{2}-\sigma_{j}^{2})\,, $$ where $N_{\mathrm{latent}}$ is the number of latent variables. If the distribution $p_{\boldsymbol{\theta}}(\boldsymbol{x}|\boldsymbol{z})$ is an isotropic Gaussian, $p_{\boldsymbol{\theta}}(\boldsymbol{x}|\boldsymbol{z})=\mathcal{N}(\boldsymbol{x};\boldsymbol{\mu}_{\boldsymbol{\theta}},\sigma_{\boldsymbol{\theta}}^2\boldsymbol{I})$, then $$ -\log p_{\boldsymbol{\theta}}(\boldsymbol{x}|\boldsymbol{z}) =\frac{N_{\mathrm{data}}}{2\sigma_{\boldsymbol{\theta}}^2}\mathrm{MSE}(\boldsymbol{\mu}_{\boldsymbol{\theta}},\boldsymbol{x})+ N_{\mathrm{data}}\log\left(\sigma_{\boldsymbol{\theta}}\sqrt{2\pi}\right)\,, $$ where $N_{\mathrm{data}}$ is the dimensionality of $\boldsymbol{x}$.
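As a quick numerical check (our own addition, not part of the derivation), the closed-form KL term above can be compared against the value computed by torch.distributions:
import torch
from torch.distributions import Normal, kl_divergence

mu, log_var = torch.randn(10), torch.randn(10)
# closed-form expression from above
kld_closed = (-0.5 * (1 + log_var - mu ** 2 - log_var.exp())).sum()
# reference value from torch.distributions
kld_ref = kl_divergence(Normal(mu, torch.exp(0.5 * log_var)), Normal(0.0, 1.0)).sum()
assert torch.allclose(kld_closed, kld_ref, atol=1e-5)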
The expectation $\mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{z}|\boldsymbol{x})}$ is estimated with a single sample of $\boldsymbol{z}$. Assuming that $\sigma_{\boldsymbol{\theta}}$ is constant, the loss function of the variational autoencoder takes the form (up to additive and multiplicative constants) $$ \mathcal{L}(\boldsymbol{\theta},\boldsymbol{\phi},\boldsymbol{x})= \mathrm{MSE}(\boldsymbol{\mu}_{\boldsymbol{\theta}},\boldsymbol{x})+ \frac{N_{\mathrm{latent}}}{N_{\mathrm{data}}}\mathbb{E}_{j}[D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(z_{j}|\boldsymbol{x})\,||\,p(z_{j}))]\,, $$ where $\mathbb{E}_{j}$ denotes the average over the latent dimensions. In $\beta$-VAE the KL term is weighted by an adjustable factor $\beta$: $$ \mathcal{L}(\boldsymbol{\theta},\boldsymbol{\phi},\boldsymbol{x})= \mathrm{MSE}(\boldsymbol{\mu}_{\boldsymbol{\theta}},\boldsymbol{x})+ \beta\frac{N_{\mathrm{latent}}}{N_{\mathrm{data}}}\mathbb{E}_{j}[D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(z_{j}|\boldsymbol{x})\,||\,p(z_{j}))]\,. $$ In $\sigma$-VAE the decoder variance is learned: the optimal $\sigma_{\boldsymbol{\theta}}$ minimizing the negative log-likelihood above is $\sigma^{{*}2}=\mathrm{MSE}(\boldsymbol{\mu}_{\boldsymbol{\theta}},\boldsymbol{x})$, and the loss function becomes $$ \mathcal{L}(\boldsymbol{\theta},\boldsymbol{\phi},\boldsymbol{x})= \mathrm{MSE}(\boldsymbol{\mu}_{\boldsymbol{\theta}},\boldsymbol{x})+ 2\sigma^{{*}2}\frac{N_{\mathrm{latent}}}{N_{\mathrm{data}}}\mathbb{E}_{j}[D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(z_{j}|\boldsymbol{x})\,||\,p(z_{j}))]\,. $$
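The loss implemented later in this notebook covers only the constant-$\sigma$ ($\beta$-VAE) case. For comparison, here is a minimal sketch of a $\sigma$-VAE-style loss following the formula above; the class name and the choice to detach $\sigma^{{*}2}$ from the graph are our own, not taken from the $\sigma$-VAE reference implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SigmaVAELoss(nn.Module):
    # Sketch of a σ-VAE-style loss: the optimal decoder variance σ*² equals the MSE
    def forward(self, outputs, target):
        output, mu, log_var = outputs
        mse = F.mse_loss(output, target)
        # per-dimension KL divergence, averaged over batch and latent dimensions
        kld = torch.mean(-0.5 * (1 + log_var - mu ** 2 - log_var.exp()))
        d_latent = mu[0].numel()
        d_data = output[0].numel()
        sigma2 = mse.detach()  # σ*² = MSE; detached so the KL weight acts as a constant
        return mse + 2 * sigma2 * (d_latent / d_data) * kld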
Imports
from pathlib import Path
from functools import partial
from collections import defaultdict
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.utils as vision_utils
Configuration
DATA_DIR = Path("./data")
MODELS_DIR = Path("./models")
NUM_WORKERS = 8
BATCH_SIZE = 128
IMAGE_SIZE = 128
IMAGE_CHANNELS = 3
LATENT_CHANNELS = 16
EPOCHS = 40
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
The loss function implemented below, with $\hat{\boldsymbol{x}}=\boldsymbol{\mu}_{\boldsymbol{\theta}}$ the reconstructed image, is $$ \mathcal{L}(\boldsymbol{\theta},\boldsymbol{\phi},\boldsymbol{x})= \mathrm{MSE}(\hat{\boldsymbol{x}},\boldsymbol{x})+ \beta\frac{N_{\mathrm{latent}}}{N_{\mathrm{data}}}\mathbb{E}_{j}\left[-\frac{1}{2}(1+\log\sigma_j^2-\mu_j^2-\sigma_j^2)\right] $$
class VAELoss(nn.Module):
def __init__(self, β=1):
super().__init__()
self.reconstruction_loss = nn.MSELoss()
self.β = β
    def kld_loss(self, mu, log_var):
        # KL divergence between N(mu, sigma^2) and the standard normal prior,
        # averaged over the batch and the latent dimensions
        kld = -0.5 * (1 + log_var - mu ** 2 - log_var.exp())
        loss = torch.mean(kld)
        return loss
def forward(self, outputs, target):
output, mu, log_var = outputs
reconst_loss = self.reconstruction_loss(output, target)
kld_loss = self.kld_loss(mu, log_var)
d_inp = np.prod(output.shape[1:])
d_latent = np.prod(mu.shape[1:])
kld_weight = self.β * d_latent / d_inp
loss = reconst_loss + kld_weight * kld_loss
return {"loss": loss, "reconstruction loss": reconst_loss, "KLD": kld_loss}
We create a fully convolutional encoder and decoder.
def init_model(model):
for m in model.modules():
if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None: nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
class ConvBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
padding = (kernel_size - 1) // 2
super().__init__(
nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
class Stack(nn.Sequential):
def __init__(self, channels_list, block):
layers = []
for in_channels, out_channels in zip(channels_list[:-1], channels_list[1:]):
layers.append(block(in_channels, out_channels))
super().__init__(*layers)
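For example (an illustrative check, not part of the original notebook), Stack([16, 32, 64], ConvBlock) chains two ConvBlocks, 16→32 and 32→64 channels, preserving the spatial size:
stack = Stack([16, 32, 64], ConvBlock)
print(stack(torch.randn(1, 16, 8, 8)).shape)  # torch.Size([1, 64, 8, 8])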
Encoder
class DownBlock(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
nn.MaxPool2d(2),
ConvBlock(in_channels, out_channels, 3)
)
class Encoder(nn.Module):
def __init__(self, in_channels, channels_list, latent_channels):
super().__init__()
self.stem = ConvBlock(in_channels, channels_list[0], 3)
self.blocks = Stack(channels_list, DownBlock)
self.to_mu = nn.Conv2d(channels_list[-1], latent_channels, 3, padding=1)
self.to_var = nn.Conv2d(channels_list[-1], latent_channels, 3, padding=1)
def forward(self, x):
x = self.stem(x)
x = self.blocks(x)
mu = self.to_mu(x)
        log_var = self.to_var(x)  # predict log σ² so that σ² is positive without constraints
return mu, log_var
Decoder
class UpBlock(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
nn.Upsample(scale_factor=2, mode='nearest'),
ConvBlock(in_channels, out_channels, 3)
)
class Decoder(nn.Module):
def __init__(self, latent_channels, channels_list, out_channels):
super().__init__()
self.stem = ConvBlock(latent_channels, channels_list[0], 3)
self.blocks = Stack(channels_list, UpBlock)
self.to_output = nn.Conv2d(channels_list[-1], out_channels, 3, padding=1)
def forward(self, x):
x = self.stem(x)
x = self.blocks(x)
x = self.to_output(x)
x = torch.sigmoid(x)
return x
Full VAE
class VAE(nn.Module):
def __init__(self, num_downsamplings, latent_channels, channels=32, in_channels=3):
super().__init__()
self.num_downsamplings = num_downsamplings
self.latent_channels = latent_channels
channels_list = [channels * 2**i for i in range(num_downsamplings + 1)]
self.encoder = Encoder(in_channels, channels_list, latent_channels)
channels_list.reverse()
self.decoder = Decoder(latent_channels, channels_list, in_channels)
def forward(self, x):
mu, log_var = self.encoder(x)
z = self.reparameterize(mu, log_var)
out = self.decoder(z)
return out, mu, log_var
    def reparameterize(self, mu, log_var):
        # reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # so gradients can flow through the sampling step
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps
    def sample(self, num_samples, shape, device, num_steps=0):
        # each downsampling halves the spatial resolution of the latent grid
        reduction = 2**self.num_downsamplings
        z = torch.randn(num_samples, self.latent_channels, shape[0] // reduction, shape[1] // reduction, device=device)
        with torch.no_grad():
            # MCMC sampling: alternately decode and re-encode to move z
            # towards the latent distribution induced by the data
            for num in range(num_steps):
                x = self.decoder(z)
                mu, log_var = self.encoder(x)
                z = self.reparameterize(mu, log_var)
            out = self.decoder(z)
        return out
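With num_downsamplings=5, a 128×128 input is encoded to a 4×4 latent grid. A quick shape check (our own addition; the shapes follow from the definitions above):
test_vae = VAE(num_downsamplings=5, latent_channels=16, channels=16, in_channels=3)
out, mu, log_var = test_vae(torch.rand(2, 3, 128, 128))
print(out.shape, mu.shape)  # torch.Size([2, 3, 128, 128]) torch.Size([2, 16, 4, 4])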
model = VAE(num_downsamplings=5, latent_channels=LATENT_CHANNELS, channels=16, in_channels=IMAGE_CHANNELS)
init_model(model)
model = model.to(DEVICE)
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 3,368,771
train_transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.CenterCrop(IMAGE_SIZE),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
])
val_transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.CenterCrop(IMAGE_SIZE),
transforms.ToTensor()
])
train_dset = datasets.CelebA(str(DATA_DIR), split='train', transform=train_transform, download=False)
val_dset = datasets.CelebA(str(DATA_DIR), split='test', transform=val_transform, download=False)
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
num_workers=NUM_WORKERS)
val_loader = torch.utils.data.DataLoader(val_dset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS)
def plot_batch(ax, batch, title=None, **kwargs):
imgs = vision_utils.make_grid(batch, padding=2, normalize=True)
imgs = np.moveaxis(imgs.numpy(), 0, -1)
ax.set_axis_off()
if title is not None: ax.set_title(title)
return ax.imshow(imgs, **kwargs)
def show_images(batch, title):
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
plot_batch(ax, batch, title)
plt.show()
test_batch, _ = next(iter(train_loader))
show_images(test_batch[:64], "Training Images")
class AverageLoss():
def __init__(self, name):
self.name = name
self.reset()
def reset(self):
self.num_samples = 0
self.total_loss = 0.
def update(self, data):
batch_size = data['batch_size']
self.num_samples += batch_size
self.total_loss += batch_size * data[self.name]
def compute(self):
avg_loss = self.total_loss / self.num_samples
metrics = {self.name: avg_loss}
return metrics
class Learner:
def __init__(self, model, loss, optimizer, train_loader, val_loader, device,
batch_scheduler=None, epoch_scheduler=None):
self.model = model
self.loss = loss
self.optimizer = optimizer
self.train_loader = train_loader
self.val_loader = val_loader
self.device = device
self.batch_scheduler = batch_scheduler
self.epoch_scheduler = epoch_scheduler
self.history = defaultdict(list)
self.metrics = [AverageLoss(x) for x in ["loss", "reconstruction loss", "KLD"]]
def iterate(self, loader, train=False):
for metric in self.metrics:
metric.reset()
for batch in loader:
images = batch[0].to(self.device)
outputs = self.model(images)
losses = self.loss(outputs, images)
if train: self.backward_pass(losses["loss"])
data = {k: v.item() for k, v in losses.items()}
data["batch_size"] = len(images)
for metric in self.metrics:
metric.update(data)
summary = {}
for metric in self.metrics:
summary.update(metric.compute())
return summary
def backward_pass(self, batch_loss):
self.optimizer.zero_grad()
batch_loss.backward()
self.optimizer.step()
if self.batch_scheduler is not None:
self.batch_scheduler.step()
def log_metrics(self, metrics, name):
print(f"{name}: ", end='', flush=True)
for key, val in metrics.items():
self.history[name + ' ' + key].append(val)
print(f"{key} {val:.3f} ", end='')
def train(self):
self.model.train()
metrics = self.iterate(self.train_loader, train=True)
self.log_metrics(metrics, 'train')
def validate(self):
self.model.eval()
with torch.no_grad():
metrics = self.iterate(self.val_loader)
self.log_metrics(metrics, 'val')
def fit(self, epochs):
for epoch in range(1, epochs + 1):
print(f"{epoch}/{epochs} - ", end='')
self.train()
print('; ', end='')
self.validate()
print()
if self.epoch_scheduler is not None:
self.epoch_scheduler.step()
        torch.save(self.model.state_dict(), str(MODELS_DIR / 'final_model.pt'))
def plot_history_train_val(history, key):
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['train ' + key]) + 1)
ax.plot(xs, history['train ' + key], '.-', label='train')
ax.plot(xs, history['val ' + key], '.-', label='val')
ax.set_xlabel('epoch')
ax.set_ylabel(key)
ax.grid()
ax.legend()
plt.show()
We exclude the parameters of the final encoder and decoder layers from weight decay.
def separate_parameters(model):
# Separate parameters of the model into groups with weight decay and no weight decay
# We exclude the parameters of final encoder and decoder layers from weight decay
parameters_decay = set()
parameters_no_decay = set()
layers_no_weight_decay = ["encoder.to_mu", "encoder.to_var", "decoder.to_output"]
for param_name, param in model.named_parameters():
if any(layer_name in param_name for layer_name in layers_no_weight_decay):
parameters_no_decay.add(param_name)
else:
parameters_decay.add(param_name)
# sanity check
assert len(parameters_decay & parameters_no_decay) == 0
assert len(parameters_decay) + len(parameters_no_decay) == len(list(model.parameters()))
return parameters_decay, parameters_no_decay
def get_optimizer(model, learning_rate, weight_decay):
param_dict = {pn: p for pn, p in model.named_parameters()}
parameters_decay, parameters_no_decay = separate_parameters(model)
optim_groups = [
{"params": [param_dict[pn] for pn in parameters_decay], "weight_decay": weight_decay},
{"params": [param_dict[pn] for pn in parameters_no_decay], "weight_decay": 0.0},
]
optimizer = optim.AdamW(optim_groups, lr=learning_rate)
return optimizer
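As a sanity check (our own addition), the two parameter groups together should cover every parameter of the model:
test_optimizer = get_optimizer(model, LEARNING_RATE, WEIGHT_DECAY)
num_grouped = sum(p.numel() for g in test_optimizer.param_groups for p in g["params"])
assert num_grouped == sum(p.numel() for p in model.parameters())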
loss = VAELoss(β=5.)
optimizer = get_optimizer(model, LEARNING_RATE, WEIGHT_DECAY)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
steps_per_epoch=len(train_loader), epochs=EPOCHS)
learner = Learner(model, loss, optimizer, train_loader, val_loader, DEVICE, batch_scheduler=lr_scheduler)
learner.fit(EPOCHS)
1/40 - train: loss 0.025 reconstruction loss 0.019 KLD 0.226 ; val: loss 0.022 reconstruction loss 0.016 KLD 0.225
2/40 - train: loss 0.021 reconstruction loss 0.015 KLD 0.225 ; val: loss 0.021 reconstruction loss 0.015 KLD 0.233
3/40 - train: loss 0.020 reconstruction loss 0.014 KLD 0.228 ; val: loss 0.021 reconstruction loss 0.015 KLD 0.225
4/40 - train: loss 0.020 reconstruction loss 0.014 KLD 0.225 ; val: loss 0.019 reconstruction loss 0.014 KLD 0.209
5/40 - train: loss 0.020 reconstruction loss 0.014 KLD 0.222 ; val: loss 0.019 reconstruction loss 0.014 KLD 0.220
6/40 - train: loss 0.019 reconstruction loss 0.014 KLD 0.223 ; val: loss 0.019 reconstruction loss 0.013 KLD 0.231
7/40 - train: loss 0.019 reconstruction loss 0.013 KLD 0.224 ; val: loss 0.019 reconstruction loss 0.014 KLD 0.219
8/40 - train: loss 0.019 reconstruction loss 0.013 KLD 0.225 ; val: loss 0.019 reconstruction loss 0.013 KLD 0.220
9/40 - train: loss 0.019 reconstruction loss 0.013 KLD 0.225 ; val: loss 0.019 reconstruction loss 0.013 KLD 0.229
10/40 - train: loss 0.019 reconstruction loss 0.013 KLD 0.225 ; val: loss 0.019 reconstruction loss 0.013 KLD 0.229
11/40 - train: loss 0.019 reconstruction loss 0.013 KLD 0.226 ; val: loss 0.019 reconstruction loss 0.013 KLD 0.229
12/40 - train: loss 0.019 reconstruction loss 0.013 KLD 0.226 ; val: loss 0.019 reconstruction loss 0.013 KLD 0.226
13/40 - train: loss 0.019 reconstruction loss 0.013 KLD 0.226 ; val: loss 0.019 reconstruction loss 0.013 KLD 0.215
14/40 - train: loss 0.019 reconstruction loss 0.013 KLD 0.226 ; val: loss 0.019 reconstruction loss 0.013 KLD 0.223
15/40 - train: loss 0.019 reconstruction loss 0.013 KLD 0.227 ; val: loss 0.018 reconstruction loss 0.013 KLD 0.227
16/40 - train: loss 0.019 reconstruction loss 0.013 KLD 0.227 ; val: loss 0.019 reconstruction loss 0.013 KLD 0.220
17/40 - train: loss 0.019 reconstruction loss 0.013 KLD 0.227 ; val: loss 0.019 reconstruction loss 0.013 KLD 0.227
18/40 - train: loss 0.018 reconstruction loss 0.013 KLD 0.227 ; val: loss 0.018 reconstruction loss 0.012 KLD 0.231
19/40 - train: loss 0.018 reconstruction loss 0.012 KLD 0.227 ; val: loss 0.019 reconstruction loss 0.013 KLD 0.223
20/40 - train: loss 0.018 reconstruction loss 0.012 KLD 0.227 ; val: loss 0.018 reconstruction loss 0.013 KLD 0.223
21/40 - train: loss 0.018 reconstruction loss 0.012 KLD 0.227 ; val: loss 0.018 reconstruction loss 0.012 KLD 0.224
22/40 - train: loss 0.018 reconstruction loss 0.012 KLD 0.227 ; val: loss 0.018 reconstruction loss 0.012 KLD 0.226
23/40 - train: loss 0.018 reconstruction loss 0.012 KLD 0.227 ; val: loss 0.018 reconstruction loss 0.012 KLD 0.235
24/40 - train: loss 0.018 reconstruction loss 0.012 KLD 0.227 ; val: loss 0.018 reconstruction loss 0.012 KLD 0.230
25/40 - train: loss 0.018 reconstruction loss 0.012 KLD 0.227 ; val: loss 0.018 reconstruction loss 0.012 KLD 0.219
26/40 - train: loss 0.018 reconstruction loss 0.012 KLD 0.227 ; val: loss 0.018 reconstruction loss 0.012 KLD 0.222
27/40 - train: loss 0.018 reconstruction loss 0.012 KLD 0.227 ; val: loss 0.018 reconstruction loss 0.012 KLD 0.228
28/40 - train: loss 0.018 reconstruction loss 0.012 KLD 0.227 ; val: loss 0.018 reconstruction loss 0.012 KLD 0.224
29/40 - train: loss 0.018 reconstruction loss 0.012 KLD 0.228 ; val: loss 0.018 reconstruction loss 0.012 KLD 0.223
30/40 - train: loss 0.018 reconstruction loss 0.012 KLD 0.228 ; val: loss 0.018 reconstruction loss 0.012 KLD 0.222
31/40 - train: loss 0.018 reconstruction loss 0.012 KLD 0.227 ; val: loss 0.018 reconstruction loss 0.012 KLD 0.228
32/40 - train: loss 0.018 reconstruction loss 0.012 KLD 0.228 ; val: loss 0.018 reconstruction loss 0.012 KLD 0.224
33/40 - train: loss 0.018 reconstruction loss 0.012 KLD 0.228 ; val: loss 0.017 reconstruction loss 0.011 KLD 0.227
34/40 - train: loss 0.017 reconstruction loss 0.012 KLD 0.228 ; val: loss 0.017 reconstruction loss 0.011 KLD 0.228
35/40 - train: loss 0.017 reconstruction loss 0.011 KLD 0.228 ; val: loss 0.017 reconstruction loss 0.011 KLD 0.230
36/40 - train: loss 0.017 reconstruction loss 0.011 KLD 0.228 ; val: loss 0.017 reconstruction loss 0.011 KLD 0.226
37/40 - train: loss 0.017 reconstruction loss 0.011 KLD 0.228 ; val: loss 0.017 reconstruction loss 0.011 KLD 0.228
38/40 - train: loss 0.017 reconstruction loss 0.011 KLD 0.228 ; val: loss 0.017 reconstruction loss 0.011 KLD 0.227
39/40 - train: loss 0.017 reconstruction loss 0.011 KLD 0.228 ; val: loss 0.017 reconstruction loss 0.011 KLD 0.229
40/40 - train: loss 0.017 reconstruction loss 0.011 KLD 0.228 ; val: loss 0.017 reconstruction loss 0.011 KLD 0.228
plot_history_train_val(learner.history, 'loss')
plot_history_train_val(learner.history, 'reconstruction loss')
plot_history_train_val(learner.history, 'KLD')
model.load_state_dict(torch.load(str(MODELS_DIR / 'final_model.pt')))
<All keys matched successfully>
model.eval();
Testing reconstruction
def reconstruct(model, batch, device):
batch = batch.to(device)
with torch.no_grad():
reconstructed_batch = model(batch)[0]
reconstructed_batch = reconstructed_batch.cpu()
return reconstructed_batch
def show_2_batches(batch1, batch2, title1, title2):
fig = plt.figure(figsize=(16, 8))
ax = fig.add_subplot(121)
plot_batch(ax, batch1, title1)
ax = fig.add_subplot(122)
plot_batch(ax, batch2, title2)
plt.show()
test_batch, _ = next(iter(val_loader))
reconstructed_batch = reconstruct(model, test_batch, DEVICE)
show_2_batches(test_batch[:64], reconstructed_batch[:64], "Validation Images", "Reconstructed Images")
Testing generation
sample_images = model.sample(64, (IMAGE_SIZE, IMAGE_SIZE), DEVICE, 2)
show_images(sample_images.cpu(), "Sample Images")
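To see the effect of the MCMC refinement, we can generate samples with a varying number of steps (an illustrative comparison; the step counts are arbitrary):
for steps in (0, 1, 2, 5):
    imgs = model.sample(16, (IMAGE_SIZE, IMAGE_SIZE), DEVICE, num_steps=steps)
    show_images(imgs.cpu(), f"Sample Images, {steps} MCMC steps")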
Reset model parameters
init_model(model)
loss = VAELoss(β=1.)
optimizer = get_optimizer(model, LEARNING_RATE, WEIGHT_DECAY)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
steps_per_epoch=len(train_loader), epochs=EPOCHS)
learner = Learner(model, loss, optimizer, train_loader, val_loader, DEVICE, batch_scheduler=lr_scheduler)
learner.fit(EPOCHS)
1/40 - train: loss 0.021 reconstruction loss 0.018 KLD 0.684 ; val: loss 0.015 reconstruction loss 0.012 KLD 0.602
2/40 - train: loss 0.015 reconstruction loss 0.011 KLD 0.593 ; val: loss 0.015 reconstruction loss 0.012 KLD 0.605
3/40 - train: loss 0.013 reconstruction loss 0.010 KLD 0.588 ; val: loss 0.013 reconstruction loss 0.010 KLD 0.601
4/40 - train: loss 0.013 reconstruction loss 0.010 KLD 0.589 ; val: loss 0.013 reconstruction loss 0.009 KLD 0.611
5/40 - train: loss 0.012 reconstruction loss 0.009 KLD 0.597 ; val: loss 0.012 reconstruction loss 0.009 KLD 0.631
6/40 - train: loss 0.012 reconstruction loss 0.009 KLD 0.606 ; val: loss 0.012 reconstruction loss 0.008 KLD 0.623
7/40 - train: loss 0.012 reconstruction loss 0.009 KLD 0.612 ; val: loss 0.012 reconstruction loss 0.009 KLD 0.628
8/40 - train: loss 0.012 reconstruction loss 0.008 KLD 0.617 ; val: loss 0.012 reconstruction loss 0.009 KLD 0.587
9/40 - train: loss 0.012 reconstruction loss 0.008 KLD 0.620 ; val: loss 0.012 reconstruction loss 0.009 KLD 0.613
10/40 - train: loss 0.011 reconstruction loss 0.008 KLD 0.623 ; val: loss 0.012 reconstruction loss 0.009 KLD 0.623
11/40 - train: loss 0.011 reconstruction loss 0.008 KLD 0.625 ; val: loss 0.012 reconstruction loss 0.009 KLD 0.643
12/40 - train: loss 0.011 reconstruction loss 0.008 KLD 0.627 ; val: loss 0.011 reconstruction loss 0.008 KLD 0.634
13/40 - train: loss 0.011 reconstruction loss 0.008 KLD 0.628 ; val: loss 0.011 reconstruction loss 0.008 KLD 0.648
14/40 - train: loss 0.011 reconstruction loss 0.008 KLD 0.629 ; val: loss 0.011 reconstruction loss 0.008 KLD 0.623
15/40 - train: loss 0.011 reconstruction loss 0.008 KLD 0.630 ; val: loss 0.011 reconstruction loss 0.008 KLD 0.636
16/40 - train: loss 0.011 reconstruction loss 0.008 KLD 0.631 ; val: loss 0.011 reconstruction loss 0.008 KLD 0.607
17/40 - train: loss 0.011 reconstruction loss 0.008 KLD 0.632 ; val: loss 0.011 reconstruction loss 0.008 KLD 0.627
18/40 - train: loss 0.011 reconstruction loss 0.008 KLD 0.633 ; val: loss 0.011 reconstruction loss 0.008 KLD 0.642
19/40 - train: loss 0.011 reconstruction loss 0.007 KLD 0.633 ; val: loss 0.011 reconstruction loss 0.008 KLD 0.620
20/40 - train: loss 0.011 reconstruction loss 0.007 KLD 0.634 ; val: loss 0.011 reconstruction loss 0.007 KLD 0.649
21/40 - train: loss 0.011 reconstruction loss 0.007 KLD 0.634 ; val: loss 0.011 reconstruction loss 0.007 KLD 0.640
22/40 - train: loss 0.011 reconstruction loss 0.007 KLD 0.635 ; val: loss 0.011 reconstruction loss 0.007 KLD 0.660
23/40 - train: loss 0.011 reconstruction loss 0.007 KLD 0.635 ; val: loss 0.011 reconstruction loss 0.007 KLD 0.631
24/40 - train: loss 0.011 reconstruction loss 0.007 KLD 0.635 ; val: loss 0.011 reconstruction loss 0.007 KLD 0.643
25/40 - train: loss 0.010 reconstruction loss 0.007 KLD 0.636 ; val: loss 0.010 reconstruction loss 0.007 KLD 0.642
26/40 - train: loss 0.010 reconstruction loss 0.007 KLD 0.636 ; val: loss 0.010 reconstruction loss 0.007 KLD 0.617
27/40 - train: loss 0.010 reconstruction loss 0.007 KLD 0.636 ; val: loss 0.010 reconstruction loss 0.007 KLD 0.624
28/40 - train: loss 0.010 reconstruction loss 0.007 KLD 0.636 ; val: loss 0.010 reconstruction loss 0.007 KLD 0.625
29/40 - train: loss 0.010 reconstruction loss 0.007 KLD 0.636 ; val: loss 0.010 reconstruction loss 0.007 KLD 0.622
30/40 - train: loss 0.010 reconstruction loss 0.007 KLD 0.637 ; val: loss 0.010 reconstruction loss 0.007 KLD 0.639
31/40 - train: loss 0.010 reconstruction loss 0.007 KLD 0.637 ; val: loss 0.010 reconstruction loss 0.007 KLD 0.638
32/40 - train: loss 0.010 reconstruction loss 0.007 KLD 0.637 ; val: loss 0.010 reconstruction loss 0.007 KLD 0.650
33/40 - train: loss 0.010 reconstruction loss 0.007 KLD 0.637 ; val: loss 0.010 reconstruction loss 0.007 KLD 0.632
34/40 - train: loss 0.010 reconstruction loss 0.007 KLD 0.637 ; val: loss 0.010 reconstruction loss 0.007 KLD 0.632
35/40 - train: loss 0.010 reconstruction loss 0.007 KLD 0.637 ; val: loss 0.010 reconstruction loss 0.007 KLD 0.632
36/40 - train: loss 0.010 reconstruction loss 0.007 KLD 0.638 ; val: loss 0.010 reconstruction loss 0.007 KLD 0.634
37/40 - train: loss 0.010 reconstruction loss 0.007 KLD 0.638 ; val: loss 0.010 reconstruction loss 0.007 KLD 0.632
38/40 - train: loss 0.010 reconstruction loss 0.007 KLD 0.638 ; val: loss 0.010 reconstruction loss 0.006 KLD 0.635
39/40 - train: loss 0.010 reconstruction loss 0.007 KLD 0.638 ; val: loss 0.010 reconstruction loss 0.006 KLD 0.635
40/40 - train: loss 0.010 reconstruction loss 0.007 KLD 0.638 ; val: loss 0.010 reconstruction loss 0.006 KLD 0.638
plot_history_train_val(learner.history, 'loss')
plot_history_train_val(learner.history, 'reconstruction loss')
plot_history_train_val(learner.history, 'KLD')
Image generation
model.eval();
img_large = model.sample(1, (4*IMAGE_SIZE, 4*IMAGE_SIZE), DEVICE)
show_images(img_large.cpu(), "Large latent")