What is Gumbel-softmax?

Gumbel-softmax is an efficient gradient estimator that replaces the non-differentiable sample from a categorical distribution with a differentiable sample from a continuous relaxation of that distribution.

Let $z$ be a categorical variable with class probabilities $\pi_1$, $\pi_2$, $\cdots$, $\pi_k$.

The Gumbel-Max trick is an efficient way to draw samples $z$ from a categorical distribution with class probabilities $\pi$:

$$ z = \text{one\_hot}(\argmax_{i}\left[g_i + \log \pi_i\right]) $$

where $g_1$, $\cdots$, $g_k$ are i.i.d. samples drawn from Gumbel(0, 1). The Gumbel(0, 1) distribution can be sampled by drawing $u \sim \text{Uniform}(0, 1)$ and computing $g = - \log (- \log (u))$.
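
As a concrete illustration, here is a minimal PyTorch sketch of the Gumbel-Max trick; the function name, tensor shapes, and example probabilities are illustrative, not from the paper.

```python
import torch

def sample_gumbel_max(logits):
    # logits: unnormalized log-probabilities, i.e. log pi_i up to an additive constant
    u = torch.rand_like(logits)                       # u ~ Uniform(0, 1)
    g = -torch.log(-torch.log(u))                     # g ~ Gumbel(0, 1)
    z = torch.argmax(logits + g, dim=-1)              # argmax_i [g_i + log pi_i]
    return torch.nn.functional.one_hot(z, num_classes=logits.shape[-1]).float()

# Example: 5 draws from a 3-class distribution with probabilities [0.1, 0.3, 0.6]
logits = torch.log(torch.tensor([0.1, 0.3, 0.6])).expand(5, 3)
print(sample_gumbel_max(logits))
```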

Then the softmax function is used as a continuous, differentiable approximation to $\argmax$:

$$ y_i = \frac{\exp((\log(\pi_i) + g_i)/\tau)}{\sum_{j=1}^{k} \exp((\log(\pi_j) + g_j)/\tau)} \quad \text { for } i=1, \ldots, k $$

Let $y = [y_1, y_2, \cdots, y_k]$.
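
A corresponding sketch of the relaxed sample is below. The temperature $\tau$ controls how closely $y$ approximates a one-hot vector: small $\tau$ yields nearly one-hot samples, while large $\tau$ yields nearly uniform ones. Again, the function name and example values are only illustrative.

```python
import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, tau=1.0):
    # logits: log pi_i (up to an additive constant); tau: temperature
    u = torch.rand_like(logits)
    g = -torch.log(-torch.log(u))                 # g_i ~ Gumbel(0, 1)
    return F.softmax((logits + g) / tau, dim=-1)  # y_i as defined above

logits = torch.log(torch.tensor([0.1, 0.3, 0.6]))
print(gumbel_softmax_sample(logits, tau=0.1))   # low tau: close to one-hot
print(gumbel_softmax_sample(logits, tau=10.0))  # high tau: close to uniform
```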

Straight-through Gumbel-softmax Estimator

Although the relaxed continuous one-hot vectors $y$ are differentiable, some scenarios require sampling genuinely discrete random variables, so $y$ must be discretized (e.g., taken to a one-hot vector via argmax), and that discretized sample is non-differentiable. In these scenarios, the straight-through estimator uses the gradient of the continuous variable $y$ during backpropagation as an approximation to the gradient of the discrete variable $z$, that is, $\nabla_{\theta}z \approx \nabla_{\theta}y$.

In the PyTorch implementation, it appears as follows:

ret = y_hard - y_soft.detach() + y_soft

In the above code snippet, during the backward pass, neither y_hard nor y_soft.detach() contributes gradients; all gradients flow through y_soft. During the forward pass, the value of this line of code equals y_hard, because y_soft.detach() and y_soft cancel numerically.
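
PyTorch exposes this straight-through behavior through torch.nn.functional.gumbel_softmax with hard=True. A minimal usage sketch follows; the toy logits and per-class scores are illustrative.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3, requires_grad=True)    # unnormalized log-probabilities

# hard=True: the forward pass returns one-hot vectors (y_hard),
# while the backward pass uses the gradient of the soft sample (y_soft)
y = F.gumbel_softmax(logits, tau=1.0, hard=True)
print(y)                                          # each row is exactly one-hot

values = torch.tensor([1.0, 2.0, 3.0])            # illustrative per-class scores
loss = -(y * values).sum()
loss.backward()
print(logits.grad)                                # nonzero despite the discrete forward output
```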

References

Jang, E., Gu, S., & Poole, B. (2017). Categorical Reparameterization with Gumbel-Softmax. In International Conference on Learning Representations (ICLR).