
Thread (18 posts)

Thinking Summary

**Clarifying Distillation Requirements**

I need to determine whether distillation requires not just the model's output text but the entire logits distribution. The user is asking about the necessity of the full probability distribution. Knowledge distillation in language modeling typically involves using the teacher's soft targets, which can be the full probability distribution over the vocabulary or partial approximations. The text output alone isn't enough; you need the teacher's probabilities to compute the KL divergence. So, yes, for standard distillation, you need logits or probabilities.

**Explaining Token-Level Distillation**

With Gemma 2 using token-level distillation and KL loss, distributions are essential. They might store top-K probabilities to save resources. In simpler terms, this method involves calculating the Kullback-Leibler divergence between the teacher's and student's token distributions. During training, the teacher's soft probabilities help minimize cross-entropy. Since storing full vocab distributions can be impractical, they opt for on-the-fly streaming distillation. While there's also sequence-level distillation, standard token-level distillation requires teacher logits for effective learning. Hence, Gemma 2's approach needs these distributions to function properly.

Essentially yes.
The variant of knowledge-distillation described in the Gemma 2 report is token-level (“soft-target”) distillation:

loss = Eₓ [ KL (P_Teacher(·|context x) ‖ P_Student(·|context x) ) ].

To compute that KL divergence you need, for every training position, the probability (or logits) that the teacher assigns to every vocabulary token (or at least to the subset of tokens you keep after truncation).
Having only a sampled continuation (the teacher’s hard output string) would collapse the distribution back to a one-hot target, which would give you nothing beyond ordinary next-token training.

How this is done in practice:

  1. The training pipeline runs a forward pass of the large teacher on the same input stream the student sees.
  2. The teacher's full softmax distribution is fed into the student's loss (in practice the raw logits are passed and the softmax is folded into the loss computation).
    • To save bandwidth you can drop very low-probability entries or keep only the top-k + “other” bucket, but you still need the probabilities, not the sampled text.
  3. You usually do this “on-the-fly” rather than storing the logits, because persisting 256 k-dimensional distributions for trillions of tokens would be infeasible (see the sketch below).
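As an illustration, here is a minimal sketch of what such an on-the-fly token-level distillation loss could look like in PyTorch. The `teacher` and `student` callables, the tensor shapes, and the reduction choice are assumptions for the example, not the actual Gemma 2 training code.

```python
import torch
import torch.nn.functional as F

def token_level_kd_loss(teacher, student, input_ids):
    """KL(P_teacher || P_student), averaged over all token positions.

    Assumes `teacher(input_ids)` and `student(input_ids)` each return
    logits of shape [batch, seq_len, vocab_size]; hypothetical interfaces.
    """
    with torch.no_grad():                      # teacher is frozen, inference only
        t_logits = teacher(input_ids)          # [B, T, V]
    s_logits = student(input_ids)              # [B, T, V]

    t_log_probs = F.log_softmax(t_logits, dim=-1)
    s_log_probs = F.log_softmax(s_logits, dim=-1)

    B, T, V = s_logits.shape
    # kl_div expects log-probabilities for the input; with log_target=True the
    # target is also given as log-probabilities. "batchmean" sums over the
    # vocabulary and averages over the B*T positions.
    return F.kl_div(
        s_log_probs.view(B * T, V),
        t_log_probs.view(B * T, V),
        log_target=True,
        reduction="batchmean",
    )
```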

There are other distillation styles—e.g. sequence-level distillation where you train on the teacher’s sampled outputs—but that is not what Gemma 2 is referring to.

Thinking Summary

**Exploring distillation requirements** The user's question is whether distillation needs the entire logit distribution rather than just the output text of a model. They mention "logo distribution," but it should be "logits distribution". Indeed, traditional distillation typically relies on the full distribution of logits or probabilities for cross-entropy on soft targets. Methods to approximate these include using soft target distributions or teacher logits. Alternatively, one can opt for hard distillation through sampled tokens, but that essentially functions as pseudolabeling. **Explaining distillation processes** Gemma 2 employs continuous distillation using predicted probabilities for each token, which refers to "knowledge distillation." Given a 256k vocabulary, storing logits isn't feasible, so they compute teacher logits online during student training or compress the distribution. Thus, yes, the full distribution or its approximation is essential for context, requiring access to the teacher model. Soft target distributions are used, calculated through cross-entropy. While hard distillation samples one token, Gemma 2 relies on soft distillation — so the user’s question about the entire logit distribution is valid.

Yes.
In the classic “soft-label” formulation of Knowledge Distillation (KD) that Hinton et al. (2015) introduced—and that the Gemma 2 paper is referring to—the student is trained to match the entire probability distribution $P_T(\cdot \mid x_c)$ produced by the teacher at every time-step, not just the single token that the teacher eventually samples.

What is actually needed

  1. For every position t in a training sequence with context $x_{<t}$, the teacher is run in inference mode to obtain its pre-softmax logits $\ell^T_t \in \mathbb{R}^{|V|}$ (or equivalently the softmax probabilities).

  2. The student is then optimized with the KL-divergence / cross-entropy between the teacher’s distribution and its own:

$$\mathcal{L}_\text{KD} = \sum_{t} \sum_{v \in V} P_T(v \mid x_{<t}) \, \log \frac{P_T(v \mid x_{<t})}{P_S(v \mid x_{<t})}.$$

Because the inner sum runs over the whole vocabulary $V$, you need (a suitably precise approximation of) the full vector of logits or probabilities for that position.
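One piece of standard algebra, not spelled out above but useful for implementation: the KL objective differs from plain soft-target cross-entropy only by the teacher's entropy, which does not depend on the student's parameters, so minimizing either one gives the student the same gradients:

$$
\mathcal{L}_\text{KD}
= \underbrace{-\sum_{v \in V} P_T(v \mid x_{<t}) \log P_S(v \mid x_{<t})}_{\text{cross-entropy with soft targets}}
\;-\; \underbrace{H\!\big(P_T(\cdot \mid x_{<t})\big)}_{\text{constant w.r.t. the student}}
$$

This is why implementations can interchangeably report a KL loss or a soft cross-entropy loss.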

Why the actual text isn’t enough

If you were to keep only the teacher’s sampled token (a “hard label”), the target collapses to a one-hot vector—exactly the same situation as ordinary next-token prediction on human-written text. The extra information that makes KD effective comes from the relative probabilities assigned to all other tokens (e.g., that “cat”, “dog” and “rat” are all plausible next words, but with different likelihoods). The student can use this richer gradient signal to learn faster and generalize better.

Practical considerations

• Storage: Caching full 256 k-dimension probability vectors for trillions of tokens is infeasible, so most groups compute teacher logits on-the-fly during student training, or keep only a compressed top-k slice (e.g., top-128 logits) that is still much richer than a one-hot label (a sketch of this top-k compression follows after this list).

• Bandwidth/throughput: Running the teacher forward pass in parallel with student training roughly doubles compute, but for small students this can still be cheaper than training them longer on one-hot targets.

• Alternatives:
– “Hard” distillation (sequence-level KD / pseudo-labeling) samples from the teacher and trains on the generated text with one-hot labels; easier to store but a weaker signal.
– “Intermediate feature” distillation matches hidden states rather than logits.
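To make the top-k option concrete, here is a hedged sketch of how a teacher distribution could be compressed and then used in the loss. The choice of k, the single merged “tail” bucket, and the function names are assumptions for illustration, not a description of what any particular team actually stores.

```python
import torch
import torch.nn.functional as F

def compress_teacher(t_logits: torch.Tensor, k: int = 128):
    """Keep only the top-k teacher probabilities plus the leftover tail mass."""
    probs = F.softmax(t_logits, dim=-1)               # [T, V]
    top_p, top_idx = probs.topk(k, dim=-1)            # [T, k]
    tail_mass = (1.0 - top_p.sum(-1)).clamp_min(0.0)  # [T]
    return top_idx, top_p, tail_mass                  # cache these, not the full [T, V] vector

def sparse_kd_loss(top_idx, top_p, tail_mass, s_logits, eps: float = 1e-9):
    """Approximate KL(teacher || student): exact on the kept tokens,
    with everything else merged into a single 'other' bucket."""
    s_log_probs = F.log_softmax(s_logits, dim=-1)     # [T, V]
    s_top = s_log_probs.gather(-1, top_idx)           # [T, k]
    kept = (top_p * (top_p.clamp_min(eps).log() - s_top)).sum(-1)

    # Student probability mass falling outside the kept set, for the tail bucket.
    s_tail = (1.0 - s_log_probs.exp().gather(-1, top_idx).sum(-1)).clamp_min(eps)
    tail = tail_mass * (tail_mass.clamp_min(eps).log() - s_tail.log())
    return (kept + tail).mean()
```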

So, if you want to reproduce the approach described in the Gemma 2 report, you need access to the teacher’s logit (or probability) vectors—either computed live or saved/quantized—not just its generated text.


Can you elaborate on using KD for richer gradients vs for training beyond the number of tokens? They mention in the intro training more than 50x compute-optimal tokens but why is the KD relevant there, or in other words why not train on 50x natural tokens? Is it a synthetic data thing?


Why Knowledge Distillation (KD) helps when you train a small model far beyond the “compute-optimal” token count comes down to two related constraints:

  1. How much information you extract per training step (gradient richness).
  2. How much fresh, high-quality text you can realistically feed the model (data availability & over-fitting).

Below is a compact explanation that ties the two together.


  1. What “compute-optimal” means

Hoffmann et al. (2022) showed that for a given amount of FLOPs there is an empirical sweet spot of roughly
$$N_\text{opt}(P) \approx c\,P^{0.7}$$
tokens for a model with $P$ parameters.
• For a 2 B-param model, $N_\text{opt}$ is only a few × 10¹⁰ tokens.
• Gemma 2 trains its 2 B model on ~2 × 10¹² tokens → 50 × the “optimal” amount (a quick sanity check on that figure follows below).
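As a back-of-the-envelope check (using the commonly cited Chinchilla rule of thumb of roughly 20 training tokens per parameter, an approximation layered on top of the formula above):

$$
N_\text{opt}(2\,\text{B}) \approx 20 \times 2\times10^{9} = 4\times10^{10}\ \text{tokens},
\qquad
\frac{2\times10^{12}}{4\times10^{10}} = 50.
$$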

Training that long with ordinary one-hot targets runs into two problems:

a. You run out of new, high-quality text. After deduplication, filtering, etc., the open web simply doesn’t have hundreds of trillions of unique, clean tokens.
b. If you just keep re-cycling the same corpus, the gradients very quickly approach zero because the model already predicts the one-hot target with high confidence → diminishing returns and over-fitting.


  2. How KD changes the gradient signal

In KD the target at position t is the teacher’s full distribution $P_T(\cdot \mid x_{<t})$, not a single “correct” token.
Updating the student with

$$\mathcal{L}_\text{KD}(t) = \mathrm{KL}\!\left(P_T \,\|\, P_S\right)$$

delivers many more informative bits per token:

• Tokens that the teacher thinks are plausible but not top-1 (e.g. synonyms, plural forms, typographical variants, sub-phrases…) still contribute gradient.
• Even if the student already predicts the teacher’s most-likely token correctly, it usually does not match the relative probabilities of the next 10–100 alternatives → useful learning signal persists for many more passes.

Think of KD as label-smoothing on steroids: every revisit to a sentence still teaches the student something about fine-grained probability mass, so extra epochs aren’t wasted.
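A tiny numeric illustration of why the signal persists, using the standard identity that the gradient of cross-entropy with respect to the logits is softmax(logits) minus the target. The specific numbers are invented for the example:

```python
import torch
import torch.nn.functional as F

# A student that is already confident in the "correct" top-1 token.
student_logits = torch.tensor([4.0, 1.0, 0.5, 0.0])
p_student = F.softmax(student_logits, dim=-1)         # ~[0.91, 0.05, 0.03, 0.02]

hard_target = torch.tensor([1.0, 0.0, 0.0, 0.0])      # one-hot label
soft_target = torch.tensor([0.60, 0.25, 0.10, 0.05])  # hypothetical teacher distribution

# dCE/dlogits = p_student - target (valid when the target sums to 1).
grad_hard = p_student - hard_target   # norm ~0.10: almost nothing left to learn
grad_soft = p_student - soft_target   # norm ~0.38: the alternatives still teach

print(grad_hard.norm().item(), grad_soft.norm().item())
```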


  3. How this mitigates the “50 × tokens” hurdle

Scenario without KD
• After one or two passes over the 40–100 B unique tokens you have, gradients collapse; the remaining 1.9 T updates are nearly redundant.
• Adding more real data of that quality/size is practically impossible.

Scenario with KD
• Because each replayed sentence is paired with rich soft targets, the student keeps getting non-zero gradients → the extra compute is useful, not redundant.
• You can safely up-sample or cache a limited corpus many times and still extract new information on each pass.
• If you do generate synthetic teacher text to reach huge token counts, it is automatically labeled (the teacher logit vectors) at no extra human cost.

In the Gemma 2 setup the small 2 B / 9 B students see:
• Natural data (web, code, science) that is replayed dozens of times, each time paired with fresh teacher logits.
• Additional teacher-generated continuations that effectively expand the corpus.
Both exploit KD’s rich target signal; neither would help much if you only used one-hot labels.


  4. Why not just gather “50 × natural tokens”?

  1. Sheer scarcity: After deduplication and filtering for quality/safety, the public web tops out around a few × 10¹³ tokens; very little of that is high quality.
  2. Legal & safety filtering gets stricter every year; usable data shrinks.
  3. Cost: Curating, hosting and streaming tens of trillions of new tokens is far more expensive than re-using an existing cleaned corpus plus on-the-fly teacher logits.
  4. Signal-to-noise: If you indiscriminately scrape to reach 50 × tokens you end up with noisy or harmful text that can hurt performance.

KD lets you invest compute into extracting more value from the data you already trust, rather than hunting for ever-larger raw text dumps.


  5. Synthetic data angle

Yes—teacher-generated continuations are an easy way to create effectively unlimited “new” sequences. But the key benefit isn’t that they are synthetic per se; it’s that every token, synthetic or natural, is accompanied by a dense probability vector that encodes what the larger teacher knows. That’s the mechanism that keeps gradient quality high far past the compute-optimal regime.


Take-away

• KD addresses the diminishing-returns problem of small models trained for extremely long horizons.
• It does so by supplying richer gradients and by making corpus re-use (and teacher-generated text) genuinely informative.
• Therefore a 2 B/9 B Gemma student can productively consume 50 × more token-updates than scaling laws would predict feasible with plain next-token supervision.



Thinking Summary

The user is asking about a key claim in the Gemma 2 paper - they use knowledge distillation to "simulate training beyond the number of available tokens" and train on 50x the compute-optimal number of tokens. They want to understand:

1. How KD provides richer gradients vs just extending training
2. Why KD is relevant for training beyond compute-optimal tokens
3. Why not just train on 50x natural tokens instead

Let me think through this carefully. The paper says:
- "One solution is to improve the quality of information received by the network at each training step by replacing the next token prediction task with a richer objective"
- "This approach is often used to reduce the training time of smaller models by giving them richer gradients. In this work, we instead train for large quantities of tokens with distillation in order to simulate training beyond the number of available tokens."
- They train the 2B model on 2T tokens (which is 50x the compute-optimal for a 2B model according to Chinchilla scaling laws)

The key insights here are:
1. Natural data has diminishing returns - training on more tokens helps but logarithmically
2. KD provides a richer signal per token than one-hot labels
3. This allows them to extract more learning from the same number of tokens

Why not just use 50x natural tokens? A few reasons:
- Data quality/availability - they may not have 50x more high-quality tokens
- Diminishing returns - natural data follows power laws, so 50x more data might only give modest improvements
- The teacher model can provide better signal than raw data

The "synthetic data" angle is interesting but I think it's more subtle - they're not generating new text, they're enriching the training signal on existing text.

This is a great question that gets at the heart of what makes the Gemma 2 approach clever. Let me break down the two perspectives and why KD enables training far beyond compute-optimal token counts.

Richer Gradients vs. More Tokens

Traditional scaling: When you train on natural text with one-hot labels, you face diminishing returns. Chinchilla-style scaling laws imply that loss improves only as a power law in the amount of data, so each further gain needs multiplicatively more tokens. Training a 2B model on 50x its "compute-optimal" budget (2T instead of ~40B tokens) with plain next-token prediction would therefore yield only modest improvements, maybe 10-20% on benchmarks.

KD's richer signal: With distillation, each token provides much more information:

  • One-hot label: "The next token is 'cat'"
  • KD distribution: "Next token: 45% 'cat', 20% 'kitten', 15% 'dog', 5% 'animal'..."

This soft distribution encodes the teacher's uncertainty, synonyms, and contextual nuances. The gradient from matching this full distribution teaches the student about relationships between tokens, not just which one is "correct."
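For concreteness, here is a small worked example with hypothetical numbers (the remaining 15% of the teacher's mass is lumped into an "other" bucket, and the student distribution is invented for illustration):

$$
P_T = (0.45,\ 0.20,\ 0.15,\ 0.05,\ 0.15), \qquad
P_S = (0.85,\ 0.05,\ 0.05,\ 0.02,\ 0.03)
$$
$$
\mathrm{KL}(P_T \,\|\, P_S) = \sum_v P_T(v)\,\log\frac{P_T(v)}{P_S(v)} \approx 0.44\ \text{nats}.
$$

Both distributions agree that "cat" is the most likely next token, yet the loss (and hence the gradient) is far from zero, because the student misallocates probability among the alternatives.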

Why KD Enables "Training Beyond Available Tokens"

The key insight is that KD changes the effective information content per token. Consider:

  1. Natural tokens hit a quality ceiling: After seeing "The cat sat on the..." → "mat" a thousand times, there's little more to learn from that pattern. The one-hot signal is exhausted.

  2. Teacher distributions remain informative: Even on the thousandth occurrence, the teacher's distribution might reveal that "mat" (0.6), "chair" (0.2), "floor" (0.1) are all plausible, teaching the student about the underlying uncertainty and alternatives.

  3. Simulating unseen data: The teacher's distributions implicitly contain information about patterns it learned from its own massive training set. When it assigns 0.2 probability to "chair," it's effectively teaching the student about contexts where "chair" would appear—even if "chair" never appears in this specific training example.

Why Not Just Use 50x Natural Tokens?

Several practical reasons:

  1. Data availability: High-quality text is finite. They already use 13T tokens for the 27B model. Finding 50x more (650T tokens) of similar quality would be extremely difficult.

  2. Severely diminishing returns: The log-scaling of natural data means 50x more tokens might improve performance by only ~15%, whereas their KD approach shows improvements of 10+ points on some benchmarks.

  3. Quality vs. quantity: Rather than scraping increasingly low-quality web text, KD lets them extract more learning from their curated 2T token dataset by enriching each example with the teacher's knowledge.

Not Exactly "Synthetic Data"

This isn't quite synthetic data generation in the usual sense. They're not asking the teacher to generate new text. Instead, they're:

  • Taking existing human text
  • Running the teacher on it to get probability distributions
  • Training the student to match these distributions

It's more like annotating existing data with expert knowledge than creating new data. Each token position gets upgraded from a simple label to a rich, nuanced target that encodes the teacher's full understanding.

The brilliance is that this allows small models to benefit from the teacher's exposure to vastly more data without actually needing access to that data—the knowledge is compressed into the probability distributions.
