
Thinking Summary

The user is asking about Equation 10 in the paper. Let me find it. Equation 9 defines the loss function:

$$\mathcal{L}(\mathbf{x}):=\underbrace{\left\lVert\mathbf{x}-\hat{\mathbf{x}}(\mathbf{f}(\mathbf{x}))\right\rVert_{2}^{2}}_{\mathcal{L}_{\text{reconstruct}}}+\underbrace{\lambda\left\lVert\mathbf{f}(\mathbf{x})\right\rVert_{0}}_{\mathcal{L}_{\text{sparsity}}}.$$

Then Equation 10 says:

$$\mathcal{L}_{\text{sparsity}}:=\lambda\left\lVert\mathbf{f}(\mathbf{x})\right\rVert_{0}=\lambda\sum_{i=1}^{M}H(\pi_{i}(\mathbf{x})-\theta_{i}).$$

Here $\mathbf{f}(\mathbf{x})$ is defined in Equation 8 as

$$\mathbf{f}(\mathbf{x}):=\text{JumpReLU}_{\boldsymbol{\theta}}\left(\mathbf{W}_{\text{enc}}\mathbf{x}+\mathbf{b}_{\text{enc}}\right),$$

and JumpReLU is defined in Equation 4 as

$$\text{JumpReLU}_{\theta}(z):=z\,H(z-\theta).$$

So if $\mathbf{f}(\mathbf{x})$ has components $f_i(\mathbf{x}) = \text{JumpReLU}_{\theta_i}(\pi_i(\mathbf{x}))$, where $\pi_i(\mathbf{x})$ is the $i$-th pre-activation, then $f_i(\mathbf{x}) = \pi_i(\mathbf{x}) \cdot H(\pi_i(\mathbf{x}) - \theta_i)$. The L0 norm tracks non-zero entries: $\|f_i(\mathbf{x})\|_0$ equals 1 when $f_i(\mathbf{x})$ is non-zero and 0 otherwise, and a component is zero exactly when its pre-activation falls below the threshold. The Heaviside function determines when a feature is active, creating a sparse representation by selectively activating features based on their pre-activation values relative to learned thresholds. Consequently, the total L0 norm sums the number of activated components, providing a compact measure of feature sparsity in the representation.

Yes, exactly! Equation 10 is showing that when you expand the L0 norm of the feature vector $\mathbf{f}(\mathbf{x})$ using the JumpReLU definition, you get a sum of Heaviside step functions.

Here's the reasoning:

  • Each feature $f_i(\mathbf{x}) = \text{JumpReLU}_{\theta_i}(\pi_i(\mathbf{x})) = \pi_i(\mathbf{x}) \cdot H(\pi_i(\mathbf{x}) - \theta_i)$
  • The L0 norm counts non-zero entries
  • $f_i(\mathbf{x})$ is non-zero exactly when $H(\pi_i(\mathbf{x}) - \theta_i) = 1$ (since pre-activations are continuous, $\pi_i(\mathbf{x}) = 0$ happens with probability 0)
  • Therefore $\|f_i(\mathbf{x})\|_0 = H(\pi_i(\mathbf{x}) - \theta_i)$

So the sparsity penalty becomes: $\mathcal{L}_{\text{sparsity}} = \lambda\|\mathbf{f}(\mathbf{x})\|_0 = \lambda\sum_{i=1}^{M}H(\pi_i(\mathbf{x})-\theta_i)$

This reformulation is crucial for the paper's approach because it shows that the threshold parameters $\theta_i$ only appear inside Heaviside functions in the loss, which is what motivates their use of straight-through estimators to get gradient signals for training these thresholds.
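
A quick numerical sanity check of this identity (an illustrative sketch, not from the paper; the values are arbitrary):

import torch

# Illustrative pre-activations pi_i(x) and thresholds theta_i for M = 4 features
pi = torch.tensor([0.5, -0.2, 1.3, 0.05])
theta = torch.tensor([0.1, 0.1, 1.0, 0.1])

# JumpReLU: f_i = pi_i * H(pi_i - theta_i)
f = pi * (pi > theta).float()

# L0 norm of f vs. the sum of Heaviside terms in Equation 10
l0 = (f != 0).sum()
heaviside_sum = (pi > theta).sum()
assert l0 == heaviside_sum  # both equal 2 here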

Thinking Summary

The user is asking for a tutorial on straight-through estimators (STEs). This is a good opportunity to explain this important concept clearly, as it's central to the JumpReLU paper and many other works in deep learning. I should: (1) start with the motivation/problem, (2) explain the basic idea, (3) give a simple example, (4) show how it works in practice, (5) discuss when/why it works, and (6) maybe mention some applications. I should keep it concise but comprehensive, and use examples to make it concrete.

A Tutorial on Straight-Through Estimators (STEs)

The Problem

Imagine you're training a neural network with a step function:

def step(x):
    return 1.0 if x > 0 else 0.0

The gradient of this function is zero everywhere (except at x=0 where it's undefined). If you try to backpropagate through it, you get zero gradients - your parameters won't update!
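
You can see this directly in PyTorch (a small illustration; torch.sign stands in for the step, since a raw comparison like (x > 0).float() doesn't even attach to the autograd graph):

import torch

x = torch.tensor([0.5], requires_grad=True)
y = torch.sign(x)   # piecewise-constant, like step(x)
y.backward()
print(x.grad)       # tensor([0.]) -- zero everywhere, so no learning signal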

This problem appears whenever we have:

  • Discrete decisions (rounding, thresholding, sampling)
  • Piecewise constant functions
  • Non-differentiable operations

The Core Idea

A Straight-Through Estimator (STE) says: "In the forward pass, use the non-differentiable function. In the backward pass, pretend it was something differentiable."

The simplest STE is the identity STE:

# Forward: y = step(x)
# Backward: pretend dy/dx = 1

A Concrete Example

Let's implement a sign function with an STE:

import torch

class SignSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.sign(x)  # Returns -1, 0, or 1
    
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output    # Pretend derivative is 1

Now gradients can flow through the sign function!
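
A quick sanity check of the gradient flow (illustrative):

x = torch.tensor([0.7, -1.2], requires_grad=True)
y = SignSTE.apply(x)   # forward: tensor([ 1., -1.])
y.sum().backward()
print(x.grad)          # tensor([1., 1.]) -- the gradient passed straight through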

How STEs Work in Practice

Consider training a network to output binary values:

# Without STE: gradients get stuck
logits = model(input)
binary_output = (logits > 0).float()  # Gradient is 0!
loss = criterion(binary_output, target)

# With STE: gradients flow
logits = model(input)
binary_output = SignSTE.apply(logits)  # Gradient passes through
loss = criterion(binary_output, target)

The JumpReLU Example

In the JumpReLU paper, they need to train through:

  • $\text{JumpReLU}_\theta(z) = z \cdot H(z - \theta)$
  • The L0 penalty: $\sum_i H(\pi_i(\mathbf{x}) - \theta_i)$

Both involve the Heaviside step function $H$, which has zero gradient with respect to $\theta$.

Their STE solution uses a "soft" gradient in a small window:

# Forward: use the actual step function
y = (x > theta).float()

# Backward: use a soft gradient in a window around theta,
# where rectangle(u) = 1 if |u| < 1/2 and 0 otherwise
grad_theta = -(1/ε) * rectangle((x - theta)/ε) * grad_output

This gives a gradient signal whenever $x$ is within $\epsilon/2$ of $\theta$.
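
To make this concrete, here is a minimal sketch of the Heaviside STE as a full autograd function (illustrative, not the paper's actual code; it assumes the rectangle kernel, and it passes no gradient to x, whose almost-everywhere derivative through H is zero):

class HeavisideSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, theta, eps=0.001):
        ctx.save_for_backward(x, theta)
        ctx.eps = eps
        return (x > theta).float()

    @staticmethod
    def backward(ctx, grad_output):
        x, theta = ctx.saved_tensors
        eps = ctx.eps
        # rectangle((x - theta)/eps): 1 inside the window, 0 outside
        in_window = ((x - theta).abs() < eps / 2).float()
        grad_theta = -(1.0 / eps) * in_window * grad_output
        # Sum out batch dimensions so the gradient matches theta's shape
        while grad_theta.dim() > theta.dim():
            grad_theta = grad_theta.sum(dim=0)
        return None, grad_theta, None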

When Do STEs Work?

STEs are surprisingly effective when:

  1. The forward function preserves information: Even though $\text{sign}(x)$ loses magnitude, it preserves direction
  2. Gradients point in roughly the right direction: We don't need exact gradients, just reasonable descent directions
  3. The mismatch is local: STEs work best when the difference between forward and backward is small

Common Applications

  1. Quantized Neural Networks: Training networks with discrete weights/activations
  2. Discrete Latent Variables: Training VAEs with categorical variables
  3. Hard Attention: Training attention mechanisms that make discrete selections
  4. Sparse Models: Like the JumpReLU SAEs, where features are either on or off

Key Insights

The magic of STEs is that optimization often just needs "good enough" gradients. Perfect gradients are great, but biased gradients that point roughly the right way can still find good solutions.

As the JumpReLU paper shows, you can even design custom STEs based on your problem structure - they use kernel density estimation to create gradients that estimate the true expected loss gradient!

Code Example: Threshold Learning

Here's a simplified version of how JumpReLU uses STEs:

class ThresholdSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, threshold, bandwidth=0.001):
        ctx.save_for_backward(x, threshold)
        ctx.bandwidth = bandwidth
        # JumpReLU forward: pass x through wherever it exceeds the threshold
        return x * (x > threshold)
    
    @staticmethod  
    def backward(ctx, grad_output):
        x, threshold = ctx.saved_tensors
        ε = ctx.bandwidth
        
        # Gradient w.r.t. x: the usual JumpReLU derivative, H(x - threshold)
        grad_x = grad_output * (x > threshold)
        
        # Gradient w.r.t. threshold: custom STE,
        # -(threshold/ε) * rectangle((x - threshold)/ε)
        in_window = (x - threshold).abs() < ε/2
        grad_threshold = -threshold/ε * in_window * grad_output
        
        # Sum out batch dimensions so the gradient matches threshold's shape
        while grad_threshold.dim() > threshold.dim():
            grad_threshold = grad_threshold.sum(dim=0)
        
        return grad_x, grad_threshold, None

This allows learning thresholds even though the step function has zero gradient!
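
A hypothetical usage sketch (the wide bandwidth here is chosen purely so the gradient is visibly non-zero on random data):

x = torch.randn(128)                           # pre-activations
theta = torch.tensor(0.2, requires_grad=True)  # learnable threshold

f = ThresholdSTE.apply(x, theta, 0.5)          # JumpReLU with wide bandwidth
loss = f.pow(2).mean()                         # any downstream loss
loss.backward()
print(theta.grad)  # non-zero: some samples fall within ε/2 of the threshold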
