NVAE: A Deep Hierarchical Variational Autoencoder, arXiv:2007.03898 [stat.ML]
VDVAE: Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images, arXiv:2011.10650 [cs.LG]
VDVAE GitHub: https://github.com/openai/vdvae
Latent variables $\boldsymbol{z}$ are split into $L$ layers $\boldsymbol{z}_i$, $i=1,\ldots,L$, where $\boldsymbol{z}_{i<}\equiv(\boldsymbol{z}_{i+1},\ldots,\boldsymbol{z}_{L})$ denotes the layers above layer $i$. The prior distribution factorizes top-down:
$$ p_{\boldsymbol{\theta}}(\boldsymbol{z})=p(\boldsymbol{z}_{L})\prod_{i=1}^{L-1}p_{\boldsymbol{\theta}}(\boldsymbol{z}_{i}|\boldsymbol{z}_{i<})\,. $$
The inference model (encoder) follows the same factorization:
$$ q_{\boldsymbol{\phi}}(\boldsymbol{z}|\boldsymbol{x})=q_{\boldsymbol{\phi}}(\boldsymbol{z}_{L}|\boldsymbol{x})\prod_{i=1}^{L-1}q_{\boldsymbol{\phi}}(\boldsymbol{z}_{i}|\boldsymbol{z}_{i<},\boldsymbol{x})\,. $$
The ELBO takes the form
$$ \log p(\boldsymbol{x})\geq\mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{z}|\boldsymbol{x})}[\log p_{\boldsymbol{\theta}}(\boldsymbol{x}|\boldsymbol{z})]-D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(\boldsymbol{z}|\boldsymbol{x})||p_{\boldsymbol{\theta}}(\boldsymbol{z}))\,, $$
where the Kullback-Leibler divergence decomposes layer by layer:
$$ D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(\boldsymbol{z}|\boldsymbol{x})||p_{\boldsymbol{\theta}}(\boldsymbol{z}))=D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(\boldsymbol{z}_{L}|\boldsymbol{x})||p(\boldsymbol{z}_{L}))+\sum_{i=1}^{L-1}\mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{z}_{i<}|\boldsymbol{x})}[D_{\mathrm{KL}}(q_{\boldsymbol{\phi}}(\boldsymbol{z}_{i}|\boldsymbol{z}_{i<},\boldsymbol{x})||p_{\boldsymbol{\theta}}(\boldsymbol{z}_{i}|\boldsymbol{z}_{i<}))]\,. $$
Here we introduced the marginal posterior over the layers above layer $i$,
$$ q_{\boldsymbol{\phi}}(\boldsymbol{z}_{i<}|\boldsymbol{x})=q_{\boldsymbol{\phi}}(\boldsymbol{z}_{L}|\boldsymbol{x})\prod_{j=i+1}^{L-1}q_{\boldsymbol{\phi}}(\boldsymbol{z}_{j}|\boldsymbol{z}_{j<},\boldsymbol{x})\,. $$
Assuming diagonal Gaussian distributions,
$$ \begin{aligned} p(\boldsymbol{z}_{L})&=\mathcal{N}(\boldsymbol{z}_{L};\boldsymbol{0},\boldsymbol{I})\,,\\ p_{\boldsymbol{\theta}}(\boldsymbol{z}_{i}|\boldsymbol{z}_{i<})&=\mathcal{N}(\boldsymbol{z}_{i};\boldsymbol{\mu}_{\boldsymbol{\theta},i}(\boldsymbol{z}_{i<}),\boldsymbol{\sigma}_{\boldsymbol{\theta},i}^{2}(\boldsymbol{z}_{i<})\boldsymbol{I})\,,\\ q_{\boldsymbol{\phi}}(\boldsymbol{z}_{L}|\boldsymbol{x})&=\mathcal{N}(\boldsymbol{z}_{L};\boldsymbol{\mu}_{\boldsymbol{\phi},L}(\boldsymbol{x}),\boldsymbol{\sigma}_{\boldsymbol{\phi},L}^{2}(\boldsymbol{x})\boldsymbol{I})\,,\\ q_{\boldsymbol{\phi}}(\boldsymbol{z}_{i}|\boldsymbol{z}_{i<},\boldsymbol{x})&=\mathcal{N}(\boldsymbol{z}_{i};\boldsymbol{\mu}_{\boldsymbol{\phi},i}(\boldsymbol{z}_{i<},\boldsymbol{x}),\boldsymbol{\sigma}_{\boldsymbol{\phi},i}^{2}(\boldsymbol{z}_{i<},\boldsymbol{x})\boldsymbol{I})\,, \end{aligned} $$
the Kullback-Leibler divergence takes the closed form
$$ D_{\mathrm{KL}} = \frac{1}{2}\sum_{j=1}^{D_{L}}\left(\mu_{\boldsymbol{\phi},L}^{(j)2}+\sigma_{\boldsymbol{\phi},L}^{(j)2}-\log\sigma_{\boldsymbol{\phi},L}^{(j)2}-1\right) +\frac{1}{2}\sum_{i=1}^{L-1}\sum_{j=1}^{D_{i}}\left(\frac{(\mu_{\boldsymbol{\phi},i}^{(j)}-\mu_{\boldsymbol{\theta},i}^{(j)})^{2}}{\sigma_{\boldsymbol{\theta},i}^{(j)2}}+\frac{\sigma_{\boldsymbol{\phi},i}^{(j)2}}{\sigma_{\boldsymbol{\theta},i}^{(j)2}}-\log\frac{\sigma_{\boldsymbol{\phi},i}^{(j)2}}{\sigma_{\boldsymbol{\theta},i}^{(j)2}}-1\right)\,, $$
where $D_i$ is the dimensionality of $\boldsymbol{z}_i$ and the superscript $(j)$ indexes its components. This per-component expression is exactly what the kl_divergence function defined below implements.
Imports
from pathlib import Path
from functools import partial
from collections import defaultdict
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.utils as vision_utils
Configuration
DATA_DIR = Path("./data")
MODELS_DIR = Path("./models")
MODELS_DIR.mkdir(parents=True, exist_ok=True)  # ensure the checkpoint directory exists before saving
NUM_WORKERS = 8
BATCH_SIZE = 128
IMAGE_SIZE = 128
IMAGE_CHANNELS = 3
ENCODER_CHANNELS = 16
LATENT_REDUCTION = 16 # ratio of the number of feature channels to the number of latent channels
NUM_DOWNSAMPLINGS = 5
EPOCHS = 40
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ConvBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, act=True):
padding = (kernel_size - 1) // 2
layers = [
nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False),
nn.BatchNorm2d(out_channels)
]
if act: layers.append(nn.ReLU(inplace=True))
super().__init__(*layers)
def reparameterize(mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + std * eps
def kl_divergence(mu, log_var, mu_prior, log_var_prior):
d_mu = mu - mu_prior
d_log_var = log_var - log_var_prior
kld = 0.5 * (d_mu**2 / log_var_prior.exp() + d_log_var.exp() - d_log_var - 1)
return kld
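As a quick sanity check (not part of the original notebook), the closed-form Gaussian KL above can be compared against torch.distributions, which implements the same quantity; the tensor shapes here are arbitrary:
# Sanity check: closed-form KL vs. torch.distributions (illustrative)
_mu, _log_var = torch.randn(4, 8), torch.randn(4, 8)
_mu_p, _log_var_p = torch.randn(4, 8), torch.randn(4, 8)
_q = torch.distributions.Normal(_mu, (0.5 * _log_var).exp())
_p = torch.distributions.Normal(_mu_p, (0.5 * _log_var_p).exp())
assert torch.allclose(kl_divergence(_mu, _log_var, _mu_p, _log_var_p),
                      torch.distributions.kl_divergence(_q, _p), rtol=1e-4, atol=1e-4)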
def init_model(model):
for m in model.modules():
if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None: nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
class DownBlock(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
nn.MaxPool2d(2),
ConvBlock(in_channels, out_channels, 3)
)
class Encoder(nn.Module):
def __init__(self, in_channels, channels, num_downsamplings):
super().__init__()
self.stem = ConvBlock(in_channels, channels, 3)
self.blocks = nn.ModuleList()
in_channels = channels
for _ in range(num_downsamplings):
out_channels = in_channels * 2
self.blocks.append(DownBlock(in_channels, out_channels))
in_channels = out_channels
def forward(self, x):
x = self.stem(x)
xs = []
for block in self.blocks:
xs.append(x)
x = block(x)
return x, xs
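A small shape check (illustrative configuration, not the one trained below) makes the encoder's contract explicit: each DownBlock halves the resolution and doubles the channels, while xs collects the pre-downsampling feature maps later used as skip connections:
# Encoder shape check: 2 downsamplings on a 32×32 input
_enc = Encoder(in_channels=3, channels=8, num_downsamplings=2)
_x, _xs = _enc(torch.randn(2, 3, 32, 32))
print(_x.shape)                # torch.Size([2, 32, 8, 8])
print([s.shape for s in _xs])  # [torch.Size([2, 8, 32, 32]), torch.Size([2, 16, 16, 16])]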
class VAELayer(nn.Module):
def __init__(self, channels, latent_reduction):
super().__init__()
latent_channels = channels // latent_reduction
        # Conv2d + BatchNorm (via ConvBlock) on the prior/posterior heads stabilizes the KL losses
self.to_prior = ConvBlock(channels, 2 * latent_channels, 3, act=False)
self.encoder = ConvBlock(channels, 2 * latent_channels, 3, act=False)
self.decoder = ConvBlock(latent_channels, channels, 3, act=False)
self.γ = nn.Parameter(torch.zeros(1))
self.act = nn.ReLU(inplace=True)
    def forward(self, x, x_prior):
        # Prior parameters come from the decoder state, posterior parameters from the encoder features
        mu_prior, log_var_prior = self.to_prior(x_prior).chunk(2, dim=1)
        mu, log_var = self.encoder(x).chunk(2, dim=1)
        z = reparameterize(mu, log_var)
        kld = kl_divergence(mu, log_var, mu_prior, log_var_prior)
        z = self.decoder(z)
        # Residual update of the decoder state, gated by the zero-initialized scale γ
        out = x_prior + self.γ * z
        out = self.act(out)
        return out, kld
    def sample(self, x_prior, t=1.):
        # Sample from the prior; the temperature t rescales the prior variance (t < 1 gives less diverse samples)
        mu_prior, log_var_prior = self.to_prior(x_prior).chunk(2, dim=1)
        log_var_prior = log_var_prior + math.log(t)
        z = reparameterize(mu_prior, log_var_prior)
        z = self.decoder(z)
        out = x_prior + self.γ * z
        out = self.act(out)
        return out
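The layer's shapes can be verified with a minimal sketch (illustrative values): the posterior input x and the prior input x_prior share one shape, and the KL map has channels // latent_reduction channels at the same resolution:
# VAELayer shape check: 64 feature channels, 64 / 16 = 4 latent channels
_layer = VAELayer(channels=64, latent_reduction=16)
_out, _kld = _layer(torch.randn(2, 64, 8, 8), torch.randn(2, 64, 8, 8))
print(_out.shape, _kld.shape)  # torch.Size([2, 64, 8, 8]) torch.Size([2, 4, 8, 8])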
class DecoderStem(nn.Module):
def __init__(self, channels, latent_reduction, shape):
super().__init__()
self.vae = VAELayer(channels, latent_reduction)
self.prior = nn.Parameter(torch.zeros(1, channels, *shape))
def forward(self, x):
out, kld = self.vae(x, self.prior)
return out, kld
def sample(self, num_samples, t=1.):
out = self.vae.sample(self.prior.expand(num_samples, -1, -1, -1), t=t)
return out
class UpBlock(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
nn.Upsample(scale_factor=2, mode='nearest'),
ConvBlock(in_channels, out_channels, 3)
)
class DecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, latent_reduction):
super().__init__()
self.up = UpBlock(in_channels, out_channels)
self.vae = VAELayer(out_channels, latent_reduction)
    def forward(self, x, skip):
        x = self.up(x)
        # The encoder skip features feed the posterior; the upsampled decoder state feeds the prior
        out, kld = self.vae(skip, x)
        return out, kld
def sample(self, x, t=1.):
x = self.up(x)
out = self.vae.sample(x, t=t)
return out
class DecoderHead(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
ConvBlock(in_channels, out_channels, 3, act=False),
nn.Sigmoid()
)
class Decoder(nn.Module):
def __init__(self, in_channels, out_channels, latent_reduction, num_upsamplings, shape):
super().__init__()
self.stem = DecoderStem(in_channels, latent_reduction, shape)
self.blocks = nn.ModuleList()
for _ in range(num_upsamplings):
channels = in_channels // 2
self.blocks.append(DecoderBlock(in_channels, channels, latent_reduction))
in_channels = channels
self.head = DecoderHead(channels, out_channels)
def forward(self, x, xs):
klds = []
x, kld = self.stem(x)
klds.append(kld)
for block, skip in zip(self.blocks, reversed(xs)):
x, kld = block(x, skip)
klds.append(kld)
x = self.head(x)
return x, klds
def sample(self, num_samples, t=1.):
z = self.stem.sample(num_samples, t=t)
for block in self.blocks:
z = block.sample(z, t)
z = self.head(z)
return z
class VAE(nn.Module):
def __init__(self, num_downsamplings, latent_reduction, enc_channels, image_shape, in_channels=3):
super().__init__()
reduction = 2**num_downsamplings
shape = (image_shape[0] // reduction, image_shape[1] // reduction)
self.encoder = Encoder(in_channels, enc_channels, num_downsamplings)
self.decoder = Decoder(enc_channels * reduction, in_channels, latent_reduction, num_downsamplings, shape)
def forward(self, x):
x, xs = self.encoder(x)
out, klds = self.decoder(x, xs)
return out, klds
def sample(self, num_samples, t=1.):
with torch.no_grad():
out = self.decoder.sample(num_samples, t)
return out
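A toy forward pass (small illustrative configuration, not the model trained below) confirms the full model's contract: the output matches the input shape, and one KL map is returned per latent layer, ordered from the lowest-resolution stem upwards:
# Toy VAE forward pass: 2 downsamplings on 32×32 RGB images
with torch.no_grad():
    _toy = VAE(num_downsamplings=2, latent_reduction=4, enc_channels=8, image_shape=(32, 32))
    _out, _klds = _toy(torch.randn(2, 3, 32, 32))
print(_out.shape)                # torch.Size([2, 3, 32, 32])
print([k.shape for k in _klds])  # [torch.Size([2, 8, 8, 8]), torch.Size([2, 4, 16, 16]), torch.Size([2, 2, 32, 32])]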
model = VAE(num_downsamplings=NUM_DOWNSAMPLINGS, latent_reduction=LATENT_REDUCTION, enc_channels=ENCODER_CHANNELS,
image_shape=(IMAGE_SIZE, IMAGE_SIZE), in_channels=IMAGE_CHANNELS)
init_model(model)
model = model.to(DEVICE)
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
print("Number of encoder parameters: {:,}".format(sum(p.numel() for p in model.encoder.parameters())))
print("Number of decoder parameters: {:,}".format(sum(p.numel() for p in model.decoder.parameters())))
Number of parameters: 4,140,052
Number of encoder parameters: 1,573,776
Number of decoder parameters: 2,566,276
class VAELoss(nn.Module):
def __init__(self, β=1.):
super().__init__()
self.reconstruction_loss = nn.L1Loss()
self.β = β
def forward(self, outputs, target):
output, klds = outputs
reconst_loss = self.reconstruction_loss(output, target)
kld_losses = [torch.mean(kld) for kld in klds]
ndims = np.prod(output.shape[1:])
nlatent = np.array([np.prod(kld.shape[1:]) for kld in klds])
kld_weights = self.β * nlatent / ndims
total_kld_loss = sum(weight * kld_loss for weight, kld_loss in zip(kld_weights, kld_losses))
loss = reconst_loss + total_kld_loss
loss_dict = {"loss": loss, "reconstruction loss": reconst_loss}
for num, kld_loss in enumerate(kld_losses):
loss_dict[f"KLD{num}"] = kld_loss
return loss_dict
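Each KL map is averaged over all of its elements, so the weight nlatent / ndims rescales every term to a per-pixel scale comparable with the L1 reconstruction loss, and β sets the overall trade-off between reconstruction quality and KL regularization. A usage sketch on random tensors with plausible shapes (illustrative, not from the original notebook):
# VAELoss usage sketch on random tensors
_criterion = VAELoss(β=1.)
_fake_out = torch.rand(2, 3, 32, 32)
_fake_klds = [torch.rand(2, 8, 8, 8), torch.rand(2, 4, 16, 16), torch.rand(2, 2, 32, 32)]
_losses = _criterion((_fake_out, _fake_klds), torch.rand(2, 3, 32, 32))
print({k: round(v.item(), 4) for k, v in _losses.items()})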
def plot_batch(ax, batch, title=None, **kwargs):
imgs = vision_utils.make_grid(batch, padding=2, normalize=True)
imgs = np.moveaxis(imgs.numpy(), 0, -1)
ax.set_axis_off()
if title is not None: ax.set_title(title)
return ax.imshow(imgs, **kwargs)
def show_images(batch, title):
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
plot_batch(ax, batch, title)
plt.show()
def show_2_batches(batch1, batch2, title1, title2):
fig = plt.figure(figsize=(16, 8))
ax = fig.add_subplot(121)
plot_batch(ax, batch1, title1)
ax = fig.add_subplot(122)
plot_batch(ax, batch2, title2)
plt.show()
def reconstruct(model, batch, device):
batch = batch.to(device)
with torch.no_grad():
reconstructed_batch = model(batch)[0]
reconstructed_batch = reconstructed_batch.cpu()
return reconstructed_batch
train_transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.CenterCrop(IMAGE_SIZE),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
])
val_transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.CenterCrop(IMAGE_SIZE),
transforms.ToTensor()
])
# CelebA must already be present in DATA_DIR (download=False)
train_dset = datasets.CelebA(str(DATA_DIR), split='train', transform=train_transform, download=False)
val_dset = datasets.CelebA(str(DATA_DIR), split='test', transform=val_transform, download=False)
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
num_workers=NUM_WORKERS)
val_loader = torch.utils.data.DataLoader(val_dset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS)
test_batch, _ = next(iter(train_loader))
show_images(test_batch[:64], "Training Images")
class AverageLoss():
def __init__(self, name):
self.name = name
self.reset()
def reset(self):
self.num_samples = 0
self.total_loss = 0.
def update(self, data):
batch_size = data['batch_size']
self.num_samples += batch_size
self.total_loss += batch_size * data[self.name]
def compute(self):
avg_loss = self.total_loss / self.num_samples
metrics = {self.name: avg_loss}
return metrics
class Learner:
def __init__(self, model, loss, optimizer, train_loader, val_loader, device,
batch_scheduler=None, epoch_scheduler=None):
self.model = model
self.loss = loss
self.optimizer = optimizer
self.train_loader = train_loader
self.val_loader = val_loader
self.device = device
self.batch_scheduler = batch_scheduler
self.epoch_scheduler = epoch_scheduler
self.history = defaultdict(list)
self.metrics = [AverageLoss(x) for x in ["loss", "reconstruction loss"]]
for num in range(NUM_DOWNSAMPLINGS + 1):
self.metrics.append(AverageLoss(f"KLD{num}"))
def iterate(self, loader, train=False):
for metric in self.metrics:
metric.reset()
for batch in loader:
images = batch[0].to(self.device)
outputs = self.model(images)
losses = self.loss(outputs, images)
if train: self.backward_pass(losses["loss"])
data = {k: v.item() for k, v in losses.items()}
data["batch_size"] = len(images)
for metric in self.metrics:
metric.update(data)
summary = {}
for metric in self.metrics:
summary.update(metric.compute())
return summary
def backward_pass(self, batch_loss):
self.optimizer.zero_grad()
batch_loss.backward()
self.optimizer.step()
if self.batch_scheduler is not None:
self.batch_scheduler.step()
def log_metrics(self, metrics, name):
print(f"{name}: ", end='', flush=True)
for key, val in metrics.items():
self.history[name + ' ' + key].append(val)
print(f"{key} {val:.4f} ", end='')
def train(self):
self.model.train()
metrics = self.iterate(self.train_loader, train=True)
self.log_metrics(metrics, 'train')
def validate(self):
self.model.eval()
with torch.no_grad():
metrics = self.iterate(self.val_loader)
self.log_metrics(metrics, 'val')
def fit(self, epochs):
for epoch in range(1, epochs + 1):
print(f"{epoch}/{epochs}:")
self.train()
print()
self.validate()
print()
if self.epoch_scheduler is not None:
self.epoch_scheduler.step()
        torch.save(self.model.state_dict(), str(MODELS_DIR / 'final_model.pt'))
def plot_history_train_val(history, key):
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['train ' + key]) + 1)
ax.plot(xs, history['train ' + key], '.-', label='train')
ax.plot(xs, history['val ' + key], '.-', label='val')
ax.set_xlabel('epoch')
ax.set_ylabel(key)
ax.grid()
ax.legend()
plt.show()
loss = VAELoss(β=1.)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
steps_per_epoch=len(train_loader), epochs=EPOCHS)
learner = Learner(model, loss, optimizer, train_loader, val_loader, DEVICE, batch_scheduler=lr_scheduler)
learner.fit(EPOCHS)
1/40:
train: loss 0.1771 reconstruction loss 0.1010 KLD0 0.1495 KLD1 0.0527 KLD2 0.1716 KLD3 0.1206 KLD4 0.1180 KLD5 0.1096
val: loss 0.0816 reconstruction loss 0.0598 KLD0 0.0609 KLD1 0.0361 KLD2 0.1066 KLD3 0.0455 KLD4 0.0283 KLD5 0.0223
2/40:
train: loss 0.0735 reconstruction loss 0.0544 KLD0 0.0480 KLD1 0.0578 KLD2 0.0958 KLD3 0.0547 KLD4 0.0204 KLD5 0.0164
val: loss 0.0661 reconstruction loss 0.0488 KLD0 0.0436 KLD1 0.0859 KLD2 0.0883 KLD3 0.0607 KLD4 0.0143 KLD5 0.0119
3/40:
train: loss 0.0634 reconstruction loss 0.0476 KLD0 0.0409 KLD1 0.1057 KLD2 0.0891 KLD3 0.0592 KLD4 0.0102 KLD5 0.0085
val: loss 0.0639 reconstruction loss 0.0485 KLD0 0.0427 KLD1 0.1181 KLD2 0.0922 KLD3 0.0583 KLD4 0.0081 KLD5 0.0073
4/40:
train: loss 0.0587 reconstruction loss 0.0449 KLD0 0.0478 KLD1 0.1161 KLD2 0.0907 KLD3 0.0563 KLD4 0.0058 KLD5 0.0041
val: loss 0.0564 reconstruction loss 0.0432 KLD0 0.0542 KLD1 0.1110 KLD2 0.0923 KLD3 0.0556 KLD4 0.0054 KLD5 0.0028
5/40:
train: loss 0.0556 reconstruction loss 0.0433 KLD0 0.0615 KLD1 0.1116 KLD2 0.0912 KLD3 0.0534 KLD4 0.0039 KLD5 0.0014
val: loss 0.0514 reconstruction loss 0.0392 KLD0 0.0723 KLD1 0.1169 KLD2 0.0947 KLD3 0.0530 KLD4 0.0030 KLD5 0.0005
6/40:
train: loss 0.0528 reconstruction loss 0.0409 KLD0 0.0791 KLD1 0.1124 KLD2 0.0945 KLD3 0.0520 KLD4 0.0027 KLD5 0.0002
val: loss 0.0635 reconstruction loss 0.0498 KLD0 0.0903 KLD1 0.1214 KLD2 0.1028 KLD3 0.0631 KLD4 0.0037 KLD5 0.0000
7/40:
train: loss 0.0514 reconstruction loss 0.0397 KLD0 0.0894 KLD1 0.1104 KLD2 0.0977 KLD3 0.0498 KLD4 0.0017 KLD5 0.0000
val: loss 0.0512 reconstruction loss 0.0390 KLD0 0.0952 KLD1 0.1151 KLD2 0.1007 KLD3 0.0532 KLD4 0.0015 KLD5 0.0000
8/40:
train: loss 0.0504 reconstruction loss 0.0388 KLD0 0.0961 KLD1 0.1114 KLD2 0.0995 KLD3 0.0485 KLD4 0.0010 KLD5 0.0000
val: loss 0.0516 reconstruction loss 0.0392 KLD0 0.1018 KLD1 0.1169 KLD2 0.1075 KLD3 0.0511 KLD4 0.0006 KLD5 0.0000
9/40:
train: loss 0.0498 reconstruction loss 0.0382 KLD0 0.1005 KLD1 0.1142 KLD2 0.1000 KLD3 0.0476 KLD4 0.0003 KLD5 0.0000
val: loss 0.0502 reconstruction loss 0.0385 KLD0 0.1013 KLD1 0.1152 KLD2 0.0986 KLD3 0.0488 KLD4 0.0001 KLD5 0.0000
10/40:
train: loss 0.0492 reconstruction loss 0.0376 KLD0 0.1037 KLD1 0.1175 KLD2 0.0998 KLD3 0.0468 KLD4 0.0000 KLD5 0.0000
val: loss 0.0492 reconstruction loss 0.0368 KLD0 0.1049 KLD1 0.1226 KLD2 0.1057 KLD3 0.0516 KLD4 0.0000 KLD5 0.0000
11/40:
train: loss 0.0489 reconstruction loss 0.0373 KLD0 0.1069 KLD1 0.1200 KLD2 0.0992 KLD3 0.0458 KLD4 0.0000 KLD5 0.0000
val: loss 0.0481 reconstruction loss 0.0368 KLD0 0.1080 KLD1 0.1189 KLD2 0.1005 KLD3 0.0431 KLD4 0.0000 KLD5 0.0000
12/40:
train: loss 0.0487 reconstruction loss 0.0371 KLD0 0.1094 KLD1 0.1217 KLD2 0.0988 KLD3 0.0452 KLD4 0.0000 KLD5 0.0000
val: loss 0.0506 reconstruction loss 0.0388 KLD0 0.1104 KLD1 0.1213 KLD2 0.1012 KLD3 0.0470 KLD4 0.0000 KLD5 0.0000
13/40:
train: loss 0.0482 reconstruction loss 0.0366 KLD0 0.1122 KLD1 0.1232 KLD2 0.0990 KLD3 0.0444 KLD4 0.0000 KLD5 0.0000
val: loss 0.0531 reconstruction loss 0.0413 KLD0 0.1120 KLD1 0.1259 KLD2 0.1030 KLD3 0.0448 KLD4 0.0000 KLD5 0.0000
14/40:
train: loss 0.0479 reconstruction loss 0.0363 KLD0 0.1144 KLD1 0.1241 KLD2 0.0991 KLD3 0.0438 KLD4 0.0000 KLD5 0.0000
val: loss 0.0479 reconstruction loss 0.0360 KLD0 0.1151 KLD1 0.1236 KLD2 0.1005 KLD3 0.0469 KLD4 0.0000 KLD5 0.0000
15/40:
train: loss 0.0476 reconstruction loss 0.0361 KLD0 0.1167 KLD1 0.1250 KLD2 0.0989 KLD3 0.0433 KLD4 0.0000 KLD5 0.0000
val: loss 0.0492 reconstruction loss 0.0380 KLD0 0.1206 KLD1 0.1265 KLD2 0.0952 KLD3 0.0404 KLD4 0.0000 KLD5 0.0000
16/40:
train: loss 0.0472 reconstruction loss 0.0357 KLD0 0.1189 KLD1 0.1260 KLD2 0.0992 KLD3 0.0427 KLD4 0.0000 KLD5 0.0000
val: loss 0.0484 reconstruction loss 0.0366 KLD0 0.1209 KLD1 0.1325 KLD2 0.1019 KLD3 0.0425 KLD4 0.0000 KLD5 0.0000
17/40:
train: loss 0.0470 reconstruction loss 0.0355 KLD0 0.1205 KLD1 0.1270 KLD2 0.0994 KLD3 0.0420 KLD4 0.0000 KLD5 0.0000
val: loss 0.0488 reconstruction loss 0.0369 KLD0 0.1198 KLD1 0.1276 KLD2 0.1025 KLD3 0.0445 KLD4 0.0000 KLD5 0.0000
18/40:
train: loss 0.0467 reconstruction loss 0.0352 KLD0 0.1223 KLD1 0.1276 KLD2 0.0996 KLD3 0.0415 KLD4 0.0000 KLD5 0.0000
val: loss 0.0479 reconstruction loss 0.0363 KLD0 0.1227 KLD1 0.1250 KLD2 0.0985 KLD3 0.0425 KLD4 0.0000 KLD5 0.0000
19/40:
train: loss 0.0466 reconstruction loss 0.0351 KLD0 0.1238 KLD1 0.1278 KLD2 0.0994 KLD3 0.0411 KLD4 0.0000 KLD5 0.0000
val: loss 0.0527 reconstruction loss 0.0409 KLD0 0.1250 KLD1 0.1277 KLD2 0.1023 KLD3 0.0425 KLD4 0.0000 KLD5 0.0000
20/40:
train: loss 0.0465 reconstruction loss 0.0349 KLD0 0.1254 KLD1 0.1281 KLD2 0.0994 KLD3 0.0407 KLD4 0.0000 KLD5 0.0000
val: loss 0.0514 reconstruction loss 0.0399 KLD0 0.1232 KLD1 0.1279 KLD2 0.0989 KLD3 0.0414 KLD4 0.0000 KLD5 0.0000
21/40:
train: loss 0.0463 reconstruction loss 0.0348 KLD0 0.1269 KLD1 0.1285 KLD2 0.0995 KLD3 0.0403 KLD4 0.0000 KLD5 0.0000
val: loss 0.0478 reconstruction loss 0.0361 KLD0 0.1286 KLD1 0.1292 KLD2 0.1018 KLD3 0.0419 KLD4 0.0000 KLD5 -0.0000
22/40:
train: loss 0.0462 reconstruction loss 0.0347 KLD0 0.1281 KLD1 0.1286 KLD2 0.0993 KLD3 0.0400 KLD4 0.0000 KLD5 0.0000
val: loss 0.0459 reconstruction loss 0.0342 KLD0 0.1275 KLD1 0.1304 KLD2 0.1018 KLD3 0.0402 KLD4 0.0000 KLD5 0.0000
23/40:
train: loss 0.0460 reconstruction loss 0.0345 KLD0 0.1299 KLD1 0.1293 KLD2 0.0994 KLD3 0.0395 KLD4 0.0000 KLD5 0.0000
val: loss 0.0474 reconstruction loss 0.0360 KLD0 0.1282 KLD1 0.1256 KLD2 0.0971 KLD3 0.0403 KLD4 0.0000 KLD5 0.0000
24/40:
train: loss 0.0459 reconstruction loss 0.0344 KLD0 0.1312 KLD1 0.1295 KLD2 0.0993 KLD3 0.0392 KLD4 0.0000 KLD5 0.0000
val: loss 0.0463 reconstruction loss 0.0348 KLD0 0.1310 KLD1 0.1292 KLD2 0.1011 KLD3 0.0389 KLD4 0.0000 KLD5 0.0000
25/40:
train: loss 0.0457 reconstruction loss 0.0342 KLD0 0.1327 KLD1 0.1301 KLD2 0.0994 KLD3 0.0387 KLD4 0.0000 KLD5 0.0000
val: loss 0.0473 reconstruction loss 0.0358 KLD0 0.1349 KLD1 0.1318 KLD2 0.0973 KLD3 0.0395 KLD4 0.0000 KLD5 0.0000
26/40:
train: loss 0.0456 reconstruction loss 0.0341 KLD0 0.1340 KLD1 0.1302 KLD2 0.0993 KLD3 0.0385 KLD4 0.0000 KLD5 0.0000
val: loss 0.0456 reconstruction loss 0.0342 KLD0 0.1348 KLD1 0.1307 KLD2 0.1004 KLD3 0.0376 KLD4 0.0000 KLD5 -0.0000
27/40:
train: loss 0.0454 reconstruction loss 0.0340 KLD0 0.1355 KLD1 0.1305 KLD2 0.0991 KLD3 0.0381 KLD4 0.0000 KLD5 0.0000
val: loss 0.0452 reconstruction loss 0.0340 KLD0 0.1349 KLD1 0.1278 KLD2 0.0986 KLD3 0.0368 KLD4 0.0000 KLD5 0.0000
28/40:
train: loss 0.0453 reconstruction loss 0.0339 KLD0 0.1367 KLD1 0.1310 KLD2 0.0990 KLD3 0.0378 KLD4 0.0000 KLD5 0.0000
val: loss 0.0503 reconstruction loss 0.0382 KLD0 0.1359 KLD1 0.1301 KLD2 0.0991 KLD3 0.0454 KLD4 0.0000 KLD5 0.0000
29/40:
train: loss 0.0452 reconstruction loss 0.0337 KLD0 0.1381 KLD1 0.1313 KLD2 0.0989 KLD3 0.0376 KLD4 0.0000 KLD5 0.0000
val: loss 0.0460 reconstruction loss 0.0341 KLD0 0.1371 KLD1 0.1298 KLD2 0.1031 KLD3 0.0408 KLD4 0.0000 KLD5 0.0000
30/40:
train: loss 0.0450 reconstruction loss 0.0336 KLD0 0.1394 KLD1 0.1317 KLD2 0.0986 KLD3 0.0372 KLD4 0.0000 KLD5 0.0000
val: loss 0.0448 reconstruction loss 0.0333 KLD0 0.1390 KLD1 0.1294 KLD2 0.0991 KLD3 0.0383 KLD4 0.0000 KLD5 0.0000
31/40:
train: loss 0.0449 reconstruction loss 0.0335 KLD0 0.1409 KLD1 0.1321 KLD2 0.0986 KLD3 0.0369 KLD4 0.0000 KLD5 0.0000
val: loss 0.0456 reconstruction loss 0.0342 KLD0 0.1407 KLD1 0.1310 KLD2 0.0974 KLD3 0.0378 KLD4 0.0000 KLD5 0.0000
32/40:
train: loss 0.0448 reconstruction loss 0.0334 KLD0 0.1420 KLD1 0.1323 KLD2 0.0984 KLD3 0.0366 KLD4 0.0000 KLD5 0.0000
val: loss 0.0448 reconstruction loss 0.0334 KLD0 0.1421 KLD1 0.1337 KLD2 0.0980 KLD3 0.0360 KLD4 0.0000 KLD5 0.0000
33/40:
train: loss 0.0446 reconstruction loss 0.0332 KLD0 0.1433 KLD1 0.1327 KLD2 0.0982 KLD3 0.0364 KLD4 0.0000 KLD5 0.0000
val: loss 0.0450 reconstruction loss 0.0337 KLD0 0.1439 KLD1 0.1311 KLD2 0.0988 KLD3 0.0360 KLD4 0.0000 KLD5 0.0000
34/40:
train: loss 0.0445 reconstruction loss 0.0332 KLD0 0.1445 KLD1 0.1329 KLD2 0.0981 KLD3 0.0361 KLD4 0.0000 KLD5 0.0000
val: loss 0.0448 reconstruction loss 0.0334 KLD0 0.1447 KLD1 0.1323 KLD2 0.0983 KLD3 0.0358 KLD4 0.0000 KLD5 -0.0000
35/40:
train: loss 0.0444 reconstruction loss 0.0330 KLD0 0.1456 KLD1 0.1334 KLD2 0.0979 KLD3 0.0359 KLD4 0.0000 KLD5 0.0000
val: loss 0.0442 reconstruction loss 0.0329 KLD0 0.1453 KLD1 0.1325 KLD2 0.0979 KLD3 0.0356 KLD4 -0.0000 KLD5 0.0000
36/40:
train: loss 0.0443 reconstruction loss 0.0330 KLD0 0.1464 KLD1 0.1335 KLD2 0.0978 KLD3 0.0357 KLD4 0.0000 KLD5 0.0000
val: loss 0.0441 reconstruction loss 0.0328 KLD0 0.1467 KLD1 0.1330 KLD2 0.0975 KLD3 0.0352 KLD4 -0.0000 KLD5 0.0000
37/40:
train: loss 0.0442 reconstruction loss 0.0329 KLD0 0.1472 KLD1 0.1338 KLD2 0.0977 KLD3 0.0356 KLD4 -0.0000 KLD5 -0.0000
val: loss 0.0439 reconstruction loss 0.0325 KLD0 0.1473 KLD1 0.1333 KLD2 0.0976 KLD3 0.0353 KLD4 -0.0000 KLD5 0.0000
38/40:
train: loss 0.0442 reconstruction loss 0.0328 KLD0 0.1477 KLD1 0.1339 KLD2 0.0977 KLD3 0.0355 KLD4 -0.0000 KLD5 -0.0000
val: loss 0.0438 reconstruction loss 0.0325 KLD0 0.1477 KLD1 0.1327 KLD2 0.0974 KLD3 0.0352 KLD4 -0.0000 KLD5 -0.0000
39/40:
train: loss 0.0442 reconstruction loss 0.0328 KLD0 0.1480 KLD1 0.1339 KLD2 0.0976 KLD3 0.0355 KLD4 -0.0000 KLD5 -0.0000
val: loss 0.0438 reconstruction loss 0.0325 KLD0 0.1474 KLD1 0.1325 KLD2 0.0979 KLD3 0.0355 KLD4 -0.0000 KLD5 -0.0000
40/40:
train: loss 0.0441 reconstruction loss 0.0328 KLD0 0.1481 KLD1 0.1339 KLD2 0.0976 KLD3 0.0354 KLD4 -0.0000 KLD5 0.0000
val: loss 0.0438 reconstruction loss 0.0325 KLD0 0.1473 KLD1 0.1338 KLD2 0.0970 KLD3 0.0350 KLD4 -0.0000 KLD5 0.0000
plot_history_train_val(learner.history, 'loss')
plot_history_train_val(learner.history, 'reconstruction loss')
for num in range(NUM_DOWNSAMPLINGS + 1):
plot_history_train_val(learner.history, f'KLD{num}')
model.load_state_dict(torch.load(str(MODELS_DIR / 'final_model.pt')))
<All keys matched successfully>
model.eval();
Testing reconstruction
test_batch, _ = next(iter(val_loader))
reconstructed_batch = reconstruct(model, test_batch, DEVICE)
show_2_batches(test_batch[:64], reconstructed_batch[:64], "Validation Images", "Reconstructed Images")
Testing generation
sample_images = model.sample(64)
show_images(sample_images.cpu(), "Sample Images")
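Since the sampling path exposes a temperature t that rescales the prior variances, drawing samples with t < 1 typically trades diversity for fidelity; a quick variation (t = 0.7 is an arbitrary illustrative value):
sample_images_t = model.sample(64, t=0.7)
show_images(sample_images_t.cpu(), "Sample Images (t=0.7)")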