A major issue with post-hoc explainability is that its explanations may not faithfully represent the model’s underlying decision-making process. This arises primarily because post-hoc methods typically generate explanations using perturbation techniques, which create perturbed instances by altering the features of an instance, potentially pushing the perturbed instances outside the original data distribution.
Two approaches can address this issue. One ensures that perturbed instances remain within the original data distribution, typically by leveraging generative models. The other requires the model’s predictions to remain unchanged when unimportant features of the instance are perturbed. This paper adopts the latter approach.
The authors (Bhalla et al., 2023) propose the Distractor Erasure Tuning (DiET) method, which enables black-box models to adapt to the removal of distractors, thereby enhancing model robustness. Consequently, DiET offers attributions that are both discriminative and faithful.
The following is the pseudocode of DiET.
As shown in the pseudocode, DiET employs two objectives that are minimized alternately: one with respect to the masks $m$ and the other with respect to the model parameters $\theta$.
$$ \mathcal{L}_{\mathrm{QFA}}\left(\{\mathbf{m}(\mathbf{x})\}_{\mathbf{x} \in \mathcal{X}}\right)=\underset{\mathbf{x} \in \mathcal{X}}{\mathbb{E}}\Big[\underbrace{\|\mathbf{m}(\mathbf{x})\|_1}_{\text{mask sparsity}}+\lambda_1 \underbrace{\left\| f_v(\mathbf{x} ; \theta)-f_v\left(\mathbf{x}_{\mathbf{s}}(\mathbf{m}(\mathbf{x}), q) ; \theta\right)\right\|_1}_{\text{data distillation}}\Big] $$
The objective function is similar to that of other mask-based methods. The key difference is that DiET evaluates it on a $\mathcal{Q}$-robust model $f_v$ instead of the original model $f_b$.
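To make the two terms of $\mathcal{L}_{\mathrm{QFA}}$ concrete, here is a minimal per-sample sketch in plain Python. It is not the authors' implementation. It assumes that the masked input has the form $\mathbf{x}_{\mathbf{s}}(\mathbf{m}, q) = \mathbf{m} \odot \mathbf{x} + (1-\mathbf{m}) \odot q$ (this note does not define it), uses a toy linear model as a stand-in for $f_v$, and the helper names `l1`, `masked_input`, and `qfa_loss` are hypothetical.

```python
def l1(v):
    """L1 norm of a vector given as a list of floats."""
    return sum(abs(a) for a in v)


def masked_input(x, m, q):
    """Assumed form of x_s(m, q): keep x where m is 1, substitute q where m is 0."""
    return [mi * xi + (1.0 - mi) * qi for xi, mi, qi in zip(x, m, q)]


def f_v(x, theta):
    """Toy stand-in for the model f_v: one linear output per row of theta."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in theta]


def qfa_loss(x, m, q, theta, lambda_1=1.0):
    """Per-sample L_QFA: mask sparsity + lambda_1 * data distillation."""
    sparsity = l1(m)
    out_full = f_v(x, theta)
    out_masked = f_v(masked_input(x, m, q), theta)
    distillation = l1([a - b for a, b in zip(out_full, out_masked)])
    return sparsity + lambda_1 * distillation


# The toy model ignores the third feature, so that feature acts as a distractor.
theta = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
x = [3.0, 1.0, 5.0]
q = [0.0, 0.0, 0.0]  # replacement values, assumed to be a sample from Q

loss_keep_all = qfa_loss(x, [1.0, 1.0, 1.0], q, theta)
loss_drop_distractor = qfa_loss(x, [1.0, 1.0, 0.0], q, theta)
loss_drop_signal = qfa_loss(x, [0.0, 1.0, 1.0], q, theta)
print(loss_keep_all, loss_drop_distractor, loss_drop_signal)
```

In this toy setting, masking the distractor lowers the loss because the sparsity term shrinks while the output is unchanged. Masking a feature the model actually uses raises the loss through the data distillation term.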
$$ \mathcal{L}_{\text{train}}\left(\theta,\{\mathbf{m}(\mathbf{x})\}_{\mathbf{x} \in \mathcal{X}}\right)=\underset{\mathbf{x} \in \mathcal{X}}{\mathbb{E}}\Big[\underbrace{\left\| f_v(\mathbf{x} ; \theta)-f_v\left(\mathbf{x}_{\mathbf{s}}(\mathbf{m}(\mathbf{x}), q) ; \theta\right)\right\|_1}_{\text{data distillation}}+\lambda_2 \underbrace{\left\|f_b(\mathbf{x})-f_v(\mathbf{x} ; \theta)\right\|_1}_{\text{model distillation}}\Big] $$
Code
In the authors' code, only the training data is used to optimize the $\mathcal{Q}$-robust model.

    # On train dataset
    mask_rounding = [0.4, 0.32, 0.24, 0.16, 0.08]
    weight = [1, 0.82, 0.64, 0.46, 0.28]
    mask = [1, 1, ..., 1]
    for k in range(num_rounding_steps):
        while not mask_converged:
            # loss = weight[k] * l_1(m) + l_1(f_v(x) - f_v(x')) + l_1(f_b(x) - f_v(x))
            #      = regularization on m + data distillation + model distillation
            update_mask()
        mask = torch.round(mask + mask_rounding[k])
        while not model_converged:
            # loss = l_1(f_v(x) - f_v(x')) + l_1(f_b(x) - f_v(x))
            #      = data distillation + model distillation
            update_model()

    # On test dataset
    weight = [1, 0.82, 0.64, 0.46, 0.28]
    for k in range(num_rounding_steps):
        mask = [1, 1, ..., 1]
        while not mask_converged:
            # loss = weight[k] * l_1(m) + l_1(f_v(x) - f_v(x'))
            #      = regularization on m + data distillation
            update_mask()

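The rounding step `torch.round(mask + mask_rounding[k])` can be read as thresholding the soft mask at `0.5 - mask_rounding[k]`, so the threshold rises from 0.1 to 0.42 as the offset shrinks and the binarization becomes stricter in later steps. The plain-Python sketch below illustrates this on a made-up soft mask. The helper `round_mask` and the values in `soft_mask` are hypothetical, and Python's built-in `round` stands in for `torch.round` (both round exact halves to the nearest even integer). In the actual procedure the mask is re-optimized between steps, so applying every offset to the same mask is for illustration only.

```python
mask_rounding = [0.4, 0.32, 0.24, 0.16, 0.08]


def round_mask(mask, offset):
    """Binarize a soft mask in [0, 1]: entries above (0.5 - offset) become 1."""
    return [float(round(m + offset)) for m in mask]


# Hypothetical soft mask values, chosen only to show the effect of each offset.
soft_mask = [0.05, 0.15, 0.30, 0.45, 0.90]

results = {offset: round_mask(soft_mask, offset) for offset in mask_rounding}
for offset, binary_mask in results.items():
    print(f"offset={offset:.2f} threshold={0.5 - offset:.2f} mask={binary_mask}")
```

With these values, the first offset keeps four of the five entries, while the last keeps only the two largest.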
References
Bhalla, U., Srinivas, S., & Lakkaraju, H. (2023, December). Discriminative feature attributions: bridging post hoc explainability and inherent interpretability. In Proceedings of the 37th International Conference on Neural Information Processing Systems (pp. 44105-44122).