This paper proposes a saliency-guided training (SGT) procedure that reduces noisy gradients, which can lead to unfaithful feature attributions, while preserving the model's predictive performance.

During saliency-guided training, for every input $X$, they create a new input $\tilde{X}$ by masking the features with low gradient values as follows:

$$ \tilde{X} = M_{k}\left(S\left(\nabla_{X}f_{\theta}(X)\right), X\right), $$

where $S(\cdot)$ is a sorting function and $M_{k}(\cdot)$ replaces the bottom $k$ elements, according to the order provided by $S(\cdot)$, with values drawn from a mask distribution.
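To make the masking step concrete, here is a minimal PyTorch sketch, not the authors' implementation: it uses the gradient of the classification loss as the saliency signal in place of $\nabla_{X}f_{\theta}(X)$ and per-example uniform noise as the mask distribution, both of which are assumptions; the names `mask_bottom_k`, `model`, and `k` are illustrative.

```python
import torch
import torch.nn.functional as F

def mask_bottom_k(model, X, y, k):
    """Build X_tilde by replacing the k features with the lowest
    gradient magnitude by values drawn from a mask distribution."""
    X = X.clone().requires_grad_(True)
    loss = F.cross_entropy(model(X), y)
    grads = torch.autograd.grad(loss, X)[0]        # saliency signal (assumption: loss gradient)

    flat_X = X.detach().view(X.size(0), -1)
    flat_g = grads.abs().view(X.size(0), -1)

    # S(.): indices of the k smallest-gradient features per example
    _, bottom_idx = torch.topk(flat_g, k, dim=1, largest=False)

    # M_k(.): replace those features with samples from a mask distribution
    # (assumption: uniform noise over each example's value range)
    lo = flat_X.min(dim=1, keepdim=True).values
    hi = flat_X.max(dim=1, keepdim=True).values
    noise = lo + (hi - lo) * torch.rand_like(flat_X)

    X_tilde = flat_X.clone()
    X_tilde.scatter_(1, bottom_idx, noise.gather(1, bottom_idx))
    return X_tilde.view_as(X)
```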

In addition to the classification loss, saliency-guided training minimizes the KL divergence between $f_{\theta}(X)$ and $f_{\theta}(\tilde{X})$, so that the trained model produces similar output probability distributions over labels for both masked and unmasked inputs. The optimization problem of saliency-guided training is:

$$ \min_{\theta} \frac{1}{n} \sum_{i=1}^{n} \left[\mathcal{L}(f_{\theta}(X_i), y_i) + \lambda D_{KL}\left(f_{\theta}(X_i) \,\|\, f_{\theta}(\tilde{X}_i)\right) \right] $$
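A single training step under this objective can be sketched as follows, again assuming PyTorch and reusing the hypothetical `mask_bottom_k` helper above; `lam` plays the role of $\lambda$.

```python
def sgt_step(model, optimizer, X, y, k, lam):
    X_tilde = mask_bottom_k(model, X, y, k)

    logits = model(X)
    logits_masked = model(X_tilde)

    # Classification loss L(f_theta(X_i), y_i)
    ce = F.cross_entropy(logits, y)

    # D_KL(f_theta(X_i) || f_theta(X_tilde_i)) between output distributions;
    # F.kl_div expects log-probabilities of the second argument of the KL term
    kl = F.kl_div(F.log_softmax(logits_masked, dim=1),
                  F.softmax(logits, dim=1),
                  reduction="batchmean")

    loss = ce + lam * kl
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```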

I believe SGT's superior performance on interpretability measures comes from its ability to reduce the impact of unimportant features during the training phase.

References

Ismail, A. A., Corrada Bravo, H., & Feizi, S. (2021). Improving deep learning interpretability by saliency guided training. Advances in Neural Information Processing Systems, 34, 26726-26739.