Numinex

Thread (18 posts)

Thinking Summary

**Exploring distillation requirements**

The user's question is whether distillation needs the entire logit distribution rather than just the output text of a model. They mention "logo distribution," but it should be "logits distribution". Indeed, traditional distillation typically relies on the full distribution of logits or probabilities for cross-entropy on soft targets. Methods to approximate these include using soft target distributions or teacher logits. Alternatively, one can opt for hard distillation through sampled tokens, but that essentially functions as pseudolabeling.

**Explaining distillation processes**

Gemma 2 employs continuous distillation using predicted probabilities for each token, which refers to "knowledge distillation." Given a 256k vocabulary, storing logits isn't feasible, so they compute teacher logits online during student training or compress the distribution. Thus, yes, the full distribution or its approximation is essential, requiring access to the teacher model. Soft target distributions are used, calculated through cross-entropy. While hard distillation samples one token, Gemma 2 relies on soft distillation, so the user's question about the entire logit distribution is valid.

Yes.
In the classic “soft-label” formulation of Knowledge Distillation (KD) introduced by Hinton et al. (2015), and referenced in the Gemma 2 paper, the student is trained to match the entire probability distribution $P_T(\cdot \mid x_c)$ that the teacher produces at every time-step, not just the single token that the teacher eventually samples.

What is actually needed

  1. For every position $t$ in a training sequence with context $x_{<t}$, the teacher is run in inference mode to obtain its pre-softmax logits $\ell^T_t \in \mathbb{R}^{|V|}$ (or equivalently the softmax probabilities).

  2. The student is then optimized with the KL-divergence / cross-entropy between the teacher’s distribution and its own:

$$\mathcal{L}_\text{KD} = \sum_{t} \sum_{v \in V} P_T(v \mid x_{<t}) \, \log \frac{P_T(v \mid x_{<t})}{P_S(v \mid x_{<t})}.$$

Because the inner sum runs over the whole vocabulary $V$, you need (a suitably precise approximation of) the full vector of logits or probabilities for that position.
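
As a concrete illustration, here is a minimal PyTorch sketch of that objective (my own sketch, not code from the Gemma 2 report); the function name, tensor shapes, and the optional temperature are assumptions for illustration:

```python
# A minimal sketch of the soft-label KD objective in PyTorch (illustration only,
# not Gemma 2 code). Assumed shapes: teacher_logits and student_logits are
# [batch, seq_len, vocab_size], obtained from a forward pass of each model.
import torch
import torch.nn.functional as F

def kd_loss(teacher_logits: torch.Tensor,
            student_logits: torch.Tensor,
            temperature: float = 1.0) -> torch.Tensor:
    vocab = teacher_logits.size(-1)
    # Soft targets: the teacher's full probability vector at every position.
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1).reshape(-1, vocab)
    # Student log-probabilities over the same vocabulary.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1).reshape(-1, vocab)
    # KL(P_T || P_S), summed over the vocabulary and averaged over positions.
    kl = F.kl_div(log_p_student, p_teacher, reduction="batchmean")
    # Scaling by T^2 keeps gradient magnitudes comparable across temperatures
    # (Hinton et al., 2015); with T = 1 it is a no-op.
    return kl * temperature ** 2
```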

Why the actual text isn’t enough

If you were to keep only the teacher’s sampled token (a “hard label”), the target collapses to a one-hot vector—exactly the same situation as ordinary next-token prediction on human-written text. The extra information that makes KD effective comes from the relative probabilities assigned to all other tokens (e.g., that “cat”, “dog” and “rat” are all plausible next words, but with different likelihoods). The student can use this richer gradient signal to learn faster and generalize better.

Practical considerations

• Storage: Caching full 256k-dimensional probability vectors for trillions of tokens is infeasible, so most groups compute teacher logits on the fly during student training, or keep only a compressed top-k slice (e.g., the top-128 logits) that is still much richer than a one-hot label; a sketch of this follows the list below.

• Bandwidth/throughput: Running the teacher’s forward pass alongside student training adds substantial compute (for a much larger teacher, the forward pass can cost more than the student’s own update), but for small students this can still be cheaper than training them longer on one-hot targets.

• Alternatives:
– “Hard” distillation (a.k.a. self-training) samples from the teacher and trains on the generated text with one-hot labels; easier to store but weaker signal.
– “Intermediate feature” distillation matches hidden states rather than logits.
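
To make the top-k idea above concrete, here is a hypothetical sketch of caching a top-k slice of the teacher’s logits and distilling against the renormalized truncated distribution. The function names, shapes, and the choice of k = 128 are illustrative assumptions, not details from the Gemma 2 report.

```python
# Hypothetical top-k compression of teacher logits for distillation (assumed design,
# not from the Gemma 2 report).
import torch
import torch.nn.functional as F

def compress_teacher(teacher_logits: torch.Tensor, k: int = 128):
    # teacher_logits: [batch, seq_len, vocab_size]
    values, indices = teacher_logits.topk(k, dim=-1)   # both [batch, seq_len, k]
    return values.half(), indices.int()                # compact enough to cache on disk

def topk_kd_loss(values: torch.Tensor,
                 indices: torch.Tensor,
                 student_logits: torch.Tensor,
                 temperature: float = 1.0) -> torch.Tensor:
    # Teacher probabilities renormalized over the retained k tokens only.
    p_teacher = F.softmax(values.float() / temperature, dim=-1)
    # Student log-probs over the full vocabulary, gathered at the cached indices.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    log_p_student = log_p_student.gather(-1, indices.long())
    # Cross-entropy against the truncated soft targets (the teacher-entropy term
    # is constant w.r.t. the student, so it is dropped).
    return -(p_teacher * log_p_student).sum(-1).mean() * temperature ** 2
```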

So, if you want to reproduce the approach described in the Gemma 2 report, you need access to the teacher’s logit (or probability) vectors—either computed live or saved/quantized—not just its generated text.


Can you elaborate on using KD for richer gradients vs. for training beyond the compute-optimal number of tokens? They mention in the intro training on more than 50x the compute-optimal token count, but why is KD relevant there? In other words, why not train on 50x natural tokens? Is it a synthetic-data thing?


Why Knowledge Distillation (KD) helps when you train a small model far beyond the “compute-optimal” token count has to do with two related constraints:

  1. How much information you extract per training step (gradient richness).
  2. How much fresh, high-quality text you can realistically feed the model (data availability & over-fitting).

Below is a compact explanation that ties the two together.


  1. What “compute-optimal” means

Hoffmann et al. (2022) showed that for a given amount of FLOPs there is an empirical sweet spot of roughly

$$N_\text{opt}(P) \approx 20\,P$$

training tokens for a model with $P$ parameters (about 20 tokens per parameter).
• For a 2 B-param model, $N_\text{opt}$ is only a few × 10¹⁰ tokens.
• Gemma 2 trains its 2 B model on ~2 × 10¹² tokens → 50 × the “optimal” amount.
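
A quick back-of-the-envelope check of that 50 × figure, assuming the rough ~20-tokens-per-parameter rule of thumb:

```python
# Back-of-the-envelope check of the 50x figure, assuming the ~20-tokens-per-parameter
# rule of thumb from Hoffmann et al. (2022).
params = 2e9                          # Gemma 2 2B student
compute_optimal_tokens = 20 * params  # ~4e10 tokens
actual_tokens = 2e12                  # ~2 trillion training tokens for the 2B model
print(actual_tokens / compute_optimal_tokens)  # -> 50.0
```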

Training that long with ordinary one-hot targets runs into two problems:

a. You run out of new, high-quality text. After deduplication, filtering, etc., the open web simply doesn’t have hundreds of trillions of unique, clean tokens.
b. If you just keep recycling the same corpus, the gradients quickly shrink toward zero because the model already predicts the one-hot target with high confidence → diminishing returns and over-fitting.


  2. How KD changes the gradient signal

In KD the target at position $t$ is the teacher’s full distribution $P_T(\cdot \mid x_{<t})$, not a single “correct” token.
Updating the student with

$$\mathcal{L}_{\text{KD}}(t) = \mathrm{KL}\!\left(P_T \,\|\, P_S\right)$$

delivers many more informative bits per token:

• Tokens that the teacher thinks are plausible but not top-1 (e.g. synonyms, plural forms, typographical variants, sub-phrases…) still contribute gradient.
• Even if the student already predicts the teacher’s most-likely token correctly, it usually does not match the relative probabilities of the next 10–100 alternatives → useful learning signal persists for many more passes.

Think of KD as label-smoothing on steroids: every revisit to a sentence still teaches the student something about fine-grained probability mass, so extra epochs aren’t wasted.
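
A toy numerical illustration of that claim (made-up logits over a five-token vocabulary, not Gemma numbers): once the student already puts most of its mass on the teacher’s top token, the one-hot gradient is small, while the gradient against the soft targets is still sizable.

```python
# Toy illustration (5-token vocabulary, made-up numbers): the student already ranks
# the teacher's sampled token first, so the one-hot gradient is small, while the
# soft-target gradient is still sizable.
import torch
import torch.nn.functional as F

student_logits = torch.tensor([4.0, 1.0, 0.5, 0.2, 0.1], requires_grad=True)

# Hard target: one-hot on token 0 (the token the teacher would have sampled).
hard_loss = F.cross_entropy(student_logits.unsqueeze(0), torch.tensor([0]))
hard_grad = torch.autograd.grad(hard_loss, student_logits)[0]

# Soft target: the teacher also spreads probability over plausible alternatives.
teacher_probs = torch.tensor([0.60, 0.20, 0.10, 0.07, 0.03])
soft_loss = F.kl_div(F.log_softmax(student_logits, dim=-1), teacher_probs,
                     reduction="sum")
soft_grad = torch.autograd.grad(soft_loss, student_logits)[0]

print(hard_grad.abs().sum())  # ~0.22: little left to learn from the one-hot label
print(soft_grad.abs().sum())  # ~0.58: the soft targets still push the logits around
```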


  3. How this mitigates the “50 × tokens” hurdle

Scenario without KD
• After one or two passes over the 40–100 B unique tokens you have, gradients collapse; the remaining ~1.9 T updates are nearly redundant.
• Adding more real data of that quality/size is practically impossible.

Scenario with KD
• Because each replayed sentence is paired with rich soft targets, the student keeps getting non-zero gradients → the extra compute is useful, not redundant.
• You can safely up-sample or cache a limited corpus many times and still extract new information on each pass.
• If you do generate synthetic teacher text to reach huge token counts, it is automatically labeled (the teacher logit vectors) at no extra human cost.

In the Gemma 2 setup the small 2 B / 9 B students see:
• Natural data (web, code, science) that is replayed dozens of times, each time paired with fresh teacher logits.
• Additional teacher-generated continuations that effectively expand the corpus.
Both exploit KD’s rich target signal; neither would help much if you only used one-hot labels.


  4. Why not just gather “50 × natural tokens”?

  1. Sheer scarcity: After deduplication and filtering for quality/safety, the public web tops out around a few × 10¹³ tokens; very little of that is high quality.
  2. Legal & safety filtering gets stricter every year; usable data shrinks.
  3. Cost: Curating, hosting and streaming tens of trillions of new tokens is far more expensive than re-using an existing cleaned corpus plus on-the-fly teacher logits.
  4. Signal-to-noise: If you indiscriminately scrape to reach 50 × tokens you end up with noisy or harmful text that can hurt performance.

KD lets you invest compute into extracting more value from the data you already trust, rather than hunting for ever-larger raw text dumps.


  5. Synthetic data angle

Yes—teacher-generated continuations are an easy way to create effectively unlimited “new” sequences. But the key benefit isn’t that they are synthetic per se; it’s that every token, synthetic or natural, is accompanied by a dense probability vector that encodes what the larger teacher knows. That’s the mechanism that keeps gradient quality high far past the compute-optimal regime.


Take-away

• KD addresses the diminishing-returns problem of small models trained for extremely long horizons.
• It does so by supplying richer gradients and by making corpus re-use (and teacher-generated text) genuinely informative.
• Therefore a 2 B / 9 B Gemma student can productively consume 50 × more token-updates than scaling laws suggest would be worthwhile with plain next-token supervision.
