Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders
Thread (53 posts)
In Figure 5, can you elaborate on the significance of plotting delta LM loss against activation frequency?
Thinking Summary
The user is asking about Figure 5 from the "Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders" paper. Let me look at what Figure 5 shows:
- The proportion of features that activate very frequently versus delta LM loss, by SAE type, for Gemma 2 9B residual stream SAEs
- Two plots: the top one shows features active on over 10% of tokens, the bottom one features active on over 1% of tokens
- Compares TopK, JumpReLU, and Gated SAEs

The significance of plotting delta LM loss against activation frequency relates to understanding the trade-off between:
1. Reconstruction fidelity (measured by delta LM loss)
2. The characteristics of learned features (specifically how often they activate)

Let me explain why this matters in the context of SAEs.
Figure 5 plots two important characteristics of Sparse Autoencoders (SAEs) against each other to reveal a key trade-off:
Delta LM Loss (X-axis)
- What it measures: The increase in cross-entropy loss when the SAE is spliced into the language model's forward pass
- Why it matters: This is a measure of reconstruction fidelity - lower values mean the SAE is better at faithfully reconstructing the original activations without harming the model's performance
- Interpretation: Moving left on the x-axis means better reconstruction quality
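To make the x-axis metric concrete, here is a minimal PyTorch-style sketch of the splicing procedure. The `model.layers[layer]` indexing and the `sae.encode`/`sae.decode` methods are illustrative assumptions, not the paper's actual code:

```python
import torch
import torch.nn.functional as F

def delta_lm_loss(model, sae, tokens, layer):
    """Increase in next-token cross-entropy when the SAE's reconstruction
    replaces the residual-stream activations at `layer` (hypothetical API)."""
    def ce(logits):
        # Standard next-token loss: predict token t+1 from position t.
        return F.cross_entropy(logits[:, :-1].flatten(0, 1),
                               tokens[:, 1:].flatten())

    with torch.no_grad():
        clean_loss = ce(model(tokens))

        # A forward hook swaps the layer's output for its SAE reconstruction.
        def splice(module, inputs, output):
            return sae.decode(sae.encode(output))

        handle = model.layers[layer].register_forward_hook(splice)
        spliced_loss = ce(model(tokens))
        handle.remove()

    return (spliced_loss - clean_loss).item()
```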
Feature Activation Frequency (Y-axis)
- What it measures: The proportion of SAE features that fire very frequently (>10% of tokens in top plot, >1% in bottom plot)
- Why it matters: Very high frequency features tend to be:
- Less interpretable (they activate so often they may not represent specific, meaningful concepts)
- Less sparse (defeating the purpose of "sparse" autoencoders)
- Potentially acting more like bias terms than meaningful feature detectors
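The y-axis statistic is cheap to compute once an SAE is trained. A minimal sketch, assuming a hypothetical `sae.encode` that returns post-activation feature values for a sample of token activations:

```python
import torch

@torch.no_grad()
def high_frequency_proportion(sae, activations, threshold=0.10):
    """Fraction of SAE features that fire on more than `threshold` of
    tokens; `activations` is a [num_tokens, d_model] sample (assumed)."""
    feature_acts = sae.encode(activations)   # [num_tokens, num_features]
    firing_rate = (feature_acts > 0).float().mean(dim=0)  # per-feature rate
    return (firing_rate > threshold).float().mean().item()

# Figure 5's top panel corresponds to threshold=0.10, the bottom to 0.01.
```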
The Significance of This Plot
1. Different architectures have different characteristics: The plot reveals that TopK and JumpReLU SAEs tend to learn more high-frequency features than Gated SAEs, especially in the low delta LM loss (high fidelity) regime.
2. Trade-off visualization: As you improve reconstruction fidelity (move left), you tend to get more high-frequency features. This suggests there's a fundamental tension between:
   - Making accurate reconstructions (low delta LM loss)
   - Learning sparse, interpretable features (few high-frequency features)
3. Architecture comparison at fixed performance: By plotting against delta LM loss rather than just comparing raw frequency counts, the authors ensure they're comparing architectures at equivalent levels of reconstruction quality. This is a fairer comparison than looking at frequency distributions alone.
4. Practical implications: For practitioners, this helps in choosing architectures - if you care more about avoiding high-frequency features for interpretability, you might prefer Gated SAEs despite slightly worse reconstruction fidelity.
The paper notes that while JumpReLU, like TopK, has more very-high-frequency features (active on >10% of tokens), the absolute numbers remain small (<0.06% of all features even in the worst case), suggesting this may not be a major practical concern.
Ok, I'm not understanding how I move along the graph, though. For a given SAE, the SAE is fixed, right? So it has a certain delta LM loss and a certain proportion of features active on >10% of tokens, and both of those are "fixed", right? Where does the series in the chart come from?
Thinking Summary
Ah, the user is confused about what creates the curve/series in the plot. They think each SAE should just be a single point since it has fixed characteristics. Let me clarify that each point represents a different SAE trained with different hyperparameters.
Excellent question! You're right that each individual SAE, once trained, has fixed characteristics. The key insight is that each point on the curves represents a different SAE.
How the Series is Generated
The authors trained multiple SAEs of each type (JumpReLU, Gated, TopK) with different hyperparameters:
1. For JumpReLU and Gated SAEs: they varied the sparsity coefficient λ
   - Higher λ → more sparse → fewer features active → worse reconstruction (higher delta LM loss)
   - Lower λ → less sparse → more features active → better reconstruction (lower delta LM loss)
2. For TopK SAEs: they varied K (the number of top features each token keeps)
   - Lower K → more sparse → worse reconstruction
   - Higher K → less sparse → better reconstruction
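For intuition about what these two knobs act on, here is a simplified sketch of the two activation functions. It omits the real training machinery (the paper learns the per-feature JumpReLU threshold θ with straight-through estimators, with λ weighting an L0-style sparsity penalty), so treat it as illustration only:

```python
import torch

def jumprelu(pre_acts, theta):
    """JumpReLU: zero out pre-activations at or below the per-feature
    threshold theta; values above it pass through unchanged."""
    return pre_acts * (pre_acts > theta)

def topk(pre_acts, k):
    """TopK: keep only the k largest pre-activations per token."""
    values, indices = pre_acts.topk(k, dim=-1)
    return torch.zeros_like(pre_acts).scatter(-1, indices, values.relu())
```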
Creating the Curves
So the process was:
1. Train many SAEs with different sparsity settings (different λ or K values)
2. For each trained SAE, measure:
   - Its delta LM loss (reconstruction quality)
   - The proportion of its features that fire on >10% of tokens
3. Plot each SAE as a point
4. Connect points of the same architecture type to form curves
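Putting it together, each sweep looks roughly like the following. Everything here is hypothetical scaffolding (`train_jumprelu_sae`, `model`, `eval_tokens`, `eval_activations`), reusing the metric sketches above:

```python
import matplotlib.pyplot as plt

lambdas = [1e-4, 3e-4, 1e-3, 3e-3, 1e-2]  # example sparsity coefficients

points = []
for lam in lambdas:
    sae = train_jumprelu_sae(sparsity_coeff=lam)          # hypothetical trainer
    x = delta_lm_loss(model, sae, eval_tokens, layer=20)  # sketch above
    y = high_frequency_proportion(sae, eval_activations)  # sketch above
    points.append((x, y))

xs, ys = zip(*sorted(points))  # one point per trained SAE
plt.plot(xs, ys, marker="o", label="JumpReLU")
plt.xlabel("Delta LM loss")
plt.ylabel("Proportion of features active on >10% of tokens")
plt.legend()
plt.show()
```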
Why This Matters
This allows them to answer: "If I want a JumpReLU SAE with delta LM loss of 0.2, how many high-frequency features should I expect?" You can't just train one SAE and move along the curve - you need to train multiple SAEs with different sparsity penalties to explore the trade-off space.
The curves show the Pareto frontier - the best achievable trade-off between reconstruction quality and feature frequency characteristics for each architecture type.