Numinex

Thread (53 posts)

Thinking Summary

The user is asking about Figure 5 from the "Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders" paper. Let me look at what Figure 5 shows. Figure 5 shows:
- The proportion of features that activate very frequently versus delta LM loss by SAE type for Gemma 2 9B residual stream SAEs
- Two plots: the top one shows features active on over 10% of tokens, the bottom one shows features active on over 1% of tokens
- A comparison of TopK, JumpReLU, and Gated SAEs

The significance of plotting delta LM loss against activation frequency relates to understanding the trade-off between:
1. Reconstruction fidelity (measured by delta LM loss)
2. The characteristics of learned features (specifically how often they activate)

Let me explain why this matters in the context of SAEs.

Figure 5 plots two important characteristics of Sparse Autoencoders (SAEs) against each other to reveal a key trade-off:

Delta LM Loss (X-axis)

  • What it measures: The increase in cross-entropy loss when the SAE is spliced into the language model's forward pass
  • Why it matters: This is a measure of reconstruction fidelity - lower values mean the SAE is better at faithfully reconstructing the original activations without harming the model's performance
  • Interpretation: Moving left on the x-axis means better reconstruction quality
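
A minimal sketch of the metric, assuming you already have the model's next-token logits from a clean forward pass and from a pass in which the chosen layer's activations are replaced by the SAE reconstruction (obtaining those two sets of logits is model-specific and not shown). The function names are hypothetical, not the paper's code.

```python
import math

def token_cross_entropy(logits, target):
    """Cross-entropy in nats for one next-token prediction, via log-sum-exp."""
    m = max(logits)  # subtract the max logit for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]

def mean_cross_entropy(logits_per_token, targets):
    """Average next-token cross-entropy over a sequence, in nats per token."""
    losses = [token_cross_entropy(row, t) for row, t in zip(logits_per_token, targets)]
    return sum(losses) / len(losses)

def delta_lm_loss(clean_logits, spliced_logits, targets):
    """Increase in mean cross-entropy when the SAE reconstruction is spliced in."""
    return (mean_cross_entropy(spliced_logits, targets)
            - mean_cross_entropy(clean_logits, targets))
```

Each logits row is a list of vocabulary scores for one position; a perfect reconstruction gives identical logits and therefore a delta of exactly zero.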

Feature Activation Frequency (Y-axis)

  • What it measures: The proportion of SAE features that fire very frequently (on >10% of tokens in the top plot, >1% in the bottom plot)
  • Why it matters: Very high frequency features tend to be:
    • Less interpretable (they activate so often that they may not represent specific, meaningful concepts)
    • Less sparse (working against the sparsity that "sparse" autoencoders are trained for)
    • Potentially acting more like bias terms than meaningful feature detectors
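
The y-axis quantity can be sketched as follows. This assumes SAE activations are non-negative, that a feature counts as "active" on a token when its activation is strictly positive, and that activations are available as a tokens-by-features nested list; the helper names are hypothetical and this is not the paper's evaluation code.

```python
def feature_frequencies(activations):
    """Fraction of tokens on which each feature is active (activation > 0).

    activations[t][i] is the activation of feature i on token t.
    """
    n_tokens = len(activations)
    counts = [0] * len(activations[0])
    for token_acts in activations:
        for i, a in enumerate(token_acts):
            if a > 0:
                counts[i] += 1
    return [c / n_tokens for c in counts]

def proportion_above(freqs, threshold):
    """Proportion of features whose firing frequency exceeds threshold."""
    return sum(f > threshold for f in freqs) / len(freqs)
```

With freqs = feature_frequencies(acts), proportion_above(freqs, 0.10) and proportion_above(freqs, 0.01) correspond to the quantities in the top and bottom panels respectively.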

The Significance of This Plot

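  1. Different architectures have different characteristics: TopK and JumpReLU SAEs tend to have relatively more very high frequency features (active on >10% of tokens, top plot) than Gated SAEs. For features active on >1% of tokens (bottom plot), the picture is more mixed: Gated SAEs can have more of these high (but not very high) frequency features than JumpReLU SAEs, particularly in the low delta LM loss (high fidelity) regime.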

  2. Trade-off visualization: As you improve reconstruction fidelity (move left), you tend to get more high-frequency features. This suggests there may be a tension between:

    • Making accurate reconstructions (low delta LM loss)
    • Learning sparse, interpretable features (few high-frequency features)
  3. Architecture comparison at fixed performance: By plotting against delta LM loss rather than just comparing raw frequency counts, the authors ensure they're comparing architectures at equivalent levels of reconstruction quality. This is a fairer comparison than just looking at frequency distributions alone.

  4. Practical implications: For practitioners, this helps in choosing an architecture: if avoiding very high frequency features matters most for interpretability, you might prefer Gated SAEs despite their slightly worse reconstruction fidelity.
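
Point 3 can be illustrated with a small, hypothetical helper. Treat the SAEs trained for one architecture (for example at different sparsity levels) as a list of (delta LM loss, proportion of high-frequency features) points, and read off that curve's value at a shared delta LM loss by interpolating linearly in log(delta LM loss). This is one simple way to compare curves at matched fidelity, not necessarily how the paper's authors did it.

```python
import math
from bisect import bisect_left

def interpolate_at_loss(sweep, target_loss):
    """Linearly interpolate a metric at target_loss in log(delta loss) space.

    sweep is a list of (delta_lm_loss, metric) pairs for one architecture.
    Returns None if target_loss lies outside the range covered by sweep.
    """
    pts = sorted(sweep)
    xs = [math.log(loss) for loss, _ in pts]
    ys = [metric for _, metric in pts]
    x = math.log(target_loss)
    if x < xs[0] or x > xs[-1]:
        return None  # refuse to extrapolate beyond the measured points
    j = bisect_left(xs, x)  # first index with xs[j] >= x
    if xs[j] == x:
        return ys[j]
    x0, x1, y0, y1 = xs[j - 1], xs[j], ys[j - 1], ys[j]
    return y0 + (y1 - y0) * (x - x0) / (x1 - x0)
```

Evaluating each architecture's points at the same target_loss gives the "equivalent fidelity" comparison described in point 3; returning None outside a curve's range avoids comparing against values that were never measured.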

The paper notes that although JumpReLU SAEs, like TopK SAEs, have more very high frequency features (>10% of tokens) than Gated SAEs, the absolute numbers remain small (<0.06% of all features even in the worst case), which suggests this may not be a major practical concern.

Thinking Summary

The user is asking about the x-axis labeling in Figure 5. Looking at the description they provided, it seems like they're seeing axis labels that suggest a logarithmic scale. The pattern "2 5 0.1 2 5" with 0.1 being larger suggests this is indeed a log scale, likely showing values like 0.02, 0.05, 0.1, 0.2, 0.5 or similar. Let me check the paper for information about delta LM loss units and whether they use log scales. Looking at the paper, delta LM loss is defined as "the increase in the cross-entropy loss experienced by the LM when we splice the SAE into the LM's forward pass." Cross-entropy loss is typically measured in nats (natural log) or bits (log base 2), and the delta would be in the same units. The paper doesn't explicitly state the units, but cross-entropy loss for language models is commonly expressed in nats per token or bits per token. Since this is a delta (difference), it would be the increase in nats/bits per token. The logarithmic scale makes sense because these loss differences can span multiple orders of magnitude, from very small (good reconstruction) to larger values (poor reconstruction).

Yes, you're absolutely right: this is a logarithmic scale on the x-axis! The pattern you're describing (2, 5, 0.1, 2, 5) is the classic labeling for a log scale: each power of 10 gets a full label, and the ticks at 2× and 5× that power are labeled by their leading digit only. Read left to right, the labels represent:

  • 0.02, 0.05, 0.1, 0.2, 0.5

The "0.1" appears larger because it's the major tick mark at a power of 10, while 2 and 5 are minor tick marks within each decade.
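
To see how such labels map to values, here is a small sketch that lists 1-2-5 tick positions for a range of decades and labels them the way described above (full value at each power of 10, leading digit otherwise). The labeling rule is an assumption inferred from the pattern you saw, not taken from the paper's plotting code. If the axis starts just after 0.01, the visible labels are exactly "2 5 0.1 2 5".

```python
import math

def log_ticks(first_decade, last_decade, multipliers=(1, 2, 5)):
    """Tick positions m * 10**k for every decade k in [first_decade, last_decade]."""
    ticks = []
    for k in range(first_decade, last_decade + 1):
        for m in multipliers:
            # Rounding removes floating-point noise from the multiplication.
            ticks.append(round(m * 10.0 ** k, 12))
    return ticks

def tick_label(value):
    """Full label at a power of 10; leading digit only for the other ticks."""
    # The small epsilon guards against log10 landing just below an integer.
    exponent = math.floor(math.log10(value) + 1e-9)
    mantissa = round(value / 10.0 ** exponent)
    return f"{value:g}" if mantissa == 1 else str(mantissa)

labels = [tick_label(v) for v in log_ticks(-2, -1)]
print(" ".join(labels))  # 0.01 2 5 0.1 2 5
```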

Units for Delta LM Loss

Delta LM loss measures the increase in cross-entropy loss when the SAE is inserted into the model. Cross-entropy loss for language models is typically measured in:

  • Nats per token (using natural logarithm)
  • Bits per token (using log base 2)

The paper doesn't explicitly state which unit they're using, but nats per token is more common in deep learning contexts.
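
If you need to move between the two units, the conversion is a single division (or multiplication) by ln 2. Frameworks such as PyTorch compute cross-entropy with the natural logarithm, so a delta computed from their loss values is in nats. The helper names below are illustrative.

```python
import math

def nats_to_bits(nats):
    """Convert a loss (or loss difference) from nats to bits."""
    return nats / math.log(2)

def bits_to_nats(bits):
    """Convert a loss (or loss difference) from bits to nats."""
    return bits * math.log(2)

print(f"{nats_to_bits(0.1):.4f}")  # 0.1443
```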

Since this is a delta (difference), the units would be the increase in nats/bits per token. For example:

  • A delta LM loss of 0.1 nats means the model's perplexity increases by a factor of e^0.1 ≈ 1.105 when using the SAE reconstruction
  • A very small delta (e.g., 0.01) means the SAE barely affects model performance
  • A larger delta (e.g., 0.5) indicates significant degradation
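
The perplexity arithmetic above can be checked with a short sketch, assuming the loss is measured in nats so that perplexity is exp of the mean per-token loss; the function name is illustrative.

```python
import math

def perplexity_ratio(delta_nats):
    """Factor by which perplexity grows for a delta LM loss given in nats.

    Perplexity is exp(mean loss in nats), so
    exp(loss + delta) / exp(loss) == exp(delta).
    """
    return math.exp(delta_nats)

for delta in (0.01, 0.1, 0.5):
    print(f"delta = {delta} nats -> perplexity x {perplexity_ratio(delta):.3f}")
# delta = 0.01 nats -> perplexity x 1.010
# delta = 0.1 nats -> perplexity x 1.105
# delta = 0.5 nats -> perplexity x 1.649
```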

The logarithmic scale is appropriate here because these deltas can differ by more than an order of magnitude between SAEs (the tick labels you describe already span roughly 0.02 to 0.5, a factor of 25), and a log scale keeps the small, high-fidelity values from being squeezed together near zero.
