Numinex

Thread (53 posts)

Our key insight is to notice that although such a loss function is piecewise-constant with respect to the threshold – and therefore provides zero gradient to train this parameter – the derivative of the expected loss can be analytically derived, and is generally non-zero, albeit it is expressed in terms of probability densities of the feature activation distribution that need to be estimated.

The first part makes sense, but when they say "expected loss", what is the precise setup for that expectation?

Thinking Summary

The user is asking about the expected loss in the JumpReLU SAE paper. The paper introduces a loss function involving a JumpReLU activation with threshold parameters θ; the loss is piecewise constant with respect to θ, so the gradient with respect to θ is zero, yet the authors claim the derivative of the expected loss is non-zero. The user wants to know what expectation they are talking about.

In Section 4, "How STEs enable training through the jump", the paper says: "Why does this work? The key is to notice that during SGD, we actually want to estimate the gradient of the expected loss, 𝔼_𝐱[ℒ_θ(𝐱)], in order to calculate parameter updates." So the expectation is over the data distribution: they take the expectation over input activations 𝐱 drawn from the dataset 𝒟.

In Appendix B they derive this more formally. The loss function is ℒ_θ(𝐱) = ||𝐱 - 𝐱̂(𝐟(𝐱))||²₂ + λ||𝐟(𝐱)||₀, and they want to compute ∂𝔼_𝐱[ℒ_θ(𝐱)]/∂θᵢ, where 𝐱 are samples from the training data distribution.

The key insight is that while ℒ_θ(𝐱) is piecewise constant in θ for any fixed 𝐱, the expected loss 𝔼_𝐱[ℒ_θ(𝐱)] is smooth, because you are averaging over many different values of 𝐱 and the locations of the discontinuities vary with 𝐱.

Looking at Section 4 and Appendix B of the paper, the expectation is taken over the data distribution. Specifically:

𝔼_𝐱[ℒ_θ(𝐱)] where 𝐱 ~ 𝒟

Here, 𝐱 represents language model activations sampled from the training dataset 𝒟.
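For reference, the per-example loss being averaged is the reconstruction error plus an $L_0$ sparsity penalty:

$$\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x}) = \left\lVert\mathbf{x} - \hat{\mathbf{x}}(\mathbf{f}(\mathbf{x}))\right\rVert_2^2 + \lambda\left\lVert\mathbf{f}(\mathbf{x})\right\rVert_0$$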

The key insight is that while the loss ℒ_θ(𝐱) is piecewise constant in θ for any fixed input 𝐱 (because the JumpReLU and Heaviside functions create discontinuities), when you take the expectation over many different inputs 𝐱, the expected loss becomes smooth and differentiable.

This happens because:

  • For a fixed 𝐱, the discontinuities occur when pre-activations π_i(𝐱) equal the threshold θ_i
  • Different inputs 𝐱 have different pre-activation values
  • So the discontinuity locations vary across the data distribution
  • Averaging over the distribution smooths out these discontinuities (the short numerical sketch below illustrates this)
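To make the smoothing concrete, here is a small numerical sketch (hypothetical code, not from the paper; the lognormal pre-activation distribution is an arbitrary stand-in): for a single fixed input, the $L_0$ penalty term is a step function of the threshold, while its batch average varies smoothly with the threshold.

```python
import numpy as np

# Hypothetical sketch: the per-example L0 penalty is a step function of the
# threshold theta, but its average over inputs x is smooth, because the jump
# location pi_i(x) differs from input to input.

rng = np.random.default_rng(0)

# Stand-in for pre-activations pi_i(x) of one latent over many inputs x.
pre_acts = rng.lognormal(mean=0.0, sigma=1.0, size=100_000)

thetas = np.linspace(0.5, 2.0, 200)

# For one fixed input, H(pi_i(x) - theta) only takes the values 0 and 1,
# so it provides zero gradient with respect to theta almost everywhere.
single_x = pre_acts[0]
per_example = (single_x > thetas).astype(float)
print(np.unique(per_example))       # {0., 1.}: piecewise constant in theta

# Averaging the same quantity over the whole batch gives a smooth curve whose
# slope at theta approximates -p_i(theta), the negative pre-activation density.
expected = (pre_acts[None, :] > thetas[:, None]).mean(axis=1)
print(np.round(expected[::50], 3))  # decreases gradually as theta increases
```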

In Equation 13, they show the derivative of this expected loss:

$$\frac{\partial\,\mathbb{E}_{\mathbf{x}}\left[\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})\right]}{\partial\theta_{i}} = \left(\mathbb{E}_{\mathbf{x}}\left[I_{i}(\mathbf{x})\mid\pi_{i}(\mathbf{x})=\theta_{i}\right]-\lambda\right)p_{i}(\theta_{i})$$

where $p_i$ is the probability density of the pre-activations $\pi_i(\mathbf{x})$ across the data distribution.
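As a rough illustration of how the right-hand side could be estimated from samples, here is a hedged Monte Carlo sketch (not the paper's estimator): it localises the condition $\pi_i(\mathbf{x})=\theta_i$ to a narrow window of width $\varepsilon$, and reads $I_i(\mathbf{x})$ as the change in reconstruction error when feature $i$ switches off.

```python
import numpy as np

def grad_expected_loss_wrt_theta(pre_acts, recon_deltas, theta, lam, eps=1e-3):
    """Monte Carlo sketch of the Equation 13 gradient (assumed setup, not the
    paper's code).

    pre_acts:     samples of pi_i(x) over a batch, shape [B]
    recon_deltas: samples of I_i(x), read here as the change in reconstruction
                  error when feature i is switched off, shape [B]
    theta:        current threshold theta_i
    lam:          sparsity coefficient lambda
    eps:          window width used to localise the event pi_i(x) = theta_i
    """
    # Samples whose pre-activation falls in a narrow window around theta.
    near = np.abs(pre_acts - theta) < eps / 2

    # Density estimate p_i(theta): fraction of samples in the window, per unit width.
    p_hat = near.mean() / eps

    # Conditional expectation E[I_i(x) | pi_i(x) = theta_i], estimated on the window.
    cond = recon_deltas[near].mean() if near.any() else 0.0

    return (cond - lam) * p_hat
```

In the paper itself this density-weighted quantity is not estimated explicitly; Section 4 describes how straight-through estimators with kernel-based pseudo-derivatives produce parameter updates that approximate this gradient in expectation.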
