What is Gumbel-softmax?

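Gumbel-softmax is a technique for drawing differentiable, approximate samples from a categorical distribution. It is a continuous relaxation of the Gumbel-Max trick, which draws exact samples.

Let $z$ be a categorical variable with class probabilities $\pi_1, \pi_2, \cdots, \pi_k$. The Gumbel-Max trick draws a one-hot sample of $z$ as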

$$ z = \text{one\_hot}\left(\operatorname*{arg\,max}_{i}\left[g_i + \log \pi_i\right]\right) $$

where $g_1, \cdots, g_k$ are i.i.d. samples drawn from Gumbel(0, 1). A Gumbel(0, 1) sample can be drawn by sampling $u \sim \text{Uniform}(0, 1)$ and computing $g = -\log(-\log(u))$.
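A minimal plain-Python sketch of the Gumbel-Max trick (standard library only; the function names are illustrative, not a library API). It draws Gumbel(0, 1) noise with $g = -\log(-\log(u))$, picks the argmax of $g_i + \log \pi_i$, and compares empirical class frequencies with $\pi$:

```python
import math
import random
from collections import Counter


def sample_gumbel(rng: random.Random) -> float:
    """Draw g ~ Gumbel(0, 1) via g = -log(-log(u)) with u ~ Uniform(0, 1)."""
    u = rng.random()
    # random() returns values in [0.0, 1.0); resample if u == 0, where log(0) is undefined.
    while u == 0.0:
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_max_sample(probs: list[float], rng: random.Random) -> list[int]:
    """Return a one-hot sample z = one_hot(argmax_i [g_i + log pi_i])."""
    scores = [sample_gumbel(rng) + math.log(p) for p in probs]
    best = max(range(len(probs)), key=lambda i: scores[i])
    return [1 if i == best else 0 for i in range(len(probs))]


if __name__ == "__main__":
    rng = random.Random(0)
    probs = [0.1, 0.3, 0.6]
    n = 100_000
    # Count which class each one-hot sample selects.
    counts = Counter(gumbel_max_sample(probs, rng).index(1) for _ in range(n))
    for i, p in enumerate(probs):
        print(f"class {i}: target {p:.2f}, empirical {counts[i] / n:.3f}")
```

With enough draws the empirical frequencies should match the target probabilities up to Monte Carlo error, since $\arg\max_i (g_i + \log \pi_i)$ equals $i$ with probability exactly $\pi_i$.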

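However, the argmax is not differentiable with respect to $\pi$ (it is piecewise constant, so its gradient is zero almost everywhere), which means the Gumbel-Max trick cannot be trained with backpropagation. Gumbel-softmax replaces the argmax with a softmax at temperature $\tau > 0$, giving a continuous, differentiable approximation of the one-hot sample: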

$$ y_i = \frac{\exp((\log(\pi_i) + g_i)/\tau)}{\sum_{j=1}^{k} \exp((\log(\pi_j) + g_j)/\tau)} \quad \text { for } i=1, \ldots, k $$

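Let $y = [y_1, y_2, \cdots, y_k]$. As $\tau \to 0$, $y$ approaches the one-hot sample $z$; as $\tau \to \infty$, $y$ approaches the uniform vector $[1/k, \cdots, 1/k]$.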
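The relaxed sample $y$ can be sketched the same way in plain Python (the function names are hypothetical). Passing the Gumbel noise $g$ in explicitly makes it easy to reuse one noise draw across several temperatures:

```python
import math
import random


def sample_gumbels(k: int, rng: random.Random) -> list[float]:
    """Draw k i.i.d. Gumbel(0, 1) samples via g = -log(-log(u))."""
    gs = []
    for _ in range(k):
        u = rng.random()
        while u == 0.0:  # random() can return 0.0, and log(0) is undefined
            u = rng.random()
        gs.append(-math.log(-math.log(u)))
    return gs


def gumbel_softmax(probs: list[float], g: list[float], tau: float) -> list[float]:
    """Compute y_i = exp((log pi_i + g_i)/tau) / sum_j exp((log pi_j + g_j)/tau)."""
    if tau <= 0:
        raise ValueError("tau must be positive")
    logits = [(math.log(p) + gi) / tau for p, gi in zip(probs, g)]
    m = max(logits)  # subtracting the max avoids overflow and leaves y unchanged
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    rng = random.Random(1)
    probs = [0.1, 0.3, 0.6]
    g = sample_gumbels(len(probs), rng)  # one noise draw, reused for every tau
    for tau in (5.0, 1.0, 0.1):
        y = gumbel_softmax(probs, g, tau)
        print(f"tau={tau}: " + ", ".join(f"{v:.3f}" for v in y))
```

Because dividing by $\tau > 0$ and applying softmax both preserve ordering, the largest entry of $y$ always sits at the Gumbel-Max index $\arg\max_i (g_i + \log \pi_i)$; lowering $\tau$ only sharpens $y$ toward that one-hot vector.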

The following figure shows the effect of $\tau$ on the expectation and on individual samples of $y$.

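[Figure: effect of the temperature $\tau$ on the expectation and on individual samples of the Gumbel-softmax distribution]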
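The two quantities the figure compares can also be computed numerically: estimate the expectation $\mathbb{E}[y]$ by Monte Carlo and draw a single sample at several temperatures. A plain-Python sketch; the probabilities, temperatures, and sample counts are arbitrary illustrative choices:

```python
import math
import random


def gumbel_softmax_sample(probs: list[float], tau: float, rng: random.Random) -> list[float]:
    """Draw one relaxed sample y using fresh Gumbel(0, 1) noise."""
    logits = []
    for p in probs:
        u = rng.random()
        while u == 0.0:  # guard against log(0)
            u = rng.random()
        g = -math.log(-math.log(u))  # g ~ Gumbel(0, 1)
        logits.append((math.log(p) + g) / tau)
    m = max(logits)  # stability shift; the softmax output is unchanged
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def monte_carlo_mean(probs: list[float], tau: float, n: int, rng: random.Random) -> list[float]:
    """Estimate E[y] by averaging n independent relaxed samples."""
    acc = [0.0] * len(probs)
    for _ in range(n):
        for i, v in enumerate(gumbel_softmax_sample(probs, tau, rng)):
            acc[i] += v
    return [a / n for a in acc]


if __name__ == "__main__":
    rng = random.Random(0)
    probs = [0.1, 0.3, 0.6]
    for tau in (0.1, 0.5, 1.0, 10.0):
        mean = monte_carlo_mean(probs, tau, 20_000, rng)
        sample = gumbel_softmax_sample(probs, tau, rng)
        print(f"tau={tau:<5} E[y]~{[round(v, 3) for v in mean]}  one sample: {[round(v, 3) for v in sample]}")
```

At small $\tau$ the average stays close to $\pi$ while individual samples are nearly one-hot; at large $\tau$ both the average and the samples flatten toward the uniform vector.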

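The temperature $\tau$ controls a trade-off: at small $\tau$, samples are close to one-hot but the gradient estimates have high variance; at large $\tau$, samples are smooth and the gradient variance is lower. In practice, an annealing schedule is commonly used to reduce the temperature gradually during training.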
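One common form of annealing schedule, similar in shape to the one described by Jang et al. (2016), is exponential decay with a floor, $\tau = \max(\tau_{\min}, \tau_0 e^{-rt})$, updated every $N$ steps. The sketch below assumes that form; the function name and all default values are hypothetical:

```python
import math


def annealed_tau(step: int, tau0: float = 1.0, rate: float = 1e-4,
                 tau_min: float = 0.5, every: int = 1000) -> float:
    """Hypothetical schedule: tau = max(tau_min, tau0 * exp(-rate * t)).

    t is the training step rounded down to a multiple of `every`, so the
    temperature only changes once every `every` steps.
    """
    t = (step // every) * every
    return max(tau_min, tau0 * math.exp(-rate * t))


if __name__ == "__main__":
    for step in (0, 999, 1000, 5000, 10_000, 50_000):
        print(f"step {step:>6}: tau = {annealed_tau(step):.4f}")
```

Inside a training loop one would call annealed_tau(step) and pass the result as $\tau$ to the Gumbel-softmax; the floor tau_min keeps $\tau$ away from the very small values where gradient variance becomes large.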

Straight-through Gumbel-softmax Estimator. Although the relaxed vector $y$ is differentiable, some applications need genuinely discrete one-hot samples, and discretizing $y$ (for example, $z = \text{one\_hot}(\operatorname*{arg\,max}_i y_i)$) is again non-differentiable. The straight-through estimator uses the discrete $z$ in the forward pass and, in the backward pass, approximates the gradient of $z$ with the gradient of the continuous $y$, that is, $\nabla_{\theta} z \approx \nabla_{\theta} y$.

PyTorch's implementation (torch.nn.functional.gumbel_softmax with hard=True) expresses this in a single line:

ret = y_hard - y_soft.detach() + y_soft

In this line, y_hard and y_soft.detach() carry no gradient, so during the backward pass the gradient flows only through y_soft. During the forward pass, the two y_soft terms cancel (up to floating-point rounding), so the value of ret equals y_hard.
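A self-contained PyTorch sketch of the straight-through estimator (requires PyTorch), written out by hand so each step is visible. It uses unnormalized logits in place of $\log \pi$, which is valid because softmax is unchanged when the same constant is added to every logit; the logits, loss weights, and temperature are arbitrary. PyTorch also packages these steps as torch.nn.functional.gumbel_softmax(logits, tau=tau, hard=True).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

k, tau = 4, 0.5
# Arbitrary unnormalized logits standing in for log(pi) (softmax ignores a constant shift).
logits = torch.randn(k, requires_grad=True)

# Gumbel(0, 1) noise via g = -log(-log(u)); torch.rand is in [0, 1), so clamp away from 0.
u = torch.rand(k).clamp(min=1e-10)
g = -torch.log(-torch.log(u))

# Relaxed, differentiable sample.
y_soft = F.softmax((logits + g) / tau, dim=-1)

# Discrete one-hot sample at the argmax of y_soft (argmax itself passes no gradient).
index = y_soft.argmax(dim=-1, keepdim=True)
y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)

# Straight-through: forward value equals y_hard, backward gradient is that of y_soft.
ret = y_hard - y_soft.detach() + y_soft

# Toy downstream loss with arbitrary per-class weights.
w = torch.tensor([1.0, 2.0, 3.0, 4.0])
loss = (ret * w).sum()
loss.backward()

print("y_hard:        ", y_hard)
print("forward value: ", ret.detach())
print("grad of logits:", logits.grad)
```

The forward value matches y_hard, yet logits.grad is generally nonzero: it equals the gradient of the same loss evaluated on y_soft, which is the approximation $\nabla_{\theta} z \approx \nabla_{\theta} y$.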

We also ran several experiments that use Gumbel-softmax to generate images with a VAE and to generate random variables. Here is the link to the experimental code.

Relation to the Reparameterization Trick

$$ z \sim \mathcal{N}(\mu, \sigma^2) \rightarrow z = \mu + \sigma \epsilon, \text{ where } \epsilon \sim \mathcal{N}(0, 1). $$

$$ z \sim \text{Categorical}(\pi_1, \pi_2, \cdots, \pi_k) \rightarrow y = \text{softmax}\left(\frac{\log \pi + g}{\tau}\right), $$

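where $g = [g_1, \cdots, g_k]$ with $g_i = -\log(-\log(u_i))$ and $u_i \sim \text{Uniform}(0, 1)$ i.i.d., and $\log \pi$ is taken elementwise. In both cases the randomness is moved into a parameter-free noise variable ($\epsilon$ or $g$), so the sample becomes a deterministic, differentiable function of the parameters ($\mu$ and $\sigma$, or $\pi$) and gradients can flow through it.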
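This parallel can be checked numerically: with the noise ($\epsilon$ or $g$) sampled once and then held fixed, both samples are ordinary differentiable functions of their parameters. The plain-Python sketch below (all names are illustrative) compares central finite differences with the analytic derivatives $\partial z / \partial \mu = 1$, $\partial z / \partial \sigma = \epsilon$, and $\partial y_1 / \partial \log \pi_1 = y_1 (1 - y_1) / \tau$:

```python
import math
import random


def gaussian_reparam(mu: float, sigma: float, eps: float) -> float:
    """z = mu + sigma * eps; eps ~ N(0, 1) is drawn by the caller and held fixed."""
    return mu + sigma * eps


def sample_gumbel(rng: random.Random) -> float:
    """g = -log(-log(u)) with u ~ Uniform(0, 1), excluding u == 0."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_softmax_reparam(log_pi: list[float], g: list[float], tau: float) -> list[float]:
    """y = softmax((log pi + g) / tau); g is drawn by the caller and held fixed."""
    x = [(lp + gi) / tau for lp, gi in zip(log_pi, g)]
    m = max(x)  # stability shift; the softmax output is unchanged
    e = [math.exp(v - m) for v in x]
    s = sum(e)
    return [v / s for v in e]


def finite_diff(f, x: float, h: float = 1e-6) -> float:
    """Central-difference estimate of df/dx for a scalar function f."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


if __name__ == "__main__":
    rng = random.Random(0)

    # Gaussian case: sample eps once, then differentiate z with respect to mu and sigma.
    eps = rng.gauss(0.0, 1.0)
    mu, sigma = 0.5, 2.0
    print("dz/dmu:   ", finite_diff(lambda m: gaussian_reparam(m, sigma, eps), mu), "analytic:", 1.0)
    print("dz/dsigma:", finite_diff(lambda s: gaussian_reparam(mu, s, eps), sigma), "analytic:", eps)

    # Categorical case: sample g once, then differentiate y_1 (index 0) w.r.t. log pi_1.
    log_pi = [math.log(p) for p in (0.2, 0.5, 0.3)]
    g = [sample_gumbel(rng) for _ in log_pi]
    tau = 0.7
    y = gumbel_softmax_reparam(log_pi, g, tau)

    def y1(a: float) -> float:
        return gumbel_softmax_reparam([a] + log_pi[1:], g, tau)[0]

    print("dy1/dlogpi1:", finite_diff(y1, log_pi[0]), "analytic:", y[0] * (1.0 - y[0]) / tau)
```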

References

  1. Jang, E., Gu, S., & Poole, B. (2016, November). Categorical Reparameterization with Gumbel-Softmax. In International Conference on Learning Representations.

  2. https://sassafras13.github.io/GumbelSoftmax/

  3. https://neptune.ai/blog/gumbel-softmax-loss-function-guide-how-to-implement-it-in-pytorch