Anchored VAE

PUBLISHED ON MAY 7, 2020 — CATEGORIES: explorations

Variational Autoencoders (VAEs) are beautiful and exciting models, both in their mathematical formulation and in their technological implications. Proposed by Kingma and Welling in 2013, they have been an extremely popular choice for many deep learning tasks ever since.

For this reason, their details and implications are extensively covered elsewhere. Here it will suffice to know that autoencoder networks used for compression map their input to a lower-dimensional space (called the latent space), generating a code, and then map the code back to the original space:

During training, they are pushed to make their output as similar as possible to the input. This way, the code can be used as a compressed version of the input. And they are good at that! In this 2019 paper you can see a relatively recent example in which a particular type of VAE achieves state-of-the-art compression for speech.

One important difference between “regular” autoencoders and VAEs is that in VAEs we know more about the latent space: since training pulls the codes towards the prior, if we randomly sample our codes from $\mathcal{N}(0, 1)$, it is likely that we will get valid codes. For this setup, I trained the following simple VAE with 2 latent dimensions on the MNIST dataset:

import torch


class VAE(torch.nn.Module):
    """
    A simple variational autoencoder
    """

    def __init__(self, latent_dims=20):
        """
        """
        super(VAE, self).__init__()
        #
        self.enc1 = torch.nn.Linear(784, 400)
        self.activ_enc1 = torch.nn.ReLU(inplace=True)
        self.enc2 = torch.nn.Linear(400, 100)
        self.activ_enc2 = torch.nn.ReLU(inplace=True)
        self.mean_layer = torch.nn.Linear(100, latent_dims)
        self.logvar_layer = torch.nn.Linear(100, latent_dims)
        #
        self.dec1 = torch.nn.Linear(latent_dims, 100)
        self.activ_dec1 = torch.nn.ReLU(inplace=True)
        self.dec2 = torch.nn.Linear(100, 400)
        self.activ_dec2 = torch.nn.ReLU(inplace=True)
        self.dec3 = torch.nn.Linear(400, 784)
        self.activ_dec3 = torch.nn.Sigmoid()

    def encode(self, batch_in):
        """
        Forward propagation for the encoder. It returns the
        mean and the log-variance (working with ln(stddev^2)
        simplifies the loss computation and barely affects
        optimization).
        """
        a = self.activ_enc1(self.enc1(batch_in))
        a = self.activ_enc2(self.enc2(a))
        return self.mean_layer(a), self.logvar_layer(a)

    def reparametrize(self, mean_batch, logvar_batch):
        """
        The encoder produces two identically shaped tensors
        which are here reparametrized following the theory
        given in the VAE paper,section 2.4: It returns a
        tensor of same shape than any of the inputs, and
        containing element-wise (mean+stddev*x), where x~norm(I).
        In this case stddev_batch could be assumed to be sigma, and
        it would work, but assuming ln(sigma) helps the
        optimization.
        """
        std = torch.exp(logvar_batch/2)
        eps = torch.FloatTensor(std.size()).normal_().to(DEVICE)
        return mean_batch + std * eps

    def decode(self, batch_in):
        """
        """
        b = self.activ_dec1(self.dec1(batch_in))
        b = self.activ_dec2(self.dec2(b))
        return self.activ_dec3(self.dec3(b))

    def forward(self, batch_in):
        """
        Full forward propagation of the network: Given a batch of
        images, performs (encoding->reparametrization->decoding)
        and returns (encoder_mean, encoder_logvar, reconstruction)
        """
        enc_mean, enc_logvar = self.encode(batch_in)
        repar = self.reparametrize(enc_mean, enc_logvar)
        reconstruction = self.decode(repar)
        return enc_mean, enc_logvar, reconstruction
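
If you want to reproduce this, a minimal training loop could look like the sketch below. The plain ELBO-style objective (reconstruction error plus KL divergence to the prior), the optimizer, and all hyperparameters are my assumptions and were not stated in the post; the anchored variant of the loss appears further below.

import torchvision

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

mnist = torchvision.datasets.MNIST(
    "data", train=True, download=True,
    transform=torchvision.transforms.ToTensor())
loader = torch.utils.data.DataLoader(mnist, batch_size=128, shuffle=True)

vae = VAE(latent_dims=2).to(DEVICE)
opt = torch.optim.Adam(vae.parameters(), lr=1e-3)

for epoch in range(10):
    for imgs, _ in loader:
        flat = imgs.view(imgs.size(0), -1).to(DEVICE)  # (batch, 784)
        mean, logvar, recon = vae(flat)
        # plain VAE objective: reconstruction error + KL to N(0, I)
        recon_err = torch.nn.functional.mse_loss(recon, flat, reduction="sum")
        kl_err = -0.5 * torch.sum(1 + logvar - mean**2 - logvar.exp())
        loss = recon_err + kl_err
        opt.zero_grad()
        loss.backward()
        opt.step()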

Then I sampled some codes, and passed them through the decoder to obtain the corresponding images:
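
A sketch of that sampling step, assuming the `vae` and `DEVICE` from the training sketch above (the 4x4 grid layout is an arbitrary choice of mine):

import matplotlib.pyplot as plt

# decode a batch of codes sampled from the standard normal prior
vae.eval()
with torch.no_grad():
    codes = torch.randn(16, 2).to(DEVICE)  # 16 codes ~ N(0, I)
    imgs = vae.decode(codes).view(-1, 28, 28).cpu()

fig, axes = plt.subplots(4, 4)
for ax, img in zip(axes.flat, imgs):
    ax.imshow(img, cmap="gray")
    ax.axis("off")
plt.show()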

Indeed we see that many of these codes correspond to valid numbers. Note another property of VAE latent spaces: they are locally smooth, i.e., nearby codes result in “similar” images.
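
That smoothness can be probed directly by decoding a straight line between two codes; a small sketch, with arbitrarily chosen endpoints:

# linear interpolation between two latent codes: with a locally
# smooth latent space, the decoded images morph gradually
with torch.no_grad():
    z_a = torch.tensor([[-1.5, 0.0]]).to(DEVICE)  # arbitrary endpoints
    z_b = torch.tensor([[1.5, 1.0]]).to(DEVICE)
    steps = torch.linspace(0, 1, 8).view(-1, 1).to(DEVICE)
    path = z_a + steps * (z_b - z_a)  # (8, 2) codes along the line
    morph = vae.decode(path).view(-1, 28, 28).cpu()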


Anchored VAE

Now see what happens if I retrain the model:

We still see many valid codes, and the local smoothness. But the overall structure has changed! Also note that the same number appears in multiple disjoint regions. This can be problematic: imagine that you have trained a VAE with a 100-dimensional latent space using unlabeled data (a pretty common scenario), and you want to generate some new data. It is not trivial to find the codes associated with a particular feature.

One way of fixing that is introducing anchors into the objective: anchors are data instances with specific features that the objective function associates with specific regions of the latent space. This way, we explicitly encourage the VAE to build a latent space with known semantics at given locations.

One interesting and simple example is the following: imagine we want to associate all “round-ish” numbers (like 0) with one corner of the latent space, and all “stick-ish” numbers (like 1) with the opposite one. The corresponding objective would look like this:

class VAEAnchorLoss(torch.nn.Module):
    """
    Implementation of the 'parametrized' loss function for the
    VAE, as described in the paper, with the addendum of 'anchor'
    semantics.
    Logvar instead of var for numerical stability and faster
    computations, as explained here:
    http://louistiao.me/posts/implementing-variational-autoencoders-in-keras-beyond-the-quickstart-tutorial/
    """

    DONUT_MAP = torch.FloatTensor([[0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0],
                                   [0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0],
                                   [0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0],
                                   [0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0],
                                   [0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0],
                                   [0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0],
                                   [0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0],
                                   [0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0]])

    STICK_MAP = torch.FloatTensor([[0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0],
                                   [0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0]])

    def __init__(self, device):
        """
        """
        super(VAEAnchorLoss, self).__init__()
        self.recon_loss = torch.nn.MSELoss(size_average=False)
        self.donut_flat = self.DONUT_MAP.view(-1).to(device)
        self.stick_flat = self.STICK_MAP.view(-1).to(device)
        self.to(device)

    def forward(self, mean, logvar, pred, target, anchor_rate=0):
        """
        """
        recon_err = self.recon_loss(pred, target)
        #
        target_sum_by_elt = target.sum(-1)
        donut_score_by_elt = (target * self.donut_flat).sum(-1) / target_sum_by_elt
        stick_score_by_elt = (target * self.stick_flat).sum(-1) / target_sum_by_elt
        donutness = donut_score_by_elt - stick_score_by_elt
        donutness_expanded = donutness.view(-1,1).expand(mean.size()) # to be broadcastable to mean
        #
        var = logvar.exp()
        anchored_mean = mean + anchor_rate * donutness_expanded
        kl_err = -0.5 * torch.sum(1 + logvar - anchored_mean**2 - var)
        #
        return recon_err + kl_err

The DONUT and STICK maps are here just a proof of concept and could be replaced with aggregations of real-world data. As intended, training the VAE with this loss function (see the sketch below) causes the samples with more “donutness” to go to one side of the latent space, and the samples with less to the opposite side:
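
For reference, here is a sketch of how this loss could plug into the earlier training loop; the `anchor_rate` value is an assumption, since the post does not state the one used:

# using the anchored loss in the training step
criterion = VAEAnchorLoss(DEVICE)
anchored_vae = VAE(latent_dims=2).to(DEVICE)
opt = torch.optim.Adam(anchored_vae.parameters(), lr=1e-3)

for imgs, _ in loader:
    flat = imgs.view(imgs.size(0), -1).to(DEVICE)
    mean, logvar, recon = anchored_vae(flat)
    # anchor_rate=5.0 is an assumed value; it trades off anchoring
    # strength against the other loss terms
    loss = criterion(mean, logvar, recon, flat, anchor_rate=5.0)
    opt.zero_grad()
    loss.backward()
    opt.step()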

To make sure this wasn't a lucky shot, let's retrain the model and see if it works again:

Success! Here you can find the pretrained 2D anchorVAE if you want to run it yourself:

Of course, this is a proof of concept, and more useful or refined constraints could be imposed. More generally, the characterization of VAE latent spaces is an active field of research with new results coming up continuously. In this context it is worth mentioning the work of my colleagues at GU Frankfurt, who did a much more thorough characterization of VAE latent spaces; see e.g. here and here.

Also tangentially related to this topic, but usually relevant when working with VAEs trained with a normal prior, is this brief report, which contains the proof and the steps to perform statistical thresholding given the $\mu$ and $\sigma$ outputs from the VAE.


Original media in this post is licensed under CC BY-NC-ND 4.0. Software snippets are licensed under GPLv3