Distribution-Aware Heatmap Decoder for Human Pose Estimation

PUBLISHED ON JUN 8, 2020 — CATEGORIES: utilities

Why DARK?

When performing heatmap-based human pose estimation (HPE), the way heatmaps are encoded and decoded has a considerable impact on the final accuracy. The DARK paper (arXiv version here) is, according to its authors, the first systematic study of this aspect. In it, the authors propose a “model-agnostic plug-in that significantly improves the performance of a variety of state-of-the-art human pose estimation models”.

Indeed, this table shows that DARK greatly boosts the precision of many models. Moreover, the method is computationally efficient. This is crucial when striving for real-time performance, since the decoding process can be a real bottleneck in many instances.

These qualities, and the fact that it acts as a plug-in, make it very desirable to have DARK as part of any HPE setup. Unfortunately, at the time of writing an implementation seems to be missing from the paper’s GitHub repository. In this entry I provide my own implementation, together with some remarks.

If you are further interested in this topic you can check some of my work on human pose estimation here. Among other things, I include a review of some recent literature, including the DARK paper.


Decoding Process

In this image you can see what the heatmap decoding process looks like. Given an image (from the COCO val2017 dataset), a convolutional neural network generates a set of heatmaps, and the decoder extracts the “central location” of each heatmap.


(Note that the channels for the different keypoints have been aggregated into a single image for easier visualization).
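The simplest conceivable decoder, useful as a mental baseline, takes the per-channel argmax. It assumes one peak per channel (i.e. a single person), and its output is quantized to the integer pixel grid, which is exactly the error DARK tries to reduce. A minimal numpy sketch; the helper names `argmax_decode` and `gaussian_heatmap` are hypothetical and not part of the implementation below:

```python
import numpy as np


def argmax_decode(heatmaps):
    """Naive baseline decoder: one integer (y, x) location per channel.

    heatmaps: array of shape (K, H, W), one heatmap per keypoint type.
    Returns an int array of shape (K, 2) with the (y, x) of each maximum.
    """
    num_kp, _, width = heatmaps.shape
    flat_idx = heatmaps.reshape(num_kp, -1).argmax(axis=1)
    # Convert flat indices back to (row, column) coordinates.
    return np.stack([flat_idx // width, flat_idx % width], axis=1)


def gaussian_heatmap(height, width, mu_y, mu_x, sigma=2.0):
    """Synthetic heatmap with an isotropic Gaussian blob at (mu_y, mu_x)."""
    yy, xx = np.mgrid[0:height, 0:width]
    return np.exp(-((yy - mu_y) ** 2 + (xx - mu_x) ** 2) / (2 * sigma ** 2))


hms = np.stack([gaussian_heatmap(16, 16, 5, 7),
                gaussian_heatmap(16, 16, 10.4, 3.2)])
decoded = argmax_decode(hms)
# -> (5, 7) and (10, 3): the sub-pixel part of (10.4, 3.2) is lost
```

Note that the second blob is centered at (10.4, 3.2), but the decoder can only report (10, 3).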

Traditionally, the go-to procedure was to perform non-maximum suppression (NMS) plus some noise removal, optionally refined with sub-pixel correction heuristics. Among other things, the DARK paper points out the great relevance of the decoding step and proposes a process that is summarized in the following image (from the paper).


In it, $m$ is the discretized location that we would obtain with e.g. a naïve NMS, and $\mu$ is the actual location of the mode, which is unobserved. To estimate $\mu$, the authors propose a formula based on a second-order Taylor expansion of the logarithm of the assumed (Gaussian) distribution around $m$. This formula requires computing the first and second derivatives at $m$, which can be done very efficiently with finite differences.
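For reference, the core step can be written out as follows. The notation is meant to mirror the paper as closely as I can: $\mathcal{P}$ is the log of the assumed Gaussian, and $\mathcal{D}'$, $\mathcal{D}''$ are its gradient and Hessian. Since the gradient vanishes at the mode, expanding it to first order around $m$ (equivalently, expanding $\mathcal{P}$ to second order) gives

$$
\mathcal{D}'(\mu) \approx \mathcal{D}'(m) + \mathcal{D}''(m)\,(\mu - m) = 0
\quad\Longrightarrow\quad
\mu = m - \left(\mathcal{D}''(m)\right)^{-1} \mathcal{D}'(m).
$$

With a diagonal covariance (the assumption discussed in the last section), $\mathcal{D}''(m)$ is diagonal, so each axis can be handled independently. With unit-step central differences, as used in `rectify_dark`, one axis reads

$$
\mathcal{D}'(m) \approx \frac{\mathcal{P}(m+1) - \mathcal{P}(m-1)}{2}, \qquad
\mathcal{D}''(m) \approx \mathcal{P}(m+1) - 2\,\mathcal{P}(m) + \mathcal{P}(m-1), \qquad
\hat{\mu} = m - \frac{\mathcal{D}'(m)}{\mathcal{D}''(m)}.
$$

Because the log of a Gaussian is exactly quadratic, both the expansion and these central differences are exact for a noise-free Gaussian heatmap.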

Check the paper for more details about the model and its assumptions.
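For comparison, the traditional sub-pixel correction heuristics mentioned above can be much cruder. A widely used one (found, to my knowledge, in the SimpleBaselines/HRNet reference code) moves the integer maximum by a quarter pixel towards the higher of its two neighbors along each axis. Thus it can only produce offsets of 0 or ±0.25 pixels. A minimal numpy sketch in that spirit (the function name `quarter_offset_refine` is hypothetical):

```python
import numpy as np


def quarter_offset_refine(heatmap, y, x):
    """Classic heuristic: shift 0.25 px towards the higher neighbor per axis.

    heatmap: 2D array; (y, x): integer location of a maximum.
    Returns a float (y, x) pair. Axes where (y, x) lies on the border
    are left unchanged.
    """
    height, width = heatmap.shape
    ry, rx = float(y), float(x)
    if 0 < x < width - 1:
        rx += 0.25 * np.sign(heatmap[y, x + 1] - heatmap[y, x - 1])
    if 0 < y < height - 1:
        ry += 0.25 * np.sign(heatmap[y + 1, x] - heatmap[y - 1, x])
    return float(ry), float(rx)


# Gaussian blob centered at (5.0, 7.3); its integer maximum is at (5, 7).
yy, xx = np.mgrid[0:16, 0:16]
hm = np.exp(-((yy - 5.0) ** 2 + (xx - 7.3) ** 2) / (2 * 2.0 ** 2))
print(quarter_offset_refine(hm, 5, 7))  # (5.0, 7.25)
```

The true center is at x = 7.3, but the heuristic can only report 7.25. DARK instead estimates the offset from the local curvature.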


Implementation

The implementation follows the described process, with a few quirks:

  1. The input heatmaps are (optionally) modulated via 2D Gaussian convolution.
  2. NMS is performed via the “max pooling” trick: a pixel is kept if it equals the maximum of its neighborhood. Results below a threshold are filtered out (useful to remove e.g. background plateaus).
  3. From each connected component in the NMS map, a single location is extracted (this prevents e.g. a 10-pixel plateau from being regarded as 10 different keypoints).
  4. The vicinity around each location is retrieved, and the first and second derivatives are computed.
  5. The derivatives are used to shift the location towards the mode of the assumed distribution.
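To see why step 3 is needed, here is a numpy-only sketch of steps 2 and 3. It stands in for `torch.nn.MaxPool2d` and `skimage.measure.label`, and the helpers `maxpool_nms` and `label_8conn` are hypothetical. A two-pixel plateau passes the max-pooling test twice, but forms a single connected component. The labeling uses 8-connectivity, which as far as I know matches skimage's default for 2D inputs.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def maxpool_nms(heatmap, ksize=3, thresh=0.1):
    """'Max pooling' NMS: keep pixels equal to their ksize x ksize window max."""
    pad = ksize // 2
    padded = np.pad(heatmap, pad, mode="constant", constant_values=-np.inf)
    pooled = sliding_window_view(padded, (ksize, ksize)).max(axis=(-2, -1))
    return (pooled == heatmap) & (heatmap >= thresh)


def label_8conn(mask):
    """Label 8-connected components of a boolean mask with a flood fill."""
    labels = np.zeros(mask.shape, dtype=np.int32)
    num = 0
    for y, x in zip(*np.nonzero(mask)):
        if labels[y, x]:
            continue  # already part of a labeled component
        num += 1
        labels[y, x] = num
        stack = [(y, x)]
        while stack:
            cy, cx = stack.pop()
            for ny in range(max(cy - 1, 0), min(cy + 2, mask.shape[0])):
                for nx in range(max(cx - 1, 0), min(cx + 2, mask.shape[1])):
                    if mask[ny, nx] and not labels[ny, nx]:
                        labels[ny, nx] = num
                        stack.append((ny, nx))
    return labels, num


hm = np.zeros((8, 8))
hm[3, 3] = hm[3, 4] = 0.9  # a 2-pixel plateau: one keypoint, two maxima
hm[6, 1] = 0.5             # a second, isolated peak
nms = maxpool_nms(hm)
labels, num = label_8conn(nms)
print(nms.sum(), num)  # 3 2
# One location per component, here the centroid: (3, 3.5) and (6, 1)
centroids = [np.argwhere(labels == k).mean(axis=0) for k in range(1, num + 1)]
```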

And without further delay, here is the software (Python 3):

"""
Implementation of the DARK modulation and decoding steps from
https://arxiv.org/pdf/1910.06278.pdf
"""

import numpy as np
from skimage.measure import label, regionprops
import torch


__author__ = "aferro"
__license__ = "MIT"

def gaussian_2d_kernel(length=5, stddev=1.0, dtype=np.float32):
    """
    Creates a normalized, isotropic 2D Gaussian kernel of shape
    (length, length) with standard deviation stddev.
    Source:  https://stackoverflow.com/a/43346070
    """
    ax = np.linspace(-(length - 1) / 2., (length - 1) / 2., length)
    xx, yy = np.meshgrid(ax, ax)
    kernel = np.exp(-0.5 * (np.square(xx) + np.square(yy)) / np.square(stddev))
    kernel = kernel.astype(dtype)
    return kernel / np.sum(kernel)


class DarkKeypointDecoder:
    """
    http://arxiv.org/abs/1910.06278
    """
    NUM_SIGMAS = 2
    def __init__(self, conv_stddev=0, num_keypoints=17,
                 nms_ksize=3, get_centroids=True):
        """
        :param get_centroids: If true, compute the centroid of each NMS
          connected component. If false, just pick any pixel of it. Note
          that this slows down decoding, and doesn't make a difference for
          single-pixel regions (the most common ones).
        """
        assert conv_stddev >= 0, "Negative stddev!"
        assert nms_ksize % 2 == 1, \
            "Only odd ksize (3, 5, etc) allowed ! (makes life easier)"
        # modulation (gaussian conv)
        self.conv_stddev = conv_stddev
        if conv_stddev > 0:
            # Odd kernel size covering NUM_SIGMAS stddevs on each side plus a
            # one-pixel margin. The int() makes float stddevs work as well.
            conv_ksize = 2 * int(np.ceil(self.NUM_SIGMAS * conv_stddev)) + 3
            conv_pad = conv_ksize // 2
            # Depthwise conv (groups=num_keypoints): each heatmap channel is
            # filtered independently with the same Gaussian kernel.
            self.conv = torch.nn.Conv2d(
                num_keypoints, num_keypoints, groups=num_keypoints,
                kernel_size=conv_ksize, padding=conv_pad, bias=False)
            arr = torch.from_numpy(gaussian_2d_kernel(conv_ksize, conv_stddev))
            with torch.no_grad():
                # weight shape: (num_keypoints, 1, conv_ksize, conv_ksize)
                self.conv.weight.copy_(arr.expand(num_keypoints, 1, -1, -1))
            self.conv.weight.requires_grad_(False)
        # nms
        self.nms_ksize = nms_ksize
        self.nms_pad = nms_ksize // 2
        self.nms_pool = torch.nn.MaxPool2d(
            self.nms_ksize, stride=1, padding=self.nms_pad)
        #
        self.get_centroids = get_centroids
        # y prev, x prev, y next, x next
        self.vicinity = np.int32([[-1, 0], [0, -1], [1, 0], [0, 1]])

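    def rectify_dark(self, location_yx, weight_loc, prev_weight_yx,
                     next_weight_yx, eps=1e-10):
        """
        The paper assumes y and x are independent (i.e. diagonal covariance
        matrix), so the vertical and horizontal rectifications can be
        treated as two separate 1D problems (here they are vectorized).
        All array arguments are given in (y, x) order.
        """
        # The paper formula is based on the log PDF. Clamping with eps avoids
        # log(0) = -inf for neighbors that were clamped to zero.
        w = np.log(np.maximum(weight_loc, eps))
        pw = np.log(np.maximum(prev_weight_yx, eps))
        nw = np.log(np.maximum(next_weight_yx, eps))
        # Numeric first derivative: [f(x+h)-f(x-h)] / 2h, with h = 1 pixel
        d = (nw - pw) / 2
        # Numeric second derivative: [f(x+h)-2f(x)+f(x-h)] / h^2
        dd = nw + pw - 2 * w
        # DARK formula: mu = m - f'(m) / f''(m). Only shift along axes where
        # the log-heatmap is strictly concave (a proper peak); otherwise the
        # division is undefined or points away from the mode.
        shift = np.zeros_like(d)
        np.divide(d, dd, out=shift, where=dd < 0)
        return location_yx - shift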

    def __call__(self, heatmaps, nms_thresh=0.1):
        """
        :param heatmaps: Float tensor of shape (num_keypoints, H, W).
        :param nms_thresh: NMS peaks with a lower value are discarded.
        :returns: Dict mapping each channel index to a list of refined
          (y, x) locations, as float numpy arrays.
        """
        _, hm_h, hm_w = heatmaps.shape
        #
        with torch.no_grad():
            if self.conv_stddev > 0:
                heatmaps = self.conv(heatmaps.unsqueeze(0))[0]
            heatmaps = heatmaps.clamp(0, 1)
            nms = self.nms_pool(heatmaps)
            nms = (nms == heatmaps)
            nms = (heatmaps * nms) >= nms_thresh
        # skimage and the indexing below operate on CPU numpy arrays
        heatmaps = heatmaps.cpu().numpy()
        nms = nms.cpu().numpy()
        all_centers = {}
        #
        for i, ch in enumerate(nms):
            # isolate a single pixel per connected component
            segm = label(ch)
            rprops = regionprops(segm)
            #
            if self.get_centroids:
                centers = [rp.centroid for rp in rprops]
            else:
                centers = [rp.coords[0] for rp in rprops]
            #
            refined_centers = []
            for center in centers:
                rounded_c = np.array(center).round().astype(np.int32)
                vic = rounded_c + self.vicinity
                # Border peaks lack a neighbor on some side: skip the DARK
                # shift for them instead of misaligning prev/next values.
                inside = ((vic >= 0).all() and (vic[:, 0] < hm_h).all()
                          and (vic[:, 1] < hm_w).all())
                if not inside:
                    refined_centers.append(rounded_c.astype(np.float64))
                    continue
                # extract vicinity values to compute numerical derivatives
                w = heatmaps[i, rounded_c[0], rounded_c[1]]
                vic_w = heatmaps[i, vic[:, 0], vic[:, 1]]
                # compute DARK shift as in paper
                refined_c = self.rectify_dark(rounded_c, w,
                                              vic_w[0:2], vic_w[2:4])
                refined_centers.append(refined_c)
            all_centers[i] = refined_centers
        return all_centers

The software is released under the MIT license, so feel free to use it! Note that a proper ablation study and quantitative evaluation are still pending (contributions welcome!). Of course, I greatly appreciate any feedback on the code.


Discussion

This implementation assumes that the output Gaussians have a diagonal covariance matrix, i.e. no correlation between the vertical and horizontal axes. For that reason, the 2D derivatives can be computed as two separate 1D components. The following notes on numerical differentiation (online version here) can be used to implement step 4.
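As a sanity check of step 4 under this assumption, the per-axis estimate can be tested on a synthetic, noise-free Gaussian heatmap. Since its log is exactly quadratic, the central differences should recover the true center. This is a standalone sketch that mirrors `rectify_dark` for an interior peak only; the function name `dark_refine` is hypothetical:

```python
import numpy as np


def dark_refine(heatmap, y, x, eps=1e-10):
    """Per-axis DARK refinement at an interior integer peak (y, x).

    Assumes a diagonal covariance, so y and x are refined independently
    with unit-step central differences of the log-heatmap.
    """
    logh = np.log(np.maximum(heatmap, eps))
    c = logh[y, x]
    prev = np.array([logh[y - 1, x], logh[y, x - 1]])  # y prev, x prev
    nxt = np.array([logh[y + 1, x], logh[y, x + 1]])   # y next, x next
    d = (nxt - prev) / 2          # first derivative
    dd = nxt + prev - 2 * c       # second derivative
    return np.array([y, x]) - d / dd


# Uncorrelated Gaussian centered at (5.3, 7.6); its integer maximum is (5, 8).
yy, xx = np.mgrid[0:16, 0:16]
hm = np.exp(-((yy - 5.3) ** 2 + (xx - 7.6) ** 2) / (2 * 2.0 ** 2))
print(dark_refine(hm, 5, 8))  # approximately [5.3 7.6]
```

For a correlated blob (non-zero off-diagonal covariance), the same per-axis estimate is in general biased: the error along one axis grows with the offset between $m$ and $\mu$ along the other.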

Of course, we can observe from the left output in the image below that this assumption does not hold in general: many “blobs” show some sort of correlation. In this context, the modulation process (with an uncorrelated Gaussian kernel) dampens the correlations, as can be seen in the right image below.

Note, however, that the peak values also drop, which has to be taken into account when choosing the NMS threshold.
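One way to compensate is to rescale each modulated channel so that its maximum matches the original one. If I read the paper correctly, DARK applies a min-max rescaling of this kind right after its Gaussian modulation. The implementation above does not include it. Below is a minimal numpy sketch under that assumption; the helper `rescale_like` is hypothetical and would be applied before the NMS thresholding:

```python
import numpy as np


def rescale_like(modulated, original):
    """Per-channel min-max rescaling, so that each modulated channel has the
    same maximum as the corresponding original channel.

    Both inputs are numpy arrays of shape (K, H, W).
    """
    m_min = modulated.min(axis=(1, 2), keepdims=True)
    m_max = modulated.max(axis=(1, 2), keepdims=True)
    o_max = original.max(axis=(1, 2), keepdims=True)
    # Avoid dividing by zero on constant channels.
    span = np.where(m_max > m_min, m_max - m_min, 1.0)
    return (modulated - m_min) / span * o_max


# Synthetic example: a sharp Gaussian blob, blurred with a separable
# Gaussian kernel (two 1D passes), which lowers its peak.
yy, xx = np.mgrid[0:16, 0:16]
orig = np.exp(-((yy - 6) ** 2 + (xx - 9) ** 2) / 2.0)[None]
k1d = np.exp(-0.5 * np.arange(-3, 4) ** 2)
k1d /= k1d.sum()
blur = np.apply_along_axis(np.convolve, 1, orig[0], k1d, mode="same")
blur = np.apply_along_axis(np.convolve, 0, blur, k1d, mode="same")[None]
rescaled = rescale_like(blur, orig)
print(orig.max(), blur.max(), rescaled.max())  # peak drops, then is restored
```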

Another important remark is that the runtime of the decoder is dominated by the modulation step. If no modulation is applied, the runtime is comparable to that of a regular non-maximum suppression (NMS) decoder.

Obviously, it is crucial to implement the modulation with as many groups as channels (i.e. as a depthwise convolution), which saves a lot of redundant computation. But this implementation also leaves room for further optimization: since the isotropic Gaussian kernel has no covariance, it is separable, so the 2D convolution could be replaced by two consecutive 1D convolutions (one horizontal, one vertical).
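The separability claim can be checked numerically. An isotropic 2D Gaussian kernel is the outer product of a 1D Gaussian with itself, which is also what `gaussian_2d_kernel` computes up to rounding. Filtering with it once costs about $k^2$ multiplications per pixel, while two 1D passes cost about $2k$, and with zero padding the results are identical. The numpy sketch below illustrates only the math, with hypothetical helper names. In PyTorch, the analogous approach would presumably be two depthwise `Conv2d` layers with kernel sizes `(k, 1)` and `(1, k)`, but that variant is untested here.

```python
import numpy as np


def gaussian_1d(length=7, stddev=1.0):
    """Normalized 1D Gaussian kernel (same sampling as gaussian_2d_kernel)."""
    ax = np.linspace(-(length - 1) / 2.0, (length - 1) / 2.0, length)
    k = np.exp(-0.5 * ax ** 2 / stddev ** 2)
    return k / k.sum()


def conv2d_same(img, kernel):
    """Direct 2D 'same' correlation with zero padding: O(k^2) work per pixel."""
    kh, kw = kernel.shape
    padded = np.pad(img, ((kh // 2, kh // 2), (kw // 2, kw // 2)))
    out = np.zeros_like(img, dtype=float)
    for dy in range(kh):
        for dx in range(kw):
            out += kernel[dy, dx] * padded[dy:dy + img.shape[0],
                                           dx:dx + img.shape[1]]
    return out


def separable_conv(img, k1d):
    """Two 1D passes (horizontal, then vertical): O(2k) work per pixel."""
    tmp = conv2d_same(img, k1d[None, :])
    return conv2d_same(tmp, k1d[:, None])


rng = np.random.default_rng(0)
img = rng.random((20, 24))
k1d = gaussian_1d(7, 1.0)
full = conv2d_same(img, np.outer(k1d, k1d))
sep = separable_conv(img, k1d)
print(np.allclose(full, sep))  # True
```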

TAGS: computer vision, machine learning, pytorch, signal processing