Intuition on Conditional VAEs

PUBLISHED ON MAY 7, 2020 — CATEGORIES: explorations

Summary

In this post we’ll briefly explore the latent space behaviour of a classical Variational Autoencoder (VAE) trained on MNIST, and then we’ll impose further constraints on the latent space via an easy-to-understand, minimal working example written in PyTorch.

In a way, our example is a very simplified and rudimentary version of how Conditional VAEs work. Hopefully it helps provide some intuition!


From AE to VAE

Variational Autoencoders (VAEs) are a beautiful and exciting model, both in their mathematical formulation and their technological implications. Proposed by Kingma and Welling in 2013, they have been an extremely popular choice for many deep learning tasks ever since.

For this reason, details and implications are extensively covered elsewhere. Here it will suffice to know that autoencoder networks used for compression map their input to a lower-dimensional space (the so-called latent space), generating a code, and then map the code back to the original space:

Schema of an autoencoder (source: Wikipedia)

During training, they are encouraged to make their output as similar as possible to the input. This way, the code can be used as a compressed version of the input. And they are good at that! This 2019 paper shows a relatively recent example in which a particular type of VAE achieves state-of-the-art compression for speech.

One important difference between “regular” AEs and VAEs is that in VAEs we know more about the latent space, because we encourage it to follow a certain distribution (in the simplest case, a multivariate Gaussian with unit covariance). If we randomly sample our codes from $\mathcal{N}(0, 1)$, it is likely that we will get valid codes, which makes VAEs useful as generative models (otherwise we would have to explore a high-dimensional latent space or have access to the training data, both usually difficult tasks).
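In practice, this is enforced by adding a Kullback-Leibler divergence term to the training objective. For a diagonal Gaussian encoder output $\mathcal{N}(\mu, \sigma^2)$ and a standard normal prior, it has the well-known closed form (this is exactly the kl_err term computed in the loss code further below):

$$D_{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big) = -\frac{1}{2}\sum_j \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$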

For this setup, I trained the following simple VAE with 2 latent dimensions on the MNIST dataset (the class below defaults to latent_dims=20; here it was instantiated with latent_dims=2):

import torch


class VAE(torch.nn.Module):
    """
    A simple variational autoencoder
    """

    def __init__(self, latent_dims=20):
        """
        Fully connected encoder (784 -> 400 -> 100 -> latent_dims)
        and a mirrored decoder back to 784 dimensions.
        """
        super(VAE, self).__init__()
        #
        self.enc1 = torch.nn.Linear(784, 400)
        self.activ_enc1 = torch.nn.ReLU(inplace=True)
        self.enc2 = torch.nn.Linear(400, 100)
        self.activ_enc2 = torch.nn.ReLU(inplace=True)
        self.mean_layer = torch.nn.Linear(100, latent_dims)
        self.logvar_layer = torch.nn.Linear(100, latent_dims)
        #
        self.dec1 = torch.nn.Linear(latent_dims, 100)
        self.activ_dec1 = torch.nn.ReLU(inplace=True)
        self.dec2 = torch.nn.Linear(100, 400)
        self.activ_dec2 = torch.nn.ReLU(inplace=True)
        self.dec3 = torch.nn.Linear(400, 784)
        self.activ_dec3 = torch.nn.Sigmoid()

    def encode(self, batch_in):
        """
        Forward propagation for the encoder. It returns the mean
        and the log-variance: working with ln(stddev^2) simplifies
        the loss computation and barely affects optimization.
        """
        a = self.activ_enc1(self.enc1(batch_in))
        a = self.activ_enc2(self.enc2(a))
        return self.mean_layer(a), self.logvar_layer(a)

    def reparametrize(self, mean_batch, logvar_batch):
        """
        The encoder produces two identically shaped tensors
        which are reparametrized here following the theory
        given in the VAE paper, section 2.4: it returns a
        tensor of the same shape as either input, containing
        element-wise (mean + stddev * eps), where eps ~ N(0, I).
        The encoder could also predict stddev directly and it
        would work, but predicting ln(stddev^2) helps the
        optimization.
        """
        std = torch.exp(logvar_batch / 2)
        eps = torch.randn_like(std)  # same shape, dtype and device as std
        return mean_batch + std * eps

    def decode(self, batch_in):
        """
        Forward propagation for the decoder: maps a batch of
        latent codes back to flattened 28x28 images in [0, 1].
        """
        b = self.activ_dec1(self.dec1(batch_in))
        b = self.activ_dec2(self.dec2(b))
        return self.activ_dec3(self.dec3(b))

    def forward(self, batch_in):
        """
        Full forward propagation of the network: Given a batch of
        images, performs (encoding->reparametrization->decoding)
        and returns (encoder_mean, encoder_logvar, reconstruction)
        """
        enc_mean, enc_logvar = self.encode(batch_in)
        repar = self.reparametrize(enc_mean, enc_logvar)
        reconstruction = self.decode(repar)
        return enc_mean, enc_logvar, reconstruction
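For reference, here is a minimal sketch of a training loop for this model. The batch size, learning rate and number of epochs are illustrative assumptions, not the exact values from my runs:

import torch
from torchvision import datasets, transforms

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# MNIST images, later flattened into (batch, 784) tensors
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST("data", train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)

model = VAE(latent_dims=2).to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    for imgs, _ in train_loader:
        batch = imgs.view(imgs.size(0), -1).to(DEVICE)  # (B, 784)
        mean, logvar, recon = model(batch)
        # Plain VAE objective: summed reconstruction error plus
        # the KL term given in closed form above
        recon_err = torch.nn.functional.mse_loss(recon, batch,
                                                 reduction="sum")
        kl_err = -0.5 * torch.sum(1 + logvar - mean**2 - logvar.exp())
        loss = recon_err + kl_err
        opt.zero_grad()
        loss.backward()
        opt.step()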

Then I sampled some codes from $\mathcal{N}(0, 1)$, and passed them through the decoder to obtain the corresponding images:
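A minimal sketch of this sampling step, assuming model and DEVICE from the training sketch above:

with torch.no_grad():
    codes = torch.randn(64, 2).to(DEVICE)  # 64 codes sampled from N(0, I)
    images = model.decode(codes)           # (64, 784), values in [0, 1]
    images = images.view(-1, 28, 28)       # reshape for plotting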

Indeed we see that many of these codes correspond to valid numbers. Note another property of VAE latent spaces: they are locally smooth, i.e., nearby codes result in “similar” images.

Now let’s re-train the model from scratch:

Although we can indeed sample efficiently and the space retains some local semantics, we can still observe three issues that limit the power of VAEs as generative models:

  • The space is inconsistent at a global scale (e.g. the same number can appear in different regions)
  • We still don’t know where to find a specific number before sampling
  • The space is inconsistent throughout training: even if we retrain the exact same architecture, the latent distribution will change

Can we do something about it?


Adding Constraints to the VAE Latent Space

Let’s make a fun experiment: we’ll introduce anchors into the objective. The anchors will be data instances with specific features, and we want the objective function to associate specific regions of the latent space with them. This way, we explicitly enforce the VAE to have a latent space with known semantics at given points, similarly to what a CVAE does.

One interesting and simple example is the following: imagine we want to associate all “round-ish” numbers (like 0) with one corner, and all “stick-ish” numbers (like 1) with the opposite one. We will define two archetypal 28x28 arrays, one in the shape of a “donut” and the other a vertical “stick”. The corresponding objective looks like this:

class VAEAnchorLoss(torch.nn.Module):
    """
    Implementation of the 'parametrized' loss function for the
    VAE, as described in the paper, with the addendum of 'anchor'
    semantics.
    Logvar instead of var for numerical stability and faster
    computations, as explained here:
    http://louistiao.me/posts/implementing-variational-autoencoders-in-keras-beyond-the-quickstart-tutorial/
    """

    DONUT_MAP = torch.FloatTensor([[0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0],
                                   [0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0],
                                   [0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0],
                                   [0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0],
                                   [0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0],
                                   [0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0],
                                   [0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0]])

    STICK_MAP = torch.FloatTensor([[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0]])

    def __init__(self, device):
        """
        Moves the anchor maps to the given device and sets up the
        (summed) reconstruction loss.
        """
        super(VAEAnchorLoss, self).__init__()
        self.recon_loss = torch.nn.MSELoss(reduction="sum")  # size_average=False is deprecated
        self.donut_flat = self.DONUT_MAP.view(-1).to(device)
        self.stick_flat = self.STICK_MAP.view(-1).to(device)
        self.to(device)

    def forward(self, mean, logvar, pred, target, anchor_rate=0):
        """
        Computes the reconstruction error plus a KL term whose mean
        is shifted proportionally to each sample's 'donutness'
        score. With anchor_rate=0 this is the plain VAE loss.
        """
        recon_err = self.recon_loss(pred, target)
        #
        # Per-sample overlap with each anchor map, normalized by the
        # sample's total mass: donutness > 0 means closer to the
        # donut, donutness < 0 means closer to the stick.
        target_sum_by_elt = target.sum(-1)
        donut_score_by_elt = (target * self.donut_flat).sum(-1) / target_sum_by_elt
        stick_score_by_elt = (target * self.stick_flat).sum(-1) / target_sum_by_elt
        donutness = donut_score_by_elt - stick_score_by_elt
        donutness_expanded = donutness.view(-1, 1).expand(mean.size())  # broadcastable to mean
        #
        var = logvar.exp()
        # Shifting the mean inside the KL term pushes high-donutness
        # samples toward low latent means and stick-like samples
        # toward high ones.
        anchored_mean = mean + anchor_rate * donutness_expanded
        kl_err = -0.5 * torch.sum(1 + logvar - anchored_mean**2 - var)
        #
        return recon_err + kl_err
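This objective can slot directly into the training loop sketched earlier; the anchor_rate value below is an illustrative assumption:

loss_fn = VAEAnchorLoss(DEVICE)

for imgs, _ in train_loader:
    batch = imgs.view(imgs.size(0), -1).to(DEVICE)
    mean, logvar, recon = model(batch)
    loss = loss_fn(mean, logvar, recon, batch, anchor_rate=2.0)
    opt.zero_grad()
    loss.backward()
    opt.step()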

We compute the donutness score of each sample, and we want samples with a higher score to have a lower mean; conversely, samples with a lower score (more similar to the stick) are pushed toward a higher mean. This is reflected clearly when we train using this objective and sample from $\mathcal{N}(0, 1)$:

The zeros are at the upper left corner (typically associated with the lowest indexes in numpy arrays), and the ones at the bottom right corner (highest indexes). To make sure this wasn’t a lucky shot, let’s retrain the model and see if it works again:

Success!
  • Now the $0, 1$ entries occupy unique regions consistently
  • We can still efficiently sample from the latent distribution using $\mathcal{N}(0, 1)$, and furthermore, we can look for $(0, 1)$ entries at low/high mean regions, respectively
  • The $0, 1$ positions are consistent across training instances

Conclusion

Note that the DONUT and STICK maps are here just a proof of concept to provide an intuition on the relation between loss function, data, and latent space.

In general, more useful or refined constraints could be imposed. The maps can also be replaced with e.g. aggregations of real-world data or actual batches (as in the CVAE paper).

Download pretrained model

Here you can find the pretrained model in case you want to reproduce the experiment:

Characterization of VAE latent spaces is a vibrant field of research, with new breakthroughs continuously coming up. In this context it is worth mentioning my colleagues at GU Frankfurt, whose work in Continual Learning has contributed substantially to the characterization of VAE latent spaces; see e.g. here and here.

Also tangentially related to this topic, but usually relevant when working with VAEs trained with a normal prior, is this brief report containing the proof and steps to perform statistical thresholding, given the $\mu$ and $\sigma$ outputs of the VAE.


Original media in this post is licensed under CC BY-NC-ND 4.0. Software snippets are licensed under GPLv3

TAGS: computer vision, machine learning, proof, pytorch, vae