Numinex

Thread (53 posts)

Our key insight is to notice that although such a loss function is piecewise-constant with respect to the threshold – and therefore provides zero gradient to train this parameter – the derivative of the expected loss can be analytically derived, and is generally non-zero, albeit it is expressed in terms of probability densities of the feature activation distribution that need to be estimated.

The first part makes sense, but when they say "expected loss," what is the precise setup for that expectation?

Thinking Summary

The user is asking about the expectation in the expected loss for JumpReLU SAEs. From Section 4 of the paper: "The key is to notice that during SGD, we actually want to estimate the gradient of the expected loss, $\mathbb{E}_{\mathbf{x}}\left[\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})\right]$, in order to calculate parameter updates."

So the expectation is over the data distribution. More specifically:

  • $\mathbf{x}$ represents language model activations, and the SAE is trained on a dataset of activations $\mathbf{x} \sim \mathcal{D}$
  • The loss function $\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})$ depends on the input activations $\mathbf{x}$ and the threshold parameters $\boldsymbol{\theta}$
  • The expected loss is $\mathbb{E}_{\mathbf{x}}\left[\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})\right]$, where the expectation is taken over the distribution of language model activations

The key insight is that while the loss for a single input is piecewise constant with respect to the threshold (so has zero gradient almost everywhere), the expected loss over the distribution of inputs is smooth and has non-zero gradients. This is explained in more detail in Appendix B, where they derive the gradient of the expected loss.

The expectation is taken over the distribution of language model activations that the SAE is being trained on.

Specifically, in the paper's notation:

  • $\mathbf{x} \sim \mathcal{D}$ represents language model activations sampled from the training dataset
  • The loss function is $\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})$, which depends on both the input activations $\mathbf{x}$ and the threshold parameters $\boldsymbol{\theta}$
  • The expected loss is $\mathbb{E}_{\mathbf{x}}\left[\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})\right]$
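
Concretely, here is a minimal sketch of the per-sample loss that this expectation is taken over, assuming the standard JumpReLU SAE objective (squared reconstruction error plus a $\lambda$-weighted L0 penalty); the weight and bias names are illustrative, not taken from the paper's code:

```python
import numpy as np

def jumprelu(pre_acts, theta):
    # JumpReLU: keep a pre-activation only where it exceeds its feature's threshold.
    return pre_acts * (pre_acts > theta)

def sae_loss(x, W_enc, b_enc, W_dec, b_dec, theta, lam):
    """Per-sample loss L_theta(x): reconstruction error plus a weighted L0 penalty.

    The parameter names (W_enc, b_enc, ...) are illustrative. The expectation
    discussed above is the average of this quantity over activations x drawn
    from the training distribution D.
    """
    pre_acts = x @ W_enc + b_enc          # pi(x): encoder pre-activations
    feats = jumprelu(pre_acts, theta)     # sparse feature activations f(x)
    x_hat = feats @ W_dec + b_dec         # reconstruction of the input activation
    recon = np.sum((x - x_hat) ** 2)      # ||x - x_hat||^2
    l0 = np.count_nonzero(feats)          # ||f(x)||_0: piecewise constant in theta
    return recon + lam * l0
```

In training, $\mathbb{E}_{\mathbf{x}}\left[\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})\right]$ is approximated by averaging this loss over batches of activations sampled from $\mathcal{D}$; it is the gradient of the true expectation over the full distribution, not of any single sample's loss, that is non-zero in $\boldsymbol{\theta}$ and needs to be estimated.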

The key insight is that while the loss for any single input $\mathbf{x}$ is piecewise constant with respect to the threshold $\theta_i$ (and thus has zero gradient almost everywhere), the expected loss over the distribution of all possible activations is smooth.

This happens because different activation samples will have their pre-activations $\pi_i(\mathbf{x})$ distributed around the threshold. As you vary the threshold slightly, you change which samples fall above or below it, creating a non-zero gradient in expectation.
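
A small numerical illustration of this point; the Gaussian pre-activation distribution here is purely for illustration, since in practice the distribution is whatever the language model's activations induce:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the pre-activations pi_i(x) of one feature across many inputs x.
pre_acts = rng.normal(loc=0.5, scale=1.0, size=100_000)

thetas = np.linspace(-1.0, 2.0, 301)

# Per-sample L0 contribution of this feature as a function of the threshold:
# a step function (1 while the sample still fires, 0 once theta passes it),
# so its derivative with respect to theta is zero almost everywhere.
single_sample = np.array([float(pre_acts[0] > t) for t in thetas])

# Monte Carlo estimate of the expected L0 contribution: a smooth, decreasing
# function of theta whose slope at theta is -p_i(theta), the (negative) density
# of the pre-activations at the threshold.
expected = np.array([(pre_acts > t).mean() for t in thetas])
```

Plotting both against `thetas` makes the contrast clear: the per-sample curve is flat with a single jump, while the averaged curve varies smoothly, which is exactly why the expected loss has a usable derivative in $\theta_i$ even though each sample's loss does not.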

In Appendix B, they show this gradient is:

$$\frac{\partial\,\mathbb{E}_{\mathbf{x}}\left[\mathcal{L}_{\boldsymbol{\theta}}(\mathbf{x})\right]}{\partial\theta_{i}} = \left(\mathbb{E}_{\mathbf{x}}\left[I_{i}(\mathbf{x}) \mid \pi_{i}(\mathbf{x})=\theta_{i}\right]-\lambda\right)p_{i}(\theta_{i})$$

where $p_i$ is the probability density function of the pre-activation distribution $\pi_i(\mathbf{x})$. This density needs to be estimated from data, which is where the kernel density estimation techniques come in.
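
As a rough sketch of how that density, and the conditional expectation in the gradient formula, could be estimated from a batch using a rectangle kernel of bandwidth `eps`: the function name and interface below are illustrative, the per-sample quantity $I_i(\mathbf{x})$ is taken as given from the paper's derivation, and the paper's training procedure effectively folds this kind of estimate into straight-through gradient estimators rather than computing it explicitly.

```python
import numpy as np

def threshold_grad_estimate(pre_acts, I_values, theta, lam, eps):
    """Batch estimate of (E[I_i(x) | pi_i(x) = theta_i] - lambda) * p_i(theta_i).

    pre_acts: pi_i(x) for a batch of inputs x (1-D array)
    I_values: the corresponding per-sample I_i(x) values from the derivation
    eps:      kernel bandwidth; samples within eps/2 of the threshold are
              treated as landing "at" the threshold
    """
    in_window = np.abs(pre_acts - theta) < eps / 2
    density = in_window.mean() / eps            # rectangle-kernel estimate of p_i(theta)
    if not in_window.any():
        return 0.0                              # no samples near the threshold this batch
    cond_mean = I_values[in_window].mean()      # estimate of E[I_i(x) | pi_i(x) ~= theta]
    return (cond_mean - lam) * density
```

The bandwidth `eps` trades bias against variance: too small and few samples land in the window, too large and the conditional expectation near the threshold gets blurred.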
