Counterfactual debiasing is a post-hoc debiasing technique that uses machine unlearning to remove learned biases by replacing the influence of bias-aligned training samples with that of their counterfactual versions. Rather than merely forgetting harmful samples, the method substitutes them with bias-conflicting alternatives, so fairness is actively promoted instead of just harm being removed.
The method operates in three stages: (1) construct counterfactual pairs by flipping the protected attribute while keeping other features fixed, (2) identify harmful samples whose removal would decrease bias via an influence-on-bias function I_{up,bias}(z_k, B(θ̂)), and (3) perform Newton-step unlearning that simultaneously removes the harmful sample and introduces its counterfactual.
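Stage (1) amounts to flipping the protected attribute while leaving every other feature untouched. A minimal sketch for tabular data with a binary protected attribute (the column index `protected_idx` and the toy feature vector are assumptions for illustration, not part of the method):

```python
import numpy as np

def make_counterfactual(x, protected_idx):
    """Flip the binary protected attribute; keep all other features fixed."""
    x_cf = x.copy()
    x_cf[protected_idx] = 1.0 - x_cf[protected_idx]
    return x_cf

# toy sample: [feature_1, feature_2, protected_attribute]
x = np.array([0.3, 1.2, 1.0])
x_cf = make_counterfactual(x, protected_idx=2)
# x_cf differs from x only in the protected attribute
```

Applying the same flip to every training sample yields the counterfactual pairs (z_k, z̃_k) consumed by stages (2) and (3).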
Key Details
- Counterfactual bias metric: B(c_i, A, θ̂) = |P(Ŷ = f_θ(X,A) | X=x_i, A=a_i) - P(Ŷ = f_θ(X,A) | X=x_i, A=ā_i)|, where c_i = (x_i, a_i) and ā_i denotes the flipped protected attribute
- Influence on bias: I_{up,bias}(z_k, B) = -∇_θ B(θ̂)ᵀ H_θ̂⁻¹ ∇_θ L(z_k, θ̂) measures each sample's contribution to the bias, where H_θ̂ is the Hessian of the training loss at θ̂
- Unlearning with replacement: θ_new = θ̂ + Σ_k H_θ̂⁻¹(∇_θ L(z_k, θ̂) - ∇_θ L(z̃_k, θ̂)), where z̃_k is the counterfactual of harmful sample z_k; one Newton step simultaneously removes z_k's influence and injects z̃_k's
- External dataset variant: when the original training data is unavailable, an external dataset D_ex is used to construct the counterfactual pairs for unlearning
- Key observation: Harmful samples are bias-aligned (e.g., <blonde, female>), helpful samples are bias-conflicting
- Works on deep networks by applying unlearning only to the last classifier layer(s), which keeps the Hessian small enough to compute and invert exactly
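The influence-on-bias score and the replacement update above can be sketched end-to-end for a last-layer logistic classifier. This is a minimal numpy illustration, not the paper's implementation: the toy data, Hessian damping, number of unlearned samples, and the sign convention for selecting "harmful" samples are all assumptions, and the 1/n factor made explicit here is absorbed into the sum in the bullet's formula.

```python
import numpy as np

rng = np.random.default_rng(0)
sigmoid = lambda s: 1.0 / (1.0 + np.exp(-s))

# Toy last-layer setup: features X, labels y, binary protected attribute in column -1.
n, d = 200, 4
X = rng.normal(size=(n, d))
X[:, -1] = rng.integers(0, 2, size=n)                       # protected attribute a_i
y = (X[:, 0] + 0.8 * X[:, -1] > 0).astype(float)            # labels correlated with A (bias)
theta = np.zeros(d)
for _ in range(300):                                        # gradient descent to reach θ̂
    theta -= 0.1 * X.T @ (sigmoid(X @ theta) - y) / n

X_cf = X.copy()
X_cf[:, -1] = 1.0 - X_cf[:, -1]                             # counterfactual pairs: flip A

def grad_loss(th, x, yy):
    """Per-sample cross-entropy gradient ∇_θ L(z, θ)."""
    return (sigmoid(x @ th) - yy) * x

p, p_cf = sigmoid(X @ theta), sigmoid(X_cf @ theta)
B = np.abs(p - p_cf).mean()                                 # counterfactual bias B(θ̂)

# ∇_θ B(θ̂), using d|u|/dθ = sign(u) du/dθ and dσ(s)/dθ = σ(1-σ) x.
sgn = np.sign(p - p_cf)
grad_B = (sgn[:, None] * (p * (1 - p))[:, None] * X
          - sgn[:, None] * (p_cf * (1 - p_cf))[:, None] * X_cf).mean(axis=0)

W = p * (1 - p)
H = (X * W[:, None]).T @ X / n + 1e-3 * np.eye(d)           # damped mean Hessian (assumed damping)
H_inv = np.linalg.inv(H)

# Influence on bias: I_up,bias(z_k) = -∇_θ B(θ̂)ᵀ H⁻¹ ∇_θ L(z_k, θ̂).
grads = (p - y)[:, None] * X                                # all per-sample loss gradients
infl = -grads @ (H_inv @ grad_B)

# Samples with the largest scores are treated as harmful (sign convention assumed here).
harmful = np.argsort(infl)[-10:]

# Newton-step unlearning with replacement: θ_new = θ̂ + (1/n) H⁻¹ Σ_k (∇L(z_k) - ∇L(z̃_k)).
delta = sum(grad_loss(theta, X[k], y[k]) - grad_loss(theta, X_cf[k], y[k]) for k in harmful)
theta_new = theta + H_inv @ delta / n
B_new = np.abs(sigmoid(X @ theta_new) - sigmoid(X_cf @ theta_new)).mean()
```

In this sketch the harmful samples are exactly the bias-aligned ones (label agreeing with the protected attribute), matching the key observation above; on a deep network only the last-layer parameters θ would be updated this way.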