What is Gumbel-softmax?
Gumbel-softmax is an efficient gradient estimator for the otherwise non-differentiable operation of sampling from a categorical distribution.
Let $z$ be a categorical variable with class probabilities $\pi_1$, $\pi_2$, $\cdots$, $\pi_k$.
The Gumbel-Max trick is an efficient way to draw samples $z$ from a categorical distribution with class probabilities $\pi$:
$$ z = \text{one\_hot}(\argmax_{i}\left[g_i + \log \pi_i\right]) $$where $g_1$, $\cdots$, $g_k$ are i.i.d. samples drawn from $\text{Gumbel}(0, 1)$. The $\text{Gumbel}(0, 1)$ distribution can be sampled by drawing $u \sim \text{Uniform}(0, 1)$ and computing $g = - \log (- \log (u))$.
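The Gumbel-Max trick can be sketched in a few lines of NumPy (the function name `gumbel_max_sample` is illustrative):

```python
import numpy as np

def gumbel_max_sample(pi, rng):
    """Draw a one-hot sample from Categorical(pi) via the Gumbel-Max trick."""
    u = rng.uniform(size=len(pi))      # u ~ Uniform(0, 1)
    g = -np.log(-np.log(u))            # g ~ Gumbel(0, 1)
    idx = np.argmax(g + np.log(pi))    # argmax_i [g_i + log(pi_i)]
    z = np.zeros(len(pi))
    z[idx] = 1.0
    return z

pi = np.array([0.2, 0.5, 0.3])
z = gumbel_max_sample(pi, np.random.default_rng(0))
```

Averaged over many draws, the empirical frequency of each class matches $\pi$.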
The softmax function is then used as a continuous, differentiable approximation to $\argmax$:
$$ y_i = \frac{\exp((\log(\pi_i) + g_i)/\tau)}{\sum_{j=1}^{k} \exp((\log(\pi_j) + g_j)/\tau)} \quad \text { for } i=1, \ldots, k $$Let $y = [y_1, y_2, \cdots, y_k]$.
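As a sketch of how the temperature $\tau$ controls the relaxation (NumPy again; the function name is illustrative):

```python
import numpy as np

def gumbel_softmax_sample(pi, tau, rng):
    """Relaxed sample y from the Gumbel-softmax distribution at temperature tau."""
    u = rng.uniform(size=len(pi))      # u ~ Uniform(0, 1)
    g = -np.log(-np.log(u))            # g ~ Gumbel(0, 1)
    logits = (np.log(pi) + g) / tau
    logits -= logits.max()             # subtract the max for numerical stability
    y = np.exp(logits)
    return y / y.sum()

pi = np.array([0.2, 0.5, 0.3])
# Reusing the same seed gives the same Gumbel noise, isolating the effect of tau:
y_hi = gumbel_softmax_sample(pi, tau=5.0, rng=np.random.default_rng(42))  # smoother
y_lo = gumbel_softmax_sample(pi, tau=0.1, rng=np.random.default_rng(42))  # near one-hot
```

As $\tau \to 0$, $y$ approaches a one-hot vector; as $\tau$ grows, $y$ approaches the uniform distribution.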
Straight-through Gumbel-softmax Estimator
Although the relaxed continuous vector $y$ is differentiable, some scenarios require discrete samples, and the discretized one-hot vector $z$ is non-differentiable. In these scenarios, the straight-through estimator uses the gradient of the continuous variable $y$ to approximate the gradient of the discrete variable $z$ during backpropagation, that is, $\nabla_{\theta}z \approx \nabla_{\theta}y$.
In the PyTorch implementation, it appears as follows:
ret = y_hard - y_soft.detach() + y_soft
In the above code snippet, during the backward pass, neither `y_hard` nor `y_soft.detach()` receives gradients; the gradients flow entirely through `y_soft`. During the forward pass, the output value of this line equals `y_hard`.
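PyTorch exposes this through `torch.nn.functional.gumbel_softmax`, whose `hard=True` mode applies exactly this straight-through line. A minimal sketch:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Unnormalized log-probabilities of a 4-class categorical variable.
logits = torch.tensor([-1.0, 0.0, 0.5, 1.0], requires_grad=True)

# hard=True returns a one-hot sample in the forward pass, while the
# straight-through trick routes gradients through the soft sample y.
z = F.gumbel_softmax(logits, tau=1.0, hard=True)

# A downstream loss that depends on the discrete choice:
values = torch.tensor([1.0, 2.0, 3.0, 4.0])
loss = (z * values).sum()
loss.backward()   # logits.grad is populated despite the discrete forward pass
```

Even though the forward pass emits a discrete one-hot vector, `logits` still receives gradients and can be trained end-to-end.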
References
Jang, E., Gu, S., & Poole, B. (2017). Categorical Reparameterization with Gumbel-Softmax. In International Conference on Learning Representations (ICLR).