Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders
Thread (53 posts)
JumpReLU SAEs only require a single forward and backward pass during a training step and have an elementwise activation function (unlike TopK, which requires a partial sort), making them more efficient to train than either Gated or TopK SAEs.
Can you elaborate on the training efficiency point here in more detail?
Thinking Summary
The user is asking me to elaborate on the training efficiency claim made in the JumpReLU paper. The paper states that JumpReLU SAEs are more efficient to train than both Gated and TopK SAEs because:
1. They only require a single forward and backward pass
2. They have an elementwise activation function (unlike TopK, which requires a partial sort)

Let me break down the efficiency comparison.

**JumpReLU SAEs:**
- Single forward pass: compute pre-activations, apply JumpReLU (elementwise operation)
- Single backward pass: use straight-through estimators for gradients
- JumpReLU is elementwise: for each feature, just check if π_i(x) > θ_i

**Gated SAEs:**
- Need two encoder computations: π_gate and π_mag
- Additional auxiliary loss term requires a frozen decoder copy
- More complex forward pass with element-wise multiplication of gate and magnitude
- Requires more memory for the two encoder paths

**TopK SAEs:**
- Requires finding the top K activations, which involves sorting or partial sorting
- Partial sort has O(M log K) complexity, where M is the dictionary size
- Even with approximate algorithms (like jax.lax.approx_max_k), still more expensive than elementwise ops
- AuxK auxiliary loss adds additional computation

The key efficiency gains for JumpReLU are:
1. Simpler architecture (single encoder path vs. dual for Gated)
2. Elementwise operations only (vs. sorting for TopK)
3. No auxiliary loss terms needed (vs. both Gated and TopK, which use aux losses)
The training efficiency advantages of JumpReLU SAEs stem from their simpler computational requirements compared to Gated and TopK architectures:
JumpReLU vs Gated SAEs
Gated SAEs require:
- Two encoder paths: Computing both π_gate(x) and π_mag(x) means roughly double the encoder computation
- Auxiliary loss with frozen decoder: The auxiliary term requires maintaining a frozen copy of the decoder and computing an additional reconstruction
- Element-wise multiplication: The final features are computed as 1[π_gate(x) > 0] ⊙ ReLU(π_mag(x)), combining the gate and magnitude paths (see the sketch below)
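To make the two-path structure concrete, here is a minimal JAX sketch; parameter names (W_gate, W_mag, b_gate, b_mag) are illustrative, and details such as decoder-bias centering and the gate/magnitude weight tying used in practice are omitted:

```python
import jax.numpy as jnp

def gated_sae_encode(x, params):
    """Minimal sketch of a Gated SAE encoder: two pre-activation paths per feature."""
    pi_gate = x @ params["W_gate"] + params["b_gate"]  # decides which features fire
    pi_mag  = x @ params["W_mag"]  + params["b_mag"]   # decides how strongly they fire
    # Binary gate multiplied elementwise with a ReLU magnitude path
    return (pi_gate > 0).astype(x.dtype) * jnp.maximum(pi_mag, 0.0)
```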
JumpReLU SAEs only need:
- Single encoder path: Just compute the pre-activations π(x) once
- No auxiliary loss: The L0 penalty provides sufficient gradient signal without auxiliary terms
- Simple thresholding: Apply JumpReLU elementwise
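Compare with a minimal sketch of the JumpReLU encoder (again with illustrative parameter names; theta is the learned per-feature threshold vector):

```python
import jax.numpy as jnp

def jumprelu_sae_encode(x, params):
    """Minimal sketch of a JumpReLU SAE encoder: one pre-activation path, one elementwise threshold."""
    pre_acts = x @ params["W_enc"] + params["b_enc"]
    theta = params["theta"]               # shape (M,): one learned threshold per feature
    return pre_acts * (pre_acts > theta)  # JumpReLU: zero out features below their threshold
```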
JumpReLU vs TopK SAEs
TopK SAEs require:
- Partial sorting: Finding the top K activations has complexity O(M log K), where M is the dictionary size (e.g., 131k features)
- Approximate TopK: Even approximate algorithms (like jax.lax.approx_max_k) are still more expensive than elementwise operations (see the sketch below)
- AuxK auxiliary loss: Adds computational overhead to reduce dead features
JumpReLU's elementwise operation is just:
feature_magnitudes = pre_acts * (pre_acts > threshold)  # elementwise over all M features, O(M)
Memory Efficiency
JumpReLU also has lower memory requirements:
- No need to store frozen decoder copies (unlike Gated)
- No need to maintain sorting indices or auxiliary computations (unlike TopK)
- Straight-through estimators add minimal memory overhead
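As a rough illustration of the straight-through idea, here is a sketch of a JumpReLU with a custom backward pass in JAX. The bandwidth EPS, the rectangle kernel, and the single-example (unbatched) setup are simplifying assumptions rather than the paper's exact recipe:

```python
import jax
import jax.numpy as jnp

EPS = 1e-3  # pseudo-derivative bandwidth (hyperparameter, illustrative value)

@jax.custom_vjp
def jumprelu(pre_acts, theta):
    # Forward pass: the same cheap elementwise threshold as above
    return pre_acts * (pre_acts > theta)

def _fwd(pre_acts, theta):
    return jumprelu(pre_acts, theta), (pre_acts, theta)

def _bwd(residuals, grad_out):
    pre_acts, theta = residuals
    active = (pre_acts > theta).astype(grad_out.dtype)
    # Rectangle kernel: only pre-activations within EPS/2 of the threshold pass
    # any gradient to theta, since the hard step has zero true gradient elsewhere.
    near = (jnp.abs(pre_acts - theta) < EPS / 2).astype(grad_out.dtype)
    d_pre   = grad_out * active                   # pass-through where the feature is active
    d_theta = grad_out * (-(theta / EPS) * near)  # pseudo-derivative w.r.t. the threshold
    return d_pre, d_theta

jumprelu.defvjp(_fwd, _bwd)
```

The forward pass stays the same elementwise operation; only the backward rule changes, which is why the straight-through machinery adds little extra memory or compute per training step.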
The paper notes these efficiency gains are particularly important when training at scale - with 131k-width SAEs on billions of tokens, even small per-step efficiency improvements compound significantly.