Numinex

Thread (53 posts)

The L1 penalty also fails to be invariant under reparameterizations of a SAE; by scaling down encoder parameters and scaling up decoder parameters accordingly, it is possible to arbitrarily shrink feature magnitudes, and thus the L1 penalty, without changing either the number of active features or the SAE’s output reconstructions.

Can you give a worked toy example of this?

Thinking Summary

The user is asking for a worked toy example of how the L1 penalty is not invariant under reparameterizations of a SAE. Let me think through this. The key idea is that if we have an SAE with encoder and decoder, we can scale the encoder weights down by some factor and scale the decoder weights up by the same factor, and this will:

  1. Keep the reconstruction the same
  2. Keep the number of active features the same
  3. But reduce the L1 penalty (since feature magnitudes will be smaller)

Let me create a simple example to illustrate this.

Sure! Here's a simple toy example to illustrate how the L1 penalty is not invariant under reparameterizations:

Original SAE Setup

Consider a simple SAE with:

  • Input dimension: 2
  • Dictionary size: 2
  • Input vector: $\mathbf{x} = [1, 0]$

Original parameters:

  • Encoder weights: $\mathbf{W}_{\text{enc}} = \begin{bmatrix} 2 & 0 \\ 0 & 2 \end{bmatrix}$
  • Encoder bias: $\mathbf{b}_{\text{enc}} = [-1, -1]$
  • Decoder weights: $\mathbf{W}_{\text{dec}} = \begin{bmatrix} 0.5 & 0 \\ 0 & 0.5 \end{bmatrix}$
  • Decoder bias: $\mathbf{b}_{\text{dec}} = [0, 0]$

Forward pass:

  1. Pre-activations: $\boldsymbol{\pi} = \mathbf{W}_{\text{enc}}\mathbf{x} + \mathbf{b}_{\text{enc}} = [2 \cdot 1 - 1, 0 \cdot 1 - 1] = [1, -1]$
  2. Feature activations: $\mathbf{f} = \text{ReLU}([1, -1]) = [1, 0]$
  3. Reconstruction: $\hat{\mathbf{x}} = \mathbf{W}_{\text{dec}}\mathbf{f} = [0.5 \cdot 1, 0 \cdot 1] = [0.5, 0]$
  4. L1 penalty: $\|\mathbf{f}\|_1 = 1$

Reparameterized SAE

Now, let's scale the encoder weights and bias down by a factor of 10 and the decoder weights up by a factor of 10:

Reparameterized parameters:

  • $\mathbf{W}'_{\text{enc}} = \frac{1}{10}\mathbf{W}_{\text{enc}} = \begin{bmatrix} 0.2 & 0 \\ 0 & 0.2 \end{bmatrix}$
  • $\mathbf{b}'_{\text{enc}} = \frac{1}{10}\mathbf{b}_{\text{enc}} = [-0.1, -0.1]$
  • $\mathbf{W}'_{\text{dec}} = 10 \cdot \mathbf{W}_{\text{dec}} = \begin{bmatrix} 5 & 0 \\ 0 & 5 \end{bmatrix}$
  • $\mathbf{b}'_{\text{dec}} = \mathbf{b}_{\text{dec}} = [0, 0]$

Forward pass:

  1. Pre-activations: $\boldsymbol{\pi}' = [0.2 \cdot 1 - 0.1, 0 \cdot 1 - 0.1] = [0.1, -0.1]$
  2. Feature activations: $\mathbf{f}' = \text{ReLU}([0.1, -0.1]) = [0.1, 0]$
  3. Reconstruction: $\hat{\mathbf{x}}' = [5 \cdot 0.1, 0 \cdot 0.1] = [0.5, 0]$
  4. L1 penalty: $\|\mathbf{f}'\|_1 = 0.1$

Key Observations

  1. Same reconstruction: $\hat{\mathbf{x}} = \hat{\mathbf{x}}' = [0.5, 0]$
  2. Same sparsity pattern: Both have only the first feature active
  3. Different L1 penalties: Original = 1, Reparameterized = 0.1
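
The two forward passes above can be checked numerically. The following is a minimal, dependency-free sketch rather than a reference implementation: the helper names (`matvec`, `relu`, `sae_forward`, `rescale`) are made up for illustration, and matrices are stored as lists of rows so that `matvec(W_dec, f)` matches the $\mathbf{W}_{\text{dec}}\mathbf{f}$ convention used in the example.

```python
import math

def matvec(W, v):
    """Multiply matrix W (list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def relu(v):
    return [max(0.0, x) for x in v]

def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """Return (feature activations, reconstruction) for a ReLU SAE."""
    pre = [p + b for p, b in zip(matvec(W_enc, x), b_enc)]
    f = relu(pre)
    x_hat = [r + b for r, b in zip(matvec(W_dec, f), b_dec)]
    return f, x_hat

def rescale(W_enc, b_enc, W_dec, alpha):
    """Multiply the encoder (weights and bias) by alpha and the decoder weights by 1/alpha."""
    W_enc2 = [[alpha * w for w in row] for row in W_enc]
    b_enc2 = [alpha * b for b in b_enc]
    W_dec2 = [[w / alpha for w in row] for row in W_dec]
    return W_enc2, b_enc2, W_dec2

# Parameters of the toy example.
x = [1.0, 0.0]
W_enc = [[2.0, 0.0], [0.0, 2.0]]
b_enc = [-1.0, -1.0]
W_dec = [[0.5, 0.0], [0.0, 0.5]]
b_dec = [0.0, 0.0]

f, x_hat = sae_forward(x, W_enc, b_enc, W_dec, b_dec)
W_enc2, b_enc2, W_dec2 = rescale(W_enc, b_enc, W_dec, alpha=0.1)
f2, x_hat2 = sae_forward(x, W_enc2, b_enc2, W_dec2, b_dec)

l1 = sum(abs(v) for v in f)
l1_scaled = sum(abs(v) for v in f2)
print("original:", f, x_hat, l1)
print("rescaled:", f2, x_hat2, l1_scaled)
```

Trying smaller values of `alpha` should shrink `l1_scaled` further while the reconstruction and the set of active features stay the same.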

By multiplying the encoder weights and bias by any factor $\alpha \in (0, 1)$ and the decoder weights by $1/\alpha$, we can make the L1 penalty arbitrarily small while leaving the reconstruction and the set of active features unchanged. This is why the L1 penalty needs to be paired with decoder weight normalization constraints (typically $\|\mathbf{d}_i\|_2 = 1$) to be meaningful.

The RI-L1 penalty $\sum_i f_i \|\mathbf{d}_i\|_2$ fixes this issue because when we scale decoder weights up, the penalty increases proportionally, canceling out the decrease from smaller feature magnitudes.
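
As a quick sanity check on the toy numbers, the sketch below evaluates the RI-L1 penalty for both parameterizations. It assumes $\mathbf{d}_i$ is the $i$-th column of $\mathbf{W}_{\text{dec}}$ (consistent with $\hat{\mathbf{x}} = \mathbf{W}_{\text{dec}}\mathbf{f}$ above); the function names `column_norms` and `ri_l1` are illustrative only.

```python
import math

def column_norms(W_dec):
    """L2 norm of each decoder column d_i (W_dec is a list of rows)."""
    n_features = len(W_dec[0])
    return [math.sqrt(sum(row[i] ** 2 for row in W_dec)) for i in range(n_features)]

def ri_l1(f, W_dec):
    """Reparameterization-invariant penalty: sum_i f_i * ||d_i||_2."""
    return sum(fi * ni for fi, ni in zip(f, column_norms(W_dec)))

# Activations and decoder weights taken from the toy example.
f_orig, W_dec_orig = [1.0, 0.0], [[0.5, 0.0], [0.0, 0.5]]
f_scaled, W_dec_scaled = [0.1, 0.0], [[5.0, 0.0], [0.0, 5.0]]

penalty_orig = ri_l1(f_orig, W_dec_orig)
penalty_scaled = ri_l1(f_scaled, W_dec_scaled)
print(penalty_orig, penalty_scaled)
```

Both values should agree (up to floating-point error), in contrast to the plain L1 penalties of 1 and 0.1.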


Reading this, I'm realizing the point I'm unclear on is the reparameterization part: I can certainly imagine scaling the encoder and decoder matrices by a constant, as in your excerpt, but I'm not clear on exactly what the set (group) of allowable reparameterizations, aka symmetries, is.

Thinking Summary

The user is asking about what exactly constitutes the group of allowable reparameterizations/symmetries for SAEs. This is a good question - I gave one example of scaling, but what's the full set? Let me think about this carefully.

For a SAE, we have:

  • Encoder: f(x) = σ(W_enc * x + b_enc)
  • Decoder: x_hat = W_dec * f + b_dec

The key insight is that we want transformations that preserve the input-output mapping x → x_hat.

The main reparameterization symmetry is:

  1. For each feature i, we can scale the i-th row of W_enc and b_enc[i] by some factor α_i
  2. And correspondingly scale the i-th column of W_dec by 1/α_i

This works because:

  • Pre-activation for feature i: (W_enc[i,:] * x + b_enc[i]) gets scaled by α_i
  • After ReLU: if positive, feature activation f_i gets scaled by α_i
  • In decoder: W_dec[:,i] * f_i = (W_dec[:,i] / α_i) * (α_i * f_i) = unchanged

So the full group is diagonal scaling with positive scalars - one for each feature.

There's also potentially:

  • Permutation symmetry: we can permute the features (reorder rows of W_enc, b_enc and columns of W_dec)
  • But that's it for exact symmetries that preserve the function

For other activation functions there might be other symmetries, but for ReLU these are the main ones.

Great question! Let me clarify the full group of reparameterization symmetries for SAEs.

The Reparameterization Group for ReLU SAEs

For a standard ReLU SAE, the allowable reparameterizations that preserve the input-output function form a group with two types of transformations:

1. Feature-wise Scaling (the main one)

For any positive diagonal matrix $\mathbf{D} = \text{diag}(\alpha_1, \alpha_2, \ldots, \alpha_M)$ where $\alpha_i > 0$:

  • Transform encoder: $\mathbf{W}'_{\text{enc}} = \mathbf{D}\mathbf{W}_{\text{enc}}$, $\mathbf{b}'_{\text{enc}} = \mathbf{D}\mathbf{b}_{\text{enc}}$
  • Transform decoder: $\mathbf{W}'_{\text{dec}} = \mathbf{W}_{\text{dec}}\mathbf{D}^{-1}$

This works because for each feature $i$:

  • Pre-activation gets scaled: $\pi'_i = \alpha_i \pi_i$
  • Post-ReLU activation gets scaled: $f'_i = \alpha_i f_i$ (for every input, since $\text{ReLU}(\alpha z) = \alpha\,\text{ReLU}(z)$ whenever $\alpha > 0$)
  • Decoder compensates: contribution to reconstruction is $(W_{\text{dec}}[:,i] / \alpha_i) \cdot (\alpha_i f_i) = W_{\text{dec}}[:,i] \cdot f_i$
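
The per-feature argument can be spot-checked with a different positive scale for each feature. This is only a sketch under stated assumptions: the 2-input, 3-feature SAE below is made up for illustration, the decoder bias is omitted (it is unaffected by the transformation), and `scale_features` is a hypothetical helper name.

```python
import math

def matvec(W, v):
    """Multiply matrix W (list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def forward(x, W_enc, b_enc, W_dec):
    """ReLU SAE without decoder bias: returns (features, reconstruction)."""
    f = [max(0.0, p + b) for p, b in zip(matvec(W_enc, x), b_enc)]
    return f, matvec(W_dec, f)

def scale_features(W_enc, b_enc, W_dec, alphas):
    """Apply W_enc' = D W_enc, b_enc' = D b_enc, W_dec' = W_dec D^{-1}."""
    W_enc2 = [[a * w for w in row] for a, row in zip(alphas, W_enc)]  # scale rows
    b_enc2 = [a * b for a, b in zip(alphas, b_enc)]
    W_dec2 = [[w / a for w, a in zip(row, alphas)] for row in W_dec]  # scale columns
    return W_enc2, b_enc2, W_dec2

# Made-up parameters: 3 features (encoder rows), 2 input dimensions.
W_enc = [[1.0, -1.0], [0.5, 2.0], [-1.0, 0.3]]
b_enc = [0.1, -0.4, 0.2]
W_dec = [[0.7, -0.2, 0.4], [0.1, 0.9, -0.3]]
alphas = [0.01, 3.0, 0.5]  # one positive scale per feature

W_enc2, b_enc2, W_dec2 = scale_features(W_enc, b_enc, W_dec, alphas)

results = []
for x in [[1.0, 0.0], [0.0, 1.0], [-2.0, 0.5], [0.3, -0.7]]:
    f, x_hat = forward(x, W_enc, b_enc, W_dec)
    f2, x_hat2 = forward(x, W_enc2, b_enc2, W_dec2)
    same_recon = all(math.isclose(u, v, abs_tol=1e-9) for u, v in zip(x_hat, x_hat2))
    same_active = [v > 0 for v in f] == [v > 0 for v in f2]
    results.append((same_recon, same_active, sum(f), sum(f2)))
    print(x, same_recon, same_active, sum(f), sum(f2))
```

For each input the reconstruction and active set should match, while the last two printed numbers (the L1 penalties) generally differ.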

2. Feature Permutations

For any permutation matrix $\mathbf{P}$:

  • Transform encoder: $\mathbf{W}'_{\text{enc}} = \mathbf{P}\mathbf{W}_{\text{enc}}$, $\mathbf{b}'_{\text{enc}} = \mathbf{P}\mathbf{b}_{\text{enc}}$
  • Transform decoder: $\mathbf{W}'_{\text{dec}} = \mathbf{W}_{\text{dec}}\mathbf{P}^T$

This just reorders which feature is which, without changing the overall computation.
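
A small check of the permutation symmetry, again as a hedged sketch: the parameters are invented for illustration, and `permute_features` is a hypothetical helper that reorders encoder rows and decoder columns with the same index list (the list-based equivalent of multiplying by $\mathbf{P}$ and $\mathbf{P}^T$).

```python
import math

def matvec(W, v):
    """Multiply matrix W (list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def forward(x, W_enc, b_enc, W_dec):
    """ReLU SAE without decoder bias: returns (features, reconstruction)."""
    f = [max(0.0, p + b) for p, b in zip(matvec(W_enc, x), b_enc)]
    return f, matvec(W_dec, f)

def permute_features(W_enc, b_enc, W_dec, perm):
    """Relabel features so that new feature j is old feature perm[j]."""
    W_enc2 = [W_enc[i] for i in perm]                   # reorder encoder rows
    b_enc2 = [b_enc[i] for i in perm]
    W_dec2 = [[row[i] for i in perm] for row in W_dec]  # reorder decoder columns
    return W_enc2, b_enc2, W_dec2

# Made-up parameters: 3 features, 2 input dimensions.
W_enc = [[1.0, -1.0], [0.5, 2.0], [-1.0, 0.3]]
b_enc = [0.1, -0.4, 0.2]
W_dec = [[0.7, -0.2, 0.4], [0.1, 0.9, -0.3]]
perm = [2, 0, 1]
x = [0.0, 1.0]

f, x_hat = forward(x, W_enc, b_enc, W_dec)
f_p, x_hat_p = forward(x, *permute_features(W_enc, b_enc, W_dec, perm))
print(f, f_p)
print(x_hat, x_hat_p)
```

The activations come out in a different order, but the reconstruction and the L1 penalty are unaffected, which is why permutations are harmless for the sparsity penalty.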

The Complete Group Structure

The full reparameterization group is the semidirect product of these two: $G = (\mathbb{R}_{>0})^M \rtimes S_M$

where:

  • $(\mathbb{R}_{>0})^M$ is the group of positive scalings for each feature
  • $S_M$ is the symmetric group of permutations on $M$ features
  • $\rtimes$ denotes semidirect product (you can scale then permute, or permute then scale with appropriately transformed scalings)
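
To make "appropriately transformed scalings" concrete, here is a short worked identity. It assumes the convention $(\mathbf{P}\mathbf{v})_i = v_{\sigma(i)}$ for the permutation $\sigma$ associated with $\mathbf{P}$; this convention is an assumption of the sketch, and the opposite one simply replaces $\sigma$ with $\sigma^{-1}$.

$$
\mathbf{P}\mathbf{D} = (\mathbf{P}\mathbf{D}\mathbf{P}^T)\,\mathbf{P},
\qquad
\mathbf{P}\mathbf{D}\mathbf{P}^T = \text{diag}(\alpha_{\sigma(1)}, \ldots, \alpha_{\sigma(M)})
$$

So scaling by $\mathbf{D}$ and then permuting gives the same encoder as permuting first and then scaling by the reordered diagonal matrix $\mathbf{P}\mathbf{D}\mathbf{P}^T$. The decoder side is consistent, since $\mathbf{W}_{\text{dec}}(\mathbf{P}\mathbf{D})^{-1} = \mathbf{W}_{\text{dec}}\mathbf{D}^{-1}\mathbf{P}^T$. Permutations acting on the scalings by reordering them is what makes the product semidirect rather than direct.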

Why This Matters

The L1 penalty $\sum_i |f_i|$ is not invariant under the scaling transformations - we can make it arbitrarily small by choosing small $\alpha_i$ values. This is why we need either:

  1. Decoder normalization constraint: Fix $\|\mathbf{d}_i\|_2 = 1$ to break the scaling symmetry
  2. Reparameterization-invariant penalty: Use $\sum_i f_i \|\mathbf{d}_i\|_2$ which is invariant under these transformations

The L0 penalty $\|\mathbf{f}\|_0$ used in JumpReLU SAEs is naturally invariant under both scaling and permutation, which is one of its advantages!
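
For reference, the three penalties can be compared side by side under a feature-wise scaling. This is a short derivation sketch that uses only the relations stated earlier, $f'_i = \alpha_i f_i$ and $\mathbf{d}'_i = \mathbf{d}_i / \alpha_i$ with $\alpha_i > 0$, where $\mathbf{d}_i$ is taken to be the $i$-th column of $\mathbf{W}_{\text{dec}}$:

$$
\begin{aligned}
\|\mathbf{f}'\|_1 &= \sum_i \alpha_i f_i \;\neq\; \sum_i f_i \quad \text{in general}, \\
\sum_i f'_i \|\mathbf{d}'_i\|_2 &= \sum_i (\alpha_i f_i)\,\frac{\|\mathbf{d}_i\|_2}{\alpha_i} = \sum_i f_i \|\mathbf{d}_i\|_2, \\
\|\mathbf{f}'\|_0 &= \#\{i : \alpha_i f_i \neq 0\} = \#\{i : f_i \neq 0\} = \|\mathbf{f}\|_0 .
\end{aligned}
$$

Only the plain L1 penalty depends on the choice of $\alpha_i$. All three are sums over features, so reordering the features leaves each of them unchanged.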
