Mechanistic Interpretability: A Mathematical Framework
1. Activation Space Decomposition
Core Idea: Analyze layer activations $A \in \mathbb{R}^{n \times d}$ for $n$ samples and $d$ neurons through dimensionality reduction.
Methods:
Spectral Decomposition:
$\underset{(n \times d)}{A} = U \Sigma V^T$ (SVD)
Truncate to rank $r$: $\|A - U_r \Sigma_r V_r^T\|_F^2 \leq \epsilon$
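A minimal sketch of the spectral route in NumPy, choosing the smallest rank $r$ that meets the Frobenius-error budget $\epsilon$; the activation matrix here is synthetic.

```python
import numpy as np

def truncated_svd(A, eps):
    """Truncate the SVD of an activation matrix A (n samples x d neurons)
    to the smallest rank r whose squared Frobenius error is <= eps."""
    U, S, Vt = np.linalg.svd(A, full_matrices=False)
    sq = S ** 2
    # tail[r] = energy discarded when keeping only the top-r singular values
    tail = np.concatenate([np.cumsum(sq[::-1])[::-1], [0.0]])
    r = int(np.argmax(tail <= eps))
    A_r = (U[:, :r] * S[:r]) @ Vt[:r]   # rank-r reconstruction
    return A_r, r

# Example: 1000 samples of a 64-neuron layer with (approximately) rank-5 structure
A = np.random.randn(1000, 5) @ np.random.randn(5, 64) + 0.01 * np.random.randn(1000, 64)
A_r, r = truncated_svd(A, eps=10.0)
print(r, np.linalg.norm(A - A_r, "fro") ** 2)
```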
Nonlinear Manifold Learning:
Find an embedding $E = f_\theta(A)$ minimizing:
$\mathcal{L}_{\mathrm{top}} = \sum_{i,j} \big( d_A(a_i, a_j) - \|e_i - e_j\| \big)^2$
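One concrete way to minimize a stress objective of exactly this form is metric multidimensional scaling; a short sketch assuming scikit-learn and SciPy are available, with Euclidean $d_A$ and a synthetic activation matrix.

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform
from sklearn.manifold import MDS

# Toy activation matrix: 200 samples of a 64-neuron layer
A = np.random.default_rng(0).standard_normal((200, 64))

D = squareform(pdist(A))      # pairwise distances d_A(a_i, a_j)
mds = MDS(n_components=2, dissimilarity="precomputed", random_state=0)
E = mds.fit_transform(D)      # embedding e_i found by minimizing the stress
print(E.shape, mds.stress_)   # stress_ is the final value of the objective
```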
Strengths: Exact preservation of linear algebraic structure (spectral route)
Limitations: Nonlinearities require approximation (e.g., Lie-group-style approximations)
2. Causal Abstraction
Formalization: For a neural network $N : \mathcal{X} \to \mathcal{Y}$, define concepts $\mathcal{C} = \{c_i\}$ with:
$v_c(x) = \mathbb{E}[\, N(x) \mid \text{concept } c \text{ present} \,]$
Through an intervention operator $\mathcal{I}_c$:
$\Delta N(x) = \|N(x) - \mathcal{I}_c(N(x))\|_{\mathcal{H}}$
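As an illustration, one common instantiation of $\mathcal{I}_c$ is an activation-level ablation: project the concept direction $v_c$ out of a hidden layer and compare outputs. A toy sketch; the two-layer network and the direction `v_c` are stand-ins, not a prescribed implementation.

```python
import numpy as np

def forward(x, W1, W2, intervene=None):
    """Toy two-layer ReLU network; `intervene` optionally edits the hidden state."""
    h = np.maximum(W1 @ x, 0.0)
    if intervene is not None:
        h = intervene(h)
    return W2 @ h

def ablate_concept(v_c):
    """Intervention I_c: remove the component of h along the concept direction v_c."""
    u = v_c / np.linalg.norm(v_c)
    return lambda h: h - (h @ u) * u

rng = np.random.default_rng(0)
W1, W2 = rng.standard_normal((32, 16)), rng.standard_normal((4, 32))
x, v_c = rng.standard_normal(16), rng.standard_normal(32)

# ΔN(x): output shift caused by ablating the concept direction
delta = np.linalg.norm(forward(x, W1, W2) - forward(x, W1, W2, ablate_concept(v_c)))
print(delta)
```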
Testing Protocol:
Establish the variance ratio $R_c = \frac{\mathbb{V}[v_c]}{\mathbb{V}[N]}$
Compute the knockoff statistic:
$\tau_c = \sup_{x} \frac{d}{d\lambda} N(x + \lambda v_c) \Big|_{\lambda=0}$
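A finite-difference estimate of the directional derivative inside $\tau_c$, with the supremum approximated over a batch of probe inputs; the scalar-output `model` is a placeholder.

```python
import numpy as np

def tau_c(model, X, v_c, h=1e-3):
    """Estimate tau_c = sup_x d/dλ model(x + λ v_c)|_{λ=0} by central
    differences, taking the max over the probe inputs X (n, d)."""
    derivs = (model(X + h * v_c) - model(X - h * v_c)) / (2 * h)
    return float(np.max(np.abs(derivs)))

# Toy scalar-output model
rng = np.random.default_rng(0)
w = rng.standard_normal(16)
model = lambda X: np.tanh(X @ w)
X, v_c = rng.standard_normal((256, 16)), rng.standard_normal(16)
print(tau_c(model, X, v_c))
```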
3. Differential Analysis
First-Order Tooling:
For input $x$ and class logit $f_c(x)$:
$\text{Saliency}(x) = \nabla_x f_c(x)$
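A short autograd sketch of the saliency map; the linear model is only a stand-in for a real classifier.

```python
import torch

def saliency(model, x, c):
    """Gradient of the class-c logit f_c(x) with respect to the input x."""
    x = x.clone().detach().requires_grad_(True)
    model(x)[c].backward()
    return x.grad

# Toy classifier: 10 classes over a 784-dimensional input
model = torch.nn.Linear(784, 10)
x = torch.randn(784)
print(saliency(model, x, c=3).shape)   # torch.Size([784])
```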
Second-Order Methods:
Construct the influence matrix:
$\mathcal{H}_{ij} = \frac{\partial^2 \mathcal{L}}{\partial w_i \, \partial w_j}$
Eigendecomposition:
$\mathcal{H} = Q \Lambda Q^T \;\Rightarrow\; \lambda_{\min}/\lambda_{\max} = \kappa^{-1}$
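For a full network the Hessian is far too large to materialize (Hessian-vector products are used instead), but the decomposition is easy to see at small scale. A sketch on a least-squares toy problem:

```python
import torch
from torch.autograd.functional import hessian

# Toy loss: least squares over an 8-dimensional weight vector
X, y = torch.randn(100, 8), torch.randn(100)
loss = lambda w: ((X @ w - y) ** 2).mean()

H = hessian(loss, torch.zeros(8))            # 8x8 influence matrix H_ij
evals, Q = torch.linalg.eigh(H)              # H = Q Λ Q^T
print((evals.min() / evals.max()).item())    # inverse condition number κ^{-1}
```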
4. Computational Circuit Mapping
Component Identification:
Define functional units $\Phi = \{\phi_i : \mathbb{R}^d \to \mathbb{R}^k\}$
Compute the path activation:
$A_\pi(x) = \Big( \prod_{l \in \pi} W^{(l)} \Big) x$
Optimize the subnetwork mask:
$\min_M \|N(x) - M \odot A_\pi(x)\| + \lambda \|M\|_0$
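Since the $\|M\|_0$ penalty is not differentiable, a common surrogate is an $L_1$ penalty on a sigmoid-parameterized mask. A hedged sketch on synthetic path activations, not a prescribed recipe:

```python
import torch

def fit_circuit_mask(N_out, A_pi, lam=1e-2, steps=500):
    """Learn a soft mask M in [0,1]^k so that M ⊙ A_pi matches the network
    output, with the L0 penalty relaxed to L1."""
    logits = torch.zeros(A_pi.shape[1], requires_grad=True)
    opt = torch.optim.Adam([logits], lr=0.05)
    for _ in range(steps):
        M = torch.sigmoid(logits)
        loss = torch.norm(N_out - M * A_pi) + lam * M.abs().sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return torch.sigmoid(logits).detach()

# Toy data: the "network output" uses only 3 of 20 path activations
torch.manual_seed(0)
A_pi = torch.randn(256, 20)
true_mask = torch.zeros(20)
true_mask[:3] = 1.0
print(fit_circuit_mask(true_mask * A_pi, A_pi).round())
```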
Validation Metric:
$\alpha\text{-compressibility} = \frac{\text{Circuit Score}}{\text{Full Network Score}} \geq 1 - \epsilon$
5. Algebraic Topology Methods
Homological Analysis:
Construct the activation simplicial complex $\mathcal{S}_\epsilon$ with:
$\mathcal{S}_\epsilon = \{\, \sigma \subset A \mid \operatorname{diam}(\sigma) < \epsilon \,\}$
Compute Betti numbers $b_k(\mathcal{S}_\epsilon)$ as a function of $\epsilon$
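At a fixed scale $\epsilon$, $b_0$ (the number of connected components of the $\epsilon$-neighborhood graph) can be computed directly with union-find; higher Betti numbers require a persistent-homology library. A self-contained sketch:

```python
import numpy as np

def betti_0(A, eps):
    """b_0 of the Vietoris-Rips complex on activation vectors A (n, d) at scale eps:
    the number of connected components of the eps-neighborhood graph."""
    n = len(A)
    parent = list(range(n))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    D = np.linalg.norm(A[:, None, :] - A[None, :, :], axis=-1)
    for i in range(n):
        for j in range(i + 1, n):
            if D[i, j] < eps:
                parent[find(i)] = find(j)
    return len({find(i) for i in range(n)})

# Two well-separated activation clusters: b_0 falls from ~n to 2 to 1 as eps grows
A = np.vstack([np.random.randn(50, 8), np.random.randn(50, 8) + 10.0])
print([betti_0(A, e) for e in (1.0, 5.0, 50.0)])
```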
Theorem: For ReLU networks, $\exists\, \epsilon^*$ such that:
$b_0(\mathcal{S}_{\epsilon^*}) = \text{number of disconnected decision regions}$
6. Dynamic System Interpretation
Differential Approximation:
Approximate layer transition as:
$\frac{dh}{dt} = \sigma(W h + b) \;\Rightarrow\; h_{t+1} \approx h_t + \Delta t \, \sigma(W h_t + b)$
Stability Analysis:
Compute Lyapunov exponent:
$\lambda_{\max} = \lim_{T \to \infty} \frac{1}{T} \sum_{t=0}^{T-1} \ln \|J_t v_t\|$
where $J_t$ is the Jacobian of the step map at time $t$ and $v_t$ is a unit tangent vector renormalized after each step
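A Benettin-style sketch of the estimator, assuming a $\tanh$ nonlinearity for the Euler-discretized dynamics; `W` and `b` stand in for a layer's weights.

```python
import numpy as np

def lyapunov_max(W, b, h0, dt=0.1, T=5000):
    """Estimate the largest Lyapunov exponent of h_{t+1} = h_t + dt*tanh(W h_t + b)
    by pushing a unit tangent vector through the step Jacobian and renormalizing."""
    h, v = h0.copy(), np.ones(len(h0)) / np.sqrt(len(h0))
    total = 0.0
    for _ in range(T):
        pre = W @ h + b
        J = np.eye(len(h)) + dt * (1.0 - np.tanh(pre) ** 2)[:, None] * W  # Jacobian of the Euler step
        h = h + dt * np.tanh(pre)
        v = J @ v
        norm = np.linalg.norm(v)
        total += np.log(norm)
        v /= norm                      # renormalize the tangent vector
    return total / T                   # per-step exponent

rng = np.random.default_rng(1)
W = rng.standard_normal((16, 16)) / np.sqrt(16)
print(lyapunov_max(W, np.zeros(16), rng.standard_normal(16)))
```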
Synthesis Methodology
Consistency Validation:
For an interpretation $I$, verify:
$\exists\, K > 0 : \; \sup_{x \in \mathcal{X}} \|N(x) - I(x)\| \leq K \delta_I$
Completeness Metric:
$\text{Explanatory Coverage} = \frac{\mu(\{x \mid \delta_I(x) < \epsilon\})}{\mu(\mathcal{X})}$
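A sampling-based estimate of the coverage metric over a probe set, assuming the interpretation exposes per-input predictions `I_out` to compare against the network outputs `N_out`.

```python
import numpy as np

def explanatory_coverage(N_out, I_out, eps):
    """Fraction of probe inputs on which the interpretation matches the network
    to within eps, i.e. a Monte-Carlo estimate of μ({x : δ_I(x) < ε}) / μ(X)."""
    delta = np.linalg.norm(N_out - I_out, axis=-1)   # δ_I(x) per sample
    return float(np.mean(delta < eps))

# Toy check: an interpretation accurate on roughly 80% of inputs
rng = np.random.default_rng(0)
N_out = rng.standard_normal((1000, 4))
noise_scale = np.where(rng.random((1000, 1)) < 0.8, 0.01, 1.0)
I_out = N_out + noise_scale * rng.standard_normal((1000, 4))
print(explanatory_coverage(N_out, I_out, eps=0.1))   # ≈ 0.8
```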
Fundamental Limits:
For network depth $L$, the sample complexity of interpretation scales as:
$\Omega\left( \exp(L) \cdot \dim(\mathcal{H}) \right)$
(where $\mathcal{H}$ is the hypothesis space)
This framework reveals intrinsic tensions:
Completeness vs Compactness: Improved coverage requires exponential parameter growth
Pluralism Requirement: No single method achieves $\text{Coverage} > 1 - \epsilon$ for $\epsilon < 0.2$ in practice
Depth Complexity: Interpretation fidelity inversely proportional to $\sqrt{L}$
Current frontiers focus on operator algebra methods and non-uniform approximation theory for bounding interpretation errors.