Neural Discrete Representation Learning (VQ-VAE), arXiv:1711.00937 [cs.LG]
Generating Diverse High-Fidelity Images with VQ-VAE-2, arXiv:1906.00446 [cs.LG]
The encoder, with parameters $\boldsymbol{\phi}$, maps an input $\boldsymbol{x}$ to a continuous output $\boldsymbol{z}_{e}(\boldsymbol{x})$. We have a codebook of $K$ embedding vectors $\boldsymbol{e}_{j}\in\mathbb{R}^{D}$, $j=1,2,\ldots,K$. The posterior distribution over the discrete latent variable is a deterministic categorical distribution, $$ q(k|\boldsymbol{x})=\delta_{k,k^{*}}\,, $$ where $$ k^{*}=\mathrm{argmin}_{j}\left\Vert \boldsymbol{z}_{e}(\boldsymbol{x})-\boldsymbol{e}_{j}\right\Vert _{2}\,. $$ The prior distribution is uniform, $p(k)=\frac{1}{K}$, so the Kullback-Leibler divergence is constant: $$ D_{\mathrm{KL}}\left(q(k|\boldsymbol{x})\,\|\,p(k)\right)=\sum_{k=1}^{K}q(k|\boldsymbol{x})\log\left(\frac{q(k|\boldsymbol{x})}{p(k)}\right)=-\log p(k^{*})=\log K\,, $$ and it can be dropped from the training objective. The decoder input is the quantized latent $\boldsymbol{z}_{q}=\boldsymbol{e}_{k^{*}}$, and the decoder output is $p_{\boldsymbol{\theta}}(\boldsymbol{x}|\boldsymbol{z}_{q})$. Since quantization is not differentiable, during training we copy the gradients from the decoder input $\boldsymbol{z}_{q}$ to the encoder output $\boldsymbol{z}_{e}(\boldsymbol{x})$ (the straight-through estimator). The total loss is $$ \mathcal{L}=-\log p_{\boldsymbol{\theta}}(\boldsymbol{x}|\boldsymbol{z}_{q}) +\left\Vert \mathrm{sg}[\boldsymbol{z}_{e}(\boldsymbol{x})]-\boldsymbol{e}_{k^{*}}\right\Vert _{2}^{2} +\beta\left\Vert \boldsymbol{z}_{e}(\boldsymbol{x})-\mathrm{sg}[\boldsymbol{e}_{k^{*}}]\right\Vert _{2}^{2}\,. $$ Here $\mathrm{sg}[\cdot]$ is the stop-gradient operator. The first term is the reconstruction loss, the second is the codebook loss, and the third is the commitment loss.
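As a toy illustration of the nearest-code lookup, the straight-through gradient copy, and the two auxiliary loss terms (a minimal sketch; the shapes, names, and values below are arbitrary and not part of the model defined later):
import torch
import torch.nn.functional as F

torch.manual_seed(0)
K, D = 4, 3                                      # toy codebook: K vectors of dimension D
codebook = torch.randn(K, D, requires_grad=True)
z_e = torch.randn(2, D, requires_grad=True)      # two "encoder outputs"

dist = torch.cdist(z_e, codebook)                # (2, K) pairwise L2 distances
k_star = dist.argmin(dim=-1)                     # index of the nearest codebook vector
z_q = codebook[k_star]                           # quantized latents

# Straight-through estimator: the forward pass uses z_q,
# the backward pass routes gradients directly to z_e
z_q_st = z_e + (z_q - z_e).detach()

beta = 0.25
codebook_loss = F.mse_loss(z_q, z_e.detach())    # moves codebook vectors toward encoder outputs
commitment_loss = F.mse_loss(z_e, z_q.detach())  # keeps encoder outputs close to their codes
vq_loss = codebook_loss + beta * commitment_loss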
Instead of the codebook loss, the embedding vectors can be updated as exponential moving averages of the encoder outputs $\boldsymbol{z}_{e}(\boldsymbol{x})$ that are closest to each embedding vector.
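A sketch of such an exponential-moving-average update is given below. It is not used in the implementation that follows; the function name, buffer names, decay, and smoothing constant are illustrative assumptions.
import torch
import torch.nn.functional as F

def ema_codebook_update(embedding, flat_latents, encoding_inds,
                        ema_cluster_size, ema_embed_sum, decay=0.99, eps=1e-5):
    # flat_latents: (N, D) encoder outputs, one row per spatial position
    # encoding_inds: (N,) index of the nearest codebook vector for each row
    # ema_cluster_size: (K,) running count of assignments per codebook entry
    # ema_embed_sum: (K, D) running sum of assigned encoder outputs
    K = embedding.num_embeddings
    with torch.no_grad():
        one_hot = F.one_hot(encoding_inds, num_classes=K).to(flat_latents.dtype)  # (N, K)
        ema_cluster_size.mul_(decay).add_(one_hot.sum(dim=0), alpha=1 - decay)
        ema_embed_sum.mul_(decay).add_(one_hot.t() @ flat_latents, alpha=1 - decay)
        # Laplace smoothing keeps rarely used codes from collapsing to zero
        n = ema_cluster_size.sum()
        cluster_size = (ema_cluster_size + eps) / (n + K * eps) * n
        embedding.weight.copy_(ema_embed_sum / cluster_size.unsqueeze(-1))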
Imports
from pathlib import Path
from functools import partial
from collections import defaultdict
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.utils as vision_utils
Configuration
DATA_DIR = Path("./data")
MODELS_DIR = Path("./models")
NUM_WORKERS = 8
BATCH_SIZE = 128
IMAGE_SIZE = 128
IMAGE_CHANNELS = 3
ENCODER_CHANNELS = 16
LATENT_CHANNELS = 16
NUM_DOWNSAMPLINGS = 5
NUM_EMBEDDINGS = 256
EPOCHS = 20
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class VectorQuantizer(nn.Module):
def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
super().__init__()
self.commitment_cost = commitment_cost
self.num_embeddings = num_embeddings
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
self.reset_parameters()
def forward(self, latents):
# Compute L2 distances between latents and embedding weights
dist = torch.linalg.vector_norm(latents.movedim(1, -1).unsqueeze(-2) - self.embedding.weight, dim=-1)
encoding_inds = torch.argmin(dist, dim=-1) # Get the index of the nearest codebook vector
quantized_latents = self.quantize(encoding_inds) # Quantize the latents
# Compute the VQ Losses
codebook_loss = F.mse_loss(latents.detach(), quantized_latents)
commitment_loss = F.mse_loss(latents, quantized_latents.detach())
vq_loss = codebook_loss + self.commitment_cost * commitment_loss
# Make the gradient with respect to latents be equal to the gradient with respect to quantized latents
quantized_latents = latents + (quantized_latents - latents).detach()
return quantized_latents, vq_loss
def quantize(self, encoding_indices):
z = self.embedding(encoding_indices)
z = z.movedim(-1, 1) # Move channels back
return z
def reset_parameters(self):
nn.init.uniform_(self.embedding.weight, -1 / self.num_embeddings, 1 / self.num_embeddings)
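A quick sanity check that the quantizer preserves the latent shape and returns a scalar loss (a minimal sketch; the sizes are arbitrary):
vq = VectorQuantizer(num_embeddings=8, embedding_dim=4)
x = torch.randn(2, 4, 5, 5)               # (batch, channels, height, width)
quantized, vq_loss = vq(x)
print(quantized.shape, vq_loss.item())    # torch.Size([2, 4, 5, 5]) and a scalar VQ loss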
class ConvBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, act=True):
padding = (kernel_size - 1) // 2
layers = [
nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False),
nn.BatchNorm2d(out_channels)
]
if act: layers.append(nn.ReLU(inplace=True))
super().__init__(*layers)
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_normal_(self[0].weight)
bn = self[1]
nn.init.ones_(bn.weight)
nn.init.zeros_(bn.bias)
class Stack(nn.Sequential):
def __init__(self, channels_list, block):
layers = []
for in_channels, out_channels in zip(channels_list[:-1], channels_list[1:]):
layers.append(block(in_channels, out_channels))
super().__init__(*layers)
class DownBlock(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
nn.MaxPool2d(2),
ConvBlock(in_channels, out_channels, 3)
)
class Encoder(nn.Module):
def __init__(self, in_channels, channels_list, latent_channels):
super().__init__()
self.stem = ConvBlock(in_channels, channels_list[0], 3)
self.blocks = Stack(channels_list, DownBlock)
self.to_latent = ConvBlock(channels_list[-1], latent_channels, 3)
def forward(self, x):
x = self.stem(x)
x = self.blocks(x)
x = self.to_latent(x)
return x
class UpBlock(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
nn.Upsample(scale_factor=2, mode='nearest'),
ConvBlock(in_channels, out_channels, 3)
)
class Decoder(nn.Module):
def __init__(self, latent_channels, channels_list, out_channels):
super().__init__()
self.stem = ConvBlock(latent_channels, channels_list[0], 3)
self.blocks = Stack(channels_list, UpBlock)
self.to_output = nn.Conv2d(channels_list[-1], out_channels, 3, padding=1)
def forward(self, x):
x = self.stem(x)
x = self.blocks(x)
x = self.to_output(x)
x = torch.sigmoid(x)
return x
class VQVAE(nn.Module):
def __init__(self, num_downsamplings, latent_channels, num_embeddings, channels=32, in_channels=3):
super().__init__()
channels_list = [channels * 2**i for i in range(num_downsamplings + 1)]
self.encoder = Encoder(in_channels, channels_list, latent_channels)
self.vq = VectorQuantizer(num_embeddings, latent_channels)
channels_list.reverse()
self.decoder = Decoder(latent_channels, channels_list, in_channels)
self.reduction = 2**num_downsamplings
self.num_embeddings = num_embeddings
def forward(self, x):
latents = self.encoder(x)
z, vq_loss = self.vq(latents)
out = self.decoder(z)
return out, vq_loss
def sample(self, num_samples, shape, device):
latent_shape = (num_samples, shape[0] // self.reduction, shape[1] // self.reduction)
ind = torch.randint(0, self.num_embeddings, latent_shape, device=device)
with torch.no_grad():
z = self.vq.quantize(ind)
out = self.decoder(z)
return out
model = VQVAE(num_downsamplings=NUM_DOWNSAMPLINGS, latent_channels=LATENT_CHANNELS, num_embeddings=NUM_EMBEDDINGS,
channels=ENCODER_CHANNELS, in_channels=IMAGE_CHANNELS)
model = model.to(DEVICE)
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 3,299,139
class VAELoss(nn.Module):
def __init__(self, λ=1.):
super().__init__()
self.λ = λ
self.reconstruction_loss = nn.MSELoss()
def forward(self, outputs, target):
output, vq_loss = outputs
reconst_loss = self.reconstruction_loss(output, target)
loss = reconst_loss + self.λ * vq_loss
return {"loss": loss, "reconstruction loss": reconst_loss, "VQ loss": vq_loss}
def plot_batch(ax, batch, title=None, **kwargs):
imgs = vision_utils.make_grid(batch, padding=2, normalize=True)
imgs = np.moveaxis(imgs.numpy(), 0, -1)
ax.set_axis_off()
if title is not None: ax.set_title(title)
return ax.imshow(imgs, **kwargs)
def show_images(batch, title):
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
plot_batch(ax, batch, title)
plt.show()
def reconstruct(model, batch, device):
batch = batch.to(device)
with torch.no_grad():
reconstructed_batch = model(batch)[0]
reconstructed_batch = reconstructed_batch.cpu()
return reconstructed_batch
def show_2_batches(batch1, batch2, title1, title2):
fig = plt.figure(figsize=(16, 8))
ax = fig.add_subplot(121)
plot_batch(ax, batch1, title1)
ax = fig.add_subplot(122)
plot_batch(ax, batch2, title2)
plt.show()
train_transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.CenterCrop(IMAGE_SIZE),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
])
val_transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.CenterCrop(IMAGE_SIZE),
transforms.ToTensor()
])
train_dset = datasets.CelebA(str(DATA_DIR), split='train', transform=train_transform, download=False)
val_dset = datasets.CelebA(str(DATA_DIR), split='test', transform=val_transform, download=False)
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
num_workers=NUM_WORKERS)
val_loader = torch.utils.data.DataLoader(val_dset, batch_size=BATCH_SIZE, shuffle=False,
num_workers=NUM_WORKERS)
test_batch, _ = next(iter(train_loader))
show_images(test_batch[:64], "Training Images")
class AverageLoss():
def __init__(self, name):
self.name = name
self.reset()
def reset(self):
self.num_samples = 0
self.total_loss = 0.
def update(self, data):
batch_size = data['batch_size']
self.num_samples += batch_size
self.total_loss += batch_size * data[self.name]
def compute(self):
avg_loss = self.total_loss / self.num_samples
metrics = {self.name: avg_loss}
return metrics
class Learner:
def __init__(self, model, loss, optimizer, train_loader, val_loader, device,
batch_scheduler=None, epoch_scheduler=None):
self.model = model
self.loss = loss
self.optimizer = optimizer
self.train_loader = train_loader
self.val_loader = val_loader
self.device = device
self.batch_scheduler = batch_scheduler
self.epoch_scheduler = epoch_scheduler
self.history = defaultdict(list)
self.metrics = [AverageLoss(x) for x in ["loss", "reconstruction loss", "VQ loss"]]
def iterate(self, loader, train=False):
for metric in self.metrics:
metric.reset()
for batch in loader:
images = batch[0].to(self.device)
outputs = self.model(images)
losses = self.loss(outputs, images)
if train: self.backward_pass(losses["loss"])
data = {k: v.item() for k, v in losses.items()}
data["batch_size"] = len(images)
for metric in self.metrics:
metric.update(data)
summary = {}
for metric in self.metrics:
summary.update(metric.compute())
return summary
def backward_pass(self, batch_loss):
self.optimizer.zero_grad()
batch_loss.backward()
self.optimizer.step()
if self.batch_scheduler is not None:
self.batch_scheduler.step()
def log_metrics(self, metrics, name):
print(f"{name}: ", end='', flush=True)
for key, val in metrics.items():
self.history[name + ' ' + key].append(val)
print(f"{key} {val:.3f} ", end='')
def train(self):
self.model.train()
metrics = self.iterate(self.train_loader, train=True)
self.log_metrics(metrics, 'train')
def validate(self):
self.model.eval()
with torch.no_grad():
metrics = self.iterate(self.val_loader)
self.log_metrics(metrics, 'val')
def fit(self, epochs):
for epoch in range(1, epochs + 1):
print(f"{epoch}/{epochs} - ", end='')
self.train()
print('; ', end='')
self.validate()
print()
if self.epoch_scheduler is not None:
self.epoch_scheduler.step()
torch.save(self.model.state_dict(), str(MODELS_DIR / 'final_model.pt'))
def plot_history_train_val(history, key):
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['train ' + key]) + 1)
ax.plot(xs, history['train ' + key], '.-', label='train')
ax.plot(xs, history['val ' + key], '.-', label='val')
ax.set_xlabel('epoch')
ax.set_ylabel(key)
ax.grid()
ax.legend()
plt.show()
loss = VAELoss(λ=0.1)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
steps_per_epoch=len(train_loader), epochs=EPOCHS)
learner = Learner(model, loss, optimizer, train_loader, val_loader, DEVICE, batch_scheduler=lr_scheduler)
learner.fit(EPOCHS)
1/20 - train: loss 0.126 reconstruction loss 0.036 VQ loss 0.896 ; val: loss 0.054 reconstruction loss 0.021 VQ loss 0.327
2/20 - train: loss 0.035 reconstruction loss 0.018 VQ loss 0.169 ; val: loss 0.027 reconstruction loss 0.016 VQ loss 0.108
3/20 - train: loss 0.024 reconstruction loss 0.015 VQ loss 0.095 ; val: loss 0.023 reconstruction loss 0.014 VQ loss 0.087
4/20 - train: loss 0.021 reconstruction loss 0.013 VQ loss 0.077 ; val: loss 0.020 reconstruction loss 0.012 VQ loss 0.072
5/20 - train: loss 0.020 reconstruction loss 0.013 VQ loss 0.072 ; val: loss 0.020 reconstruction loss 0.013 VQ loss 0.075
6/20 - train: loss 0.020 reconstruction loss 0.012 VQ loss 0.078 ; val: loss 0.020 reconstruction loss 0.012 VQ loss 0.083
7/20 - train: loss 0.020 reconstruction loss 0.012 VQ loss 0.085 ; val: loss 0.021 reconstruction loss 0.012 VQ loss 0.089
8/20 - train: loss 0.021 reconstruction loss 0.012 VQ loss 0.090 ; val: loss 0.022 reconstruction loss 0.012 VQ loss 0.099
9/20 - train: loss 0.021 reconstruction loss 0.012 VQ loss 0.094 ; val: loss 0.022 reconstruction loss 0.011 VQ loss 0.101
10/20 - train: loss 0.021 reconstruction loss 0.011 VQ loss 0.098 ; val: loss 0.022 reconstruction loss 0.011 VQ loss 0.103
11/20 - train: loss 0.021 reconstruction loss 0.011 VQ loss 0.100 ; val: loss 0.022 reconstruction loss 0.011 VQ loss 0.106
12/20 - train: loss 0.021 reconstruction loss 0.011 VQ loss 0.101 ; val: loss 0.021 reconstruction loss 0.011 VQ loss 0.100
13/20 - train: loss 0.021 reconstruction loss 0.011 VQ loss 0.102 ; val: loss 0.021 reconstruction loss 0.011 VQ loss 0.097
14/20 - train: loss 0.021 reconstruction loss 0.011 VQ loss 0.103 ; val: loss 0.021 reconstruction loss 0.011 VQ loss 0.103
15/20 - train: loss 0.021 reconstruction loss 0.011 VQ loss 0.103 ; val: loss 0.021 reconstruction loss 0.011 VQ loss 0.102
16/20 - train: loss 0.021 reconstruction loss 0.011 VQ loss 0.104 ; val: loss 0.021 reconstruction loss 0.010 VQ loss 0.106
17/20 - train: loss 0.021 reconstruction loss 0.010 VQ loss 0.104 ; val: loss 0.021 reconstruction loss 0.010 VQ loss 0.108
18/20 - train: loss 0.021 reconstruction loss 0.010 VQ loss 0.104 ; val: loss 0.021 reconstruction loss 0.010 VQ loss 0.104
19/20 - train: loss 0.021 reconstruction loss 0.010 VQ loss 0.105 ; val: loss 0.021 reconstruction loss 0.010 VQ loss 0.105
20/20 - train: loss 0.021 reconstruction loss 0.010 VQ loss 0.105 ; val: loss 0.021 reconstruction loss 0.010 VQ loss 0.104
plot_history_train_val(learner.history, 'loss')
plot_history_train_val(learner.history, 'reconstruction loss')
plot_history_train_val(learner.history, 'VQ loss')
model.load_state_dict(torch.load(str(MODELS_DIR / 'final_model.pt')))
<All keys matched successfully>
model.eval();
Testing reconstruction
test_batch, _ = next(iter(val_loader))
reconstructed_batch = reconstruct(model, test_batch, DEVICE)
show_2_batches(test_batch[:64], reconstructed_batch[:64], "Validation Images", "Reconstructed Images")
Testing generation
sample_images = model.sample(64, (IMAGE_SIZE, IMAGE_SIZE), DEVICE)
show_images(sample_images.cpu(), "Sample Images")