The latent variables are split into $L$ layers. Each layer $i$ has a codebook of $K_i$ embedding vectors $\boldsymbol{e}_{i,j}\in\mathbb{R}^{D_i}$, $j=1,2,\ldots,K_i$. The posterior categorical distribution over the discrete latent variables is deterministic,
$$ q(k_i|k_{<i},\boldsymbol{x})=\delta_{k_i,k_i^{*}}\,, $$
where $k_i^{*}$ is the index of the codebook vector nearest to the encoder output,
$$ k_i^{*}=\operatorname*{argmin}_{j}\left\Vert q_{\boldsymbol{\phi}}(\boldsymbol{z}_i|\boldsymbol{z}_{<i},\boldsymbol{x})-\boldsymbol{e}_{i,j}\right\Vert _{2}\,. $$
The Kullback-Leibler divergence between this posterior and the prior is
$$ D_{\mathrm{KL}}\big(q(k_i|k_{<i},\boldsymbol{x})\,\|\,p_{\boldsymbol{\theta}}(k_i|k_{<i})\big)= \sum_{k_i=1}^{K_i}q(k_i|k_{<i},\boldsymbol{x})\log\left(\frac{q(k_i|k_{<i},\boldsymbol{x})}{p_{\boldsymbol{\theta}}(k_i|k_{<i})}\right)= -\log p_{\boldsymbol{\theta}}(k_i^{*}|k_{<i})\,. $$
Because the posterior puts all of its mass on $k_i^{*}$, the Kullback-Leibler divergence reduces to the cross-entropy loss of the prior evaluated at the selected code.
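As a quick numerical check of this identity (a standalone sketch, not part of the model code below; the names are purely illustrative), PyTorch's cross-entropy of the prior logits at the selected index equals the negative log-probability of that index:

import torch
import torch.nn.functional as F

K = 5
logits = torch.randn(1, K)                 # prior logits for p_theta(k_i | k_<i)
k_star = torch.tensor([2])                 # index chosen by the nearest-codebook lookup
ce = F.cross_entropy(logits, k_star)       # cross-entropy loss for the prior
kl = -torch.log_softmax(logits, dim=-1)[0, k_star[0]]  # -log p_theta(k_i* | k_<i)
print(torch.allclose(ce, kl))              # True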
Imports
from pathlib import Path
from functools import partial
from collections import defaultdict
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as distributions
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.utils as vision_utils
Configuration
DATA_DIR = Path("./data")
MODELS_DIR = Path("./models")
NUM_WORKERS = 8
BATCH_SIZE = 64
IMAGE_SIZE = 128
IMAGE_CHANNELS = 3
ENCODER_CHANNELS = 16
LATENT_REDUCTION = 16 # ratio of the number of feature channels to the number of latent channels
NUM_DOWNSAMPLINGS = 5
NUM_EMBEDDINGS = 128
EPOCHS = 10
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ConvBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, act=True):
padding = (kernel_size - 1) // 2
layers = [
nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False),
nn.BatchNorm2d(out_channels)
]
if act: layers.append(nn.ReLU(inplace=True))
super().__init__(*layers)
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_normal_(self[0].weight)
bn = self[1]
nn.init.ones_(bn.weight)
nn.init.zeros_(bn.bias)
Vector Quantizer
class VectorQuantizer(nn.Module):
def __init__(self, num_embeddings, embedding_dim, commitment_cost = 0.25):
super().__init__()
self.commitment_cost = commitment_cost
self.num_embeddings = num_embeddings
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
self.reset_parameters()
def forward(self, latents):
# Compute squared L2 distances between latents and the codebook embedding vectors
dist = (torch.sum(latents**2, dim=1).unsqueeze(-1) +
torch.sum(self.embedding.weight**2, dim=1) -
2 * torch.matmul(latents.movedim(1, -1), self.embedding.weight.T))
# Alternative implementation:
# dist = torch.linalg.vector_norm(latents.movedim(1, -1).unsqueeze(-2) - self.embedding.weight, dim=-1)
encoding_inds = torch.argmin(dist, dim=-1) # Get the number of the nearest codebook vector
quantized_latents = self.quantize(encoding_inds) # Quantize the latents
# Compute the VQ Losses
codebook_loss = F.mse_loss(latents.detach(), quantized_latents)
commitment_loss = F.mse_loss(latents, quantized_latents.detach())
vq_loss = codebook_loss + self.commitment_cost * commitment_loss
# Make the gradient with respect to latents be equal to the gradient with respect to quantized latents
quantized_latents = latents + (quantized_latents - latents).detach()
return quantized_latents, encoding_inds, vq_loss
def quantize(self, encoding_indices):
z = self.embedding(encoding_indices)
z = z.movedim(-1, 1) # Move channels back
return z
def reset_parameters(self):
nn.init.uniform_(self.embedding.weight, -1 / self.num_embeddings, 1 / self.num_embeddings)
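A quick sanity check of the quantizer (a toy sketch with made-up shapes, not part of the model): latents of shape (B, D, H, W) come back quantized with the same shape, the index map has shape (B, H, W), and the straight-through trick lets gradients reach the continuous latents despite the argmin.

vq = VectorQuantizer(num_embeddings=8, embedding_dim=4)
latents = torch.randn(2, 4, 5, 5, requires_grad=True)
quantized, inds, vq_loss = vq(latents)
print(quantized.shape, inds.shape)  # torch.Size([2, 4, 5, 5]) torch.Size([2, 5, 5])
quantized.sum().backward()
print(latents.grad is not None)     # True: gradients flow straight through the quantization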
VAE for skip connections
class VAELayer(nn.Module):
def __init__(self, channels, latent_reduction, num_embeddings):
super().__init__()
latent_channels = channels // latent_reduction
self.to_prior = nn.Sequential(nn.ReLU(inplace=False),
ConvBlock(channels, channels, 3),
nn.Conv2d(channels, num_embeddings, 1))
self.encoder = ConvBlock(channels, latent_channels, 3, act=False)
self.vq = VectorQuantizer(num_embeddings, latent_channels)
self.decoder = ConvBlock(latent_channels, channels, 3)
self.ce = nn.CrossEntropyLoss(reduction='none')
self.γ = nn.Parameter(torch.zeros(1))
self.act = nn.ReLU(inplace=True)
def forward(self, x, x_prior):
logits_prior = self.to_prior(x_prior)
latents = self.encoder(x)
z, encoding_inds, vq_loss = self.vq(latents)
kld = self.ce(logits_prior, encoding_inds)
z = self.decoder(z)
out = z + self.γ * x_prior
out = self.act(out)
return out, vq_loss, kld
def sample(self, x_prior, t=1.):
logits_prior = self.to_prior(x_prior) / t
dist = distributions.Categorical(logits=logits_prior.movedim(1, -1))
ind = dist.sample()
z = self.vq.quantize(ind)
z = self.decoder(z)
out = z + self.γ * x_prior
out = self.act(out)
return out
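A toy forward pass through a single VAELayer (illustrative shapes only): it takes the encoder skip features together with the features arriving from the layer above, and returns features of the same shape plus the VQ loss and a per-position KL (cross-entropy) map.

layer = VAELayer(channels=64, latent_reduction=16, num_embeddings=128)
x = torch.randn(2, 64, 8, 8)        # skip features from the encoder
x_prior = torch.randn(2, 64, 8, 8)  # features from the layer above (the prior path)
out, vq_loss, kld = layer(x, x_prior)
print(out.shape, kld.shape)         # torch.Size([2, 64, 8, 8]) torch.Size([2, 8, 8])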
class DownBlock(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
nn.MaxPool2d(2),
ConvBlock(in_channels, out_channels, 3)
)
class Encoder(nn.Module):
def __init__(self, in_channels, channels, num_downsamplings):
super().__init__()
self.stem = ConvBlock(in_channels, channels, 3)
self.blocks = nn.ModuleList()
in_channels = channels
for _ in range(num_downsamplings):
out_channels = in_channels * 2
self.blocks.append(DownBlock(in_channels, out_channels))
in_channels = out_channels
def forward(self, x):
x = self.stem(x)
xs = []
for block in self.blocks:
xs.append(x)
x = block(x)
return x, xs
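As a sketch of what the encoder produces (assuming the channel and resolution settings from the configuration above, with a toy batch size), it returns the deepest feature map together with the skip features captured before each downsampling:

encoder = Encoder(in_channels=3, channels=16, num_downsamplings=5)
x = torch.randn(2, 3, 128, 128)
out, skips = encoder(x)
print(out.shape)                            # torch.Size([2, 512, 4, 4])
print([tuple(s.shape[1:]) for s in skips])  # [(16, 128, 128), (32, 64, 64), (64, 32, 32), (128, 16, 16), (256, 8, 8)]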
class DecoderStem(nn.Module):
def __init__(self, channels, latent_reduction, num_embeddings, shape):
super().__init__()
self.vae = VAELayer(channels, latent_reduction, num_embeddings)
self.prior = nn.Parameter(torch.full((1, channels, *shape), 0.01))
def forward(self, x):
out, vq_loss, kld = self.vae(x, self.prior.expand(x.size(0), -1, -1, -1))
return out, vq_loss, kld
def sample(self, num_samples, t=1.):
out = self.vae.sample(self.prior.expand(num_samples, -1, -1, -1), t)
return out
class UpBlock(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
nn.Upsample(scale_factor=2, mode='nearest'),
ConvBlock(in_channels, out_channels, 3, act=False)
)
class DecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, latent_reduction, num_embeddings):
super().__init__()
self.up = UpBlock(in_channels, out_channels)
self.vae = VAELayer(out_channels, latent_reduction, num_embeddings)
def forward(self, x, skip):
x = self.up(x)
out, vq_loss, kld = self.vae(skip, x)
return out, vq_loss, kld
def sample(self, x, t=1.):
x = self.up(x)
out = self.vae.sample(x, t)
return out
class DecoderHead(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
ConvBlock(in_channels, out_channels, 3, act=False),
nn.Sigmoid()
)
class Decoder(nn.Module):
def __init__(self, in_channels, out_channels, latent_reduction, num_upsamplings, num_embeddings, shape):
super().__init__()
self.stem = DecoderStem(in_channels, latent_reduction, num_embeddings, shape)
self.blocks = nn.ModuleList()
for _ in range(num_upsamplings):
channels = in_channels // 2
self.blocks.append(DecoderBlock(in_channels, channels, latent_reduction, num_embeddings))
in_channels = channels
self.head = DecoderHead(channels, out_channels)
def forward(self, x, xs):
vq_losses = []
klds = []
x, vql, kld = self.stem(x)
vq_losses.append(vql)
klds.append(kld)
for block, skip in zip(self.blocks, reversed(xs)):
x, vql, kld = block(x, skip)
vq_losses.append(vql)
klds.append(kld)
x = self.head(x)
return x, vq_losses, klds
def sample(self, num_samples, t=1.):
z = self.stem.sample(num_samples, t)
for block in self.blocks:
z = block.sample(z, t)
z = self.head(z)
return z
class VQVAE(nn.Module):
def __init__(self, num_downsamplings, latent_reduction, enc_channels, num_embeddings, image_shape, in_channels=3):
super().__init__()
reduction = 2**num_downsamplings
shape = (image_shape[0] // reduction, image_shape[1] // reduction)
self.encoder = Encoder(in_channels, enc_channels, num_downsamplings)
self.decoder = Decoder(enc_channels * reduction, in_channels, latent_reduction, num_downsamplings, num_embeddings, shape)
def forward(self, x):
x, xs = self.encoder(x)
out, vq_losses, klds = self.decoder(x, xs)
return out, vq_losses, klds
def sample(self, num_samples, t=1.):
with torch.no_grad():
out = self.decoder.sample(num_samples, t)
return out
model = VQVAE(num_downsamplings=NUM_DOWNSAMPLINGS, latent_reduction=LATENT_REDUCTION, enc_channels=ENCODER_CHANNELS,
num_embeddings=NUM_EMBEDDINGS, image_shape=(IMAGE_SIZE, IMAGE_SIZE), in_channels=IMAGE_CHANNELS)
model = model.to(DEVICE)
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
print("Number of encoder parameters: {:,}".format(sum(p.numel() for p in model.encoder.parameters())))
print("Number of decoder parameters: {:,}".format(sum(p.numel() for p in model.decoder.parameters())))
Number of parameters: 6,834,826
Number of encoder parameters: 1,573,776
Number of decoder parameters: 5,261,050
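A quick shape check of the full model (a sketch using a random batch; the listed shapes assume the configuration above): the reconstruction matches the input shape, and there is one VQ loss and one KL map per hierarchy level, from 4x4 at the bottom of the U-Net up to 128x128 at the top.

with torch.no_grad():
    x = torch.randn(2, IMAGE_CHANNELS, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
    out, vq_losses, klds = model(x)
print(out.shape)                               # torch.Size([2, 3, 128, 128])
print([tuple(kld.shape[1:]) for kld in klds])  # [(4, 4), (8, 8), (16, 16), (32, 32), (64, 64), (128, 128)]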
class VAELoss(nn.Module):
def __init__(self, λ_vq=1., λ_kld=1.):
super().__init__()
self.reconstruction_loss = nn.L1Loss() # nn.MSELoss()
self.λ_vq = λ_vq
self.λ_kld = λ_kld
def forward(self, outputs, target):
output, vq_losses, klds = outputs
reconst_loss = self.reconstruction_loss(output, target)
total_vq_loss = sum(vq_losses)
kld_losses = [torch.mean(kld) for kld in klds]
ndims = np.prod(output.shape[1:])
nlatent = np.array([np.prod(kld.shape[1:]) for kld in klds])
kld_weights = nlatent / ndims
total_kld_loss = sum(weight * kld_loss for weight, kld_loss in zip(kld_weights, kld_losses))
loss = reconst_loss + self.λ_vq * total_vq_loss + self.λ_kld * total_kld_loss
loss_dict = {"loss": loss, "reconstruction": reconst_loss}
for num, (vq_loss, kld_loss) in enumerate(zip(vq_losses, kld_losses)):
loss_dict[f"VQ{num}"] = vq_loss
loss_dict[f"KLD{num}"] = kld_loss
return loss_dict
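For intuition about kld_weights: each KL map has one value per latent position, so a level is weighted by its number of positions relative to the number of image dimensions. Under the configuration above (128x128 RGB images, 5 downsamplings), the weights can be printed with a small sketch:

ndims = IMAGE_CHANNELS * IMAGE_SIZE * IMAGE_SIZE      # 3 * 128 * 128 = 49152
for num in range(NUM_DOWNSAMPLINGS + 1):
    res = IMAGE_SIZE // 2**(NUM_DOWNSAMPLINGS - num)  # 4, 8, 16, 32, 64, 128
    print(f"KLD{num}: {res}x{res} positions, weight {(res * res) / ndims:.5f}")

The coarsest level contributes with weight 16/49152 ≈ 0.0003, while the finest 128x128 level contributes with weight 1/3.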
def plot_batch(ax, batch, title=None, **kwargs):
imgs = vision_utils.make_grid(batch, padding=2, normalize=True)
imgs = np.moveaxis(imgs.numpy(), 0, -1)
ax.set_axis_off()
if title is not None: ax.set_title(title)
return ax.imshow(imgs, **kwargs)
def show_images(batch, title):
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
plot_batch(ax, batch, title)
plt.show()
def show_2_batches(batch1, batch2, title1, title2):
fig = plt.figure(figsize=(16, 8))
ax = fig.add_subplot(121)
plot_batch(ax, batch1, title1)
ax = fig.add_subplot(122)
plot_batch(ax, batch2, title2)
plt.show()
def reconstruct(model, batch, device):
batch = batch.to(device)
with torch.no_grad():
reconstructed_batch = model(batch)[0]
reconstructed_batch = reconstructed_batch.cpu()
return reconstructed_batch
train_transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.CenterCrop(IMAGE_SIZE),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
])
val_transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.CenterCrop(IMAGE_SIZE),
transforms.ToTensor()
])
train_dset = datasets.CelebA(str(DATA_DIR), split='train', transform=train_transform, download=False)
val_dset = datasets.CelebA(str(DATA_DIR), split='test', transform=val_transform, download=False)
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
num_workers=NUM_WORKERS)
val_loader = torch.utils.data.DataLoader(val_dset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS)
test_batch, _ = next(iter(train_loader))
show_images(test_batch[:64], "Training Images")
class AverageLoss():
def __init__(self, name):
self.name = name
self.reset()
def reset(self):
self.num_samples = 0
self.total_loss = 0.
def update(self, data):
batch_size = data['batch_size']
self.num_samples += batch_size
self.total_loss += batch_size * data[self.name]
def compute(self):
avg_loss = self.total_loss / self.num_samples
metrics = {self.name: avg_loss}
return metrics
class Learner:
def __init__(self, model, loss, optimizer, train_loader, val_loader, device,
batch_scheduler=None, epoch_scheduler=None):
self.model = model
self.loss = loss
self.optimizer = optimizer
self.train_loader = train_loader
self.val_loader = val_loader
self.device = device
self.batch_scheduler = batch_scheduler
self.epoch_scheduler = epoch_scheduler
self.history = defaultdict(list)
self.metrics = [AverageLoss(x) for x in ["loss", "reconstruction"]]
for num in range(NUM_DOWNSAMPLINGS + 1):
self.metrics.append(AverageLoss(f"VQ{num}"))
self.metrics.append(AverageLoss(f"KLD{num}"))
def iterate(self, loader, train=False):
for metric in self.metrics:
metric.reset()
for batch in loader:
images = batch[0].to(self.device)
outputs = self.model(images)
losses = self.loss(outputs, images)
if train: self.backward_pass(losses["loss"])
data = {k: v.item() for k, v in losses.items()}
data["batch_size"] = len(images)
for metric in self.metrics:
metric.update(data)
summary = {}
for metric in self.metrics:
summary.update(metric.compute())
return summary
def backward_pass(self, batch_loss):
self.optimizer.zero_grad()
batch_loss.backward()
self.optimizer.step()
if self.batch_scheduler is not None:
self.batch_scheduler.step()
def log_metrics(self, metrics, name):
print(f"{name}: ", end='', flush=True)
newline = False
for key, val in metrics.items():
self.history[name + ' ' + key].append(val)
print(f"{key} {val:.4f} ", end='\n' if newline else '')
newline = not newline
def train(self):
self.model.train()
metrics = self.iterate(self.train_loader, train=True)
self.log_metrics(metrics, 'train')
def validate(self):
self.model.eval()
with torch.no_grad():
metrics = self.iterate(self.val_loader)
self.log_metrics(metrics, 'val')
def fit(self, epochs):
for epoch in range(1, epochs + 1):
print(f"{epoch}/{epochs}:")
self.train()
print()
self.validate()
print()
if self.epoch_scheduler is not None:
self.epoch_scheduler.step()
torch.save(self.model.state_dict(), str(MODELS_DIR / 'final_model.pt'))
def plot_history_train_val(history, key):
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['train ' + key]) + 1)
ax.plot(xs, history['train ' + key], '.-', label='train')
ax.plot(xs, history['val ' + key], '.-', label='val')
ax.set_xlabel('epoch')
ax.set_ylabel(key)
ax.grid()
ax.legend()
plt.show()
loss = VAELoss(λ_kld=1.)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
steps_per_epoch=len(train_loader), epochs=EPOCHS)
learner = Learner(model, loss, optimizer, train_loader, val_loader, DEVICE, batch_scheduler=lr_scheduler)
learner.fit(EPOCHS)
1/10:
train: loss 2.4117 reconstruction 0.0708 VQ0 0.1906 KLD0 1.3742 VQ1 0.2228 KLD1 1.8777 VQ2 0.2260 KLD2 1.8539 VQ3 0.2435 KLD3 1.6117 VQ4 0.2054 KLD4 2.3494 VQ5 0.3219 KLD5 2.0666
val: loss 1.1344 reconstruction 0.0295 VQ0 0.0006 KLD0 2.1680 VQ1 0.0075 KLD1 3.3566 VQ2 0.0148 KLD2 3.1963 VQ3 0.0172 KLD3 2.7910 VQ4 0.0089 KLD4 2.4617 VQ5 0.0001 KLD5 2.3121
2/10:
train: loss 1.1344 reconstruction 0.0323 VQ0 0.0002 KLD0 3.4017 VQ1 0.0052 KLD1 3.8634 VQ2 0.0126 KLD2 3.2622 VQ3 0.0270 KLD3 3.0208 VQ4 0.0329 KLD4 2.5932 VQ5 0.0002 KLD5 2.1653
val: loss 1.1733 reconstruction 0.0458 VQ0 0.0001 KLD0 3.8078 VQ1 0.0045 KLD1 3.9697 VQ2 0.0129 KLD2 3.3749 VQ3 0.0356 KLD3 3.0983 VQ4 0.0494 KLD4 2.6788 VQ5 0.0001 KLD5 2.1396
3/10:
train: loss 1.1795 reconstruction 0.0306 VQ0 0.0001 KLD0 3.7786 VQ1 0.0044 KLD1 3.8532 VQ2 0.0113 KLD2 3.3692 VQ3 0.0394 KLD3 3.0674 VQ4 0.0614 KLD4 2.6834 VQ5 0.0001 KLD5 2.1630
val: loss 1.1999 reconstruction 0.0404 VQ0 0.0000 KLD0 3.8461 VQ1 0.0083 KLD1 3.9766 VQ2 0.0160 KLD2 3.6928 VQ3 0.0436 KLD3 3.1772 VQ4 0.0693 KLD4 2.6736 VQ5 0.0001 KLD5 2.1229
4/10:
train: loss 1.1901 reconstruction 0.0273 VQ0 0.0000 KLD0 3.7572 VQ1 0.0048 KLD1 3.8406 VQ2 0.0124 KLD2 3.3442 VQ3 0.0443 KLD3 2.9983 VQ4 0.0808 KLD4 2.6673 VQ5 0.0001 KLD5 2.1360
val: loss 1.2644 reconstruction 0.0415 VQ0 0.0000 KLD0 3.7596 VQ1 0.0061 KLD1 3.9554 VQ2 0.0134 KLD2 3.4190 VQ3 0.0485 KLD3 3.1371 VQ4 0.0919 KLD4 2.7719 VQ5 0.0001 KLD5 2.2271
5/10:
train: loss 1.1924 reconstruction 0.0259 VQ0 0.0000 KLD0 3.7586 VQ1 0.0052 KLD1 3.8576 VQ2 0.0135 KLD2 3.3073 VQ3 0.0468 KLD3 2.9630 VQ4 0.1004 KLD4 2.7068 VQ5 0.0001 KLD5 2.0692
val: loss 1.4005 reconstruction 0.0278 VQ0 0.0000 KLD0 3.7296 VQ1 0.0093 KLD1 4.0881 VQ2 0.0208 KLD2 3.6042 VQ3 0.0656 KLD3 3.1830 VQ4 0.1038 KLD4 2.9225 VQ5 0.0001 KLD5 2.5142
6/10:
train: loss 1.2022 reconstruction 0.0253 VQ0 0.0001 KLD0 3.7480 VQ1 0.0055 KLD1 3.8617 VQ2 0.0144 KLD2 3.2894 VQ3 0.0476 KLD3 2.9636 VQ4 0.1065 KLD4 2.7472 VQ5 0.0001 KLD5 2.0660
val: loss 1.3441 reconstruction 0.0291 VQ0 0.0001 KLD0 3.7801 VQ1 0.0058 KLD1 3.9030 VQ2 0.0156 KLD2 3.4361 VQ3 0.0489 KLD3 3.0342 VQ4 0.1063 KLD4 2.7525 VQ5 0.0001 KLD5 2.4645
7/10:
train: loss 1.1900 reconstruction 0.0246 VQ0 0.0001 KLD0 3.7402 VQ1 0.0060 KLD1 3.8588 VQ2 0.0154 KLD2 3.2641 VQ3 0.0493 KLD3 2.9599 VQ4 0.1107 KLD4 2.7596 VQ5 0.0001 KLD5 2.0073
val: loss 1.2423 reconstruction 0.0280 VQ0 0.0001 KLD0 3.6870 VQ1 0.0064 KLD1 3.8951 VQ2 0.0148 KLD2 3.3283 VQ3 0.0489 KLD3 3.0313 VQ4 0.1159 KLD4 2.8151 VQ5 0.0001 KLD5 2.1207
8/10:
train: loss 1.1782 reconstruction 0.0240 VQ0 0.0001 KLD0 3.7342 VQ1 0.0066 KLD1 3.8475 VQ2 0.0167 KLD2 3.2387 VQ3 0.0515 KLD3 2.9604 VQ4 0.1138 KLD4 2.7716 VQ5 0.0001 KLD5 1.9489
val: loss 1.2094 reconstruction 0.0223 VQ0 0.0001 KLD0 3.7249 VQ1 0.0098 KLD1 3.8730 VQ2 0.0212 KLD2 3.2905 VQ3 0.0595 KLD3 3.0230 VQ4 0.1227 KLD4 2.8599 VQ5 0.0001 KLD5 1.9469
9/10:
train: loss 1.1730 reconstruction 0.0234 VQ0 0.0001 KLD0 3.7385 VQ1 0.0073 KLD1 3.8387 VQ2 0.0180 KLD2 3.2102 VQ3 0.0539 KLD3 2.9612 VQ4 0.1170 KLD4 2.7848 VQ5 0.0001 KLD5 1.9093
val: loss 1.1597 reconstruction 0.0211 VQ0 0.0001 KLD0 3.7650 VQ1 0.0067 KLD1 3.8191 VQ2 0.0175 KLD2 3.1914 VQ3 0.0530 KLD3 2.9719 VQ4 0.1175 KLD4 2.8129 VQ5 0.0001 KLD5 1.8740
10/10:
train: loss 1.1708 reconstruction 0.0231 VQ0 0.0001 KLD0 3.7462 VQ1 0.0076 KLD1 3.8386 VQ2 0.0188 KLD2 3.1996 VQ3 0.0550 KLD3 2.9657 VQ4 0.1182 KLD4 2.7927 VQ5 0.0001 KLD5 1.8916
val: loss 1.1524 reconstruction 0.0207 VQ0 0.0001 KLD0 3.7456 VQ1 0.0071 KLD1 3.8184 VQ2 0.0180 KLD2 3.1951 VQ3 0.0535 KLD3 2.9655 VQ4 0.1147 KLD4 2.7785 VQ5 0.0001 KLD5 1.8662
plot_history_train_val(learner.history, 'loss')
plot_history_train_val(learner.history, 'reconstruction')
for num in range(NUM_DOWNSAMPLINGS + 1):
plot_history_train_val(learner.history, f'VQ{num}')
for num in range(NUM_DOWNSAMPLINGS + 1):
plot_history_train_val(learner.history, f'KLD{num}')
model.load_state_dict(torch.load(str(MODELS_DIR / 'final_model.pt')))
<All keys matched successfully>
model.eval();
Testing reconstruction
test_batch, _ = next(iter(val_loader))
reconstructed_batch = reconstruct(model, test_batch, DEVICE)
show_2_batches(test_batch[:64], reconstructed_batch[:64], "Validation Images", "Reconstructed Images")
Testing generation
sample_images = model.sample(64, t=0.3)
show_images(sample_images.cpu(), "Sample Images")
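The temperature t divides the prior logits before sampling, so small values concentrate probability on the most likely codes at every level (smoother but less diverse faces), while t = 1 samples from the unmodified learned prior. A short sweep to compare (a sketch reusing the helpers above):

for t in (0.1, 0.3, 0.6, 1.0):
    sample_images = model.sample(16, t=t)
    show_images(sample_images.cpu(), f"Sample Images, t={t}")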