In this post we’ll briefly explore the latent space behaviour of a classical Variational Autoencoder (VAE) trained on MNIST, and then we’ll impose further constraints on the latent space via an easy-to-understand, minimal working example written in PyTorch.
In a way, our example is a very simplified and rudimentary version of how Conditional VAEs work. Hopefully it helps provide some intuition!
Variational Autoencoders (VAEs) are a beautiful and exciting class of models, both in terms of their mathematical formulation and their technological implications. Proposed by Kingma and Welling in 2013, they have been an extremely popular choice for many deep learning-related tasks since then.
For this reason, details and implications are extensively covered elsewhere. Here it will suffice to know that autoencoder networks used for compression map their input to a lower-dimensional (so-called latent) space, producing a code, and then map the code back to the original space:
During training, they are pushed to make their output as similar as possible to the input. This way, the code can be used as a compressed version of the input. And they are good at that! In this 2019 paper you can see a relatively recent example where a particular type of VAE achieves state-of-the-art compression for speech.
One important difference between “regular” AEs and VAEs is that in VAEs, we know more about the latent space, because we enforce it to follow a certain distribution (in the simplest case, a multivariate Gaussian with unit covariance). If we randomly sample our codes from $\mathcal{N}(0, 1)$, it is likely that we will get valid codes, which makes them useful as generative models (otherwise we would have to explore a high-dimensional latent space or have access to the training data, usually difficult tasks).
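Concretely, the latent distribution is enforced by adding a KL-divergence penalty to the reconstruction objective. For a diagonal Gaussian posterior $\mathcal{N}(\mu, \sigma^2)$ and a $\mathcal{N}(0, I)$ prior, this penalty has a closed form, and with a squared-error reconstruction term (as used later in this post) the overall loss reads roughly:

$$\mathcal{L} = \underbrace{\lVert x - \hat{x} \rVert^2}_{\text{reconstruction}} \; \underbrace{- \frac{1}{2} \sum_j \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)}_{\text{KL to } \mathcal{N}(0, I)}$$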
For this setup, I trained the following simple VAE with 2 latent dimensions on the MNIST dataset:
import torch

class VAE(torch.nn.Module):
"""
A simple variational autoencoder
"""
    def __init__(self, latent_dims=20):
        """
        :param latent_dims: Dimensionality of the latent space.
        """
super(VAE, self).__init__()
#
self.enc1 = torch.nn.Linear(784, 400)
self.activ_enc1 = torch.nn.ReLU(inplace=True)
self.enc2 = torch.nn.Linear(400, 100)
self.activ_enc2 = torch.nn.ReLU(inplace=True)
self.mean_layer = torch.nn.Linear(100, latent_dims)
self.logvar_layer = torch.nn.Linear(100, latent_dims)
#
self.dec1 = torch.nn.Linear(latent_dims, 100)
self.activ_dec1 = torch.nn.ReLU(inplace=True)
self.dec2 = torch.nn.Linear(100, 400)
self.activ_dec2 = torch.nn.ReLU(inplace=True)
self.dec3 = torch.nn.Linear(400, 784)
self.activ_dec3 = torch.nn.Sigmoid()
def encode(self, batch_in):
"""
        Forward propagation for the encoder. It generates
        mean and logvar (working with ln(stddev^2) rather than
        stddev simplifies the computation of the loss and
        doesn't noticeably affect the optimization).
        """
a = self.activ_enc1(self.enc1(batch_in))
a = self.activ_enc2(self.enc2(a))
return self.mean_layer(a), self.logvar_layer(a)
def reparametrize(self, mean_batch, logvar_batch):
"""
The encoder produces two identically shaped tensors
        which are reparametrized here following the theory
        given in the VAE paper, section 2.4: it returns a
        tensor of the same shape as either input, containing
        element-wise (mean + stddev * eps), where eps ~ N(0, I).
        The encoder could predict stddev directly and this
        would still work, but predicting ln(stddev^2) helps
        the optimization.
        """
        std = torch.exp(logvar_batch / 2)
        eps = torch.randn_like(std)  # standard normal noise, same shape and device as std
        return mean_batch + std * eps
    def decode(self, batch_in):
        """
        Forward propagation for the decoder: maps a batch of
        latent codes back to flattened 28x28 images in [0, 1].
        """
b = self.activ_dec1(self.dec1(batch_in))
b = self.activ_dec2(self.dec2(b))
return self.activ_dec3(self.dec3(b))
def forward(self, batch_in):
"""
Full forward propagation of the network: Given a batch of
images, performs (encoding->reparametrization->decoding)
and returns (encoder_mean, encoder_logvar, reconstruction)
"""
enc_mean, enc_logvar = self.encode(batch_in)
repar = self.reparametrize(enc_mean, enc_logvar)
reconstruction = self.decode(repar)
return enc_mean, enc_logvar, reconstruction
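The training code itself is not the focus of this post, but for completeness, a minimal training loop for this model could look roughly like the following sketch (train_loader, the number of epochs and the learning rate are placeholders, not the exact settings used for the figures):

# Minimal training sketch (hypothetical hyperparameters and data loader,
# not the exact setup used for the results shown in this post).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = VAE(latent_dims=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(20):
    for batch, _ in train_loader:               # MNIST images; labels unused
        batch = batch.view(-1, 784).to(device)  # flatten 28x28 -> 784
        mean, logvar, reconstruction = model(batch)
        recon_err = torch.nn.functional.mse_loss(reconstruction, batch, reduction="sum")
        kl_err = -0.5 * torch.sum(1 + logvar - mean**2 - logvar.exp())
        loss = recon_err + kl_err
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()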
Then I sampled some codes from $\mathcal{N}(0, 1)$, and passed them through the decoder to obtain the corresponding images:
Indeed we see that many of these codes correspond to valid numbers. Note another property of VAE latent spaces: they are locally smooth, i.e., nearby codes result in “similar” images.
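For reference, the sampling step described above boils down to something like this (a sketch, assuming a trained model instance vae living on device as in the training sketch above):

# Sketch: sample codes from the prior and decode them into images
# (`vae` is a trained VAE with 2 latent dimensions).
with torch.no_grad():
    codes = torch.randn(64, 2).to(device)     # 64 samples from N(0, I)
    samples = vae.decode(codes)               # shape: (64, 784)
    samples = samples.view(-1, 28, 28).cpu()  # back to 28x28 images for plotting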
Now let’s re-train the model from scratch:
Although we can indeed sample efficiently and retain some local semantics, we can still observe three issues that hinder the power of VAEs as generative models:
Can we do something about it?
Let’s try a fun experiment: we’ll introduce anchors into the objective. The anchors will be data instances with specific features, and we want the objective function to associate specific regions of the latent space with them. This way, we explicitly enforce the VAE to have a latent space with known semantics at given points, similarly to what a CVAE does.
One interesting and simple example is the following: imagine we want to associate all “round-ish” numbers (like 0) with one corner, and all “stick-ish” numbers (like 1) with the opposite one. We will define two archetypical 28x28 arrays, one shaped like a “donut” and the other like a vertical “stick”. The corresponding objective would look like this:
class VAEAnchorLoss(torch.nn.Module):
"""
Implementation of the 'parametrized' loss function for the
VAE, as described in the paper, with the addendum of 'anchor'
semantics.
Logvar instead of var for numerical stability and faster
computations, as explained here:
http://louistiao.me/posts/implementing-variational-autoencoders-in-keras-beyond-the-quickstart-tutorial/
"""
DONUT_MAP = torch.FloatTensor([[0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0],
[0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0],
[0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0],
[0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0],
[0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
[0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0],
[0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0],
[0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0],
[0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0],
[0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0]])
STICK_MAP = torch.FloatTensor([[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0]])
    def __init__(self, device):
        """
        :param device: Device on which the anchor maps and the loss
            module are placed (e.g. 'cuda' or 'cpu').
        """
super(VAEAnchorLoss, self).__init__()
        self.recon_loss = torch.nn.MSELoss(reduction="sum")
self.donut_flat = self.DONUT_MAP.view(-1).to(device)
self.stick_flat = self.STICK_MAP.view(-1).to(device)
self.to(device)
    def forward(self, mean, logvar, pred, target, anchor_rate=0):
        """
        Computes the anchored VAE loss: the reconstruction error plus
        a KL term whose mean is shifted proportionally to the
        'donutness' of each target image.
        """
recon_err = self.recon_loss(pred, target)
#
target_sum_by_elt = target.sum(-1)
donut_score_by_elt = (target * self.donut_flat).sum(-1) / target_sum_by_elt
stick_score_by_elt = (target * self.stick_flat).sum(-1) / target_sum_by_elt
donutness = donut_score_by_elt - stick_score_by_elt
donutness_expanded = donutness.view(-1,1).expand(mean.size()) # to be broadcastable to mean
#
var = logvar.exp()
anchored_mean = mean + anchor_rate * donutness_expanded
kl_err = -0.5 * torch.sum(1 + logvar - anchored_mean**2 - var)
#
return recon_err + kl_err
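Before looking at the effect, note what the anchor term does mathematically: shifting the mean inside the closed-form KL expression is equivalent to keeping the usual KL, but measuring it against a per-sample prior whose mean has been moved. With $r$ the anchor_rate and $d$ the donutness of a sample,

$$-\frac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - (\mu_j + r\,d)^2 - \sigma_j^2\right) = D_{\mathrm{KL}}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(-r\,d\,\mathbf{1}, I)\right),$$

so round-ish samples (positive $d$) are pulled towards negative latent coordinates and stick-ish samples (negative $d$) towards positive ones.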
We are computing the donutness score of each sample, and we want samples with a higher score to have a lower mean. Conversely, samples with a lower score (more similar to the stick) are enforced to have a higher mean. This is reflected perfectly when we train using this objective and sample from $\mathcal{N}(0, 1)$:
The zeros are at the upper left corner (typically associated with the lowest indexes in numpy arrays), and the ones at the bottom right corner (highest indexes). To make sure this wasn’t a lucky shot, let’s retrain the model and see if it works again:
Note that the DONUT and STICK maps are here just a proof of concept, to provide some intuition on the relation between loss function, data, and latent space.
In general, more useful or refined constraints could be imposed. The maps can also be replaced with e.g. aggregations of real-world data or actual batches (as in the CVAE paper).
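For completeness, a single training step with the anchored objective might look roughly like this (a sketch; model, optimizer, batch and the anchor_rate value are placeholders rather than the exact setup used here):

# Sketch: one optimization step with the anchored objective
# (model, optimizer, batch and anchor_rate are hypothetical placeholders).
criterion = VAEAnchorLoss(device)
mean, logvar, reconstruction = model(batch)  # batch: (B, 784), values in [0, 1]
loss = criterion(mean, logvar, reconstruction, batch, anchor_rate=3.0)
optimizer.zero_grad()
loss.backward()
optimizer.step()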
Here you can find the pretrained model in case you want to reproduce the experiment:
Characterization of VAE latent spaces is a lively field of research, with new results continuously coming up. In this context it is worth mentioning my colleagues at GU Frankfurt, whose work in Continual Learning has contributed substantially to this characterization; see e.g. here and here.
Also tangentially related to this topic, but often relevant when working with VAEs trained with a normal prior, is this brief report, which contains the proof and the steps needed to perform statistical thresholding given the $\mu$ and $\sigma$ outputs of the VAE.
Original media in this post is licensed under CC BY-NC-ND 4.0. Software snippets are licensed under GPLv3