Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders
Thread (53 posts)
The L1 penalty also fails to be invariant under reparameterizations of a SAE; by scaling down encoder parameters and scaling up decoder parameters accordingly, it is possible to arbitrarily shrink feature magnitudes, and thus the L1 penalty, without changing either the number of active features or the SAE’s output reconstructions.
Can you give a worked toy example of this?
Thinking Summary
The user is asking for a worked toy example of how the L1 penalty is not invariant under reparameterizations of a SAE. Let me think through this. The key idea is that if we have an SAE with encoder and decoder, we can scale the encoder weights down by some factor and scale the decoder weights up by the same factor, and this will:
1. Keep the reconstruction the same
2. Keep the number of active features the same
3. But reduce the L1 penalty (since feature magnitudes will be smaller)
Let me create a simple example to illustrate this.
Sure! Here's a simple toy example to illustrate how the L1 penalty is not invariant under reparameterizations:
Original SAE Setup
Consider a simple SAE with:
- Input dimension: 2
- Dictionary size: 2
- Input vector: $x = \begin{bmatrix} 2 \\ 1 \end{bmatrix}$
Original parameters:
- Encoder weights: $W_{\text{enc}} = \begin{bmatrix} 1 & -1 \\ -1 & 1 \end{bmatrix}$
- Encoder bias: $b_{\text{enc}} = \begin{bmatrix} 0 \\ 0 \end{bmatrix}$
- Decoder weights: $W_{\text{dec}} = \begin{bmatrix} 2 & 0 \\ 1 & 1 \end{bmatrix}$
- Decoder bias: $b_{\text{dec}} = \begin{bmatrix} 0 \\ 0 \end{bmatrix}$
Forward pass:
- Pre-activations: $W_{\text{enc}} x + b_{\text{enc}} = \begin{bmatrix} 1 \\ -1 \end{bmatrix}$
- Feature activations: $f = \mathrm{ReLU}\!\left(\begin{bmatrix} 1 \\ -1 \end{bmatrix}\right) = \begin{bmatrix} 1 \\ 0 \end{bmatrix}$
- Reconstruction: $\hat{x} = W_{\text{dec}} f + b_{\text{dec}} = \begin{bmatrix} 2 \\ 1 \end{bmatrix}$
- L1 penalty: $\|f\|_1 = 1$
Reparameterized SAE
Now, let's scale the encoder weights (and bias) down by a factor of 10 and the decoder weights up by a factor of 10:
Reparameterized parameters:
- Encoder weights: $W'_{\text{enc}} = \tfrac{1}{10} W_{\text{enc}} = \begin{bmatrix} 0.1 & -0.1 \\ -0.1 & 0.1 \end{bmatrix}$
- Encoder bias: $b'_{\text{enc}} = \tfrac{1}{10} b_{\text{enc}} = \begin{bmatrix} 0 \\ 0 \end{bmatrix}$
- Decoder weights: $W'_{\text{dec}} = 10\, W_{\text{dec}} = \begin{bmatrix} 20 & 0 \\ 10 & 10 \end{bmatrix}$
- Decoder bias: $b'_{\text{dec}} = b_{\text{dec}} = \begin{bmatrix} 0 \\ 0 \end{bmatrix}$
Forward pass:
- Pre-activations: $W'_{\text{enc}} x + b'_{\text{enc}} = \begin{bmatrix} 0.1 \\ -0.1 \end{bmatrix}$
- Feature activations: $f' = \mathrm{ReLU}\!\left(\begin{bmatrix} 0.1 \\ -0.1 \end{bmatrix}\right) = \begin{bmatrix} 0.1 \\ 0 \end{bmatrix}$
- Reconstruction: $\hat{x}' = W'_{\text{dec}} f' + b'_{\text{dec}} = \begin{bmatrix} 2 \\ 1 \end{bmatrix}$
- L1 penalty: $\|f'\|_1 = 0.1$
Key Observations
- Same reconstruction: $\hat{x} = \hat{x}' = \begin{bmatrix} 2 \\ 1 \end{bmatrix}$
- Same sparsity pattern: Both have only the first feature active
- Different L1 penalties: Original = 1, Reparameterized = 0.1
By scaling the encoder down by any factor $c > 1$ and scaling the decoder up by the same factor $c$, we can make the L1 penalty arbitrarily small while keeping everything else the same. This is why the L1 penalty needs to be paired with decoder weight normalization constraints (typically $\|(W_{\text{dec}})_{:,i}\|_2 = 1$ for every decoder column) to be meaningful.
The RI-L1 penalty $\sum_i f_i\, \|(W_{\text{dec}})_{:,i}\|_2$ fixes this issue because when we scale decoder weights up, the penalty increases proportionally, canceling out the decrease from smaller feature magnitudes.
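If it helps to see this numerically, here is a minimal NumPy sketch (my own check, not code from the paper) that runs both parameterizations of the toy SAE above and compares the reconstructions and penalties; the `sae_forward` helper is just for this example:

```python
import numpy as np

def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """Standard ReLU SAE: features f = ReLU(W_enc x + b_enc), reconstruction = W_dec f + b_dec."""
    f = np.maximum(W_enc @ x + b_enc, 0.0)
    return f, W_dec @ f + b_dec

x = np.array([2.0, 1.0])

# Original parameters (the toy values used above)
W_enc = np.array([[1.0, -1.0], [-1.0, 1.0]])
b_enc = np.zeros(2)
W_dec = np.array([[2.0, 0.0], [1.0, 1.0]])
b_dec = np.zeros(2)

# Reparameterized: encoder weights and bias scaled down by 10, decoder weights scaled up by 10
c = 10.0
f1, xhat1 = sae_forward(x, W_enc, b_enc, W_dec, b_dec)
f2, xhat2 = sae_forward(x, W_enc / c, b_enc / c, W_dec * c, b_dec)

print("reconstructions equal:", np.allclose(xhat1, xhat2))   # True
print("L1 penalties:", f1.sum(), "vs", f2.sum())              # 1.0 vs 0.1

# RI-L1 weights each feature by its decoder column norm, so the scaling cancels
ri_l1_1 = np.sum(f1 * np.linalg.norm(W_dec, axis=0))
ri_l1_2 = np.sum(f2 * np.linalg.norm(W_dec * c, axis=0))
print("RI-L1 penalties:", ri_l1_1, "vs", ri_l1_2)             # equal
```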
Reading this I'm realizing the point I'm unclear on is the reparameterization part: I can certainly imagine scaling the encoder and decoder matrices by a constant as in your excerpt, but I'm not clear on exactly what the set (group) of allowable reparameterizations aka symmetries is
Thinking Summary
The user is asking about what exactly constitutes the group of allowable reparameterizations/symmetries for SAEs. This is a good question - I gave one example of scaling, but what's the full set? Let me think about this carefully. For a SAE, we have:
- Encoder: f(x) = σ(W_enc * x + b_enc)
- Decoder: x_hat = W_dec * f + b_dec
The key insight is that we want transformations that preserve the input-output mapping x → x_hat. The main reparameterization symmetry is:
1. For each feature i, we can scale the i-th row of W_enc and b_enc[i] by some factor α_i
2. And correspondingly scale the i-th column of W_dec by 1/α_i
This works because:
- Pre-activation for feature i: (W_enc[i,:] * x + b_enc[i]) gets scaled by α_i
- After ReLU: if positive, feature activation f_i gets scaled by α_i
- In decoder: W_dec[:,i] * f_i = (W_dec[:,i] / α_i) * (α_i * f_i) = unchanged
So the full group is diagonal scaling with positive scalars - one for each feature. There's also potentially:
- Permutation symmetry: we can permute the features (reorder rows of W_enc, b_enc and columns of W_dec)
- But that's it for exact symmetries that preserve the function
For other activation functions there might be other symmetries, but for ReLU these are the main ones.
Great question! Let me clarify the full group of reparameterization symmetries for SAEs.
The Reparameterization Group for ReLU SAEs
For a standard ReLU SAE, with encoder $f(x) = \mathrm{ReLU}(W_{\text{enc}} x + b_{\text{enc}})$ and decoder $\hat{x} = W_{\text{dec}} f(x) + b_{\text{dec}}$, the allowable reparameterizations that preserve the input-output function form a group with two types of transformations:
1. Feature-wise Scaling (the main one)
For any positive diagonal matrix $D = \mathrm{diag}(\alpha_1, \ldots, \alpha_M)$ where each $\alpha_i > 0$ (one scale per feature):
- Transform encoder: $W_{\text{enc}} \to D\, W_{\text{enc}}$, $b_{\text{enc}} \to D\, b_{\text{enc}}$
- Transform decoder: $W_{\text{dec}} \to W_{\text{dec}}\, D^{-1}$
This works because for each feature $i$:
- Pre-activation gets scaled: $(W_{\text{enc}} x + b_{\text{enc}})_i \to \alpha_i\, (W_{\text{enc}} x + b_{\text{enc}})_i$
- Post-ReLU activation gets scaled: $f_i \to \alpha_i f_i$ (since $\mathrm{ReLU}(\alpha_i z) = \alpha_i\, \mathrm{ReLU}(z)$ when $\alpha_i > 0$)
- Decoder compensates: the contribution to the reconstruction is $\left(\tfrac{1}{\alpha_i} d_i\right)(\alpha_i f_i) = d_i f_i$, where $d_i$ is the $i$-th column of $W_{\text{dec}}$
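As a quick numerical sanity check (my own sketch, not from the paper): draw random SAE parameters and a random positive diagonal $D$, and confirm that the features scale but the reconstruction does not change:

```python
import numpy as np

rng = np.random.default_rng(0)
n, M = 4, 8  # input dimension and dictionary size, chosen arbitrarily for the check

W_enc, b_enc = rng.normal(size=(M, n)), rng.normal(size=M)
W_dec, b_dec = rng.normal(size=(n, M)), rng.normal(size=n)
x = rng.normal(size=n)

alpha = rng.uniform(0.1, 10.0, size=M)   # one positive scaling factor per feature
D = np.diag(alpha)

# Original and reparameterized forward passes
f = np.maximum(W_enc @ x + b_enc, 0.0)
f_scaled = np.maximum(D @ W_enc @ x + D @ b_enc, 0.0)          # encoder rows scaled by alpha_i

xhat = W_dec @ f + b_dec
xhat_scaled = (W_dec @ np.linalg.inv(D)) @ f_scaled + b_dec    # decoder columns scaled by 1/alpha_i

print(np.allclose(f_scaled, alpha * f))          # True: feature i is scaled by alpha_i
print(np.allclose(xhat, xhat_scaled))            # True: reconstruction is unchanged
print(np.abs(f).sum(), np.abs(f_scaled).sum())   # L1 penalty differs between the two
```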
2. Feature Permutations
For any permutation matrix $P$:
- Transform encoder: $W_{\text{enc}} \to P\, W_{\text{enc}}$, $b_{\text{enc}} \to P\, b_{\text{enc}}$
- Transform decoder: $W_{\text{dec}} \to W_{\text{dec}}\, P^{-1} = W_{\text{dec}}\, P^\top$
This just reorders which feature is which, without changing the overall computation.
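The analogous check for permutations (again just a sketch I'm using to illustrate, not code from the paper):

```python
import numpy as np

rng = np.random.default_rng(1)
n, M = 4, 8
W_enc, b_enc = rng.normal(size=(M, n)), rng.normal(size=M)
W_dec, b_dec = rng.normal(size=(n, M)), rng.normal(size=n)
x = rng.normal(size=n)

perm = rng.permutation(M)
P = np.eye(M)[perm]                         # permutation matrix: row i is e_{perm[i]}

f = np.maximum(W_enc @ x + b_enc, 0.0)
f_perm = np.maximum(P @ W_enc @ x + P @ b_enc, 0.0)

xhat = W_dec @ f + b_dec
xhat_perm = (W_dec @ P.T) @ f_perm + b_dec  # P^{-1} = P^T for permutation matrices

print(np.allclose(f_perm, f[perm]))         # True: features are just reindexed
print(np.allclose(xhat, xhat_perm))         # True: reconstruction is unchanged
```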
The Complete Group Structure
The full reparameterization group is the semidirect product of these two:
$(\mathbb{R}_{>0})^M \rtimes S_M$
where:
- $(\mathbb{R}_{>0})^M$ is the group of positive scalings, one for each of the $M$ features
- $S_M$ is the symmetric group of permutations on the $M$ features
- $\rtimes$ denotes the semidirect product (you can scale then permute, or permute then scale with appropriately permuted scaling factors)
Why This Matters
The L1 penalty $\|f\|_1$ is not invariant under the scaling transformations - we can make it arbitrarily small by choosing small $\alpha_i$ values. This is why we need either:
- Decoder normalization constraint: Fix $\|d_i\|_2 = 1$ for every decoder column to break the scaling symmetry
- Reparameterization-invariant penalty: Use $\sum_i f_i\, \|d_i\|_2$, which is invariant under these transformations
The L0 penalty used in JumpReLU SAEs is naturally invariant under both scaling and permutation, which is one of its advantages!
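To make that concrete, here is one more small sketch (my own check, under the same random-parameter setup as above) comparing the plain L1 penalty, the RI-L1 penalty, and the L0 penalty before and after a scaling reparameterization; only the plain L1 penalty changes:

```python
import numpy as np

rng = np.random.default_rng(2)
n, M = 4, 8
W_enc, b_enc = rng.normal(size=(M, n)), rng.normal(size=M)
W_dec = rng.normal(size=(n, M))
x = rng.normal(size=n)

def penalties(W_enc, b_enc, W_dec, x):
    f = np.maximum(W_enc @ x + b_enc, 0.0)
    l1 = np.abs(f).sum()                                # plain L1 penalty
    ri_l1 = np.sum(f * np.linalg.norm(W_dec, axis=0))   # RI-L1: weight features by decoder column norms
    l0 = np.count_nonzero(f)                            # L0 penalty (as used by JumpReLU SAEs)
    return l1, ri_l1, l0

alpha = rng.uniform(0.1, 10.0, size=M)
before = penalties(W_enc, b_enc, W_dec, x)
after = penalties(alpha[:, None] * W_enc, alpha * b_enc, W_dec / alpha, x)

print("before:", before)
print("after: ", after)   # L1 changes; RI-L1 and L0 stay the same
```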