Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders
Thread (53 posts)
loss function that is simply the weighted sum of a L2 reconstruction error term and a L0 sparsity penalty, eschewing easier-to-train proxies to L0, such as L1, and avoiding the need for auxiliary tasks to train the threshold.
Our key insight is to notice that although such a loss function is piecewise-constant with respect to the threshold – and therefore provides zero gradient to train this parameter – the derivative of the expected loss can be analytically derived, and is generally non-zero, albeit it is expressed in terms of probability densities of the feature activation distribution that need to be estimated.
First part makes sense but when they say expected loss, what is the precise setup for that expectation?
Thinking Summary
The user is asking about the expected loss in the JumpReLU SAE paper. Looking at the paper, they introduce a loss function that has a JumpReLU activation function with threshold parameters θ. The loss function is piecewise constant with respect to θ, so the gradient is zero. But they claim the derivative of the expected loss is non-zero. The user wants to know what expectation they're talking about. Looking at the paper, in Section 4 "How STEs enable training through the jump", they say: "Why does this work? The key is to notice that during SGD, we actually want to estimate the gradient of the expected loss, 𝔼_𝐱[ℒ_θ(𝐱)], in order to calculate parameter updates" So the expectation is over the data distribution - they're taking the expectation over input activations 𝐱 drawn from the dataset 𝒟. In Appendix B, they derive this more formally. The loss function is: ℒ_θ(𝐱) = ||𝐱 - 𝐱̂(𝐟(𝐱))||²₂ + λ||𝐟(𝐱)||₀ And they want to compute: ∂𝔼_𝐱[ℒ_θ(𝐱)]/∂θᵢ where 𝐱 are samples from the training data distribution. The key insight is that while ℒ_θ(𝐱) is piecewise constant in θ for any fixed 𝐱, the expected loss 𝔼_𝐱[ℒ_θ(𝐱)] is smooth because you're averaging over many different values of 𝐱, and the locations of the discontinuities vary with 𝐱.
Looking at Section 4 and Appendix B of the paper, the expectation is taken over the data distribution. Specifically:
𝔼_𝐱[ℒ_θ(𝐱)] where 𝐱 ~ 𝒟
Here, 𝐱 represents language model activations sampled from the training dataset 𝒟.
The key insight is that while the loss ℒ_θ(𝐱) is piecewise constant in θ for any fixed input 𝐱 (because the JumpReLU and Heaviside functions create discontinuities), when you take the expectation over many different inputs 𝐱, the expected loss becomes smooth and differentiable.
This happens because:
- For a fixed 𝐱, the discontinuities occur when pre-activations π_i(𝐱) equal the threshold θ_i
- Different inputs 𝐱 have different pre-activation values
- So the discontinuity locations vary across the data distribution
- Averaging over the distribution smooths out these discontinuities
In Equation 13, they give this derivative of the expected loss in closed form; it is proportional to $p_i(\theta_i)$, the probability density of the pre-activation $\pi_i(\mathbf{x})$ at the threshold, taken across the data distribution, which needs to be estimated from data.
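To make the smoothing effect concrete, here is a minimal numerical sketch (my own toy setup, not from the paper) using a 1-D "SAE" with feature $f = x\,H(x-\theta)$ and reconstruction equal to $f$: for a single sample the loss is a step function of $\theta$, but the empirical average over many samples is smooth in $\theta$.
import numpy as np

rng = np.random.default_rng(0)
lam = 0.1
samples = rng.normal(size=100_000)  # toy stand-in for pre-activations

def loss(x, theta):
    # 1-D toy SAE: f = x * H(x - theta), reconstruction = f
    f = x * (x > theta)
    return (x - f) ** 2 + lam * (f != 0)  # L2 reconstruction error + lambda * L0

thetas = np.linspace(-1.0, 2.0, 201)
single_sample = np.array([loss(samples[:1], t).mean() for t in thetas])  # piecewise constant in theta
expected = np.array([loss(samples, t).mean() for t in thetas])           # empirical average: smooth in theta
# single_sample only jumps where theta crosses samples[0]; expected varies smoothly,
# so its derivative with respect to theta is generally non-zero.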
Our key insight is to notice that although such a loss function is piecewise-constant with respect to the threshold – and therefore provides zero gradient to train this parameter – the derivative of the expected loss can be analytically derived, and is generally non-zero, albeit it is expressed in terms of probability densities of the feature activation distribution that need to be estimated.
First part makes sense but when they say expected loss, what is the precise setup for that expectation?
Thinking Summary
The user is asking about the expectation in the expected loss for JumpReLU SAEs. Let me look at the paper details to understand what they mean by expected loss. From the paper: In Section 4, they write: "The key is to notice that during SGD, we actually want to estimate the gradient of the expected loss, $\mathbb{E}_{\mathbf{x}}\left[\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})\right]$, in order to calculate parameter updates" So the expectation is over the data distribution. More specifically: - $\mathbf{x}$ represents language model activations - They're training on a dataset of language model activations $\mathbf{x} \sim \mathcal{D}$ - The loss function $\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})$ depends on the input activations $\mathbf{x}$ and the threshold parameters $\boldsymbol{\theta}$ - The expected loss is $\mathbb{E}_{\mathbf{x}}\left[\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})\right]$ where the expectation is taken over the distribution of language model activations The key insight is that while the loss for a single input is piecewise constant w.r.t. the threshold (so has zero gradient almost everywhere), the expected loss over the distribution of inputs is smooth and has non-zero gradients. This is explained in more detail in Appendix B where they derive the gradient of the expected loss.
The expectation is taken over the distribution of language model activations that the SAE is being trained on.
Specifically, in the paper's notation:
- $\mathbf{x}$ represents language model activations sampled from the training dataset $\mathcal{D}$
- The loss function is $\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})$, which depends on both the input activations and the threshold parameters $\boldsymbol{\theta}$
- The expected loss is $\mathbb{E}_{\mathbf{x}\sim\mathcal{D}}\left[\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})\right]$
The key insight is that while the loss for any single input is piecewise constant with respect to the threshold (and thus has zero gradient almost everywhere), the expected loss over the distribution of all possible activations is smooth.
This happens because different activation samples will have their pre-activations distributed around the threshold. As you vary the threshold slightly, you change which samples fall above or below it, creating a non-zero gradient in expectation.
In Appendix B, they derive this gradient analytically (Equation 13); it involves $p_i(\theta_i)$, the probability density of the pre-activation distribution $\pi_i(\mathbf{x})$ evaluated at the threshold. This density needs to be estimated from data, which is where the kernel density estimation techniques come in.
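As a rough illustration of where the density comes in, here is a small sketch (my own, not the paper's code) estimating $p_i(\theta_i)$ from a batch of pre-activations with a rectangle kernel of width $\varepsilon$, which is the kind of kernel density estimate the STE gradients implicitly compute.
import numpy as np

rng = np.random.default_rng(0)
pre_acts = rng.normal(size=100_000)  # stand-in for pre-activations pi_i(x) over a batch
theta, eps = 0.5, 0.05               # threshold and bandwidth (illustrative values)

# Rectangle-kernel density estimate: fraction of pre-activations landing
# within eps/2 of the threshold, divided by the window width eps
p_hat = (np.abs(pre_acts - theta) < eps / 2).mean() / eps
p_true = np.exp(-theta**2 / 2) / np.sqrt(2 * np.pi)  # exact N(0,1) density, for comparison
print(p_hat, p_true)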
JumpReLU SAEs only require a single forward and backward pass during a training step and have an elementwise activation function (unlike TopK, which requires a partial sort), making them more efficient to train than either Gated or TopK SAEs.
Can you elaborate on the training efficiency point here in more detail?
Thinking Summary
The user is asking me to elaborate on the training efficiency claim made in the JumpReLU paper. The paper states that JumpReLU SAEs are more efficient to train than both Gated and TopK SAEs because: 1. They only require a single forward and backward pass 2. They have an elementwise activation function (unlike TopK which requires a partial sort) Let me break down the efficiency comparison: **JumpReLU SAEs:** - Single forward pass: compute pre-activations, apply JumpReLU (elementwise operation) - Single backward pass: use straight-through estimators for gradients - JumpReLU is elementwise: for each feature, just check if π_i(x) > θ_i **Gated SAEs:** - Need two encoder computations: π_gate and π_mag - Additional auxiliary loss term requires frozen decoder copy - More complex forward pass with element-wise multiplication of gate and magnitude - Requires more memory for the two encoder paths **TopK SAEs:** - Requires finding the top K activations, which involves sorting or partial sorting - Partial sort has O(M log K) complexity where M is the dictionary size - Even with approximate algorithms (like jax.lax.approx_max_k), still more expensive than elementwise ops - AuxK auxiliary loss adds additional computation The key efficiency gains for JumpReLU are: 1. Simpler architecture (single encoder path vs dual for Gated) 2. Elementwise operations only (vs sorting for TopK) 3. No auxiliary loss terms needed (vs both Gated and TopK which use aux losses)
The training efficiency advantages of JumpReLU SAEs stem from their simpler computational requirements compared to Gated and TopK architectures:
JumpReLU vs Gated SAEs
Gated SAEs require:
- Two encoder paths: Computing both $\boldsymbol{\pi}_{\text{gate}}(\mathbf{x})$ and $\boldsymbol{\pi}_{\text{mag}}(\mathbf{x})$ means double the encoder computations
- Auxiliary loss with frozen decoder: The auxiliary term requires maintaining a frozen copy of the decoder and computing an additional reconstruction
- Element-wise multiplication: The final features are computed as $\tilde{\mathbf{f}}(\mathbf{x}) = \mathbb{1}\left[\boldsymbol{\pi}_{\text{gate}}(\mathbf{x}) > 0\right] \odot \text{ReLU}\left(\boldsymbol{\pi}_{\text{mag}}(\mathbf{x})\right)$
JumpReLU SAEs only need:
- Single encoder path: Just compute $\boldsymbol{\pi}(\mathbf{x}) = \mathbf{W}_{\text{enc}}\mathbf{x} + \mathbf{b}_{\text{enc}}$ once
- No auxiliary loss: The L0 penalty provides sufficient gradient signal without auxiliary terms
- Simple thresholding: Apply JumpReLU elementwise
JumpReLU vs TopK SAEs
TopK SAEs require:
- Partial sorting: Finding the top K activations has complexity $O(M \log K)$, where $M$ is the dictionary size (e.g., 131k features)
- Even approximate TopK (like jax.lax.approx_max_k) is more expensive than elementwise operations
- AuxK auxiliary loss: Adds computational overhead to reduce dead features
JumpReLU's elementwise operation is just:
feature_magnitudes = x * (x > threshold) # O(M) elementwise
Memory Efficiency
JumpReLU also has lower memory requirements:
- No need to store frozen decoder copies (unlike Gated)
- No need to maintain sorting indices or auxiliary computations (unlike TopK)
- Straight-through estimators add minimal memory overhead
The paper notes these efficiency gains are particularly important when training at scale - with 131k-width SAEs on billions of tokens, even small per-step efficiency improvements compound significantly.
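To make the "single encoder path, elementwise threshold" point concrete, here is a minimal sketch of a JumpReLU SAE forward pass (my own illustration; the tensor names and toy shapes are assumptions, not the paper's code):
import torch

def jumprelu_sae(x, W_enc, b_enc, W_dec, b_dec, theta):
    pre_acts = x @ W_enc + b_enc          # single matmul + bias, shape (batch, M)
    f = pre_acts * (pre_acts > theta)     # elementwise JumpReLU with per-feature thresholds
    x_hat = f @ W_dec + b_dec             # decode back to activation space, shape (batch, n)
    return f, x_hat

n, M, batch = 16, 64, 4                   # toy sizes for illustration
x = torch.randn(batch, n)
W_enc, b_enc = torch.randn(n, M) / n**0.5, torch.zeros(M)
W_dec, b_dec = torch.randn(M, n) / M**0.5, torch.zeros(n)
theta = torch.full((M,), 0.5)
f, x_hat = jumprelu_sae(x, W_enc, b_enc, W_dec, b_dec, theta)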
JumpReLU SAEs only require a single forward and backward pass during a training step and have an elementwise activation function (unlike TopK, which requires a partial sort), making them more efficient to train than either Gated or TopK SAEs.
Can you elaborate on the training efficiency point here in more detail? Only focus comparison on TopK
Thinking Summary
The user is asking me to elaborate on the training efficiency comparison between JumpReLU and TopK SAEs, specifically focusing on TopK. The paper mentions that JumpReLU is more efficient to train than TopK, and I should explain why. From the paper, the key points are: 1. JumpReLU uses a single forward and backward pass 2. JumpReLU has an elementwise activation function 3. TopK requires a partial sort operation Let me think about the computational complexity implications: - TopK requires finding the top K values among M features for each input - This is typically O(M log K) or O(M) with partial sorting algorithms - JumpReLU just applies an elementwise threshold comparison, which is O(M) The paper mentions they use `jax.lax.approx_max_k` for their TopK implementation, which is an approximate algorithm that's faster than exact TopK but still more expensive than elementwise operations.
The training efficiency advantage of JumpReLU over TopK SAEs comes down to the fundamental difference in their activation mechanisms:
TopK's computational overhead:
- Partial sorting requirement: For each input, TopK must identify and retain only the K largest pre-activations out of M total features (where M is the dictionary size, e.g., 131k in the paper's experiments)
- Complexity: Even with efficient algorithms, this requires O(M log K) operations for exact TopK, or O(M) for approximate methods
- Non-local operation: The activation of any feature depends on the values of all other features, requiring global information
JumpReLU's efficiency:
- Simple elementwise comparison: Each feature activation is determined by a single comparison:
if pre_activation > threshold
- Complexity: O(M) with a very low constant factor - just M comparisons
- Local operation: Each feature's activation is independent of others
Practical implications:
The paper mentions using jax.lax.approx_max_k (based on Chern et al., 2022) for their TopK implementation, which is faster than exact sorting but still involves:
- Multiple passes through the data to approximate the top-K selection
- Additional memory access patterns that are less cache-friendly
- More complex gradient computation in the backward pass
For a 131k-width SAE trained on billions of tokens, these differences compound significantly. The elementwise nature of JumpReLU also makes it easier to parallelise and fuse with surrounding operations than top-K selection, which requires a reduction across all features.
This efficiency gain is particularly important when training multiple SAEs with different sparsity levels or when scaling to larger models and dictionary sizes.
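If you want to see the gap yourself, a quick micro-benchmark along these lines compares the elementwise threshold against a top-k selection at SAE-like widths (illustrative only; absolute timings depend heavily on hardware and on whether an exact or approximate top-k is used):
import time
import torch

batch, M, K = 256, 131_072, 64
pre = torch.randn(batch, M)
theta = torch.full((M,), 0.5)

t0 = time.perf_counter()
f_jump = pre * (pre > theta)              # elementwise threshold
t1 = time.perf_counter()
vals, idx = torch.topk(pre, k=K, dim=-1)  # exact top-k selection
t2 = time.perf_counter()
print(f"threshold: {t1 - t0:.4f}s, topk: {t2 - t1:.4f}s")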
In the caption for figure 1, "consider a direction" - what precisely does direction mean here, in concrete terms?
Thinking Summary
The user is asking about Figure 1's caption, specifically what "direction" means in this context. They have a typo - "excrete dear" should probably be "concrete terms" or similar. Looking at Figure 1's caption in the paper: "Figure 1: A toy model illustrating why JumpReLU (or similar activation functions, such as TopK) are an improvement over ReLU for training sparse yet faithful SAEs. Consider a direction in which the encoder pre-activation is high when the corresponding feature is active and low, but not always negative, when the feature is inactive (far-left)." In this context, "direction" refers to a feature direction in the SAE - essentially one of the learned dictionary elements that the SAE uses to decompose activations. Each "direction" corresponds to one feature in the SAE's dictionary, represented by: - A row in the encoder weight matrix W_enc - A column in the decoder weight matrix W_dec The figure is showing how the encoder pre-activation (the value before applying the activation function) behaves for one particular feature/direction.
In the context of Figure 1's caption, "direction" refers to a feature direction in the Sparse Autoencoder (SAE) - essentially one of the learned dictionary elements.
More precisely, each direction corresponds to:
- A row in the encoder weight matrix $\mathbf{W}_{\text{enc}}$
- A column in the decoder weight matrix $\mathbf{W}_{\text{dec}}$ (denoted $\mathbf{d}_i$ in the paper)
So when the caption says "Consider a direction in which the encoder pre-activation is high...", it's referring to one specific feature in the SAE's dictionary. The figure shows how the encoder pre-activation behaves for that particular feature $i$.
The key insight illustrated is that for this feature:
- When the feature is truly "active" (should be detected), the pre-activation is high
- When the feature is "inactive" (shouldn't be detected), the pre-activation is low but sometimes still positive
This creates the challenge that ReLU addresses poorly but JumpReLU handles well by introducing an adjustable threshold.
Just to double check, what basis is that row of the encoder matrix expressed in?
Thinking Summary
The user is asking about the basis for the encoder weight matrix rows. Let me think carefully about this. The encoder takes in activations $\mathbf{x} \in \mathbb{R}^n$ from the language model and produces pre-activations. So the encoder weight matrix $\mathbf{W}_{\text{enc}}$ has shape $M \times n$ where: - $M$ is the number of features in the SAE dictionary - $n$ is the dimension of the language model activations Each row of $\mathbf{W}_{\text{enc}}$ is therefore a vector in $\mathbb{R}^n$, expressed in the same basis as the input language model activations $\mathbf{x}$. This basis is the standard basis of the language model's activation space at whatever layer/site is being analyzed (e.g., residual stream, MLP output, or attention output).
Each row of the encoder matrix $\mathbf{W}_{\text{enc}}$ is expressed in the same basis as the language model activations $\mathbf{x}$.
Specifically:
- If the SAE is trained on residual stream activations, the basis is the standard basis of the residual stream space
- If trained on MLP outputs, it's the standard basis of the MLP output space
- If trained on attention outputs, it's the standard basis of the attention output space
So each encoder row is a vector in $\mathbb{R}^n$ that defines a linear projection from the LM's activation space. When we compute $\mathbf{W}_{\text{enc}}\mathbf{x} + \mathbf{b}_{\text{enc}}$, each element of the result is the dot product between a row of $\mathbf{W}_{\text{enc}}$ and the input activation vector $\mathbf{x}$ (plus a bias), both expressed in the same coordinate system.
This is in contrast to the decoder columns $\mathbf{d}_i$, which also live in the same space but represent the actual feature directions that the SAE uses to reconstruct activations.
Am I correctly understanding the caption / point of Figure 1 to be that because JumpReLU has a vector of offsets which is learned during training, it can precisely identify the "intended" activations without altering them?
Thinking Summary
The user is asking about Figure 1 in the paper. Let me look at what Figure 1 shows and its caption. From the paper: "Figure 1: A toy model illustrating why JumpReLU (or similar activation functions, such as TopK) are an improvement over ReLU for training sparse yet faithful SAEs. Consider a direction in which the encoder pre-activation is high when the corresponding feature is active and low, but not always negative, when the feature is inactive (far-left). Applying a ReLU activation function fails to remove all false positives (centre-left), harming sparsity. It is possible to get rid of false positives while maintaining the ReLU, e.g. by decreasing the encoder bias (centre-right), but this leads to feature magnitudes being systematically underestimated, harming fidelity. The JumpReLU activation function (far-right) provides an independent threshold below which pre-activations are screened out, minimising false positives, while leaving pre-activations above the threshold unaffected, improving fidelity." So the key points are: 1. The encoder pre-activation can be high when the feature is truly active, but also sometimes positive (though lower) when the feature is inactive 2. ReLU passes through all positive values, so it includes false positives (when the feature is inactive but pre-activation is positive) 3. You could lower the bias to make more pre-activations negative, but then you systematically underestimate the magnitudes of true activations 4. JumpReLU has a learned threshold θ that can be set to separate true activations from false positives, while preserving the magnitudes of true activations above the threshold The user is asking if they correctly understand that JumpReLU can "precisely identify the 'intended' activations without altering them" - this is essentially correct. The threshold allows it to filter out false positives while preserving the magnitudes of true activations.
Yes, you've got it exactly right! Figure 1 illustrates the key advantage of JumpReLU over standard ReLU for sparse autoencoders.
The core insight is that with a learned threshold vector $\boldsymbol{\theta}$, JumpReLU can:
- Filter out false positives: Set the threshold above the "noise floor" of pre-activations that occur when the feature is actually inactive (but happen to be slightly positive)
- Preserve true activation magnitudes: Unlike solutions that shift the bias to make more pre-activations negative, JumpReLU doesn't systematically shrink the magnitudes of true activations - it just passes them through unchanged if they're above the threshold
The figure shows this nicely:
- ReLU (center-left): Lets through all positive values, including false positives
- ReLU with shifted bias (center-right): Eliminates false positives but systematically underestimates all magnitudes
- JumpReLU (far-right): Clean separation - zeros out everything below the threshold $\theta_i$, preserves everything above it
This is why the paper emphasizes that JumpReLU (like TopK and Gated SAEs) can "separate out deciding which features are active from estimating active features' magnitudes." The threshold handles the binary active/inactive decision, while the magnitude above threshold captures the strength of activation.
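Here is a small numerical sketch of the Figure 1 story (my own toy distributions, not the paper's data): the feature is active 5% of the time with large pre-activations, and inactive the rest of the time with small but occasionally positive noise. Plain ReLU lets the noise through, lowering the bias removes it but shrinks the true magnitudes, and JumpReLU filters the noise while preserving magnitudes.
import numpy as np

rng = np.random.default_rng(0)
n_samples = 100_000
active = rng.random(n_samples) < 0.05
# high pre-activations when active, small (sometimes positive) noise when inactive
pre = np.where(active, rng.normal(3.0, 0.5, n_samples), rng.normal(-0.5, 0.5, n_samples))

relu = np.maximum(pre, 0.0)                 # lets false positives through
relu_shifted = np.maximum(pre - 1.5, 0.0)   # lowered bias: fewer false positives, shrunk magnitudes
theta = 1.5
jump = pre * (pre > theta)                  # thresholded, magnitudes preserved

print("false positive rate:",
      (relu[~active] > 0).mean(), (relu_shifted[~active] > 0).mean(), (jump[~active] > 0).mean())
print("mean output when active (true mean ~3.0):",
      relu[active].mean(), relu_shifted[active].mean(), jump[active].mean())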
OK and the point of this part is that there's actually a way to learn that parameter theta?
Thinking Summary
The user is asking about whether this part (referring to the quote about expected loss) is showing that there's actually a way to learn the parameter theta. They seem to understand that the individual loss is piecewise constant (no gradient) but are asking about how the expected loss insight enables learning theta.
Yes, exactly! This is the paper's key technical contribution - showing how to actually learn $\boldsymbol{\theta}$ despite the fact that the loss function has zero gradient with respect to it.
The clever trick is:
- Individual samples give no gradient: For any single activation $\mathbf{x}$, the loss is piecewise constant in $\boldsymbol{\theta}$ (flat everywhere except at discontinuities)
- But the expected loss is smooth: When you average over the distribution of all activations, you get a smooth function with non-zero gradients
The authors then use straight-through estimators (STEs) to approximate this expected gradient during training. Specifically:
- They define pseudo-derivatives (Equations 11-12) that approximate the true gradient of the expected loss
- These use kernel functions to "smooth out" the discontinuity in a small window of width $\varepsilon$ around the threshold
- During backpropagation, whenever they would normally get zero gradient, they substitute these pseudo-derivatives instead
The mathematical justification (Section 4) shows that this STE approach is actually computing a kernel density estimate of the gradient of the true expected loss.
So yes - despite the discontinuous activation function, they can train effectively using standard gradient descent, because the STEs give them a principled way to estimate gradients that would otherwise be zero.
What is theta called in the literature? Does it have any other interesting interpretations than the obvious one.
Thinking Summary
The user is asking about what $\theta$ (the threshold parameter in JumpReLU) is called in the literature and if it has any interesting interpretations beyond the obvious one of being a threshold. Let me think about this: 1. In the JumpReLU paper, they simply call it the "threshold" or "threshold parameter" 2. In the context of sparse coding/dictionary learning, similar parameters might be called: - Activation threshold - Firing threshold - Detection threshold - Gating threshold (in Gated SAEs) 3. Interesting interpretations: - Statistical interpretation: $\theta_i$ could be seen as defining a hypothesis test threshold - we're testing whether feature $i$ is "truly present" vs just noise - Information theoretic: It's setting a signal-to-noise threshold - Neuroscience analogy: Similar to neuronal firing thresholds - Bayesian interpretation: Could be viewed as related to a decision boundary in a classification problem - It implicitly defines a "baseline activation level" below which we consider the feature to be absent Let me check if the paper mentions any of these interpretations...
In this paper and related SAE literature, $\boldsymbol{\theta}$ is typically just called the threshold or threshold parameter. In the Gated SAE paper, the equivalent parameter (when using weight sharing) emerges from the "gating" mechanism, hence "gating threshold."
Some interesting interpretations beyond the obvious:
- Statistical significance threshold: $\theta_i$ implicitly defines a hypothesis test - "is this feature truly active or just noise?" The threshold separates signal from background fluctuations in the pre-activations.
- Implicit prior on sparsity: The threshold encodes an assumption about how strong a feature's pre-activation needs to be before we believe it's "really there." Higher thresholds encode stronger priors that features should be rare.
- Neuroscience analogy: Similar to neuronal firing thresholds - a neuron needs sufficient input before it "decides" to fire. The JumpReLU makes this binary decision explicit.
- Information-theoretic view: The threshold defines a channel capacity constraint - features must exceed a certain "signal strength" to transmit information through the autoencoder.
- Learned noise floor: As Figure 1 illustrates, $\theta_i$ can be interpreted as learning the typical "false positive" activation level for each feature when it's not truly present.
The paper hints at another interpretation in Equation 13 - the gradient involves terms that measure:
- How much reconstruction error increases when we raise the threshold (screening out marginal activations)
- How much sparsity improves when we raise the threshold
So $\theta_i$ represents the learned equilibrium point where the marginal cost to reconstruction equals $\lambda$ times the marginal benefit to sparsity.
The L1 penalty also fails to be invariant under reparameterizations of a SAE; by scaling down encoder parameters and scaling up decoder parameters accordingly, it is possible to arbitrarily shrink feature magnitudes, and thus the L1 penalty, without changing either the number of active features or the SAE’s output reconstructions.
Can you give a worked toy example of this?
Thinking Summary
The user is asking for a worked toy example of how the L1 penalty is not invariant under reparameterizations of a SAE. Let me think through this. The key idea is that if we have an SAE with encoder and decoder, we can scale the encoder weights down by some factor and scale the decoder weights up by the same factor, and this will: 1. Keep the reconstruction the same 2. Keep the number of active features the same 3. But reduce the L1 penalty (since feature magnitudes will be smaller) Let me create a simple example to illustrate this.
Sure! Here's a simple toy example to illustrate how the L1 penalty is not invariant under reparameterizations:
Original SAE Setup
Consider a simple SAE with:
- Input dimension: 2
- Dictionary size: 2
- Input vector: $\mathbf{x} = (1, 0)^\top$
Original parameters:
- Encoder weights: $\mathbf{W}_{\text{enc}} = \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}$
- Encoder bias: $\mathbf{b}_{\text{enc}} = \mathbf{0}$
- Decoder weights: $\mathbf{W}_{\text{dec}} = \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}$
- Decoder bias: $\mathbf{b}_{\text{dec}} = \mathbf{0}$
Forward pass:
- Pre-activations: $\boldsymbol{\pi}(\mathbf{x}) = \mathbf{W}_{\text{enc}}\mathbf{x} + \mathbf{b}_{\text{enc}} = (1, 0)^\top$
- Feature activations: $\mathbf{f}(\mathbf{x}) = \text{ReLU}(\boldsymbol{\pi}(\mathbf{x})) = (1, 0)^\top$
- Reconstruction: $\hat{\mathbf{x}} = \mathbf{W}_{\text{dec}}\mathbf{f}(\mathbf{x}) + \mathbf{b}_{\text{dec}} = (1, 0)^\top$
- L1 penalty: $\left\lVert\mathbf{f}(\mathbf{x})\right\rVert_1 = 1$
Reparameterized SAE
Now, let's scale encoder weights down by a factor of 10 and decoder weights up by a factor of 10:
Reparameterized parameters:
- $\mathbf{W}'_{\text{enc}} = \mathbf{W}_{\text{enc}} / 10$, $\mathbf{b}'_{\text{enc}} = \mathbf{b}_{\text{enc}} / 10$, $\mathbf{W}'_{\text{dec}} = 10\,\mathbf{W}_{\text{dec}}$, $\mathbf{b}'_{\text{dec}} = \mathbf{b}_{\text{dec}}$
Forward pass:
- Pre-activations: $\boldsymbol{\pi}'(\mathbf{x}) = \mathbf{W}'_{\text{enc}}\mathbf{x} + \mathbf{b}'_{\text{enc}} = (0.1, 0)^\top$
- Feature activations: $\mathbf{f}'(\mathbf{x}) = \text{ReLU}(\boldsymbol{\pi}'(\mathbf{x})) = (0.1, 0)^\top$
- Reconstruction: $\hat{\mathbf{x}} = \mathbf{W}'_{\text{dec}}\mathbf{f}'(\mathbf{x}) + \mathbf{b}'_{\text{dec}} = (1, 0)^\top$
- L1 penalty: $\left\lVert\mathbf{f}'(\mathbf{x})\right\rVert_1 = 0.1$
Key Observations
- Same reconstruction: $\hat{\mathbf{x}} = (1, 0)^\top$ in both cases
- Same sparsity pattern: Both have only the first feature active
- Different L1 penalties: Original = 1, Reparameterized = 0.1
By scaling down the encoder by any factor $\alpha$ and scaling up the decoder by the same factor, we can make the L1 penalty arbitrarily small while keeping everything else the same. This is why the L1 penalty needs to be paired with decoder weight normalization constraints (typically $\left\lVert\mathbf{d}_i\right\rVert_2 = 1$) to be meaningful.
The RI-L1 penalty fixes this issue because when we scale decoder weights up, the penalty increases proportionally, canceling out the decrease from smaller feature magnitudes.
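Here is a short numerical check of the same point (a sketch under my own random parameters, not code from either paper): scaling the encoder down and the decoder up by the same factor leaves the reconstruction and the number of active features unchanged, shrinks the L1 penalty, and leaves the RI-L1 penalty untouched.
import numpy as np

rng = np.random.default_rng(0)
n, M, alpha = 4, 8, 10.0
W_enc, b_enc = rng.normal(size=(M, n)), rng.normal(size=M)
W_dec, b_dec = rng.normal(size=(n, M)), rng.normal(size=n)
x = rng.normal(size=n)

def relu_sae(W_enc, b_enc, W_dec, b_dec, x):
    f = np.maximum(W_enc @ x + b_enc, 0.0)
    return f, W_dec @ f + b_dec

f1, xhat1 = relu_sae(W_enc, b_enc, W_dec, b_dec, x)
f2, xhat2 = relu_sae(W_enc / alpha, b_enc / alpha, W_dec * alpha, b_dec, x)

print(np.allclose(xhat1, xhat2))              # True: identical reconstructions
print((f1 > 0).sum(), (f2 > 0).sum())         # same number of active features (L0 unchanged)
print(f1.sum(), f2.sum())                     # L1 penalty shrinks by the factor alpha
ri_l1_1 = (f1 * np.linalg.norm(W_dec, axis=0)).sum()
ri_l1_2 = (f2 * np.linalg.norm(W_dec * alpha, axis=0)).sum()
print(np.isclose(ri_l1_1, ri_l1_2))           # RI-L1 penalty is unchanged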
Reading this I'm realizing the point I'm unclear on is the reparameterization part: I can certainly imagine scaling the encoder and decoder matrices by a constant as in your example, but I'm not clear on exactly what the set (group) of allowable reparameterizations, aka symmetries, is.
Thinking Summary
The user is asking about what exactly constitutes the group of allowable reparameterizations/symmetries for SAEs. This is a good question - I gave one example of scaling, but what's the full set? Let me think about this carefully. For a SAE, we have: - Encoder: f(x) = σ(W_enc * x + b_enc) - Decoder: x_hat = W_dec * f + b_dec The key insight is that we want transformations that preserve the input-output mapping x → x_hat. The main reparameterization symmetry is: 1. For each feature i, we can scale the i-th row of W_enc and b_enc[i] by some factor α_i 2. And correspondingly scale the i-th column of W_dec by 1/α_i This works because: - Pre-activation for feature i: (W_enc[i,:] * x + b_enc[i]) gets scaled by α_i - After ReLU: if positive, feature activation f_i gets scaled by α_i - In decoder: W_dec[:,i] * f_i = (W_dec[:,i] / α_i) * (α_i * f_i) = unchanged So the full group is diagonal scaling with positive scalars - one for each feature. There's also potentially: - Permutation symmetry: we can permute the features (reorder rows of W_enc, b_enc and columns of W_dec) - But that's it for exact symmetries that preserve the function For other activation functions there might be other symmetries, but for ReLU these are the main ones.
Great question! Let me clarify the full group of reparameterization symmetries for SAEs.
The Reparameterization Group for ReLU SAEs
For a standard ReLU SAE, the allowable reparameterizations that preserve the input-output function form a group with two types of transformations:
1. Feature-wise Scaling (the main one)
For any positive diagonal matrix $\mathbf{D} = \operatorname{diag}(\alpha_1, \ldots, \alpha_M)$ with each $\alpha_i > 0$:
- Transform encoder: $\mathbf{W}_{\text{enc}} \to \mathbf{D}\mathbf{W}_{\text{enc}}$, $\mathbf{b}_{\text{enc}} \to \mathbf{D}\mathbf{b}_{\text{enc}}$
- Transform decoder: $\mathbf{W}_{\text{dec}} \to \mathbf{W}_{\text{dec}}\mathbf{D}^{-1}$
This works because for each feature $i$:
- Pre-activation gets scaled: $\pi_i(\mathbf{x}) \to \alpha_i\,\pi_i(\mathbf{x})$
- Post-ReLU activation gets scaled: $f_i(\mathbf{x}) \to \alpha_i\,f_i(\mathbf{x})$ (since $\text{ReLU}(\alpha_i z) = \alpha_i\,\text{ReLU}(z)$ when $\alpha_i > 0$)
- Decoder compensates: the contribution to the reconstruction is $(\mathbf{d}_i / \alpha_i)(\alpha_i f_i) = \mathbf{d}_i f_i$, unchanged
2. Feature Permutations
For any permutation matrix $\mathbf{P}$:
- Transform encoder: $\mathbf{W}_{\text{enc}} \to \mathbf{P}\mathbf{W}_{\text{enc}}$, $\mathbf{b}_{\text{enc}} \to \mathbf{P}\mathbf{b}_{\text{enc}}$
- Transform decoder: $\mathbf{W}_{\text{dec}} \to \mathbf{W}_{\text{dec}}\mathbf{P}^{\top}$
This just reorders which feature is which, without changing the overall computation.
The Complete Group Structure
The full reparameterization group is the semidirect product of these two:
$$(\mathbb{R}_{>0})^{M} \rtimes S_M$$
where:
- $(\mathbb{R}_{>0})^{M}$ is the group of positive scalings, one for each feature
- $S_M$ is the symmetric group of permutations on $M$ features
- $\rtimes$ denotes a semidirect product (you can scale then permute, or permute then scale with appropriately transformed scalings)
Why This Matters
The L1 penalty is not invariant under the scaling transformations - we can make it arbitrarily small by choosing small $\alpha_i$ values. This is why we need either:
- Decoder normalization constraint: Fix $\left\lVert\mathbf{d}_i\right\rVert_2 = 1$ to break the scaling symmetry
- Reparameterization-invariant penalty: Use $\sum_i f_i(\mathbf{x})\left\lVert\mathbf{d}_i\right\rVert_2$, which is invariant under these transformations
The L0 penalty used in JumpReLU SAEs is naturally invariant under both scaling and permutation, which is one of its advantages!
This is a loss function of the standard form Eq. 5 where crucially we are using a L0 sparsity penalty to avoid the limitations of training with a L1 sparsity penalty (Wright and Sharkey, 2024; Rajamanoharan et al., 2024)
What are the limitations of training with an L1 sparsity penalty in the linked paper?
Thinking Summary
The user is asking about the limitations of training with an L1 sparsity penalty that are mentioned in the JumpReLU SAEs paper. The paper references Wright and Sharkey (2024) and Rajamanoharan et al. (2024) for this point. Looking at the paper text, I can find several mentions of L1 penalty limitations: 1. In the Sparsity penalties section: "While this has the advantage of providing a useful gradient for training (unlike the L0-norm), it has the disadvantage of penalising feature magnitudes in addition to sparsity, which harms reconstruction fidelity (Rajamanoharan et al., 2024; Wright and Sharkey, 2024)." 2. Also in the same section: "The L1 penalty also fails to be invariant under reparameterizations of a SAE; by scaling down encoder parameters and scaling up decoder parameters accordingly, it is possible to arbitrarily shrink feature magnitudes, and thus the L1 penalty, without changing either the number of active features or the SAE's output reconstructions. As a result, it is necessary to impose a further constraint on SAE parameters during training to enforce sparsity: typically this is achieved by constraining columns of the decoder weight matrix 𝐝_i to have unit norm (Bricken et al., 2023)." 3. In Section 3: "This is a loss function of the standard form Eq. 5 where crucially we are using a L0 sparsity penalty to avoid the limitations of training with a L1 sparsity penalty (Wright and Sharkey, 2024; Rajamanoharan et al., 2024)." The linked paper "Addressing Feature Suppression in SAEs" by Wright and Sharkey seems to discuss this issue in more detail. So the main limitations are: 1. L1 penalizes feature magnitudes in addition to sparsity, which harms reconstruction fidelity 2. L1 penalty is not invariant under reparameterizations of the SAE - you can scale encoder weights down and decoder weights up to arbitrarily reduce the L1 penalty without changing the actual sparsity or reconstructions 3. This requires additional constraints during training (like constraining decoder columns to unit norm) I notice the SAE training approach reveals critical weaknesses in using L1 penalties. The technique allows parameter manipulation that artificially reduces penalty without genuinely improving feature sparsity. Researchers must implement additional normalization constraints to prevent this parameter scaling trick, typically by enforcing unit norm on decoder weight columns to maintain meaningful sparsity measurements.
The paper identifies two key limitations of L1 sparsity penalties for training SAEs:
1. Feature Magnitude Suppression
The L1 penalty penalizes feature magnitudes in addition to sparsity, which harms reconstruction fidelity. This is because L1 measures the sum of absolute values, so it incentivizes making active features smaller even when they should have larger values for accurate reconstruction.
2. Lack of Reparameterization Invariance
The L1 penalty is not invariant under reparameterizations of the SAE. Specifically:
- You can scale down encoder parameters and scale up decoder parameters proportionally
- This arbitrarily shrinks feature magnitudes (and thus the L1 penalty)
- But it doesn't change the number of active features or the SAE's output reconstructions
- This creates a "cheat" where the model can reduce the loss without actually becoming more sparse
Due to this issue, L1-trained SAEs require additional constraints during training, typically constraining decoder weight matrix columns (𝐝ᵢ) to have unit norm.
The paper introduces the RI-L1 (Reparameterization-Invariant L1) penalty as an alternative:
$$\mathcal{S}_{\text{RI-L1}}(\mathbf{f}) := \sum_{i=1}^{M} f_i \left\lVert \mathbf{d}_i \right\rVert_2$$
This weights each feature coefficient by the norm of the corresponding dictionary direction, making the penalty invariant to the reparameterization described above.
By using L0 instead, JumpReLU SAEs avoid both issues - L0 only counts the number of non-zero features, not their magnitudes, and is inherently invariant to reparameterization.
Sorry, when you say "the paper", which paper do you mean? I was asking about the contents of the "Addressing Feature Suppression" paper.
Thinking Summary
Oh, I see! The user was asking about what the linked paper "Addressing Feature Suppression in SAEs" by Wright and Sharkey (2024) says about the limitations of L1 sparsity penalties, not what the JumpReLU paper says about it. I should look at the content snapshot from the embedded link to see what information is available about that paper. Looking at the embedded content, I only see: - Title: "Addressing Feature Suppression in SAEs" - URL: https://www.alignmentforum.org/posts/3JuSjTZyMzaSeTxKk/addressing-feature-suppression-in-saes - Description: "Produced as part of the ML Alignment Theory Scholars Program - Winter 2023-24 Cohort as part of Lee Sharkey's stream. …" - Content Snapshot: "AI ALIGNMENT FORUMAF" Unfortunately, the content snapshot doesn't actually contain any substantive content from the Wright and Sharkey paper - it just shows "AI ALIGNMENT FORUMAF" which appears to be a header or navigation element from the website. So I cannot actually tell what Wright and Sharkey (2024) specifically says about the limitations of L1 sparsity penalties based on the provided embed. I should clarify this to the user.
You're right - I apologize for the confusion! You were asking about what Wright and Sharkey (2024) "Addressing Feature Suppression in SAEs" says about L1 limitations.
Unfortunately, the embedded content snapshot for that paper only shows "AI ALIGNMENT FORUMAF" and doesn't include the actual content of the paper. Based on just the title "Addressing Feature Suppression in SAEs", I can infer that Wright and Sharkey likely discuss how L1 penalties cause feature suppression (reducing the magnitudes of features that should be active), but I cannot see their specific arguments or findings from the provided embed.
To properly answer your question about what limitations of L1 sparsity penalties Wright and Sharkey identify, I would need access to the actual content of their paper rather than just the title and URL.
Equation 10 is just "inlining" the definition of JumpReLU into the definition of the loss function in Equation 9, correct?
Thinking Summary
The user is asking about Equation 10 in the paper. Let me find it: Looking at the paper, Equation 9 defines the loss function: $$\mathcal{L}(\mathbf{x}):=\underbrace{\left\lVert\mathbf{x}-\hat{\mathbf{x}}(\mathbf{f}(\mathbf{x}))\right\rVert_{2}^{2}}_{\mathcal{L}_{\text{reconstruct}}}+\underbrace{\lambda\left\lVert\mathbf{f}(\mathbf{x})\right\rVert_{0}}_{\mathcal{L}_{\text{sparsity}}}.$$ Then Equation 10 says: $$\mathcal{L}_{\text{sparsity}}:=\lambda\left\lVert\mathbf{f}(\mathbf{x})\right\rVert_{0}=\lambda\sum_{i=1}^{M}H(\pi_{i}(\mathbf{x})-\theta_{i}).$$ Where $\mathbf{f}(\mathbf{x})$ is defined in Equation 8 as: $$\mathbf{f}(\mathbf{x}):=\text{JumpReLU}_{\boldsymbol{\theta}}\left(\mathbf{W}_{\text{enc}}\mathbf{x}+\mathbf{b}_{\text{enc}}\right).$$ And JumpReLU is defined in Equation 4 as: $$\text{JumpReLU}_{\theta}(z):=z\,H(z-\theta)$$ So if we have $\mathbf{f}(\mathbf{x})$ with component $f_i(\mathbf{x}) = \text{JumpReLU}_{\theta_i}(\pi_i(\mathbf{x}))$, where $\pi_i(\mathbf{x})$ is the i-th pre-activation, then: $f_i(\mathbf{x}) = \pi_i(\mathbf{x}) \cdot H(\pi_i(\mathbf{x}) - \theta_i)$ The L0 norm essentially tracks non-zero entries. Specifically, $\|f_i(\mathbf{x})\|_0$ equals 1 when $f_i(\mathbf{x})$ is non-zero, and 0 otherwise. This means the component becomes zero only when the pre-activation falls below the threshold. The Heaviside function determines when the feature becomes active, effectively creating a sparse representation by selectively activating features based on their pre-activation values relative to learned thresholds. Consequently, the total L0 norm across all features sums the number of activated components, providing a compact measure of feature sparsity in the representation.
Yes, exactly! Equation 10 is showing that when you expand the L0 norm of the feature vector using the JumpReLU definition, you get a sum of Heaviside step functions.
Here's the reasoning:
- Each feature is $f_i(\mathbf{x}) = \text{JumpReLU}_{\theta_i}(\pi_i(\mathbf{x})) = \pi_i(\mathbf{x})\,H(\pi_i(\mathbf{x}) - \theta_i)$
- The L0 norm counts non-zero entries
- $f_i(\mathbf{x})$ is non-zero exactly when $\pi_i(\mathbf{x}) > \theta_i$ (since pre-activations are continuous, $\pi_i(\mathbf{x}) = \theta_i$ happens with probability 0)
- Therefore $\mathbb{1}\left[f_i(\mathbf{x}) \neq 0\right] = H(\pi_i(\mathbf{x}) - \theta_i)$
So the sparsity penalty becomes:
$$\mathcal{L}_{\text{sparsity}} = \lambda\left\lVert\mathbf{f}(\mathbf{x})\right\rVert_0 = \lambda\sum_{i=1}^{M} H(\pi_i(\mathbf{x}) - \theta_i)$$
This reformulation is crucial for the paper's approach because it shows that the threshold parameters only appear inside Heaviside functions in the loss, which is what motivates their use of straight-through estimators to get gradient signals for training these thresholds.
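A tiny check of this identity, with made-up numbers:
import numpy as np

pi = np.array([0.2, -0.5, 1.3, 0.05])     # pre-activations pi_i(x)
theta = np.array([0.1, 0.1, 0.5, 0.1])    # per-feature thresholds theta_i
f = pi * (pi > theta)                     # JumpReLU features
print((f != 0).sum(), (pi > theta).sum()) # L0 norm equals the sum of Heaviside terms (both 2)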
Can you write a short tutorial on straight-through estimators?
Thinking Summary
The user is asking for a tutorial on straight-through estimators (STEs). This is a good opportunity to explain this important concept clearly, as it's central to the JumpReLU paper and many other works in deep learning. I should: 1. Start with the motivation/problem 2. Explain the basic idea 3. Give a simple example 4. Show how it works in practice 5. Discuss when/why it works 6. Maybe mention some applications I should keep it concise but comprehensive, and use examples to make it concrete.
A Tutorial on Straight-Through Estimators (STEs)
The Problem
Imagine you're training a neural network with a step function:
def step(x):
    return 1.0 if x > 0 else 0.0
The gradient of this function is zero everywhere (except at x=0 where it's undefined). If you try to backpropagate through it, you get zero gradients - your parameters won't update!
This problem appears whenever we have:
- Discrete decisions (rounding, thresholding, sampling)
- Piecewise constant functions
- Non-differentiable operations
The Core Idea
A Straight-Through Estimator (STE) says: "In the forward pass, use the non-differentiable function. In the backward pass, pretend it was something differentiable."
The simplest STE is the identity STE:
# Forward: y = step(x)
# Backward: pretend dy/dx = 1
A Concrete Example
Let's implement a sign function with an STE:
import torch
class SignSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.sign(x)  # Returns -1, 0, or 1

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # Pretend derivative is 1
Now gradients can flow through the sign function!
How STEs Work in Practice
Consider training a network to output binary values:
# Without STE: gradients get stuck
logits = model(input)
binary_output = (logits > 0).float() # Gradient is 0!
loss = criterion(binary_output, target)
# With STE: gradients flow
logits = model(input)
binary_output = SignSTE.apply(logits) # Gradient passes through
loss = criterion(binary_output, target)
The JumpReLU Example
In the JumpReLU paper, they need to train through:
- The JumpReLU activation: $f_i(\mathbf{x}) = \pi_i(\mathbf{x})\,H(\pi_i(\mathbf{x}) - \theta_i)$
- The L0 penalty: $\lambda\sum_{i=1}^{M} H(\pi_i(\mathbf{x}) - \theta_i)$
Both involve the Heaviside step function $H(\cdot)$, which has zero gradient with respect to the threshold $\theta_i$.
Their STE solution uses a "soft" gradient in a small window:
# Forward: use actual step function
y = (x > theta).float()
# Backward: use soft gradient
grad_theta = -(1/ε) * rectangle((x - theta)/ε) * grad_output
This gives gradient signal whenever the pre-activation is within $\varepsilon/2$ of the threshold $\theta$ (for the rectangle kernel).
When Do STEs Work?
STEs are surprisingly effective when:
- The forward function preserves information: Even though $\text{sign}(x)$ discards magnitude, it preserves direction
- Gradients point in roughly the right direction: We don't need exact gradients, just reasonable descent directions
- The mismatch is local: STEs work best when the difference between forward and backward is small
Common Applications
- Quantized Neural Networks: Training networks with discrete weights/activations
- Discrete Latent Variables: Training VAEs with categorical variables
- Hard Attention: Training attention mechanisms that make discrete selections
- Sparse Models: Like the JumpReLU SAEs, where features are either on or off
Key Insights
The magic of STEs is that optimization often just needs "good enough" gradients. Perfect gradients are great, but biased gradients that point roughly the right way can still find good solutions.
As the JumpReLU paper shows, you can even design custom STEs based on your problem structure - they use kernel density estimation to create gradients that estimate the true expected loss gradient!
Code Example: Threshold Learning
Here's a simplified version of how JumpReLU uses STEs:
class ThresholdSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, threshold, bandwidth=0.001):
        ctx.save_for_backward(x, threshold)
        ctx.bandwidth = bandwidth
        return x * (x > threshold)

    @staticmethod
    def backward(ctx, grad_output):
        x, threshold = ctx.saved_tensors
        ε = ctx.bandwidth
        # Gradient w.r.t x: normal gradient where x > threshold
        grad_x = grad_output * (x > threshold)
        # Gradient w.r.t threshold: custom STE, non-zero only in a window of width ε
        in_window = ((x - threshold).abs() < ε / 2)
        grad_threshold = -threshold / ε * in_window * grad_output
        return grad_x, grad_threshold, None
This allows learning thresholds even though the step function has zero gradient!
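For completeness, here is how you might exercise this sketch (the bandwidth is deliberately exaggerated so some inputs land in the window; the paper's value is ε = 0.001 on normalised data, and the shapes here are illustrative assumptions):
M = 8
x = torch.randn(M, requires_grad=True)
theta = torch.full((M,), 0.3, requires_grad=True)
y = ThresholdSTE.apply(x, theta, 0.5)   # large bandwidth, for illustration only
y.sum().backward()
print(x.grad)      # 1 where x > theta, 0 elsewhere
print(theta.grad)  # non-zero only where |x - theta| < bandwidth / 2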
In equation 11, why do they use an "eth" symbol in place of a partial derivative? Is this a known convention?
Thinking Summary
The user is asking about the notation in equation 11 of the paper, where they use ð (eth) instead of the standard partial derivative symbol ∂. Let me look at the context. Looking at equation 11: ``` ððθJumpReLUθ(z):=−θεK(z−θε) ``` And in the footnote 4, they explain: "We use the notation ð/ðz to denote pseudo-derivatives, to avoid conflating them with actual partial derivatives for these functions." So they're using ð (the lowercase eth letter from Old English and Icelandic) to denote "pseudo-derivatives" rather than true partial derivatives. This is because the JumpReLU function and Heaviside step function are not differentiable in the classical sense at the threshold point - they have discontinuities. The authors are defining custom gradients (straight-through estimators) for these non-differentiable functions. This doesn't seem to be a widely established convention - it appears to be a notation choice made by the authors of this specific paper to clearly distinguish between: - True partial derivatives (∂) - Their custom-defined pseudo-derivatives for the backward pass (ð) The eth symbol (ð) is sometimes used in mathematics, but not typically for this purpose. It's more commonly seen in: - Icelandic and Old English text - Some areas of mathematics like differential geometry (though not commonly) - Phonetics (to represent the voiced dental fricative) The authors are being careful to make it clear that these aren't true derivatives but rather custom gradient definitions used for training through non-differentiable functions.
Great observation! The ð (eth) symbol is not a standard mathematical convention - it's a notation choice specific to this paper.
The authors explain in footnote 4:
"We use the notation ð/ðz to denote pseudo-derivatives, to avoid conflating them with actual partial derivatives for these functions."
They're using ð instead of the standard ∂ (partial derivative) because the JumpReLU and Heaviside functions aren't differentiable in the classical sense - they have discontinuities at the threshold. The "pseudo-derivatives" are custom gradient definitions they create for the backward pass during training.
This is actually quite thoughtful notation design! Since they're defining straight-through estimators (STEs) that approximate gradients where true gradients don't exist or are zero, using a visually similar but distinct symbol helps avoid confusion. It makes it immediately clear when they're talking about:
- ∂: True partial derivatives (for differentiable functions)
- ð: Their custom-defined gradients for training (for non-differentiable points)
The eth symbol (ð) comes from Old English and Icelandic alphabets. While it's occasionally used in some mathematical contexts, this specific usage as "pseudo-derivative" appears to be novel to this paper. It's a clever way to maintain mathematical clarity while introducing new gradient approximations for discontinuous functions.
Maybe jumping ahead but where does the parameter epsilon used to define the pseudo derivative come from?
Thinking Summary
The user is asking about the epsilon parameter used in the pseudo-derivatives defined in equations 11 and 12. Let me look at what the paper says about this. From the paper: 1. In equations 11 and 12, epsilon appears as a bandwidth parameter in the kernel function K. 2. The paper mentions that epsilon plays the role of a KDE (kernel density estimation) bandwidth. 3. In footnote 5, they mention: "For the experiments in this paper, we swept this parameter and found ε=0.001 (assuming a dataset normalised such that 𝔼𝐱[𝐱²]=1) works well across different models, layers and sites. However, we suspect there are more principled ways to determine this parameter, borrowing from the literature on KDE bandwidth selection." 4. In Section 4, they explain: "As we show in Section 4, the hyperparameter ε plays the role of a KDE bandwidth, and needs to be selected accordingly: too low and gradient estimates become too noisy, too high and estimates become too biased." 5. The paper shows that the STEs are actually computing kernel density estimates of the true gradient of the expected loss. So epsilon is a bandwidth parameter that controls the width of the window around the threshold where gradients are non-zero. It's a hyperparameter they had to tune.
Excellent question! The epsilon (ε) parameter is essentially a hyperparameter that controls the "bandwidth" or width of the window where the pseudo-derivatives are non-zero.
Here's what the paper tells us about it:
Where it comes from
- It's borrowed from kernel density estimation (KDE): The authors show that their STEs are actually computing kernel density estimates of the true gradient. In KDE, you always need a bandwidth parameter that controls the smoothing.
- They found it empirically: From footnote 5:
"For the experiments in this paper, we swept this parameter and found ε=0.001 (assuming a dataset normalised such that 𝔼[𝐱²]=1) works well across different models, layers and sites."
What it does
Looking at the pseudo-derivatives (Equations 11-12):
$$\frac{\eth}{\eth\theta}\text{JumpReLU}_{\theta}(z) := -\frac{\theta}{\varepsilon}\,K\!\left(\frac{z-\theta}{\varepsilon}\right), \qquad \frac{\eth}{\eth\theta}H(z-\theta) := -\frac{1}{\varepsilon}\,K\!\left(\frac{z-\theta}{\varepsilon}\right)$$
The epsilon controls:
- Width of the gradient window: Only when $|z - \theta| < \varepsilon/2$ (for the rectangular kernel) do we get non-zero gradients
- Magnitude of gradients: Smaller ε means larger gradient magnitudes (note the 1/ε factor)
The tradeoff
As they explain in Section 4:
- ε too small: Gradient estimates become too noisy (few samples fall in the narrow window)
- ε too large: Estimates become too biased (you're averaging over too wide a region)
Visual intuition
From Figure 4, you can see that epsilon defines a small window around the threshold where the pseudo-derivative is non-zero. Outside this window, the gradient is zero (just like the true derivative).
Why not something more principled?
The authors acknowledge this is somewhat ad-hoc:
"However, we suspect there are more principled ways to determine this parameter, borrowing from the literature on KDE bandwidth selection."
So epsilon is essentially a "gradient window width" hyperparameter they tuned empirically. It's small enough to be local (ε=0.001) but large enough to catch some activations and provide gradient signal. The fact that the same value works across different models and layers suggests it's reasonably robust, though more principled selection methods could likely improve it.
§5.1: can you pull up the definition of fraction of variance unexplained from the cited paper?
Thinking Summary
The user is asking about the definition of "fraction of variance unexplained" from the cited paper. Looking at the embed they shared, this is from the paper "Scaling and evaluating sparse autoencoders" by Gao et al. 2024. Let me search through the content for this definition. Looking at the paper content... Found it! In the "Setup" section under Methods (Section 2.1), they write: "Evaluation: After training, we evaluate autoencoders on sparsity L0, and reconstruction mean-squared error (MSE). We report a normalized version of all MSE numbers, where we divide by a baseline reconstruction error of always predicting the mean activations." So the fraction of variance unexplained (FVU) is the mean reconstruction loss (MSE) normalized by the reconstruction loss obtained by always predicting the dataset mean. This is a standard way to normalize reconstruction error - it's the fraction of variance that remains unexplained by the model. The user is referring to this in the context of Section 5.1 of the JumpReLU paper, where they mention FVU as a secondary metric for measuring reconstruction fidelity.
Looking at the Gao et al. paper, the definition of "fraction of variance unexplained" (FVU) can be found in Section 2.1 under "Evaluation":
Evaluation: After training, we evaluate autoencoders on sparsity L0, and reconstruction mean-squared error (MSE). We report a normalized version of all MSE numbers, where we divide by a baseline reconstruction error of always predicting the mean activations.
So FVU is calculated as:
$$\text{FVU} = \frac{\mathbb{E}_{\mathbf{x}}\left[\left\lVert\mathbf{x} - \hat{\mathbf{x}}(\mathbf{x})\right\rVert_2^2\right]}{\mathbb{E}_{\mathbf{x}}\left[\left\lVert\mathbf{x} - \bar{\mathbf{x}}\right\rVert_2^2\right]}$$
where $\bar{\mathbf{x}}$ is the mean activation vector. This is also referred to as the "normalized loss" in the paper. It's a standard normalization that measures what fraction of the variance in the data remains unexplained by the model. A perfect reconstruction would have FVU = 0, while a model that does no better than predicting the mean would have FVU = 1.
The JumpReLU paper uses this as a secondary metric alongside their primary metric of delta LM loss (the increase in cross-entropy loss when the SAE is spliced into the LM's forward pass).
I missed that part while reading the paper, can you backfill a bit of context on the setup, like what is the data, what mean is that MSE over, etc.?
Thinking Summary
The user wants me to provide more context about the setup for FVU (fraction of variance unexplained). Let me look at both papers to understand: 1. What data is being used 2. What the MSE is calculated over 3. What "mean activations" refers to From the JumpReLU paper Section 5: - They trained SAEs on Gemma 2 9B activations - Specifically on residual stream, attention output, and MLP output sites after layers 9, 20, and 31 - All metrics computed on 2,048 sequences of length 1,024, excluding special tokens From the Gao et al. paper Section 2.1: - They train on residual streams of GPT-2 small and GPT-4 models - They use a layer 5/6 of the way into the network for GPT-4 series models, and layer 8 (3/4 of the way) for GPT-2 small - Context length of 64 tokens - They subtract the mean over the d_model dimension and normalize to unit norm before passing to autoencoder So the setup involves: 1. Data: Language model activations (residual stream, MLP outputs, or attention outputs) from specific layers 2. The MSE is calculated over tokens/positions in sequences 3. The "mean activations" baseline is the mean activation vector across the dataset
Sure! Let me provide more context on the setup for FVU (fraction of variance unexplained):
Data Setup
What data is being evaluated:
- The JumpReLU paper evaluates SAEs trained on Gemma 2 9B activations at specific sites:
- Residual stream activations
- MLP output activations
- Attention output activations (before the output projection)
- These are taken from layers 9, 20, and 31 of the model
Preprocessing (from Gao et al.):
- Activations are preprocessed by:
  - Subtracting the mean over the d_model dimension
  - Normalizing to unit norm
Evaluation dataset:
- All metrics are computed on 2,048 sequences of length 1,024 tokens
- Special tokens (pad, start, end of sequence) are excluded
What the MSE is calculated over
The MSE measures reconstruction error:
$$\text{MSE} = \mathbb{E}_{\mathbf{x}}\left[\left\lVert\mathbf{x} - \hat{\mathbf{x}}(\mathbf{x})\right\rVert_2^2\right]$$
Where:
- $\mathbf{x}$ is the original activation vector (e.g., dimension 3584 for the Gemma 2 9B residual stream)
- $\hat{\mathbf{x}}$ is the SAE reconstruction
- The expectation is over all token positions in the evaluation dataset
The baseline for normalization
The "mean activations" baseline is:
- Simply predicting the dataset mean activation vector for every input
- This gives the variance of the data as the MSE
- So FVU = (SAE's MSE) / (Variance of data)
This normalization makes the metric interpretable: FVU = 0.1 means the SAE fails to explain 10% of the variance in the activations.
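Putting that together, FVU can be computed in a few lines. This is a sketch of the metric as described above, with x and x_hat as hypothetical tensors of evaluation activations and SAE reconstructions:
import torch

def fvu(x: torch.Tensor, x_hat: torch.Tensor) -> float:
    # x, x_hat: (num_tokens, d_model)
    mse = ((x - x_hat) ** 2).sum(dim=-1).mean()
    baseline = ((x - x.mean(dim=0)) ** 2).sum(dim=-1).mean()  # MSE of always predicting the mean
    return (mse / baseline).item()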
This is either naive or pedantic or both, but in what sense is this "variance"? I thought variance was defined as $\mathbb{E}\left[\left\lVert\mathbf{x} - \mathbb{E}[\mathbf{x}]\right\rVert^2\right]$, i.e. the expectation of squared deviation from the mean, but here we are using reconstruction error.
Thinking Summary
The user is making a good point. They're asking about why this is called "fraction of variance unexplained" when the denominator is the MSE of predicting the mean, which is indeed the variance, but the numerator is the MSE of the SAE reconstruction, not necessarily related to variance in the traditional sense. Let me think about this carefully: 1. The denominator: MSE of always predicting the mean = $\mathbb{E}[\|\mathbf{x} - \mathbb{E}[\mathbf{x}]\|^2]$ = variance of the data 2. The numerator: MSE of SAE reconstruction = $\mathbb{E}[\|\mathbf{x} - \hat{\mathbf{x}}\|^2]$ where $\hat{\mathbf{x}}$ is the SAE reconstruction The name "fraction of variance unexplained" comes from the fact that: - If you could perfectly reconstruct, MSE = 0, so FVU = 0 (all variance explained) - If your reconstruction is no better than predicting the mean, MSE = variance, so FVU = 1 (no variance explained) - In between, FVU tells you what fraction of the total variance remains as error This is actually a standard concept in statistics, related to R² (coefficient of determination). In fact, FVU = 1 - R². The user is right that it's not literally "variance" in the numerator, but it's the residual variance after the model's predictions.
You're absolutely right to point this out! This is actually a really good question about the terminology.
You're correct that variance is $\mathbb{E}\left[\left\lVert\mathbf{x} - \mathbb{E}[\mathbf{x}]\right\rVert^2\right]$, the expected squared deviation from the mean. The name "fraction of variance unexplained" comes from statistical tradition, particularly from linear regression where this concept originates.
Here's why it makes sense:
The denominator is indeed the variance: $\mathbb{E}_{\mathbf{x}}\left[\left\lVert\mathbf{x} - \bar{\mathbf{x}}\right\rVert_2^2\right]$, where $\bar{\mathbf{x}} = \mathbb{E}[\mathbf{x}]$ is the mean activation.
The numerator is the MSE of reconstruction: $\mathbb{E}_{\mathbf{x}}\left[\left\lVert\mathbf{x} - \hat{\mathbf{x}}(\mathbf{x})\right\rVert_2^2\right]$
The interpretation is:
- The total variance in the data represents all the "variation" we'd like to explain
- The reconstruction error represents the "residual variance" - the variation that remains unexplained after our model
- So FVU measures what fraction of the original variance is still left as error
This is exactly analogous to $R^2$ (coefficient of determination) in statistics: $\text{FVU} = 1 - R^2$.
So while the numerator isn't literally "variance" in the formal sense, it represents the variance that remains unexplained. The terminology comes from thinking of the total variance as something to be "explained away" by the model, with perfect explanation leaving zero residual error.
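Written out in the usual regression notation (a standard identity, not something the paper states explicitly):

$$R^2 = 1 - \frac{\sum_i \|\mathbf{x}_i - \hat{\mathbf{x}}_i\|^2}{\sum_i \|\mathbf{x}_i - \bar{\mathbf{x}}\|^2}, \qquad \text{FVU} = \frac{\sum_i \|\mathbf{x}_i - \hat{\mathbf{x}}_i\|^2}{\sum_i \|\mathbf{x}_i - \bar{\mathbf{x}}\|^2} = 1 - R^2$$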
In figure 5 can you elaborate on the significance of plotting delta LM loss against activation frequency?
Thinking Summary
The user is asking about Figure 5 from the "Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders" paper. Let me look at what Figure 5 shows.

Figure 5 shows:
- The proportion of features that activate very frequently versus delta LM loss by SAE type for Gemma 2 9B residual stream SAEs
- Two plots: the top one shows features active on over 10% of tokens, the bottom one shows features active on over 1% of tokens
- Compares TopK, JumpReLU, and Gated SAEs

The significance of plotting delta LM loss against activation frequency relates to understanding the trade-off between:
1. Reconstruction fidelity (measured by delta LM loss)
2. The characteristics of learned features (specifically how often they activate)

Let me explain why this matters in the context of SAEs.
Figure 5 is plotting two important characteristics of Sparse Autoencoders (SAEs) against each other to reveal a key trade-off:
Delta LM Loss (X-axis)
- What it measures: The increase in cross-entropy loss when the SAE is spliced into the language model's forward pass
- Why it matters: This is a measure of reconstruction fidelity - lower values mean the SAE is better at faithfully reconstructing the original activations without harming the model's performance
- Interpretation: Moving left on the x-axis means better reconstruction quality
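To make the splicing concrete, here is a minimal PyTorch-style sketch (my own illustration, not the papers' code) of how delta LM loss could be measured. It assumes `model(tokens)` returns per-token logits, `sae(x)` returns a reconstruction of `x`, and the hooked module outputs a plain activation tensor:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def delta_lm_loss(model, sae, layer_module, tokens):
    """Increase in next-token cross-entropy (nats/token) when the SAE
    reconstruction replaces the activations produced by `layer_module`."""
    def ce(logits):
        # Standard next-token prediction loss, averaged over tokens.
        return F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)),
                               tokens[:, 1:].reshape(-1))

    clean_loss = ce(model(tokens))

    def splice(module, inputs, output):
        # Returning a value from a forward hook replaces the module's output.
        return sae(output)

    handle = layer_module.register_forward_hook(splice)
    spliced_loss = ce(model(tokens))
    handle.remove()

    return (spliced_loss - clean_loss).item()
```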
Feature Activation Frequency (Y-axis)
- What it measures: The proportion of SAE features that fire very frequently (>10% of tokens in top plot, >1% in bottom plot)
- Why it matters: Very high frequency features tend to be:
- Less interpretable (they activate so often they may not represent specific, meaningful concepts)
- Less sparse (defeating the purpose of "sparse" autoencoders)
- Potentially acting more like bias terms than meaningful feature detectors
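For concreteness, the y-axis quantity can be computed roughly like this (a sketch assuming you have a [n_tokens, n_features] matrix of SAE feature activations; not the authors' code):

```python
import numpy as np

def high_frequency_fraction(feature_acts: np.ndarray, freq_threshold: float = 0.10) -> float:
    """Fraction of SAE features that fire on more than `freq_threshold` of tokens."""
    firing_rates = (feature_acts > 0).mean(axis=0)        # per-feature activation frequency
    return float((firing_rates > freq_threshold).mean())  # 0.10 for the top plot, 0.01 for the bottom
```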
The Significance of This Plot
1. Different architectures have different characteristics: The plot reveals that TopK and JumpReLU SAEs tend to learn more high-frequency features than Gated SAEs, especially in the low delta LM loss (high fidelity) regime.
2. Trade-off visualization: As you improve reconstruction fidelity (move left), you tend to get more high-frequency features. This suggests there's a fundamental tension between:
   - Making accurate reconstructions (low delta LM loss)
   - Learning sparse, interpretable features (few high-frequency features)
3. Architecture comparison at fixed performance: By plotting against delta LM loss rather than just comparing raw frequency counts, the authors ensure they're comparing architectures at equivalent levels of reconstruction quality. This is a fairer comparison than just looking at frequency distributions alone.
4. Practical implications: For practitioners, this helps in choosing architectures - if you care more about avoiding high-frequency features for interpretability, you might prefer Gated SAEs despite slightly worse reconstruction fidelity.
The paper notes that while JumpReLU SAEs, like TopK SAEs, have more very-high-frequency features (active on >10% of tokens), the absolute numbers remain small (<0.06% of all features even in the worst case), suggesting this may not be a major practical concern.
The x axis in the charts has the following text
2 5 0.1 2 5
the 0.1 is larger than the others. there are no other labels. is this indicating a log plot? what are the units for delta LM loss?
Thinking Summary
The user is asking about the x-axis labeling in Figure 5. Looking at the description they provided, it seems like they're seeing axis labels that suggest a logarithmic scale. The pattern "2 5 0.1 2 5" with 0.1 being larger suggests this is indeed a log scale, likely showing values like 0.02, 0.05, 0.1, 0.2, 0.5 or similar. Let me check the paper for information about delta LM loss units and whether they use log scales.

Looking at the paper, delta LM loss is defined as "the increase in the cross-entropy loss experienced by the LM when we splice the SAE into the LM's forward pass." Cross-entropy loss is typically measured in nats (natural log) or bits (log base 2), and the delta would be in the same units. The paper doesn't explicitly state the units, but cross-entropy loss for language models is commonly expressed in nats per token or bits per token. Since this is a delta (difference), it would be the increase in nats/bits per token.

The logarithmic scale makes sense because these loss differences can span multiple orders of magnitude, from very small (good reconstruction) to larger values (poor reconstruction).
Yes, you're absolutely right - this is a logarithmic scale on the x-axis! The pattern you're describing (2, 5, 0.1, 2, 5) is the classic labeling for a log scale, where it cycles through powers of 10. It would represent values like:
- 0.02, 0.05, 0.1, 0.2, 0.5 (or similar sequences)
The "0.1" appears larger because it's the major tick mark at a power of 10, while 2 and 5 are minor tick marks within each decade.
Units for Delta LM Loss
Delta LM loss measures the increase in cross-entropy loss when the SAE is inserted into the model. Cross-entropy loss for language models is typically measured in:
- Nats per token (using natural logarithm)
- Bits per token (using log base 2)
The paper doesn't explicitly state which unit they're using, but nats per token is more common in deep learning contexts.
Since this is a delta (difference), the units would be the increase in nats/bits per token. For example:
- A delta LM loss of 0.1 nats means the model's perplexity increases by a factor of e^0.1 ≈ 1.105 when using the SAE reconstruction
- A very small delta (e.g., 0.01) means the SAE barely affects model performance
- A larger delta (e.g., 0.5) indicates significant degradation
The logarithmic scale is appropriate here because these deltas can span multiple orders of magnitude - from nearly perfect reconstructions (delta ≈ 0.001) to quite poor ones (delta ≈ 1.0), and a log scale helps visualize this full range effectively.
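The perplexity interpretation above follows from the standard relationship between cross-entropy measured in nats and perplexity (a general identity, not specific to the paper):

$$\mathrm{PPL} = e^{\mathrm{CE}} \quad\Rightarrow\quad \frac{\mathrm{PPL}_{\text{SAE}}}{\mathrm{PPL}_{\text{orig}}} = e^{\Delta\mathrm{CE}}, \qquad e^{0.1} \approx 1.105$$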
Ok, I'm not understanding how I move along the graph though. For a given SAE the SAE is fixed, right? So it has a certain delta LM loss and a certain proportion of features active on >10% of tokens, and both of those are "fixed", right? Where does the series in the chart come from?
Thinking Summary
Ah, the user is confused about what creates the curve/series in the plot. They think each SAE should just be a single point since it has fixed characteristics. Let me clarify that each point represents a different SAE trained with different hyperparameters.
Excellent question! You're right that each individual SAE, once trained, has fixed characteristics. The key insight is that each point on the curves represents a different SAE.
How the Series is Generated
The authors trained multiple SAEs of each type (JumpReLU, Gated, TopK) with different hyperparameters:
1. For JumpReLU and Gated SAEs: They varied the sparsity coefficient λ
   - Higher λ → more sparse → fewer features active → worse reconstruction (higher delta LM loss)
   - Lower λ → less sparse → more features active → better reconstruction (lower delta LM loss)
2. For TopK SAEs: They varied K (the number of top features to keep)
   - Lower K → more sparse → worse reconstruction
   - Higher K → less sparse → better reconstruction
Creating the Curves
So the process was:
1. Train many SAEs with different sparsity settings (different λ or K values)
2. For each trained SAE, measure:
   - Its delta LM loss (reconstruction quality)
   - What proportion of its features fire on >10% of tokens
3. Plot each SAE as a point
4. Connect points of the same architecture type to form curves
Why This Matters
This allows them to answer: "If I want a JumpReLU SAE with delta LM loss of 0.2, how many high-frequency features should I expect?" You can't just train one SAE and move along the curve - you need to train multiple SAEs with different sparsity penalties to explore the trade-off space.
The curves show the Pareto frontier - the best achievable trade-off between reconstruction quality and feature frequency characteristics for each architecture type.