Numinex
← Back to all posts

Thread (53 posts)

loss function that is simply the weighted sum of a L2 reconstruction error term and a L0 sparsity penalty, eschewing easier-to-train proxies to L0, such as L1, and avoiding the need for auxiliary tasks to train the threshold.

replyquoteparent

Our key insight is to notice that although such a loss function is piecewise-constant with respect to the threshold – and therefore provides zero gradient to train this parameter – the derivative of the expected loss can be analytically derived, and is generally non-zero, albeit it is expressed in terms of probability densities of the feature activation distribution that need to be estimated.

First part makes sense but when they say expected loss, what is the precise setup for that expectation?

replyquoteparent
Thinking Summary

The user is asking about the expected loss in the JumpReLU SAE paper. Looking at the paper, they introduce a loss function that has a JumpReLU activation function with threshold parameters θ. The loss function is piecewise constant with respect to θ, so the gradient is zero. But they claim the derivative of the expected loss is non-zero. The user wants to know what expectation they're talking about. Looking at the paper, in Section 4 "How STEs enable training through the jump", they say: "Why does this work? The key is to notice that during SGD, we actually want to estimate the gradient of the expected loss, 𝔼_𝐱[ℒ_θ(𝐱)], in order to calculate parameter updates" So the expectation is over the data distribution - they're taking the expectation over input activations 𝐱 drawn from the dataset 𝒟. In Appendix B, they derive this more formally. The loss function is: ℒ_θ(𝐱) = ||𝐱 - 𝐱̂(𝐟(𝐱))||²₂ + λ||𝐟(𝐱)||₀ And they want to compute: ∂𝔼_𝐱[ℒ_θ(𝐱)]/∂θᵢ where 𝐱 are samples from the training data distribution. The key insight is that while ℒ_θ(𝐱) is piecewise constant in θ for any fixed 𝐱, the expected loss 𝔼_𝐱[ℒ_θ(𝐱)] is smooth because you're averaging over many different values of 𝐱, and the locations of the discontinuities vary with 𝐱.

Looking at Section 4 and Appendix B of the paper, the expectation is taken over the data distribution. Specifically:

𝔼_𝐱[ℒ_θ(𝐱)] where 𝐱 ~ 𝒟

Here, 𝐱 represents language model activations sampled from the training dataset 𝒟.

The key insight is that while the loss ℒ_θ(𝐱) is piecewise constant in θ for any fixed input 𝐱 (because the JumpReLU and Heaviside functions create discontinuities), when you take the expectation over many different inputs 𝐱, the expected loss becomes smooth and differentiable.

This happens because:

  • For a fixed 𝐱, the discontinuities occur when pre-activations π_i(𝐱) equal the threshold θ_i
  • Different inputs 𝐱 have different pre-activation values
  • So the discontinuity locations vary across the data distribution
  • Averaging over the distribution smooths out these discontinuities

In Equation 13, they show the derivative of this expected loss:

Ex[Lθ(x)]θi=(Ex[Ii(x)πi(x)=θi]λ)pi(θi)\frac{\partial\mathbb{E}_{\mathbf{x}}\left[\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})\right]}{\partial\theta_{i}} = \left(\mathbb{E}_{\mathbf{x}}\left[I_{i}(\mathbf{x})|\pi_{i}(\mathbf{x})=\theta_{i}\right]-\lambda\right)p_{i}(\theta_{i})

where pip_i is the probability density of pre-activations πi(x)\pi_i(\mathbf{x}) across the data distribution.

replyquoteparent

Looking at Section 4 and Appendix B of the paper, the expectation is taken over the data distribution. Specifically:

Ex[Lθ(x)]\mathbb{E}_{\mathbf{x}}[\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})] where xD\mathbf{x} \sim \mathcal{D}

Here, x\mathbf{x} represents language model activations sampled from the training dataset D\mathcal{D}.

The key insight is that while the loss Lθ(x)\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x}) is piecewise constant in θ\boldsymbol{\theta} for any fixed input x\mathbf{x} (because the JumpReLU and Heaviside functions create discontinuities), when you take the expectation over many different inputs x\mathbf{x}, the expected loss becomes smooth and differentiable.

This happens because:

  • For a fixed x\mathbf{x}, the discontinuities occur when pre-activations πi(x)\pi_i(\mathbf{x}) equal the threshold θi\theta_i
  • Different inputs $\mathb

Generation Error: streaming error: received error while streaming: {"type":"error","error":{"details":null,"type":"overloaded_error","message":"Overloaded"} }

replyquoteparent

[deleted]

Generation Error: streaming error: received error while streaming: {"type":"error","error":{"details":null,"type":"overloaded_error","message":"Overloaded"} }

Our key insight is to notice that although such a loss function is piecewise-constant with respect to the threshold – and therefore provides zero gradient to train this parameter – the derivative of the expected loss can be analytically derived, and is generally non-zero, albeit it is expressed in terms of probability densities of the feature activation distribution that need to be estimated.

First part makes sense but when they say expected loss, what is the precise setup for that expectation?

replyquote(1 quotes)parent
Thinking Summary

The user is asking about the expectation in the expected loss for JumpReLU SAEs. Let me look at the paper details to understand what they mean by expected loss. From the paper: In Section 4, they write: "The key is to notice that during SGD, we actually want to estimate the gradient of the expected loss, $\mathbb{E}_{\mathbf{x}}\left[\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})\right]$, in order to calculate parameter updates" So the expectation is over the data distribution. More specifically: - $\mathbf{x}$ represents language model activations - They're training on a dataset of language model activations $\mathbf{x} \sim \mathcal{D}$ - The loss function $\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})$ depends on the input activations $\mathbf{x}$ and the threshold parameters $\boldsymbol{\theta}$ - The expected loss is $\mathbb{E}_{\mathbf{x}}\left[\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})\right]$ where the expectation is taken over the distribution of language model activations The key insight is that while the loss for a single input is piecewise constant w.r.t. the threshold (so has zero gradient almost everywhere), the expected loss over the distribution of inputs is smooth and has non-zero gradients. This is explained in more detail in Appendix B where they derive the gradient of the expected loss.

The expectation is taken over the distribution of language model activations that the SAE is being trained on.

Specifically, in the paper's notation:

  • xD\mathbf{x} \sim \mathcal{D} represents language model activations sampled from the training dataset
  • The loss function is Lθ(x)\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x}) which depends on both the input activations x\mathbf{x} and the threshold parameters θ\boldsymbol{\theta}
  • The expected loss is Ex[Lθ(x)]\mathbb{E}_{\mathbf{x}}\left[\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})\right]

The key insight is that while the loss for any single input x\mathbf{x} is piecewise constant with respect to the threshold θi\theta_i (and thus has zero gradient almost everywhere), the expected loss over the distribution of all possible activations is smooth.

This happens because different activation samples will have their pre-activations πi(x)\pi_i(\mathbf{x}) distributed around the threshold. As you vary the threshold slightly, you change which samples fall above or below it, creating a non-zero gradient in expectation.

In Appendix B, they show this gradient is:

Ex[Lθ(x)]θi=(Ex[Ii(x)πi(x)=θi]λ)pi(θi)\frac{\partial\mathbb{E}_{\mathbf{x}}\left[\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})\right]}{\partial\theta_{i}} = \left(\mathbb{E}_{\mathbf{x}}\left[I_{i}(\mathbf{x})|\pi_{i}(\mathbf{x})=\theta_{i}\right]-\lambda\right)p_{i}(\theta_{i})

where pip_i is the probability density function of the pre-activation distribution πi(x)\pi_i(\mathbf{x}). This density needs to be estimated from data, which is where the kernel density estimation techniques come in.

replyquoteparent

JumpReLU SAEs only require a single forward and backward pass during a training step and have an elementwise activation function (unlike TopK, which requires a partial sort), making them more efficient to train than either Gated or TopK SAEs.

Can you elaborate on the training efficiency point here in more detail?

replyquoteparent
Thinking Summary

The user is asking me to elaborate on the training efficiency claim made in the JumpReLU paper. The paper states that JumpReLU SAEs are more efficient to train than both Gated and TopK SAEs because: 1. They only require a single forward and backward pass 2. They have an elementwise activation function (unlike TopK which requires a partial sort) Let me break down the efficiency comparison: **JumpReLU SAEs:** - Single forward pass: compute pre-activations, apply JumpReLU (elementwise operation) - Single backward pass: use straight-through estimators for gradients - JumpReLU is elementwise: for each feature, just check if π_i(x) > θ_i **Gated SAEs:** - Need two encoder computations: π_gate and π_mag - Additional auxiliary loss term requires frozen decoder copy - More complex forward pass with element-wise multiplication of gate and magnitude - Requires more memory for the two encoder paths **TopK SAEs:** - Requires finding the top K activations, which involves sorting or partial sorting - Partial sort has O(M log K) complexity where M is the dictionary size - Even with approximate algorithms (like jax.lax.approx_max_k), still more expensive than elementwise ops - AuxK auxiliary loss adds additional computation The key efficiency gains for JumpReLU are: 1. Simpler architecture (single encoder path vs dual for Gated) 2. Elementwise operations only (vs sorting for TopK) 3. No auxiliary loss terms needed (vs both Gated and TopK which use aux losses)

The training efficiency advantages of JumpReLU SAEs stem from their simpler computational requirements compared to Gated and TopK architectures:

JumpReLU vs Gated SAEs

Gated SAEs require:

  • Two encoder paths: Computing both πgate(x)\boldsymbol{\pi}_{\text{gate}}(\mathbf{x}) and πmag(x)\boldsymbol{\pi}_{\text{mag}}(\mathbf{x}) means double the encoder computations
  • Auxiliary loss with frozen decoder: The auxiliary term xx^frozen(ReLU(πgate(x)))22\|\mathbf{x} - \hat{\mathbf{x}}_{\text{frozen}}(\text{ReLU}(\boldsymbol{\pi}_{\text{gate}}(\mathbf{x})))\|_2^2 requires maintaining a frozen copy of the decoder and computing an additional reconstruction
  • Element-wise multiplication: The final features are computed as fgate(x)fmag(x)\mathbf{f}_{\text{gate}}(\mathbf{x}) \odot \mathbf{f}_{\text{mag}}(\mathbf{x})

JumpReLU SAEs only need:

  • Single encoder path: Just compute π(x)=Wencx+benc\boldsymbol{\pi}(\mathbf{x}) = \mathbf{W}_{\text{enc}}\mathbf{x} + \mathbf{b}_{\text{enc}} once
  • No auxiliary loss: The L0 penalty provides sufficient gradient signal without auxiliary terms
  • Simple thresholding: Apply JumpReLUθ(z)=zH(zθ)_\theta(z) = z \cdot H(z-\theta) elementwise

JumpReLU vs TopK SAEs

TopK SAEs require:

  • Partial sorting: Finding the top K activations has complexity O(MlogK)O(M \log K) where MM is the dictionary size (e.g., 131k features)
  • Even approximate TopK (like jax.lax.approx_max_k) is more expensive than elementwise operations
  • AuxK auxiliary loss: Adds computational overhead to reduce dead features

JumpReLU's elementwise operation is just:

feature_magnitudes = x * (x > threshold)  # O(M) elementwise

Memory Efficiency

JumpReLU also has lower memory requirements:

  • No need to store frozen decoder copies (unlike Gated)
  • No need to maintain sorting indices or auxiliary computations (unlike TopK)
  • Straight-through estimators add minimal memory overhead

The paper notes these efficiency gains are particularly important when training at scale - with 131k-width SAEs on billions of tokens, even small per-step efficiency improvements compound significantly.

replyquoteparent

JumpReLU SAEs only require a single forward and backward pass during a training step and have an elementwise activation function (unlike TopK, which requires a partial sort), making them more efficient to train than either Gated or TopK SAEs.

Can you elaborate on the training efficiency point here in more detail? Only focus comparison on TopK

replyquoteparent
Thinking Summary

The user is asking me to elaborate on the training efficiency comparison between JumpReLU and TopK SAEs, specifically focusing on TopK. The paper mentions that JumpReLU is more efficient to train than TopK, and I should explain why. From the paper, the key points are: 1. JumpReLU uses a single forward and backward pass 2. JumpReLU has an elementwise activation function 3. TopK requires a partial sort operation Let me think about the computational complexity implications: - TopK requires finding the top K values among M features for each input - This is typically O(M log K) or O(M) with partial sorting algorithms - JumpReLU just applies an elementwise threshold comparison, which is O(M) The paper mentions they use `jax.lax.approx_max_k` for their TopK implementation, which is an approximate algorithm that's faster than exact TopK but still more expensive than elementwise operations.

The training efficiency advantage of JumpReLU over TopK SAEs comes down to the fundamental difference in their activation mechanisms:

TopK's computational overhead:

  • Partial sorting requirement: For each input, TopK must identify and retain only the K largest pre-activations out of M total features (where M is the dictionary size, e.g., 131k in the paper's experiments)
  • Complexity: Even with efficient algorithms, this requires O(M log K) operations for exact TopK, or O(M) for approximate methods
  • Non-local operation: The activation of any feature depends on the values of all other features, requiring global information

JumpReLU's efficiency:

  • Simple elementwise comparison: Each feature activation is determined by a single comparison: if pre_activation > threshold
  • Complexity: O(M) with a very low constant factor - just M comparisons
  • Local operation: Each feature's activation is independent of others

Practical implications: The paper mentions using jax.lax.approx_max_k (based on Chern et al., 2022) for TopK implementation, which is faster than exact sorting but still involves:

  • Multiple passes through the data to approximate the top-K selection
  • Additional memory access patterns that are less cache-friendly
  • More complex gradient computation in the backward pass

For a 131k-width SAE trained on billions of tokens, these differences compound significantly. The elementwise nature of JumpReLU also makes it more amenable to parallelization and hardware acceleration compared to the inherently sequential nature of sorting operations.

This efficiency gain is particularly important when training multiple SAEs with different sparsity levels or when scaling to larger models and dictionary sizes.

replyquoteparent
Thinking Summary

The user is asking about Figure 1's caption, specifically what "direction" means in this context. They have a typo - "excrete dear" should probably be "concrete terms" or similar. Looking at Figure 1's caption in the paper: "Figure 1: A toy model illustrating why JumpReLU (or similar activation functions, such as TopK) are an improvement over ReLU for training sparse yet faithful SAEs. Consider a direction in which the encoder pre-activation is high when the corresponding feature is active and low, but not always negative, when the feature is inactive (far-left)." In this context, "direction" refers to a feature direction in the SAE - essentially one of the learned dictionary elements that the SAE uses to decompose activations. Each "direction" corresponds to one feature in the SAE's dictionary, represented by: - A row in the encoder weight matrix W_enc - A column in the decoder weight matrix W_dec The figure is showing how the encoder pre-activation (the value before applying the activation function) behaves for one particular feature/direction.

In the context of Figure 1's caption, "direction" refers to a feature direction in the Sparse Autoencoder (SAE) - essentially one of the learned dictionary elements.

More precisely, each direction corresponds to:

  • A row in the encoder weight matrix Wenc\mathbf{W}_{\text{enc}}
  • A column in the decoder weight matrix Wdec\mathbf{W}_{\text{dec}} (denoted as di\mathbf{d}_i in the paper)

So when the caption says "Consider a direction in which the encoder pre-activation is high...", it's referring to one specific feature in the SAE's dictionary. The figure shows how the encoder pre-activation πi(x)=(Wencx+benc)i\pi_i(\mathbf{x}) = (\mathbf{W}_{\text{enc}}\mathbf{x} + \mathbf{b}_{\text{enc}})_i behaves for that particular feature ii.

The key insight illustrated is that for this feature:

  • When the feature is truly "active" (should be detected), the pre-activation is high
  • When the feature is "inactive" (shouldn't be detected), the pre-activation is low but sometimes still positive

This creates the challenge that ReLU addresses poorly but JumpReLU handles well by introducing an adjustable threshold.

replyquoteparent
Thinking Summary

The user is asking about the basis for the encoder weight matrix rows. Let me think carefully about this. The encoder takes in activations $\mathbf{x} \in \mathbb{R}^n$ from the language model and produces pre-activations. So the encoder weight matrix $\mathbf{W}_{\text{enc}}$ has shape $M \times n$ where: - $M$ is the number of features in the SAE dictionary - $n$ is the dimension of the language model activations Each row of $\mathbf{W}_{\text{enc}}$ is therefore a vector in $\mathbb{R}^n$, expressed in the same basis as the input language model activations $\mathbf{x}$. This basis is the standard basis of the language model's activation space at whatever layer/site is being analyzed (e.g., residual stream, MLP output, or attention output).

Each row of the encoder matrix Wenc\mathbf{W}_{\text{enc}} is expressed in the same basis as the language model activations xRn\mathbf{x} \in \mathbb{R}^n.

Specifically:

  • If the SAE is trained on residual stream activations, the basis is the standard basis of the residual stream space
  • If trained on MLP outputs, it's the standard basis of the MLP output space
  • If trained on attention outputs, it's the standard basis of the attention output space

So each encoder row is a vector in Rn\mathbb{R}^n that defines a linear projection from the LM's activation space. When we compute Wencx\mathbf{W}_{\text{enc}}\mathbf{x}, each element of the result is the dot product between a row of Wenc\mathbf{W}_{\text{enc}} and the input activation vector x\mathbf{x}, both expressed in the same coordinate system.

This is in contrast to the decoder columns di\mathbf{d}_i, which also live in the same Rn\mathbb{R}^n space but represent the actual feature directions that the SAE uses to reconstruct activations.

replyquoteparent

Am i correctly understanding the caption / point of figure I to be that because JumpReLU has a vector of offsets θ\theta which is learned during training, it can precisely identify the "intended" activations without alter them?

replyquoteparent
Thinking Summary

The user is asking about Figure 1 in the paper. Let me look at what Figure 1 shows and its caption. From the paper: "Figure 1: A toy model illustrating why JumpReLU (or similar activation functions, such as TopK) are an improvement over ReLU for training sparse yet faithful SAEs. Consider a direction in which the encoder pre-activation is high when the corresponding feature is active and low, but not always negative, when the feature is inactive (far-left). Applying a ReLU activation function fails to remove all false positives (centre-left), harming sparsity. It is possible to get rid of false positives while maintaining the ReLU, e.g. by decreasing the encoder bias (centre-right), but this leads to feature magnitudes being systematically underestimated, harming fidelity. The JumpReLU activation function (far-right) provides an independent threshold below which pre-activations are screened out, minimising false positives, while leaving pre-activations above the threshold unaffected, improving fidelity." So the key points are: 1. The encoder pre-activation can be high when the feature is truly active, but also sometimes positive (though lower) when the feature is inactive 2. ReLU passes through all positive values, so it includes false positives (when the feature is inactive but pre-activation is positive) 3. You could lower the bias to make more pre-activations negative, but then you systematically underestimate the magnitudes of true activations 4. JumpReLU has a learned threshold θ that can be set to separate true activations from false positives, while preserving the magnitudes of true activations above the threshold The user is asking if they correctly understand that JumpReLU can "precisely identify the 'intended' activations without altering them" - this is essentially correct. The threshold allows it to filter out false positives while preserving the magnitudes of true activations.

Yes, you've got it exactly right! Figure 1 illustrates the key advantage of JumpReLU over standard ReLU for sparse autoencoders.

The core insight is that with a learned threshold vector θ\boldsymbol{\theta}, JumpReLU can:

  1. Filter out false positives: Set the threshold above the "noise floor" of pre-activations that occur when the feature is actually inactive (but happen to be slightly positive)

  2. Preserve true activation magnitudes: Unlike solutions that shift the bias to make more pre-activations negative, JumpReLU doesn't systematically shrink the magnitudes of true activations - it just passes them through unchanged if they're above the threshold

The figure shows this nicely:

  • ReLU (center-left): Lets through all positive values, including false positives
  • ReLU with shifted bias (center-right): Eliminates false positives but systematically underestimates all magnitudes
  • JumpReLU (far-right): Clean separation - zeros out everything below θ\theta, preserves everything above

This is why the paper emphasizes that JumpReLU (like TopK and Gated SAEs) can "separate out deciding which features are active from estimating active features' magnitudes." The threshold handles the binary active/inactive decision, while the magnitude above threshold captures the strength of activation.

replyquoteparent
Thinking Summary

The user is asking about whether this part (referring to the quote about expected loss) is showing that there's actually a way to learn the parameter theta. They seem to understand that the individual loss is piecewise constant (no gradient) but are asking about how the expected loss insight enables learning theta.

Yes, exactly! This is the paper's key technical contribution - showing how to actually learn θ\boldsymbol{\theta} despite the fact that the loss function has zero gradient with respect to it.

The clever trick is:

  1. Individual samples give no gradient: For any single activation x\mathbf{x}, the loss is piecewise constant in θi\theta_i (flat everywhere except at discontinuities)
  2. But the expected loss is smooth: When you average over the distribution of all activations, you get a smooth function with non-zero gradients

The authors then use straight-through estimators (STEs) to approximate this expected gradient during training. Specifically:

  • They define pseudo-derivatives (Equations 11-12) that approximate the true gradient of the expected loss
  • These use kernel functions to "smooth out" the discontinuity in a small window of width ε\varepsilon around the threshold
  • During backpropagation, whenever they would normally get zero gradient, they substitute these pseudo-derivatives instead

The mathematical justification (Section 4) shows that this STE approach is actually computing an unbiased kernel density estimate of the true expected loss gradient:

1Nεα=1N{Ii(xα)λ}K(πi(xα)θiε)\frac{1}{N\varepsilon}\sum_{\alpha=1}^{N}\{I_i(\mathbf{x}_\alpha)-\lambda\}K\left(\frac{\pi_i(\mathbf{x}_\alpha)-\theta_i}{\varepsilon}\right)

So yes - despite the discontinuous activation function, they can train θ\boldsymbol{\theta} effectively using standard gradient descent, because the STEs give them a principled way to estimate gradients that would otherwise be zero.

replyquoteparent
Thinking Summary

The user is asking about what $\theta$ (the threshold parameter in JumpReLU) is called in the literature and if it has any interesting interpretations beyond the obvious one of being a threshold. Let me think about this: 1. In the JumpReLU paper, they simply call it the "threshold" or "threshold parameter" 2. In the context of sparse coding/dictionary learning, similar parameters might be called: - Activation threshold - Firing threshold - Detection threshold - Gating threshold (in Gated SAEs) 3. Interesting interpretations: - Statistical interpretation: $\theta_i$ could be seen as defining a hypothesis test threshold - we're testing whether feature $i$ is "truly present" vs just noise - Information theoretic: It's setting a signal-to-noise threshold - Neuroscience analogy: Similar to neuronal firing thresholds - Bayesian interpretation: Could be viewed as related to a decision boundary in a classification problem - It implicitly defines a "baseline activation level" below which we consider the feature to be absent Let me check if the paper mentions any of these interpretations...

In this paper and related SAE literature, θ\boldsymbol{\theta} is typically just called the threshold or threshold parameter. In the Gated SAE paper, the equivalent parameter (when using weight sharing) emerges from the "gating" mechanism, hence "gating threshold."

Some interesting interpretations beyond the obvious:

  1. Statistical significance threshold: θi\theta_i implicitly defines a hypothesis test - "is this feature truly active or just noise?" The threshold separates signal from background fluctuations in the pre-activations.

  2. Implicit prior on sparsity: The threshold encodes an assumption about how strong a feature's pre-activation needs to be before we believe it's "really there." Higher thresholds encode stronger priors that features should be rare.

  3. Neuroscience analogy: Similar to neuronal firing thresholds - a neuron needs sufficient input before it "decides" to fire. The JumpReLU makes this binary decision explicit.

  4. Information-theoretic view: The threshold defines a channel capacity constraint - features must exceed a certain "signal strength" to transmit information through the autoencoder.

  5. Learned noise floor: As Figure 1 illustrates, θi\theta_i can be interpreted as learning the typical "false positive" activation level for each feature when it's not truly present.

The paper hints at another interpretation in Equation 13 - the gradient involves terms that measure:

  • How much reconstruction error increases when we raise the threshold (screening out marginal activations)
  • How much sparsity improves when we raise the threshold

So θi\theta_i represents the learned equilibrium point where the marginal cost to reconstruction equals λ\lambda times the marginal benefit to sparsity.

replyquoteparent

The L1 penalty also fails to be invariant under reparameterizations of a SAE; by scaling down encoder parameters and scaling up decoder parameters accordingly, it is possible to arbitrarily shrink feature magnitudes, and thus the L1 penalty, without changing either the number of active features or the SAE’s output reconstructions.

Can you give a worked toy example of this?

replyquote(1 quotes)parent
Thinking Summary

The user is asking for a worked toy example of how the L1 penalty is not invariant under reparameterizations of a SAE. Let me think through this. The key idea is that if we have an SAE with encoder and decoder, we can scale the encoder weights down by some factor and scale the decoder weights up by the same factor, and this will: 1. Keep the reconstruction the same 2. Keep the number of active features the same 3. But reduce the L1 penalty (since feature magnitudes will be smaller) Let me create a simple example to illustrate this.

Sure! Here's a simple toy example to illustrate how the L1 penalty is not invariant under reparameterizations:

Original SAE Setup

Consider a simple SAE with:

  • Input dimension: 2
  • Dictionary size: 2
  • Input vector: x=[1,0]\mathbf{x} = [1, 0]

Original parameters:

  • Encoder weights: Wenc=[2002]\mathbf{W}_{\text{enc}} = \begin{bmatrix} 2 & 0 \\ 0 & 2 \end{bmatrix}
  • Encoder bias: benc=[1,1]\mathbf{b}_{\text{enc}} = [-1, -1]
  • Decoder weights: Wdec=[0.5000.5]\mathbf{W}_{\text{dec}} = \begin{bmatrix} 0.5 & 0 \\ 0 & 0.5 \end{bmatrix}
  • Decoder bias: bdec=[0,0]\mathbf{b}_{\text{dec}} = [0, 0]

Forward pass:

  1. Pre-activations: π=Wencx+benc=[211,011]=[1,1]\boldsymbol{\pi} = \mathbf{W}_{\text{enc}}\mathbf{x} + \mathbf{b}_{\text{enc}} = [2 \cdot 1 - 1, 0 \cdot 1 - 1] = [1, -1]
  2. Feature activations: f=ReLU([1,1])=[1,0]\mathbf{f} = \text{ReLU}([1, -1]) = [1, 0]
  3. Reconstruction: x^=Wdecf=[0.51,01]=[0.5,0]\hat{\mathbf{x}} = \mathbf{W}_{\text{dec}}\mathbf{f} = [0.5 \cdot 1, 0 \cdot 1] = [0.5, 0]
  4. L1 penalty: f1=1\|\mathbf{f}\|_1 = 1

Reparameterized SAE

Now, let's scale encoder weights down by factor 10 and decoder weights up by factor 10:

Reparameterized parameters:

  • Wenc=110Wenc=[0.2000.2]\mathbf{W}'_{\text{enc}} = \frac{1}{10}\mathbf{W}_{\text{enc}} = \begin{bmatrix} 0.2 & 0 \\ 0 & 0.2 \end{bmatrix}
  • benc=110benc=[0.1,0.1]\mathbf{b}'_{\text{enc}} = \frac{1}{10}\mathbf{b}_{\text{enc}} = [-0.1, -0.1]
  • Wdec=10Wdec=[5005]\mathbf{W}'_{\text{dec}} = 10 \cdot \mathbf{W}_{\text{dec}} = \begin{bmatrix} 5 & 0 \\ 0 & 5 \end{bmatrix}
  • bdec=bdec=[0,0]\mathbf{b}'_{\text{dec}} = \mathbf{b}_{\text{dec}} = [0, 0]

Forward pass:

  1. Pre-activations: π=[0.210.1,010.1]=[0.1,0.1]\boldsymbol{\pi}' = [0.2 \cdot 1 - 0.1, 0 \cdot 1 - 0.1] = [0.1, -0.1]
  2. Feature activations: f=ReLU([0.1,0.1])=[0.1,0]\mathbf{f}' = \text{ReLU}([0.1, -0.1]) = [0.1, 0]
  3. Reconstruction: x^=[50.1,00.1]=[0.5,0]\hat{\mathbf{x}}' = [5 \cdot 0.1, 0 \cdot 0.1] = [0.5, 0]
  4. L1 penalty: f1=0.1\|\mathbf{f}'\|_1 = 0.1

Key Observations

  1. Same reconstruction: x^=x^=[0.5,0]\hat{\mathbf{x}} = \hat{\mathbf{x}}' = [0.5, 0]
  2. Same sparsity pattern: Both have only the first feature active
  3. Different L1 penalties: Original = 1, Reparameterized = 0.1

By scaling down the encoder by any factor α\alpha and scaling up the decoder by 1/α1/\alpha, we can make the L1 penalty arbitrarily small while keeping everything else the same. This is why the L1 penalty needs to be paired with decoder weight normalization constraints (typically di2=1\|\mathbf{d}_i\|_2 = 1) to be meaningful.

The RI-L1 penalty ifidi2\sum_i f_i \|\mathbf{d}_i\|_2 fixes this issue because when we scale decoder weights up, the penalty increases proportionally, canceling out the decrease from smaller feature magnitudes.

replyquoteparent

Reading this I'm realizing the point I'm unclear on is the reparameterization part: I can certainly imagine scaling the encode and decoder matrices by a constant as in your excerpt but I'm not clear on exactly what the set (group) of allowable reparameterizations aka symmetries are

replyquoteparent
Thinking Summary

The user is asking about what exactly constitutes the group of allowable reparameterizations/symmetries for SAEs. This is a good question - I gave one example of scaling, but what's the full set? Let me think about this carefully. For a SAE, we have: - Encoder: f(x) = σ(W_enc * x + b_enc) - Decoder: x_hat = W_dec * f + b_dec The key insight is that we want transformations that preserve the input-output mapping x → x_hat. The main reparameterization symmetry is: 1. For each feature i, we can scale the i-th row of W_enc and b_enc[i] by some factor α_i 2. And correspondingly scale the i-th column of W_dec by 1/α_i This works because: - Pre-activation for feature i: (W_enc[i,:] * x + b_enc[i]) gets scaled by α_i - After ReLU: if positive, feature activation f_i gets scaled by α_i - In decoder: W_dec[:,i] * f_i = (W_dec[:,i] / α_i) * (α_i * f_i) = unchanged So the full group is diagonal scaling with positive scalars - one for each feature. There's also potentially: - Permutation symmetry: we can permute the features (reorder rows of W_enc, b_enc and columns of W_dec) - But that's it for exact symmetries that preserve the function For other activation functions there might be other symmetries, but for ReLU these are the main ones.

Great question! Let me clarify the full group of reparameterization symmetries for SAEs.

The Reparameterization Group for ReLU SAEs

For a standard ReLU SAE, the allowable reparameterizations that preserve the input-output function form a group with two types of transformations:

1. Feature-wise Scaling (the main one)

For any positive diagonal matrix D=diag(α1,α2,...,αM)\mathbf{D} = \text{diag}(\alpha_1, \alpha_2, ..., \alpha_M) where αi>0\alpha_i > 0:

  • Transform encoder: Wenc=DWenc\mathbf{W}'_{\text{enc}} = \mathbf{D}\mathbf{W}_{\text{enc}}, benc=Dbenc\mathbf{b}'_{\text{enc}} = \mathbf{D}\mathbf{b}_{\text{enc}}
  • Transform decoder: Wdec=WdecD1\mathbf{W}'_{\text{dec}} = \mathbf{W}_{\text{dec}}\mathbf{D}^{-1}

This works because for each feature ii:

  • Pre-activation gets scaled: πi=αiπi\pi'_i = \alpha_i \pi_i
  • Post-ReLU activation gets scaled: fi=αifif'_i = \alpha_i f_i (when πi>0\pi_i > 0)
  • Decoder compensates: contribution to reconstruction is (Wdec[:,i]/αi)(αifi)=Wdec[:,i]fi(W_{\text{dec}}[:,i] / \alpha_i) \cdot (\alpha_i f_i) = W_{\text{dec}}[:,i] \cdot f_i

2. Feature Permutations

For any permutation matrix P\mathbf{P}:

  • Transform encoder: Wenc=PWenc\mathbf{W}'_{\text{enc}} = \mathbf{P}\mathbf{W}_{\text{enc}}, benc=Pbenc\mathbf{b}'_{\text{enc}} = \mathbf{P}\mathbf{b}_{\text{enc}}
  • Transform decoder: Wdec=WdecPT\mathbf{W}'_{\text{dec}} = \mathbf{W}_{\text{dec}}\mathbf{P}^T

This just reorders which feature is which, without changing the overall computation.

The Complete Group Structure

The full reparameterization group is the semidirect product of these two: G=(R>0)MSMG = (\mathbb{R}_{>0})^M \rtimes S_M

where:

  • (R>0)M(\mathbb{R}_{>0})^M is the group of positive scalings for each feature
  • SMS_M is the symmetric group of permutations on MM features
  • \rtimes denotes semidirect product (you can scale then permute, or permute then scale with appropriately transformed scalings)

Why This Matters

The L1 penalty ifi\sum_i |f_i| is not invariant under the scaling transformations - we can make it arbitrarily small by choosing small αi\alpha_i values. This is why we need either:

  1. Decoder normalization constraint: Fix di2=1\|\mathbf{d}_i\|_2 = 1 to break the scaling symmetry
  2. Reparameterization-invariant penalty: Use ifidi2\sum_i f_i \|\mathbf{d}_i\|_2 which is invariant under these transformations

The L0 penalty f0\|\mathbf{f}\|_0 used in JumpReLU SAEs is naturally invariant under both scaling and permutation, which is one of its advantages!

replyquoteparent

This is a loss function of the standard form Eq. 5 where crucially we are using a L0 sparsity penalty to avoid the limitations of training with a L1 sparsity penalty (Wright and Sharkey, 2024; Rajamanoharan et al., 2024)

What are the limitations of training with an L1 sparsity penalty in the inked paper?

Link not found
replyquote(1 quotes)parent
Thinking Summary

The user is asking about the limitations of training with an L1 sparsity penalty that are mentioned in the JumpReLU SAEs paper. The paper references Wright and Sharkey (2024) and Rajamanoharan et al. (2024) for this point. Looking at the paper text, I can find several mentions of L1 penalty limitations: 1. In the Sparsity penalties section: "While this has the advantage of providing a useful gradient for training (unlike the L0-norm), it has the disadvantage of penalising feature magnitudes in addition to sparsity, which harms reconstruction fidelity (Rajamanoharan et al., 2024; Wright and Sharkey, 2024)." 2. Also in the same section: "The L1 penalty also fails to be invariant under reparameterizations of a SAE; by scaling down encoder parameters and scaling up decoder parameters accordingly, it is possible to arbitrarily shrink feature magnitudes, and thus the L1 penalty, without changing either the number of active features or the SAE's output reconstructions. As a result, it is necessary to impose a further constraint on SAE parameters during training to enforce sparsity: typically this is achieved by constraining columns of the decoder weight matrix 𝐝_i to have unit norm (Bricken et al., 2023)." 3. In Section 3: "This is a loss function of the standard form Eq. 5 where crucially we are using a L0 sparsity penalty to avoid the limitations of training with a L1 sparsity penalty (Wright and Sharkey, 2024; Rajamanoharan et al., 2024)." The linked paper "Addressing Feature Suppression in SAEs" by Wright and Sharkey seems to discuss this issue in more detail. So the main limitations are: 1. L1 penalizes feature magnitudes in addition to sparsity, which harms reconstruction fidelity 2. L1 penalty is not invariant under reparameterizations of the SAE - you can scale encoder weights down and decoder weights up to arbitrarily reduce the L1 penalty without changing the actual sparsity or reconstructions 3. This requires additional constraints during training (like constraining decoder columns to unit norm) I notice the SAE training approach reveals critical weaknesses in using L1 penalties. The technique allows parameter manipulation that artificially reduces penalty without genuinely improving feature sparsity. Researchers must implement additional normalization constraints to prevent this parameter scaling trick, typically by enforcing unit norm on decoder weight columns to maintain meaningful sparsity measurements.

The paper identifies two key limitations of L1 sparsity penalties for training SAEs:

1. Feature Magnitude Suppression

The L1 penalty penalizes feature magnitudes in addition to sparsity, which harms reconstruction fidelity. This is because L1 measures the sum of absolute values, so it incentivizes making active features smaller even when they should have larger values for accurate reconstruction.

2. Lack of Reparameterization Invariance

The L1 penalty is not invariant under reparameterizations of the SAE. Specifically:

  • You can scale down encoder parameters and scale up decoder parameters proportionally
  • This arbitrarily shrinks feature magnitudes (and thus the L1 penalty)
  • But it doesn't change the number of active features or the SAE's output reconstructions
  • This creates a "cheat" where the model can reduce the loss without actually becoming more sparse

Due to this issue, L1-trained SAEs require additional constraints during training, typically constraining decoder weight matrix columns (𝐝ᵢ) to have unit norm.

The paper introduces the RI-L1 (Reparameterization-Invariant L1) penalty as an alternative:

SRI-L1(f)=i=1Mfidi2S_{\text{RI-L1}}(\mathbf{f}) = \sum_{i=1}^{M} f_i \|\mathbf{d}_i\|_2

This weights feature coefficients by the norms of corresponding dictionary directions, making it invariant to the reparameterization problem.

By using L0 instead, JumpReLU SAEs avoid both issues - L0 only counts the number of non-zero features, not their magnitudes, and is inherently invariant to reparameterization.

replyquoteparent
Thinking Summary

Oh, I see! The user was asking about what the linked paper "Addressing Feature Suppression in SAEs" by Wright and Sharkey (2024) says about the limitations of L1 sparsity penalties, not what the JumpReLU paper says about it. I should look at the content snapshot from the embedded link to see what information is available about that paper. Looking at the embedded content, I only see: - Title: "Addressing Feature Suppression in SAEs" - URL: https://www.alignmentforum.org/posts/3JuSjTZyMzaSeTxKk/addressing-feature-suppression-in-saes - Description: "Produced as part of the ML Alignment Theory Scholars Program - Winter 2023-24 Cohort as part of Lee Sharkey's stream. …" - Content Snapshot: "AI ALIGNMENT FORUMAF" Unfortunately, the content snapshot doesn't actually contain any substantive content from the Wright and Sharkey paper - it just shows "AI ALIGNMENT FORUMAF" which appears to be a header or navigation element from the website. So I cannot actually tell what Wright and Sharkey (2024) specifically says about the limitations of L1 sparsity penalties based on the provided embed. I should clarify this to the user.

You're right - I apologize for the confusion! You were asking about what Wright and Sharkey (2024) "Addressing Feature Suppression in SAEs" says about L1 limitations.

Unfortunately, the embedded content snapshot for that paper only shows "AI ALIGNMENT FORUMAF" and doesn't include the actual content of the paper. Based on just the title "Addressing Feature Suppression in SAEs", I can infer that Wright and Sharkey likely discuss how L1 penalties cause feature suppression (reducing the magnitudes of features that should be active), but I cannot see their specific arguments or findings from the provided embed.

To properly answer your question about what limitations of L1 sparsity penalties Wright and Sharkey identify, I would need access to the actual content of their paper rather than just the title and URL.

replyquoteparent
Thinking Summary

The user is asking about Equation 10 in the paper. Let me find it: Looking at the paper, Equation 9 defines the loss function: $$\mathcal{L}(\mathbf{x}):=\underbrace{\left\lVert\mathbf{x}-\hat{\mathbf{x}}(\mathbf{f}(\mathbf{x}))\right\rVert_{2}^{2}}_{\mathcal{L}_{\text{reconstruct}}}+\underbrace{\lambda\left\lVert\mathbf{f}(\mathbf{x})\right\rVert_{0}}_{\mathcal{L}_{\text{sparsity}}}.$$ Then Equation 10 says: $$\mathcal{L}_{\text{sparsity}}:=\lambda\left\lVert\mathbf{f}(\mathbf{x})\right\rVert_{0}=\lambda\sum_{i=1}^{M}H(\pi_{i}(\mathbf{x})-\theta_{i}).$$ Where $\mathbf{f}(\mathbf{x})$ is defined in Equation 8 as: $$\mathbf{f}(\mathbf{x}):=\text{JumpReLU}_{\boldsymbol{\theta}}\left(\mathbf{W}_{\text{enc}}\mathbf{x}+\mathbf{b}_{\text{enc}}\right).$$ And JumpReLU is defined in Equation 4 as: $$\text{JumpReLU}_{\theta}(z):=z\,H(z-\theta)$$ So if we have $\mathbf{f}(\mathbf{x})$ with component $f_i(\mathbf{x}) = \text{JumpReLU}_{\theta_i}(\pi_i(\mathbf{x}))$, where $\pi_i(\mathbf{x})$ is the i-th pre-activation, then: $f_i(\mathbf{x}) = \pi_i(\mathbf{x}) \cdot H(\pi_i(\mathbf{x}) - \theta_i)$ The L0 norm essentially tracks non-zero entries. Specifically, $\|f_i(\mathbf{x})\|_0$ equals 1 when $f_i(\mathbf{x})$ is non-zero, and 0 otherwise. This means the component becomes zero only when the pre-activation falls below the threshold. The Heaviside function determines when the feature becomes active, effectively creating a sparse representation by selectively activating features based on their pre-activation values relative to learned thresholds. Consequently, the total L0 norm across all features sums the number of activated components, providing a compact measure of feature sparsity in the representation.

Yes, exactly! Equation 10 is showing that when you expand the L0 norm of the feature vector f(x)\mathbf{f}(\mathbf{x}) using the JumpReLU definition, you get a sum of Heaviside step functions.

Here's the reasoning:

  • Each feature fi(x)=JumpReLUθi(πi(x))=πi(x)H(πi(x)θi)f_i(\mathbf{x}) = \text{JumpReLU}_{\theta_i}(\pi_i(\mathbf{x})) = \pi_i(\mathbf{x}) \cdot H(\pi_i(\mathbf{x}) - \theta_i)
  • The L0 norm counts non-zero entries
  • fi(x)f_i(\mathbf{x}) is non-zero exactly when H(πi(x)θi)=1H(\pi_i(\mathbf{x}) - \theta_i) = 1 (since pre-activations are continuous, πi(x)=0\pi_i(\mathbf{x}) = 0 happens with probability 0)
  • Therefore fi(x)0=H(πi(x)θi)\|f_i(\mathbf{x})\|_0 = H(\pi_i(\mathbf{x}) - \theta_i)

So the sparsity penalty becomes: Lsparsity=λf(x)0=λi=1MH(πi(x)θi)\mathcal{L}_{\text{sparsity}} = \lambda\|\mathbf{f}(\mathbf{x})\|_0 = \lambda\sum_{i=1}^{M}H(\pi_i(\mathbf{x})-\theta_i)

This reformulation is crucial for the paper's approach because it shows that the threshold parameters θi\theta_i only appear inside Heaviside functions in the loss, which is what motivates their use of straight-through estimators to get gradient signals for training these thresholds.

replyquoteparent
Thinking Summary

The user is asking for a tutorial on straight-through estimators (STEs). This is a good opportunity to explain this important concept clearly, as it's central to the JumpReLU paper and many other works in deep learning. I should: 1. Start with the motivation/problem 2. Explain the basic idea 3. Give a simple example 4. Show how it works in practice 5. Discuss when/why it works 6. Maybe mention some applications I should keep it concise but comprehensive, and use examples to make it concrete.

A Tutorial on Straight-Through Estimators (STEs)

The Problem

Imagine you're training a neural network with a step function:

def step(x):
    return 1.0 if x > 0 else 0.0

The gradient of this function is zero everywhere (except at x=0 where it's undefined). If you try to backpropagate through it, you get zero gradients - your parameters won't update!

This problem appears whenever we have:

  • Discrete decisions (rounding, thresholding, sampling)
  • Piecewise constant functions
  • Non-differentiable operations

The Core Idea

A Straight-Through Estimator (STE) says: "In the forward pass, use the non-differentiable function. In the backward pass, pretend it was something differentiable."

The simplest STE is the identity STE:

# Forward: y = step(x)
# Backward: pretend dy/dx = 1

A Concrete Example

Let's implement a sign function with an STE:

import torch

class SignSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.sign(x)  # Returns -1, 0, or 1
    
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output    # Pretend derivative is 1

Now gradients can flow through the sign function!

How STEs Work in Practice

Consider training a network to output binary values:

# Without STE: gradients get stuck
logits = model(input)
binary_output = (logits > 0).float()  # Gradient is 0!
loss = criterion(binary_output, target)

# With STE: gradients flow
logits = model(input)
binary_output = SignSTE.apply(logits)  # Gradient passes through
loss = criterion(binary_output, target)

The JumpReLU Example

In the JumpReLU paper, they need to train through:

  • JumpReLUθ(z)=zH(zθ)\text{JumpReLU}_\theta(z) = z \cdot H(z - \theta)
  • The L0 penalty: iH(πi(x)θi)\sum_i H(\pi_i(x) - \theta_i)

Both involve the Heaviside step function HH, which has zero gradient with respect to θ\theta.

Their STE solution uses a "soft" gradient in a small window:

# Forward: use actual step function
y = (x > theta).float()

# Backward: use soft gradient
grad_theta = -(1/ε) * rectangle((x - theta)/ε) * grad_output

This gives gradient signal when xx is within ϵ\epsilon of θ\theta.

When Do STEs Work?

STEs are surprisingly effective when:

  1. The forward function preserves information: Even though sign(x)\text{sign}(x) loses magnitude, it preserves direction
  2. Gradients point in roughly the right direction: We don't need exact gradients, just reasonable descent directions
  3. The mismatch is local: STEs work best when the difference between forward and backward is small

Common Applications

  1. Quantized Neural Networks: Training networks with discrete weights/activations
  2. Discrete Latent Variables: Training VAEs with categorical variables
  3. Hard Attention: Training attention mechanisms that make discrete selections
  4. Sparse Models: Like the JumpReLU SAEs, where features are either on or off

Key Insights

The magic of STEs is that optimization often just needs "good enough" gradients. Perfect gradients are great, but biased gradients that point roughly the right way can still find good solutions.

As the JumpReLU paper shows, you can even design custom STEs based on your problem structure - they use kernel density estimation to create gradients that estimate the true expected loss gradient!

Code Example: Threshold Learning

Here's a simplified version of how JumpReLU uses STEs:

class ThresholdSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, threshold, bandwidth=0.001):
        ctx.save_for_backward(x, threshold)
        ctx.bandwidth = bandwidth
        return x * (x > threshold)
    
    @staticmethod  
    def backward(ctx, grad_output):
        x, threshold = ctx.saved_tensors
        ε = ctx.bandwidth
        
        # Gradient w.r.t x: normal gradient where x > threshold
        grad_x = grad_output * (x > threshold)
        
        # Gradient w.r.t threshold: custom STE
        in_window = ((x - threshold).abs() < ε/2)
        grad_threshold = -threshold/ε * in_window * grad_output
        
        return grad_x, grad_threshold, None

This allows learning thresholds even though the step function has zero gradient!

replyquoteparent
Thinking Summary

The user is asking about the notation in equation 11 of the paper, where they use ð (eth) instead of the standard partial derivative symbol ∂. Let me look at the context. Looking at equation 11: ``` ðð⁢θ⁢JumpReLUθ⁢(z):=−θε⁢K⁢(z−θε) ``` And in the footnote 4, they explain: "We use the notation ð/ð⁢z to denote pseudo-derivatives, to avoid conflating them with actual partial derivatives for these functions." So they're using ð (the lowercase eth letter from Old English and Icelandic) to denote "pseudo-derivatives" rather than true partial derivatives. This is because the JumpReLU function and Heaviside step function are not differentiable in the classical sense at the threshold point - they have discontinuities. The authors are defining custom gradients (straight-through estimators) for these non-differentiable functions. This doesn't seem to be a widely established convention - it appears to be a notation choice made by the authors of this specific paper to clearly distinguish between: - True partial derivatives (∂) - Their custom-defined pseudo-derivatives for the backward pass (ð) The eth symbol (ð) is sometimes used in mathematics, but not typically for this purpose. It's more commonly seen in: - Icelandic and Old English text - Some areas of mathematics like differential geometry (though not commonly) - Phonetics (to represent the voiced dental fricative) The authors are being careful to make it clear that these aren't true derivatives but rather custom gradient definitions used for training through non-differentiable functions.

Great observation! The ð (eth) symbol is not a standard mathematical convention - it's a notation choice specific to this paper.

The authors explain in footnote 4:

"We use the notation ð/ð⁢z to denote pseudo-derivatives, to avoid conflating them with actual partial derivatives for these functions."

They're using ð instead of the standard ∂ (partial derivative) because the JumpReLU and Heaviside functions aren't differentiable in the classical sense - they have discontinuities at the threshold. The "pseudo-derivatives" are custom gradient definitions they create for the backward pass during training.

This is actually quite thoughtful notation design! Since they're defining straight-through estimators (STEs) that approximate gradients where true gradients don't exist or are zero, using a visually similar but distinct symbol helps avoid confusion. It makes it immediately clear when they're talking about:

  • : True partial derivatives (for differentiable functions)
  • ð: Their custom-defined gradients for training (for non-differentiable points)

The eth symbol (ð) comes from Old English and Icelandic alphabets. While it's occasionally used in some mathematical contexts, this specific usage as "pseudo-derivative" appears to be novel to this paper. It's a clever way to maintain mathematical clarity while introducing new gradient approximations for discontinuous functions.

replyquoteparent
Thinking Summary

The user is asking about the epsilon parameter used in the pseudo-derivatives defined in equations 11 and 12. Let me look at what the paper says about this. From the paper: 1. In equations 11 and 12, epsilon appears as a bandwidth parameter in the kernel function K. 2. The paper mentions that epsilon plays the role of a KDE (kernel density estimation) bandwidth. 3. In footnote 5, they mention: "For the experiments in this paper, we swept this parameter and found ε=0.001 (assuming a dataset normalised such that 𝔼𝐱[𝐱²]=1) works well across different models, layers and sites. However, we suspect there are more principled ways to determine this parameter, borrowing from the literature on KDE bandwidth selection." 4. In Section 4, they explain: "As we show in Section 4, the hyperparameter ε plays the role of a KDE bandwidth, and needs to be selected accordingly: too low and gradient estimates become too noisy, too high and estimates become too biased." 5. The paper shows that the STEs are actually computing kernel density estimates of the true gradient of the expected loss. So epsilon is a bandwidth parameter that controls the width of the window around the threshold where gradients are non-zero. It's a hyperparameter they had to tune.

Excellent question! The epsilon (ε) parameter is essentially a hyperparameter that controls the "bandwidth" or width of the window where the pseudo-derivatives are non-zero.

Here's what the paper tells us about it:

Where it comes from

  1. It's borrowed from kernel density estimation (KDE): The authors show that their STEs are actually computing kernel density estimates of the true gradient. In KDE, you always need a bandwidth parameter that controls the smoothing.

  2. They found it empirically: From footnote 5:

    "For the experiments in this paper, we swept this parameter and found ε=0.001 (assuming a dataset normalised such that 𝔼[𝐱²]=1) works well across different models, layers and sites."

What it does

Looking at the pseudo-derivatives:

ððθH(zθ)=1εK(zθε)\frac{\text{ð}}{\text{ð}\theta}H(z-\theta) = -\frac{1}{\varepsilon}K\left(\frac{z-\theta}{\varepsilon}\right)

The epsilon controls:

  • Width of the gradient window: Only when zθ<ε/2|z - \theta| < \varepsilon/2 (for rectangular kernel) do we get non-zero gradients
  • Magnitude of gradients: Smaller ε means larger gradient magnitudes (note the 1/ε factor)

The tradeoff

As they explain in Section 4:

  • ε too small: Gradient estimates become too noisy (few samples fall in the narrow window)
  • ε too large: Estimates become too biased (you're averaging over too wide a region)

Visual intuition

From Figure 4, you can see that epsilon defines a small window around the threshold where the pseudo-derivative is non-zero. Outside this window, the gradient is zero (just like the true derivative).

Why not something more principled?

The authors acknowledge this is somewhat ad-hoc:

"However, we suspect there are more principled ways to determine this parameter, borrowing from the literature on KDE bandwidth selection."

So epsilon is essentially a "gradient window width" hyperparameter they tuned empirically. It's small enough to be local (ε=0.001) but large enough to catch some activations and provide gradient signal. The fact that the same value works across different models and layers suggests it's reasonably robust, though more principled selection methods could likely improve it.

replyquoteparent
Thinking Summary

The user is asking about the definition of "fraction of variance unexplained" from the cited paper. Looking at the embed they shared, this is from the paper "Scaling and evaluating sparse autoencoders" by Gao et al. 2024. Let me search through the content for this definition. Looking at the paper content... Found it! In the "Setup" section under Methods (Section 2.1), they write: "Evaluation: After training, we evaluate autoencoders on sparsity L0, and reconstruction mean-squared error (MSE). We report a normalized version of all MSE numbers, where we divide by a baseline reconstruction error of always predicting the mean activations." So the fraction of variance unexplained (FVU) is the mean reconstruction loss (MSE) normalized by the reconstruction loss obtained by always predicting the dataset mean. This is a standard way to normalize reconstruction error - it's the fraction of variance that remains unexplained by the model. The user is referring to this in the context of Section 5.1 of the JumpReLU paper, where they mention FVU as a secondary metric for measuring reconstruction fidelity.

Looking at the Gao et al. paper, the definition of "fraction of variance unexplained" (FVU) can be found in Section 2.1 under "Evaluation":

Evaluation: After training, we evaluate autoencoders on sparsity L0, and reconstruction mean-squared error (MSE). We report a normalized version of all MSE numbers, where we divide by a baseline reconstruction error of always predicting the mean activations.

So FVU is calculated as:

FVU=MSE of SAE reconstructionMSE of always predicting the mean\text{FVU} = \frac{\text{MSE of SAE reconstruction}}{\text{MSE of always predicting the mean}}

This is also referred to as the "normalized loss" in the paper. It's a standard normalization that measures what fraction of the variance in the data remains unexplained by the model. A perfect reconstruction would have FVU = 0, while a model that does no better than predicting the mean would have FVU = 1.

The JumpReLU paper uses this as a secondary metric alongside their primary metric of delta LM loss (the increase in cross-entropy loss when the SAE is spliced into the LM's forward pass).

replyquoteparent
Thinking Summary

The user wants me to provide more context about the setup for FVU (fraction of variance unexplained). Let me look at both papers to understand: 1. What data is being used 2. What the MSE is calculated over 3. What "mean activations" refers to From the JumpReLU paper Section 5: - They trained SAEs on Gemma 2 9B activations - Specifically on residual stream, attention output, and MLP output sites after layers 9, 20, and 31 - All metrics computed on 2,048 sequences of length 1,024, excluding special tokens From the Gao et al. paper Section 2.1: - They train on residual streams of GPT-2 small and GPT-4 models - They use a layer 5/6 of the way into the network for GPT-4 series models, and layer 8 (3/4 of the way) for GPT-2 small - Context length of 64 tokens - They subtract the mean over the d_model dimension and normalize to unit norm before passing to autoencoder So the setup involves: 1. Data: Language model activations (residual stream, MLP outputs, or attention outputs) from specific layers 2. The MSE is calculated over tokens/positions in sequences 3. The "mean activations" baseline is the mean activation vector across the dataset

Sure! Let me provide more context on the setup for FVU (fraction of variance unexplained):

Data Setup

What data is being evaluated:

  • The JumpReLU paper evaluates SAEs trained on Gemma 2 9B activations at specific sites:
    • Residual stream activations
    • MLP output activations
    • Attention output activations (before the output projection)
  • These are taken from layers 9, 20, and 31 of the model

Preprocessing (from Gao et al.):

  • Activations are preprocessed by:
    1. Subtracting the mean over the d_model dimension
    2. Normalizing to unit norm

Evaluation dataset:

  • All metrics are computed on 2,048 sequences of length 1,024 tokens
  • Special tokens (pad, start, end of sequence) are excluded

What the MSE is calculated over

The MSE measures reconstruction error: MSE=Ex[xx^22]\text{MSE} = \mathbb{E}_{\mathbf{x}}[\|\mathbf{x} - \hat{\mathbf{x}}\|_2^2]

Where:

  • x\mathbf{x} is the original activation vector (e.g., dimension 2304 for Gemma 2 9B)
  • x^\hat{\mathbf{x}} is the SAE reconstruction
  • The expectation is over all token positions in the evaluation dataset

The baseline for normalization

The "mean activations" baseline is:

  • Simply predicting the dataset mean activation vector for every input
  • This gives the variance of the data as the MSE
  • So FVU = (SAE's MSE) / (Variance of data)

This normalization makes the metric interpretable: FVU = 0.1 means the SAE fails to explain 10% of the variance in the activations.

replyquoteparent

This is either naive or pedantic or both but in what sense is this "variance"? I thought variance was defined as E[(xE[x])2]\mathbb E [(x - \mathbb E [x])^2] ie the expectation of squared deviation from the mean. but here we are using reconstruction error.

replyquoteparent
Thinking Summary

The user is making a good point. They're asking about why this is called "fraction of variance unexplained" when the denominator is the MSE of predicting the mean, which is indeed the variance, but the numerator is the MSE of the SAE reconstruction, not necessarily related to variance in the traditional sense. Let me think about this carefully: 1. The denominator: MSE of always predicting the mean = $\mathbb{E}[\|\mathbf{x} - \mathbb{E}[\mathbf{x}]\|^2]$ = variance of the data 2. The numerator: MSE of SAE reconstruction = $\mathbb{E}[\|\mathbf{x} - \hat{\mathbf{x}}\|^2]$ where $\hat{\mathbf{x}}$ is the SAE reconstruction The name "fraction of variance unexplained" comes from the fact that: - If you could perfectly reconstruct, MSE = 0, so FVU = 0 (all variance explained) - If your reconstruction is no better than predicting the mean, MSE = variance, so FVU = 1 (no variance explained) - In between, FVU tells you what fraction of the total variance remains as error This is actually a standard concept in statistics, related to R² (coefficient of determination). In fact, FVU = 1 - R². The user is right that it's not literally "variance" in the numerator, but it's the residual variance after the model's predictions.

You're absolutely right to point this out! This is actually a really good question about the terminology.

You're correct that variance is E[(xE[x])2]\mathbb{E}[(x - \mathbb{E}[x])^2]. The name "fraction of variance unexplained" comes from statistical tradition, particularly from linear regression where this concept originates.

Here's why it makes sense:

The denominator is indeed the variance: MSEmean=E[xE[x]2]=Var(x)\text{MSE}_{\text{mean}} = \mathbb{E}[\|\mathbf{x} - \mathbb{E}[\mathbf{x}]\|^2] = \text{Var}(\mathbf{x})

The numerator is the MSE of reconstruction: MSESAE=E[xx^2]\text{MSE}_{\text{SAE}} = \mathbb{E}[\|\mathbf{x} - \hat{\mathbf{x}}\|^2]

The interpretation is:

  • The total variance in the data represents all the "variation" we'd like to explain
  • The reconstruction error represents the "residual variance" - the variation that remains unexplained after our model
  • So FVU measures what fraction of the original variance is still left as error

This is exactly analogous to R2R^2 (coefficient of determination) in statistics: FVU=1R2=Residual sum of squaresTotal sum of squares\text{FVU} = 1 - R^2 = \frac{\text{Residual sum of squares}}{\text{Total sum of squares}}

So while the numerator isn't literally "variance" in the formal sense, it represents the variance that remains unexplained. The terminology comes from thinking of the total variance as something to be "explained away" by the model, with perfect explanation leaving zero residual error.

replyquoteparent
Thinking Summary

The user is asking about Figure 5 from the "Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders" paper. Let me look at what Figure 5 shows. Figure 5 shows: - The proportion of features that activate very frequently versus delta LM loss by SAE type for Gemma 2 9B residual stream SAEs - Two plots: top one shows features active on over 10% of tokens, bottom one shows features active on over 1% of tokens - Compares TopK, JumpReLU, and Gated SAEs The significance of plotting delta LM loss against activation frequency relates to understanding the trade-off between: 1. Reconstruction fidelity (measured by delta LM loss) 2. The characteristics of learned features (specifically how often they activate) Let me explain why this matters in the context of SAEs.

Figure 5 is plotting two important characteristics of Sparse Autoencoders (SAEs) against each other to reveal a key trade-off:

Delta LM Loss (X-axis)

  • What it measures: The increase in cross-entropy loss when the SAE is spliced into the language model's forward pass
  • Why it matters: This is a measure of reconstruction fidelity - lower values mean the SAE is better at faithfully reconstructing the original activations without harming the model's performance
  • Interpretation: Moving left on the x-axis means better reconstruction quality

Feature Activation Frequency (Y-axis)

  • What it measures: The proportion of SAE features that fire very frequently (>10% of tokens in top plot, >1% in bottom plot)
  • Why it matters: Very high frequency features tend to be:
    • Less interpretable (they activate so often they may not represent specific, meaningful concepts)
    • Less sparse (defeating the purpose of "sparse" autoencoders)
    • Potentially acting more like bias terms than meaningful feature detectors

The Significance of This Plot

  1. Different architectures have different characteristics: The plot reveals that TopK and JumpReLU SAEs tend to learn more high-frequency features than Gated SAEs, especially in the low delta LM loss (high fidelity) regime.

  2. Trade-off visualization: As you improve reconstruction fidelity (move left), you tend to get more high-frequency features. This suggests there's a fundamental tension between:

    • Making accurate reconstructions (low delta LM loss)
    • Learning sparse, interpretable features (few high-frequency features)
  3. Architecture comparison at fixed performance: By plotting against delta LM loss rather than just comparing raw frequency counts, the authors ensure they're comparing architectures at equivalent levels of reconstruction quality. This is a fairer comparison than just looking at frequency distributions alone.

  4. Practical implications: For practitioners, this helps in choosing architectures - if you care more about avoiding high-frequency features for interpretability, you might prefer Gated SAEs despite slightly worse reconstruction fidelity.

The paper notes that while JumpReLU has more very high frequency features (>10%) like TopK, the absolute numbers remain small (<0.06% of all features even in the worst case), suggesting this may not be a major practical concern.

replyquoteparent
Thinking Summary

The user is asking about the x-axis labeling in Figure 5. Looking at the description they provided, it seems like they're seeing axis labels that suggest a logarithmic scale. The pattern "2 5 0.1 2 5" with 0.1 being larger suggests this is indeed a log scale, likely showing values like 0.02, 0.05, 0.1, 0.2, 0.5 or similar. Let me check the paper for information about delta LM loss units and whether they use log scales. Looking at the paper, delta LM loss is defined as "the increase in the cross-entropy loss experienced by the LM when we splice the SAE into the LM's forward pass." Cross-entropy loss is typically measured in nats (natural log) or bits (log base 2), and the delta would be in the same units. The paper doesn't explicitly state the units, but cross-entropy loss for language models is commonly expressed in nats per token or bits per token. Since this is a delta (difference), it would be the increase in nats/bits per token. The logarithmic scale makes sense because these loss differences can span multiple orders of magnitude, from very small (good reconstruction) to larger values (poor reconstruction).

Yes, you're absolutely right - this is a logarithmic scale on the x-axis! The pattern you're describing (2, 5, 0.1, 2, 5) is the classic labeling for a log scale, where it cycles through powers of 10. It would represent values like:

  • 0.02, 0.05, 0.1, 0.2, 0.5 (or similar sequences)

The "0.1" appears larger because it's the major tick mark at a power of 10, while 2 and 5 are minor tick marks within each decade.

Units for Delta LM Loss

Delta LM loss measures the increase in cross-entropy loss when the SAE is inserted into the model. Cross-entropy loss for language models is typically measured in:

  • Nats per token (using natural logarithm)
  • Bits per token (using log base 2)

The paper doesn't explicitly state which unit they're using, but nats per token is more common in deep learning contexts.

Since this is a delta (difference), the units would be the increase in nats/bits per token. For example:

  • A delta LM loss of 0.1 nats means the model's perplexity increases by a factor of e^0.1 ≈ 1.105 when using the SAE reconstruction
  • A very small delta (e.g., 0.01) means the SAE barely affects model performance
  • A larger delta (e.g., 0.5) indicates significant degradation

The logarithmic scale is appropriate here because these deltas can span multiple orders of magnitude - from nearly perfect reconstructions (delta ≈ 0.001) to quite poor ones (delta ≈ 1.0), and a log scale helps visualize this full range effectively.

replyquoteparent

Ok, I'm not understanding how I move along the graph though. For a given SAE the SAE is fixed right? So it has a certain LM loss and a certain proportor of features activity on 10% of tokens and both of those are "fixed” right? Where does the series in the chart come from?

replyquoteparent
Thinking Summary

Ah, the user is confused about what creates the curve/series in the plot. They think each SAE should just be a single point since it has fixed characteristics. Let me clarify that each point represents a different SAE trained with different hyperparameters.

Excellent question! You're right that each individual SAE, once trained, has fixed characteristics. The key insight is that each point on the curves represents a different SAE.

How the Series is Generated

The authors trained multiple SAEs of each type (JumpReLU, Gated, TopK) with different hyperparameters:

  1. For JumpReLU and Gated SAEs: They varied the sparsity coefficient λ

    • Higher λ → more sparse → fewer features active → worse reconstruction (higher delta LM loss)
    • Lower λ → less sparse → more features active → better reconstruction (lower delta LM loss)
  2. For TopK SAEs: They varied K (the number of top features to keep)

    • Lower K → more sparse → worse reconstruction
    • Higher K → less sparse → better reconstruction

Creating the Curves

So the process was:

  1. Train many SAEs with different sparsity settings (different λ or K values)
  2. For each trained SAE, measure:
    • Its delta LM loss (reconstruction quality)
    • What proportion of its features fire on >10% of tokens
  3. Plot each SAE as a point
  4. Connect points of the same architecture type to form curves

Why This Matters

This allows them to answer: "If I want a JumpReLU SAE with delta LM loss of 0.2, how many high-frequency features should I expect?" You can't just train one SAE and move along the curve - you need to train multiple SAEs with different sparsity penalties to explore the trade-off space.

The curves show the Pareto frontier - the best achievable trade-off between reconstruction quality and feature frequency characteristics for each architecture type.

replyquoteparent