Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders
In figure 5 can you elaborate on the significance of plotting delta LM loss against activation frequency?
Thinking Summary
The user is asking about Figure 5 from the "Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders" paper. Figure 5 plots the proportion of features that activate very frequently against delta LM loss, by SAE type, for Gemma 2 9B residual stream SAEs. There are two panels: the top shows features active on over 10% of tokens, the bottom shows features active on over 1% of tokens, comparing TopK, JumpReLU, and Gated SAEs.

The significance of plotting delta LM loss against activation frequency lies in the trade-off between reconstruction fidelity (measured by delta LM loss) and the characteristics of the learned features (specifically, how often they activate). Let me explain why this matters in the context of SAEs.
Figure 5 is plotting two important characteristics of Sparse Autoencoders (SAEs) against each other to reveal a key trade-off:
Delta LM Loss (X-axis)
- What it measures: The increase in cross-entropy loss when the SAE is spliced into the language model's forward pass
- Why it matters: This is a measure of reconstruction fidelity - lower values mean the SAE is better at faithfully reconstructing the original activations without harming the model's performance
- Interpretation: Moving left on the x-axis means better reconstruction quality
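For concreteness, here is a minimal sketch (not the paper's evaluation code) of how delta LM loss could be measured. It assumes a TransformerLens-style hooked model and an SAE object with encode/decode methods; both of these interfaces are assumptions:

```python
import torch

def delta_lm_loss(model, sae, tokens, layer):
    """Increase in cross-entropy (nats/token) when the SAE reconstruction
    replaces the residual stream at `layer`. The hook name and run_with_hooks
    interface follow TransformerLens conventions; sae.encode/decode are assumed."""
    with torch.no_grad():
        clean_loss = model(tokens, return_type="loss")

        def splice_sae(acts, hook):
            # Swap in the SAE's reconstruction of the activations
            return sae.decode(sae.encode(acts))

        spliced_loss = model.run_with_hooks(
            tokens,
            return_type="loss",
            fwd_hooks=[(f"blocks.{layer}.hook_resid_post", splice_sae)],
        )
    return (spliced_loss - clean_loss).item()
```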
Feature Activation Frequency (Y-axis)
- What it measures: The proportion of SAE features that fire very frequently (>10% of tokens in the top plot, >1% in the bottom plot)
- Why it matters: Very high-frequency features tend to be:
  - Less interpretable (they activate so often they may not represent specific, meaningful concepts)
  - Less sparse (defeating the purpose of "sparse" autoencoders)
  - Potentially acting more like bias terms than meaningful feature detectors
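As a rough illustration, the y-axis quantity could be computed over an evaluation set like this; the `encode` method returning per-token feature activations is an assumption about the SAE interface:

```python
import torch

def high_frequency_fraction(sae, activations, threshold=0.10):
    """Fraction of SAE features that fire on more than `threshold` of tokens.
    `activations`: (num_tokens, d_model) LM activations; `sae.encode` is assumed
    to return a (num_tokens, num_features) tensor of feature activations."""
    feature_acts = sae.encode(activations)
    firing_rate = (feature_acts > 0).float().mean(dim=0)   # per-feature firing rate
    return (firing_rate > threshold).float().mean().item() # fraction above threshold
```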
The Significance of This Plot
1. Different architectures have different characteristics: The plot reveals that TopK and JumpReLU SAEs tend to learn more high-frequency features than Gated SAEs, especially in the low delta LM loss (high-fidelity) regime.
2. Trade-off visualization: As you improve reconstruction fidelity (move left), you tend to get more high-frequency features. This suggests a fundamental tension between:
   - Making accurate reconstructions (low delta LM loss)
   - Learning sparse, interpretable features (few high-frequency features)
3. Architecture comparison at fixed performance: By plotting against delta LM loss rather than just comparing raw frequency counts, the authors ensure they are comparing architectures at equivalent levels of reconstruction quality. This is a fairer comparison than looking at frequency distributions alone.
4. Practical implications: For practitioners, this helps in choosing architectures: if avoiding high-frequency features matters more to you for interpretability, you might prefer Gated SAEs despite slightly worse reconstruction fidelity.
The paper notes that although JumpReLU, like TopK, learns more very-high-frequency features (active on >10% of tokens), the absolute numbers remain small (under 0.06% of all features even in the worst case), suggesting this may not be a major practical concern.
The x axis in the charts has the following text:

2 5 0.1 2 5

The 0.1 is larger than the others, and there are no other labels. Is this indicating a log plot? What are the units for delta LM loss?
Thinking Summary
The user is asking about the x-axis labeling in Figure 5. The pattern "2 5 0.1 2 5", with the 0.1 drawn larger, suggests a logarithmic scale, likely covering values such as 0.02, 0.05, 0.1, 0.2, 0.5.

The paper defines delta LM loss as "the increase in the cross-entropy loss experienced by the LM when we splice the SAE into the LM's forward pass." Cross-entropy loss is typically measured in nats (natural log) or bits (log base 2), and the delta is in the same units. The paper doesn't explicitly state the units, but cross-entropy loss for language models is commonly expressed in nats per token or bits per token; since this is a delta (difference), it is the increase in nats or bits per token.

The logarithmic scale makes sense because these loss differences can span multiple orders of magnitude, from very small (good reconstruction) to larger values (poor reconstruction).
Yes, you're right: this is a logarithmic scale on the x-axis. The pattern you're describing (2, 5, 0.1, 2, 5) is the standard labeling for a log scale, where minor ticks repeat within each decade. It would represent values like:
- 0.02, 0.05, 0.1, 0.2, 0.5 (or similar sequences)
The "0.1" appears larger because it's the major tick mark at a power of 10, while 2 and 5 are minor tick marks within each decade.
Units for Delta LM Loss
Delta LM loss measures the increase in cross-entropy loss when the SAE is inserted into the model. Cross-entropy loss for language models is typically measured in:
- Nats per token (using natural logarithm)
- Bits per token (using log base 2)
The paper doesn't explicitly state which unit they're using, but nats per token is more common in deep learning contexts.
Since this is a delta (difference), the units would be the increase in nats/bits per token. For example:
- A delta LM loss of 0.1 nats means the model's perplexity increases by a factor of e^0.1 ≈ 1.105 when using the SAE reconstruction
- A very small delta (e.g., 0.01) means the SAE barely affects model performance
- A larger delta (e.g., 0.5) indicates significant degradation
The logarithmic scale is appropriate here because these deltas can span multiple orders of magnitude - from nearly perfect reconstructions (delta ≈ 0.001) to quite poor ones (delta ≈ 1.0), and a log scale helps visualize this full range effectively.
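A tiny worked example of the nats-to-perplexity conversion used above, assuming delta LM loss is reported in nats per token:

```python
import math

def perplexity_factor(delta_lm_loss_nats: float) -> float:
    """Multiplicative increase in perplexity from splicing in the SAE,
    assuming delta LM loss is measured in nats per token."""
    return math.exp(delta_lm_loss_nats)

for delta in (0.01, 0.1, 0.5):
    print(f"delta = {delta}: perplexity increases by x{perplexity_factor(delta):.3f}")
# delta = 0.01 -> x1.010, delta = 0.1 -> x1.105, delta = 0.5 -> x1.649
```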
Ok, I'm not understanding how I move along the graph though. For a given SAE, the SAE is fixed, right? So it has a certain delta LM loss and a certain proportion of features active on >10% of tokens, and both of those are "fixed", right? Where does the series in the chart come from?
Thinking Summary
Ah, the user is confused about what creates the curve/series in the plot. They think each SAE should just be a single point since it has fixed characteristics. Let me clarify that each point represents a different SAE trained with different hyperparameters.
Excellent question! You're right that each individual SAE, once trained, has fixed characteristics. The key insight is that each point on the curves represents a different SAE.
How the Series is Generated
The authors trained multiple SAEs of each type (JumpReLU, Gated, TopK) with different hyperparameters:
- For JumpReLU and Gated SAEs: they varied the sparsity coefficient λ
  - Higher λ → more sparse → fewer features active → worse reconstruction (higher delta LM loss)
  - Lower λ → less sparse → more features active → better reconstruction (lower delta LM loss)
- For TopK SAEs: they varied K (the number of top features to keep)
  - Lower K → more sparse → worse reconstruction
  - Higher K → less sparse → better reconstruction
Creating the Curves
So the process was:
- Train many SAEs with different sparsity settings (different λ or K values)
- For each trained SAE, measure:
  - Its delta LM loss (reconstruction quality)
  - What proportion of its features fire on >10% of tokens
- Plot each SAE as a point
- Connect points of the same architecture type to form curves
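Putting that together, a schematic of the sweep that produces one curve might look like the following, reusing the delta_lm_loss and high_frequency_fraction sketches from earlier in the thread; train_sae, the λ grid, and the layer choice are all hypothetical stand-ins for the paper's actual training setup:

```python
# Hypothetical sweep: each trained SAE contributes one point to its architecture's curve.
sparsity_coefficients = [3e-4, 1e-3, 3e-3, 1e-2, 3e-2]  # example λ grid, not the paper's values

points = []
for lam in sparsity_coefficients:
    sae = train_sae(architecture="jumprelu", sparsity_coeff=lam)        # hypothetical trainer
    x = delta_lm_loss(model, sae, eval_tokens, layer=20)                # fidelity (sketch above)
    y = high_frequency_fraction(sae, eval_activations, threshold=0.10)  # share of features firing >10%
    points.append((x, y))

# Sorting by delta LM loss and connecting the points yields one curve in Figure 5.
points.sort()
```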
Why This Matters
This allows them to answer: "If I want a JumpReLU SAE with delta LM loss of 0.2, how many high-frequency features should I expect?" You can't just train one SAE and move along the curve - you need to train multiple SAEs with different sparsity penalties to explore the trade-off space.
The curves show the Pareto frontier - the best achievable trade-off between reconstruction quality and feature frequency characteristics for each architecture type.