Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders
Thread (53 posts)
Our key insight is to notice that although such a loss function is piecewise-constant with respect to the threshold – and therefore provides zero gradient to train this parameter – the derivative of the expected loss can be analytically derived, and is generally non-zero, albeit it is expressed in terms of probability densities of the feature activation distribution that need to be estimated.
First part makes sense, but when they say "expected loss", what is the precise setup for that expectation?
Thinking Summary
The user is asking about the expected loss in the JumpReLU SAE paper. Looking at the paper, they introduce a loss function that has a JumpReLU activation function with threshold parameters θ. The loss function is piecewise constant with respect to θ, so the gradient is zero. But they claim the derivative of the expected loss is non-zero. The user wants to know what expectation they're talking about.

Looking at the paper, in Section 4 "How STEs enable training through the jump", they say: "Why does this work? The key is to notice that during SGD, we actually want to estimate the gradient of the expected loss, 𝔼_𝐱[ℒ_θ(𝐱)], in order to calculate parameter updates." So the expectation is over the data distribution – they're taking the expectation over input activations 𝐱 drawn from the dataset 𝒟.

In Appendix B, they derive this more formally. The loss function is ℒ_θ(𝐱) = ||𝐱 - 𝐱̂(𝐟(𝐱))||²₂ + λ||𝐟(𝐱)||₀, and they want to compute ∂𝔼_𝐱[ℒ_θ(𝐱)]/∂θᵢ, where 𝐱 are samples from the training data distribution.

The key insight is that while ℒ_θ(𝐱) is piecewise constant in θ for any fixed 𝐱, the expected loss 𝔼_𝐱[ℒ_θ(𝐱)] is smooth, because you're averaging over many different values of 𝐱 and the locations of the discontinuities vary with 𝐱.
Looking at Section 4 and Appendix B of the paper, the expectation is taken over the data distribution. Specifically:
𝔼_𝐱[ℒ_θ(𝐱)] where 𝐱 ~ 𝒟
Here, 𝐱 represents language model activations sampled from the training dataset 𝒟.
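For concreteness, here's a minimal NumPy sketch of the per-example loss in question (the variable names, shapes, and plain affine encoder/decoder are my own simplifications, not the paper's code):

```python
import numpy as np

def jumprelu_sae_loss(x, W_enc, b_enc, W_dec, b_dec, theta, lam):
    """Per-example loss ||x - x_hat||^2_2 + lam * ||f(x)||_0 for a JumpReLU SAE.

    Shapes (my convention): x (d,), W_enc (d, m), b_enc (m,),
    W_dec (m, d), b_dec (d,), theta (m,) with theta > 0.
    """
    pre = x @ W_enc + b_enc                  # pre-activations pi_i(x)
    gate = (pre > theta).astype(x.dtype)     # Heaviside H(pi_i(x) - theta_i)
    f = pre * gate                           # JumpReLU: pi_i(x) * H(pi_i(x) - theta_i)
    x_hat = f @ W_dec + b_dec                # decoder reconstruction
    recon = np.sum((x - x_hat) ** 2)         # ||x - x_hat||^2_2
    l0 = gate.sum()                          # ||f(x)||_0
    return recon + lam * l0
```

For a fixed 𝐱, nudging θ_i without crossing the pre-activation π_i(𝐱) leaves gate, 𝐟, and hence the loss completely unchanged, which is the zero-gradient problem the quoted passage refers to.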
The key insight is that while the loss ℒ_θ(𝐱) is piecewise constant in θ for any fixed input 𝐱 (because the JumpReLU and Heaviside functions create discontinuities), when you take the expectation over many different inputs 𝐱, the expected loss becomes smooth and differentiable.
This happens because:
- For a fixed 𝐱, the discontinuities occur when pre-activations π_i(𝐱) equal the threshold θ_i
- Different inputs 𝐱 have different pre-activation values
- So the discontinuity locations vary across the data distribution
- Averaging over the distribution smooths out these discontinuities (see the numerical sketch below)
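To make the smoothing concrete, here's a toy numerical check on the sparsity term alone, using a made-up Gaussian distribution of pre-activations for a single feature (the distribution, sample size, and threshold grid are arbitrary illustrative choices, not from the paper):

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
lam = 1.0

# Made-up 1-D "pre-activation" distribution for a single feature.
mu, sigma = 1.0, 0.5
pre_acts = rng.normal(mu, sigma, size=1_000_000)

# Per-sample sparsity penalty lam * H(pi(x) - theta): a step function of theta
# for every fixed sample, so its gradient in theta is zero almost everywhere.
thetas = np.linspace(0.0, 2.0, 41)
expected_penalty = np.array([lam * (pre_acts > t).mean() for t in thetas])

# The averaged penalty lam * P(pi(x) > theta) is smooth in theta, and its
# derivative is -lam * p(theta): minus the density of pre-activations at the
# threshold.
finite_diff = np.gradient(expected_penalty, thetas)
analytic = -lam * norm.pdf(thetas, mu, sigma)

print(np.abs(finite_diff - analytic).max())  # agrees up to MC / finite-difference error
```

Each per-sample term is a step function of the threshold, but their average is smooth, and its slope matches −λ times the density of pre-activations at the threshold – exactly the kind of quantity the paper's analytic derivative involves.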
In Equation 13, they give the derivative of this expected loss; it is expressed in terms of p_i(θ_i), the probability density of the pre-activations π_i(𝐱) across the data distribution, evaluated at the threshold. This is the density that, per the quoted passage, "needs to be estimated" in practice.
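The thread doesn't reproduce Equation 13 itself, but the general shape of a derivative like this – differentiating an expectation through a data-dependent discontinuity – can be sketched as follows; this is my schematic, not necessarily the paper's exact notation:

```latex
% Schematic: increasing \theta_i past \pi_i(x) switches feature i off, so the
% expected loss changes at a rate equal to (the expected jump in the loss at
% that switch) times (the density of pre-activations at the threshold).
\[
\frac{\partial}{\partial \theta_i}\,
  \mathbb{E}_{\mathbf{x}}\!\left[\mathcal{L}_{\theta}(\mathbf{x})\right]
  \;=\;
  \mathbb{E}_{\mathbf{x}}\!\left[
    \mathcal{L}_{\theta}(\mathbf{x})\big|_{f_i(\mathbf{x}) = 0}
    \;-\;
    \mathcal{L}_{\theta}(\mathbf{x})\big|_{f_i(\mathbf{x}) = \pi_i(\mathbf{x})}
    \,\middle|\,
    \pi_i(\mathbf{x}) = \theta_i
  \right]
  \, p_i(\theta_i)
\]
```

The contribution of the λ‖𝐟(𝐱)‖₀ term to that jump is simply −λ, so the sparsity term alone contributes −λ·p_i(θ_i) to the derivative; the reconstruction term contributes the expected change in squared error from switching feature i off, conditioned on the pre-activation sitting exactly at the threshold.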