Thread (53 posts)

JumpReLU SAEs only require a single forward and backward pass during a training step and have an elementwise activation function (unlike TopK, which requires a partial sort), making them more efficient to train than either Gated or TopK SAEs.

Can you elaborate on the training efficiency point here in more detail?

Thinking Summary

The user is asking me to elaborate on the training efficiency claim made in the JumpReLU paper. The paper states that JumpReLU SAEs are more efficient to train than both Gated and TopK SAEs because:

1. They only require a single forward and backward pass.
2. They have an elementwise activation function (unlike TopK, which requires a partial sort).

Let me break down the efficiency comparison.

JumpReLU SAEs:

  • Single forward pass: compute pre-activations, apply JumpReLU (an elementwise operation)
  • Single backward pass: use straight-through estimators for gradients
  • JumpReLU is elementwise: for each feature, just check whether π_i(x) > θ_i

Gated SAEs:

  • Need two encoder computations: π_gate and π_mag
  • The additional auxiliary loss term requires a frozen decoder copy
  • More complex forward pass with element-wise multiplication of gate and magnitude
  • Require more memory for the two encoder paths

TopK SAEs:

  • Require finding the top K activations, which involves sorting or partial sorting
  • A partial sort has O(M log K) complexity, where M is the dictionary size
  • Even with approximate algorithms (like jax.lax.approx_max_k), this is still more expensive than elementwise ops
  • The AuxK auxiliary loss adds additional computation

The key efficiency gains for JumpReLU are:

1. Simpler architecture (single encoder path vs. dual paths for Gated)
2. Elementwise operations only (vs. sorting for TopK)
3. No auxiliary loss terms needed (vs. both Gated and TopK, which use aux losses)

The training efficiency advantages of JumpReLU SAEs stem from their simpler computational requirements compared to Gated and TopK architectures:

JumpReLU vs Gated SAEs

Gated SAEs require:

  • Two encoder paths: Computing both $\boldsymbol{\pi}_{\text{gate}}(\mathbf{x})$ and $\boldsymbol{\pi}_{\text{mag}}(\mathbf{x})$ means double the encoder computation
  • Auxiliary loss with frozen decoder: The auxiliary term $\|\mathbf{x} - \hat{\mathbf{x}}_{\text{frozen}}(\text{ReLU}(\boldsymbol{\pi}_{\text{gate}}(\mathbf{x})))\|_2^2$ requires maintaining a frozen copy of the decoder and computing an additional reconstruction
  • Element-wise multiplication: The final features are computed as $\mathbf{f}_{\text{gate}}(\mathbf{x}) \odot \mathbf{f}_{\text{mag}}(\mathbf{x})$

JumpReLU SAEs only need:

  • Single encoder path: Just compute $\boldsymbol{\pi}(\mathbf{x}) = \mathbf{W}_{\text{enc}}\mathbf{x} + \mathbf{b}_{\text{enc}}$ once
  • No auxiliary loss: The L0 penalty provides sufficient gradient signal without auxiliary terms
  • Simple thresholding: Apply $\text{JumpReLU}_\theta(z) = z \cdot H(z - \theta)$ elementwise (see the sketch below)
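
To make the difference concrete, here is a minimal NumPy sketch of the two forward passes. It is illustrative only: the variable names are made up, the Gated encoder is shown in its untied form (separate gate and magnitude weights), and training-time details such as gradient stopping for the frozen decoder are only indicated in comments.

```python
import numpy as np

def gated_sae_forward(x, W_gate, b_gate, W_mag, b_mag, W_dec, b_dec):
    """Gated SAE forward pass (untied form): two encoder projections,
    a gate/magnitude combination, and an extra reconstruction through
    what would be a frozen decoder copy for the auxiliary loss."""
    pi_gate = x @ W_gate + b_gate                  # gate-path pre-activations
    pi_mag = x @ W_mag + b_mag                     # magnitude-path pre-activations
    f = (pi_gate > 0) * np.maximum(pi_mag, 0.0)    # f_gate(x) ⊙ f_mag(x)
    x_hat = f @ W_dec + b_dec                      # main reconstruction
    # Auxiliary reconstruction from ReLU(pi_gate); during training, W_dec and
    # b_dec here would be detached ("frozen") copies.
    x_hat_aux = np.maximum(pi_gate, 0.0) @ W_dec + b_dec
    return f, x_hat, x_hat_aux

def jumprelu_sae_forward(x, W_enc, b_enc, W_dec, b_dec, theta):
    """JumpReLU SAE forward pass: one encoder projection and an elementwise
    threshold; no second encoder path, no auxiliary reconstruction."""
    pi = x @ W_enc + b_enc                         # pre-activations
    f = pi * (pi > theta)                          # JumpReLU_theta applied elementwise
    return f, f @ W_dec + b_dec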

JumpReLU vs TopK SAEs

TopK SAEs require:

  • Partial sorting: Finding the top K activations has complexity $O(M \log K)$, where $M$ is the dictionary size (e.g., 131k features)
  • Even approximate TopK (like jax.lax.approx_max_k) is more expensive than elementwise operations
  • AuxK auxiliary loss: Adds computational overhead to reduce dead features

JumpReLU's elementwise operation is just:

feature_magnitudes = pre_activations * (pre_activations > threshold)  # O(M), elementwise
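
For contrast, here is a hedged sketch of what a TopK activation has to do for each token, with np.argpartition standing in for the partial sort (NumPy's selection is O(M) on average, but it is still far more work and memory traffic than a single elementwise comparison). Function names and shapes are illustrative, not taken from any particular implementation; pre_acts would be (batch, M) and theta would be a per-feature vector of shape (M,).

```python
import numpy as np

def topk_activations(pre_acts, k):
    """Keep only the K largest pre-activations per token (TopK SAE style).
    Requires a per-token selection over all M features plus a scatter of the
    surviving values back into a dense feature vector."""
    idx = np.argpartition(pre_acts, -k, axis=-1)[..., -k:]   # indices of the top K
    f = np.zeros_like(pre_acts)
    np.put_along_axis(f, idx, np.take_along_axis(pre_acts, idx, axis=-1), axis=-1)
    return f

def jumprelu_activations(pre_acts, theta):
    """JumpReLU: a single compare-and-multiply, no indices to track."""
    return pre_acts * (pre_acts > theta)
```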

Memory Efficiency

JumpReLU also has lower memory requirements:

  • No need to store frozen decoder copies (unlike Gated)
  • No need to maintain sorting indices or auxiliary computations (unlike TopK)
  • Straight-through estimators add minimal memory overhead (see the sketch below)
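
As a rough illustration of why the straight-through estimator is cheap, here is a sketch of the pseudo-derivatives the paper defines for the threshold, assuming the rectangle kernel with bandwidth ε. In a real implementation these would be wired into autograd via custom gradient rules rather than computed by hand, and the function names here are illustrative.

```python
import numpy as np

def rectangle(u):
    """Rectangle kernel: 1 inside (-0.5, 0.5), 0 elsewhere."""
    return ((u > -0.5) & (u < 0.5)).astype(float)

def threshold_pseudo_gradients(pre_acts, theta, eps):
    """Pseudo-derivatives with respect to theta used by the straight-through
    estimator: one for JumpReLU_theta(pi) (reconstruction loss) and one for
    H(pi - theta) (L0 penalty). Only the pre-activations and theta are needed,
    so the extra memory footprint is minimal."""
    window = rectangle((pre_acts - theta) / eps)
    d_jumprelu_dtheta = -(theta / eps) * window    # feeds the reconstruction term
    d_heaviside_dtheta = -(1.0 / eps) * window     # feeds the L0 sparsity term
    return d_jumprelu_dtheta, d_heaviside_dtheta
```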

The paper notes these efficiency gains are particularly important when training at scale - with 131k-width SAEs on billions of tokens, even small per-step efficiency improvements compound significantly.
