A major issue with post-hoc explainability is that the explanations it produces may not faithfully represent the model’s underlying decision-making process. This stems largely from the perturbation techniques these methods typically rely on: perturbed instances are created by altering the features of an instance, which can push them outside the original data distribution.

Two approaches can address this issue. One involves ensuring that perturbed instances remain within the original data distribution, typically by leveraging generative models. The other requires the model’s predictions to remain unchanged when unimportant features of an instance are perturbed. This paper adopts the latter approach.

The authors propose Distractor Erasure Tuning (DiET), which adapts black-box models to the removal of distractors, thereby enhancing model robustness. Consequently, DiET offers attributions that are both discriminative and faithful.

The following is the pseudocode of DiET.

[Pseudocode of DiET, shown as a figure in the paper]

As shown in the pseudocode, DiET employs two objectives that are minimized alternately over the model parameters $\theta$ and the mask $m$.

DiET optimizes $\mathcal{L}_{\mathrm{QFA}}$ to obtain the optimal mask; notably, the mask in this paper defines the signal-distractor decomposition. Here $f_v$ is assumed to be a $\mathcal{Q}$-robust model, i.e., its predictions are robust to the counterfactual distribution $\mathcal{Q}$ that governs the feature removal process.
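
As a minimal sketch, the removed-feature input $\mathbf{x}_{\mathbf{s}}(\mathbf{m}, q)$ can be thought of as blending the input with a baseline drawn from $\mathcal{Q}$; the function name `remove_features` and the mean-image baseline below are illustrative assumptions, not the authors' exact choice of $\mathcal{Q}$.

```python
import torch

def remove_features(x, m, q):
    """Counterfactual feature removal x_s(m, q).

    x : input tensor, e.g. an image batch of shape (B, C, H, W)
    m : mask in [0, 1] broadcastable to x, 1 = keep (signal), 0 = remove (distractor)
    q : baseline tensor sampled from the counterfactual distribution Q, same shape as x
    """
    return m * x + (1.0 - m) * q

# Example: replace masked-out pixels with the per-channel dataset mean (one possible Q).
x = torch.rand(8, 3, 32, 32)                              # a batch of images
m = torch.rand(8, 1, 32, 32)                              # a (soft) per-pixel mask
q = x.mean(dim=(0, 2, 3), keepdim=True).expand_as(x)      # mean-image baseline
x_s = remove_features(x, m, q)
```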

$$ \mathcal{L}_{\mathrm{QFA}}\left(\{\mathbf{m}(\mathbf{x})\}_{\mathbf{x} \in \mathcal{X}}\right)=\underset{\mathbf{x} \in \mathcal{X}}{\mathbb{E}}\Big[\underbrace{\|\mathbf{m}(\mathbf{x})\|_1}_{\text{mask sparsity}}+\lambda_1 \underbrace{\left\| f_v(\mathbf{x} ; \theta)-f_v\left(\mathbf{x}_{\mathbf{s}}(\mathbf{m}(\mathbf{x}), q) ; \theta\right)\right\|_1}_{\text{data distillation}}\Big] $$

The objective function is similar to that of other mask-based methods. The key difference is that DiET evaluates it with the $\mathcal{Q}$-robust model $f_v$ instead of the original model.
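
A rough sketch of this objective, assuming the same blending-style feature removal as above and batched tensors; `f_v`, `lambda_1`, and the reduction over the batch are filled in as assumptions rather than taken from the authors' code.

```python
import torch

def loss_qfa(f_v, x, m, q, lambda_1):
    """Sketch of L_QFA: mask sparsity + data distillation w.r.t. the Q-robust model f_v."""
    x_s = m * x + (1.0 - m) * q                                      # x_s(m(x), q)
    sparsity = m.abs().flatten(1).sum(dim=1)                         # ||m(x)||_1
    data_distill = (f_v(x) - f_v(x_s)).abs().flatten(1).sum(dim=1)   # ||f_v(x) - f_v(x_s)||_1
    return (sparsity + lambda_1 * data_distill).mean()               # expectation over x
```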

The other objective is $\mathcal{L}_{\text{train}}$. Assuming the optimal signal-distractor decomposition $m$ is known, minimizing $\mathcal{L}_{\text{train}}$ over $\theta$ yields the $\mathcal{Q}$-robust model $f_v$.

$$\mathcal{L}_{\text{train}}\left(\theta,\{\mathbf{m}(\mathbf{x})\}_{\mathbf{x} \in \mathcal{X}}\right)=\underset{\mathbf{x} \in \mathcal{X}}{\mathbb{E}}\Big[\underbrace{\left\| f_v(\mathbf{x} ; \theta)-f_v\left(\mathbf{x}_{\mathbf{s}}(\mathbf{m}(\mathbf{x}), q) ; \theta\right)\right\|_1}_{\text{data distillation}}+\lambda_2 \underbrace{\left\|f_b(\mathbf{x})-f_v(\mathbf{x} ; \theta)\right\|_1}_{\text{model distillation}}\Big]$$
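
A companion sketch of $\mathcal{L}_{\text{train}}$ under the same assumptions, where `f_b` is the frozen black-box model and `f_v` (with parameters $\theta$) is the model being tuned; it mirrors the equation term by term rather than reproducing the released code.

```python
import torch

def loss_train(f_b, f_v, x, m, q, lambda_2):
    """Sketch of L_train: data distillation + model distillation, minimized over theta."""
    x_s = m * x + (1.0 - m) * q                                       # x_s(m(x), q)
    data_distill = (f_v(x) - f_v(x_s)).abs().flatten(1).sum(dim=1)    # ||f_v(x) - f_v(x_s)||_1
    model_distill = (f_b(x) - f_v(x)).abs().flatten(1).sum(dim=1)     # ||f_b(x) - f_v(x)||_1
    return (data_distill + lambda_2 * model_distill).mean()           # expectation over x
```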

Code

In their code, the authors use only the training data to optimize the $\mathcal{Q}$-robust model.

# On the train dataset
mask_rounding = [0.4, 0.32, 0.24, 0.16, 0.08]
weight = [1, 0.82, 0.64, 0.46, 0.28]
mask = [1, 1, ..., 1]                                # start from an all-ones mask
for k in range(num_rounding_steps):
    while not mask_converged:
        # loss = weight[k] * l_1(m)                  -> regularization on m
        #        + l_1(f_v(x) - f_v(x'))             -> data distillation
        #        + l_1(f_b(x) - f_v(x))              -> model distillation
        update_mask()
    mask = torch.round(mask + mask_rounding[k])      # progressively binarize the mask
    while not model_converged:
        # loss = l_1(f_v(x) - f_v(x'))               -> data distillation
        #        + l_1(f_b(x) - f_v(x))              -> model distillation
        update_model()

# On the test dataset
weight = [1, 0.82, 0.64, 0.46, 0.28]
for k in range(num_rounding_steps):
    mask = [1, 1, ..., 1]
    while not mask_converged:
        # loss = weight[k] * l_1(m)                  -> regularization on m
        #        + l_1(f_v(x) - f_v(x'))             -> data distillation
        update_mask()
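
To see what the rounding schedule does, note that `torch.round(mask + r)` sends every entry of at least `0.5 - r` to 1 (assuming mask values stay in [0, 1]), so the offsets 0.4, 0.32, ..., 0.08 correspond to progressively stricter keep-thresholds of 0.1, 0.18, ..., 0.42; the toy values below are only for illustration.

```python
import torch

mask = torch.tensor([0.05, 0.15, 0.30, 0.45, 0.60])   # toy soft-mask values
for r in [0.4, 0.32, 0.24, 0.16, 0.08]:
    print(r, torch.round(mask + r))
# r = 0.40 keeps entries >= 0.10 -> tensor([0., 1., 1., 1., 1.])
# r = 0.08 keeps entries >= 0.42 -> tensor([0., 0., 0., 1., 1.])
```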

References

Bhalla, U., Srinivas, S., & Lakkaraju, H. (2023, December). Discriminative feature attributions: bridging post hoc explainability and inherent interpretability. In Proceedings of the 37th International Conference on Neural Information Processing Systems (pp. 44105-44122).