Numinex

Thread (18 posts)

Thinking Summary

**Exploring distillation requirements**

The user's question is whether distillation needs the entire logit distribution rather than just the output text of a model. They mention "logo distribution," but it should be "logit distribution". Indeed, traditional distillation typically relies on the full distribution of logits or probabilities for cross-entropy on soft targets. Methods to approximate these include using soft target distributions or teacher logits. Alternatively, one can opt for hard distillation through sampled tokens, but that essentially functions as pseudolabeling.

**Explaining distillation processes**

Gemma 2 employs continuous distillation using predicted probabilities for each token, i.e., classic knowledge distillation. Given a 256k vocabulary, storing all logits isn't feasible, so they compute teacher logits online during student training or compress the distribution. Thus, yes, the full distribution (or an approximation of it) is needed, which requires access to the teacher model. Soft target distributions are used, trained with cross-entropy. While hard distillation samples a single token, Gemma 2 relies on soft distillation, so the user's question about the entire logit distribution is valid.

Yes.
In the classic “soft-label” formulation of Knowledge Distillation (KD) that Hinton et al. (2015) introduced, and that the Gemma 2 paper is referring to, the student is trained to match the entire probability distribution $P_T(\cdot \mid x_{<t})$ produced by the teacher at every time-step $t$, not just the single token that the teacher eventually samples.

What is actually needed

  1. For every position $t$ in a training sequence with context $x_{<t}$, the teacher is run in inference mode to obtain its pre-softmax logits $\ell^T_t \in \mathbb{R}^{|V|}$ (or equivalently the softmax probabilities).

  2. The student is then optimized with the KL-divergence / cross-entropy between the teacher’s distribution and its own:

$$\mathcal{L}_\text{KD} = \sum_{t} \sum_{v \in V} P_T(v \mid x_{<t}) \, \log \frac{P_T(v \mid x_{<t})}{P_S(v \mid x_{<t})}.$$

Because the inner sum runs over the whole vocabulary $V$, you need (a suitably precise approximation of) the full vector of logits or probabilities for that position.
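A minimal sketch of this loss, assuming PyTorch (the function name, the temperature argument, and the tensor shapes are illustrative, not taken from the Gemma 2 report):

```python
# Soft-label KD loss: KL(P_teacher || P_student), summed over the vocabulary
# and averaged over positions. Shapes: [batch, seq_len, vocab_size].
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, temperature=1.0):
    s_logprobs = F.log_softmax(student_logits / temperature, dim=-1)
    t_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # F.kl_div expects log-probabilities for the input and probabilities for the target.
    loss = F.kl_div(
        s_logprobs.flatten(0, 1),
        t_probs.flatten(0, 1),
        reduction="batchmean",
    )
    # Standard temperature^2 rescaling from Hinton et al. (2015).
    return loss * temperature ** 2
```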

Why the actual text isn’t enough

If you were to keep only the teacher’s sampled token (a “hard label”), the target collapses to a one-hot vector—exactly the same situation as ordinary next-token prediction on human-written text. The extra information that makes KD effective comes from the relative probabilities assigned to all other tokens (e.g., that “cat”, “dog” and “rat” are all plausible next words, but with different likelihoods). The student can use this richer gradient signal to learn faster and generalize better.

Practical considerations

• Storage: Caching full 256k-dimensional probability vectors for trillions of tokens is infeasible, so most groups compute teacher logits on-the-fly during student training, or keep only a compressed top-k slice (e.g., the top-128 logits) that is still much richer than a one-hot label (see the sketch after this list).

• Bandwidth/throughput: Running the teacher's forward pass alongside student training adds substantial compute (for a teacher much larger than the student, that forward pass can cost more than the student's own training step), but for small students the extra cost typically pays for itself relative to simply training longer on one-hot targets.

• Alternatives:
– “Hard” distillation (a.k.a. self-training) samples from the teacher and trains on the generated text with one-hot labels; easier to store but weaker signal.
– “Intermediate feature” distillation matches hidden states rather than logits.
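One hedged sketch of the top-k caching idea mentioned above (the helper names and the choice of k = 128 are illustrative, not from the Gemma 2 report): store only the k largest teacher logits per position, then rebuild a truncated soft target at student-training time.

```python
# Cache only the top-k teacher logits per position instead of the full
# 256k-dimensional vector, then reconstruct a (truncated) soft target later.
import torch
import torch.nn.functional as F

def compress_teacher_logits(teacher_logits, k=128):
    # teacher_logits: [batch, seq_len, vocab_size]
    values, indices = teacher_logits.topk(k, dim=-1)
    return values, indices  # roughly k / vocab_size of the original storage

def soft_targets_from_topk(values, indices, vocab_size):
    # Renormalize the kept logits; tokens outside the top-k get probability 0.
    probs_k = F.softmax(values, dim=-1)
    targets = torch.zeros(*values.shape[:-1], vocab_size, device=values.device)
    targets.scatter_(-1, indices, probs_k)
    return targets  # usable as the teacher distribution in the KD loss above
```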

So, if you want to reproduce the approach described in the Gemma 2 report, you need access to the teacher’s logit (or probability) vectors—either computed live or saved/quantized—not just its generated text.


Can you elaborate on using KD for richer gradients vs for training beyond the number of tokens? They mention in the intro training more than 50x compute-optimal tokens but why is the KD relevant there, or in other words why not train on 50x natural tokens? Is it a synthetic data thing?

Thinking Summary

The user is asking about a key claim in the Gemma 2 paper: they use knowledge distillation to "simulate training beyond the number of available tokens" and train on 50x the compute-optimal number of tokens. They want to understand:

1. How KD provides richer gradients vs. just extending training
2. Why KD is relevant for training beyond compute-optimal tokens
3. Why not just train on 50x natural tokens instead

Let me think through this carefully. The paper says:

- "One solution is to improve the quality of information received by the network at each training step by replacing the next token prediction task with a richer objective"
- "This approach is often used to reduce the training time of smaller models by giving them richer gradients. In this work, we instead train for large quantities of tokens with distillation in order to simulate training beyond the number of available tokens."
- They train the 2B model on 2T tokens (which is 50x the compute-optimal budget for a 2B model according to Chinchilla scaling laws)

The key insights here are:

1. Natural data has diminishing returns: training on more tokens helps, but only logarithmically
2. KD provides a richer signal per token than one-hot labels
3. This allows them to extract more learning from the same number of tokens

Why not just use 50x natural tokens? A few reasons:

- Data quality/availability: they may not have 50x more high-quality tokens
- Diminishing returns: natural data follows power laws, so 50x more data might only give modest improvements
- The teacher model can provide a better signal than raw data

The "synthetic data" angle is interesting but more subtle: they're not generating new text, they're enriching the training signal on existing text.

This is a great question that gets at the heart of what makes the Gemma 2 approach clever. Let me break down the two perspectives and why KD enables training far beyond compute-optimal token counts.

Richer Gradients vs. More Tokens

Traditional scaling: When you train on natural text with one-hot labels, you face diminishing returns; Chinchilla-style scaling laws show loss improving only as a power law in the amount of data. So training a 2B model on 50x its "compute-optimal" token count (2T instead of ~40B) would, on its own, yield only modest improvements on benchmarks.

KD's richer signal: With distillation, each token provides much more information:

  • One-hot label: "The next token is 'cat'"
  • KD distribution: "Next token: 45% 'cat', 20% 'kitten', 15% 'dog', 5% 'animal'..."

This soft distribution encodes the teacher's uncertainty, synonyms, and contextual nuances. The gradient from matching this full distribution teaches the student about relationships between tokens, not just which one is "correct."
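To make the "richer gradient" concrete: for a softmax followed by cross-entropy, the gradient with respect to each logit is simply the student's probability minus the target probability, so a soft target produces a non-zero, informative gradient for every vocabulary entry rather than only the "correct" one. A toy illustration (the five-token vocabulary and the probabilities are invented to mirror the example above):

```python
# Compare the gradient signal from a hard (one-hot) target vs. a soft teacher target.
import torch
import torch.nn.functional as F

vocab = ["cat", "kitten", "dog", "animal", "rat"]
logits = torch.zeros(len(vocab), requires_grad=True)  # uniform student to start

hard_target = torch.tensor([1.00, 0.00, 0.00, 0.00, 0.00])  # "the next token is cat"
soft_target = torch.tensor([0.45, 0.20, 0.15, 0.05, 0.15])  # teacher's distribution

for name, target in [("hard", hard_target), ("soft", soft_target)]:
    loss = -(target * F.log_softmax(logits, dim=-1)).sum()  # cross-entropy with the target
    (grad,) = torch.autograd.grad(loss, logits)
    print(name, grad)  # gradient = softmax(logits) - target, one entry per token
```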

Why KD Enables "Training Beyond Available Tokens"

The key insight is that KD changes the effective information content per token. Consider:

  1. Natural tokens hit a quality ceiling: After seeing "The cat sat on the..." → "mat" a thousand times, there's little more to learn from that pattern. The one-hot signal is exhausted.

  2. Teacher distributions remain informative: Even on the thousandth occurrence, the teacher's distribution might reveal that "mat" (0.6), "chair" (0.2), "floor" (0.1) are all plausible, teaching the student about the underlying uncertainty and alternatives.

  3. Simulating unseen data: The teacher's distributions implicitly contain information about patterns it learned from its own massive training set. When it assigns 0.2 probability to "chair," it's effectively teaching the student about contexts where "chair" would appear—even if "chair" never appears in this specific training example.

Why Not Just Use 50x Natural Tokens?

Several practical reasons:

  1. Data availability: High-quality text is finite. They already use 13T tokens for the 27B model. Finding 50x more (650T tokens) of similar quality would be extremely difficult.

  2. Severely diminishing returns: The power-law scaling of natural data means 50x more tokens would buy relatively little on its own, whereas the paper's ablations show distillation delivering noticeably larger gains for the same student size.

  3. Quality vs. quantity: Rather than scraping increasingly low-quality web text, KD lets them extract more learning from their curated 2T token dataset by enriching each example with the teacher's knowledge.

Not Exactly "Synthetic Data"

This isn't quite synthetic data generation in the usual sense. They're not asking the teacher to generate new text. Instead, they're:

  • Taking existing human text
  • Running the teacher on it to get probability distributions
  • Training the student to match these distributions

It's more like annotating existing data with expert knowledge than creating new data. Each token position gets upgraded from a simple label to a rich, nuanced target that encodes the teacher's full understanding.
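A minimal on-the-fly distillation step consistent with this description might look like the sketch below (PyTorch-style; `teacher`, `student`, and `batch_ids` are assumed to be a frozen teacher LM, a trainable student LM, and a batch of token ids from ordinary human text, and `kd_loss` is the helper sketched earlier in the thread; none of this is taken from the Gemma 2 codebase):

```python
# One distillation training step: run the frozen teacher on the same human-written
# batch, then train the student to match the teacher's per-position distributions.
import torch

def distillation_step(student, teacher, optimizer, batch_ids, temperature=1.0):
    # batch_ids: [batch, seq_len] token ids from existing human text.
    with torch.no_grad():
        teacher_logits = teacher(batch_ids)   # [batch, seq_len, vocab]
    student_logits = student(batch_ids)       # [batch, seq_len, vocab]

    loss = kd_loss(student_logits, teacher_logits, temperature)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```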

The brilliance is that this allows small models to benefit from the teacher's exposure to vastly more data without actually needing access to that data—the knowledge is compressed into the probability distributions.
