Gemma 2
Thread (18 posts)
knowledge distillation (Hinton et al., 2015), which replaces the one-hot vector seen at each token with the distribution of potential next tokens computed from a large model
Is it correct that this implies that distillation requires not just the output text of a model (post-sampling) but the entire logit distribution?
Thinking Summary
**Exploring distillation requirements**

The user's question is whether distillation needs the entire logit distribution rather than just the output text of a model. Traditional distillation does rely on the full distribution of logits or probabilities, used as soft targets in a cross-entropy loss. Alternatively, one can opt for hard distillation on sampled tokens, but that essentially functions as pseudo-labeling.

**Explaining distillation processes**

Gemma 2 uses soft distillation with the teacher's predicted probabilities at each token, i.e., classic knowledge distillation. Given a 256k vocabulary, storing full logits isn't feasible, so teacher logits are typically computed online during student training or the distribution is compressed. So yes: the full distribution, or a good approximation of it, is essential, and that requires access to the teacher model. While hard distillation samples one token, Gemma 2 relies on soft distillation, so the user's question about needing the entire logit distribution is answered affirmatively.
Yes.
In the classic “soft-label” formulation of Knowledge Distillation (KD) that Hinton et al. (2015) introduced—and that the Gemma 2 paper is referring to—the student is trained to match the entire probability distribution produced by the teacher at every time-step, not just the single token that the teacher eventually samples.
What is actually needed
- For every position $t$ in a training sequence with context $x_{<t}$, the teacher is run in inference mode to obtain its pre-softmax logits (or, equivalently, the softmax probabilities) $P_T(\cdot \mid x_{<t})$.
- The student is then optimized with the KL-divergence / cross-entropy between the teacher's distribution and its own:

  $$\mathcal{L}_t = -\sum_{v \in V} P_T(v \mid x_{<t}) \, \log P_S(v \mid x_{<t})$$

  Because the inner sum runs over the whole vocabulary $V$, you need (a suitably precise approximation of) the full vector of logits or probabilities for that position.
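As a concrete (if simplified) sketch of that objective, assuming a PyTorch-style setup where `student_logits` and `teacher_logits` are `[batch, seq_len, vocab]` tensors (the names and temperature handling are illustrative, not Gemma 2's actual code):

```python
import torch
import torch.nn.functional as F

def soft_target_kd_loss(student_logits, teacher_logits, temperature=1.0):
    """Per-position distillation loss: KL(teacher || student) over the vocabulary.

    Requires the teacher's full logit vector at every position, not just
    its sampled token.
    """
    t_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
    s_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    # kl_div expects the student's log-probs as input; log_target=True says
    # the teacher targets are also given as log-probs.
    return F.kl_div(s_log_probs, t_log_probs, log_target=True,
                    reduction="batchmean") * temperature ** 2
```

The `temperature ** 2` factor is the usual correction from Hinton et al. (2015) so that gradient magnitudes stay comparable when the distributions are softened.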
Why the actual text isn’t enough
If you were to keep only the teacher’s sampled token (a “hard label”), the target collapses to a one-hot vector—exactly the same situation as ordinary next-token prediction on human-written text. The extra information that makes KD effective comes from the relative probabilities assigned to all other tokens (e.g., that “cat”, “dog” and “rat” are all plausible next words, but with different likelihoods). The student can use this richer gradient signal to learn faster and generalize better.
Practical considerations
• Storage: Caching full 256k-dimensional probability vectors for trillions of tokens is infeasible, so most groups compute teacher logits on-the-fly during student training, or keep only a compressed top-k slice (e.g., the top-128 logits) that is still much richer than a one-hot label; a rough sketch of this appears after this list.
• Bandwidth/throughput: Running the teacher forward pass in parallel with student training roughly doubles compute, but for small students this can still be cheaper than training them longer on one-hot targets.
• Alternatives:
– “Hard” distillation (a.k.a. self-training) samples from the teacher and trains on the generated text with one-hot labels; easier to store but weaker signal.
– “Intermediate feature” distillation matches hidden states rather than logits.
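To make the storage point concrete, here is a rough sketch of caching only a top-k slice of the teacher's distribution and training against it. The function names and the exact compression format are assumptions for illustration, not a documented Gemma 2 recipe:

```python
import torch
import torch.nn.functional as F

def compress_teacher_targets(teacher_logits, k=128):
    """Keep the top-k teacher probabilities per position plus the leftover
    tail mass, so targets can be cached far more cheaply than full vectors."""
    probs = F.softmax(teacher_logits, dim=-1)            # [batch, seq, vocab]
    top_p, top_idx = probs.topk(k, dim=-1)               # [batch, seq, k]
    tail_mass = 1.0 - top_p.sum(dim=-1, keepdim=True)    # mass outside the top-k
    return top_idx, top_p, tail_mass

def sparse_kd_loss(student_logits, top_idx, top_p):
    """Cross-entropy against the cached top-k slice of the teacher distribution."""
    s_log_probs = F.log_softmax(student_logits, dim=-1)
    picked = s_log_probs.gather(-1, top_idx)             # student log-probs at the teacher's top-k ids
    return -(top_p * picked).sum(dim=-1).mean()
```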
So, if you want to reproduce the approach described in the Gemma 2 report, you need access to the teacher’s logit (or probability) vectors—either computed live or saved/quantized—not just its generated text.
Can you elaborate on using KD for richer gradients vs. for training beyond the number of available tokens? They mention in the intro training on more than 50x the compute-optimal number of tokens, but why is KD relevant there? In other words, why not train on 50x natural tokens? Is it a synthetic data thing?
Why Knowledge Distillation (KD) helps when you train a small model far beyond the “compute-optimal” token count has to do with two related constraints:
- How much information you extract per training step (gradient richness).
- How much fresh, high-quality text you can realistically feed the model (data availability & over-fitting).
Below is a compact explanation that ties the two together.
- What “compute-optimal” means
Hoffmann et al. (2022) showed that for a given FLOP budget there is an empirical sweet spot of roughly $D_{\text{opt}} \approx 20P$ training tokens for a model with $P$ parameters.
• For a 2B-parameter model, $D_{\text{opt}}$ is only a few × 10¹⁰ tokens.
• Gemma 2 trains its 2B model on ~2 × 10¹² tokens → 50× the "optimal" amount.
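Spelling out that arithmetic, using the common ~20 tokens-per-parameter reading of the Chinchilla result:

$$
D_{\text{opt}}(2\,\mathrm{B}) \approx 20 \times 2\times10^{9} = 4\times10^{10}\ \text{tokens},
\qquad
\frac{2\times10^{12}}{4\times10^{10}} = 50.
$$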
Training that long with ordinary one-hot targets runs into two problems:
a. You run out of new, high-quality text. After deduplication, filtering, etc., the open web simply doesn’t have hundreds of trillions of unique, clean tokens.
b. If you just keep recycling the same corpus, the gradients quickly shrink toward zero because the model already predicts the one-hot target with high confidence → diminishing returns and over-fitting.
- How KD changes the gradient signal
In KD the target at position $t$ is the teacher's full distribution $P_T(\cdot \mid x_{<t})$, not a single "correct" token.
Updating the student with

$$\mathcal{L}_t = \mathrm{KL}\big(P_T(\cdot \mid x_{<t}) \,\|\, P_S(\cdot \mid x_{<t})\big)$$

delivers many more informative bits per token:
• Tokens that the teacher thinks are plausible but not top-1 (e.g. synonyms, plural forms, typographical variants, sub-phrases…) still contribute gradient.
• Even if the student already predicts the teacher’s most-likely token correctly, it usually does not match the relative probabilities of the next 10–100 alternatives → useful learning signal persists for many more passes.
Think of KD as label-smoothing on steroids: every revisit to a sentence still teaches the student something about fine-grained probability mass, so extra epochs aren’t wasted.
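A toy numerical illustration of that persistence (the numbers are invented for the example): a student that already puts ~99% of its mass on the correct token gets almost no signal from a one-hot target, but still sees a clearly non-zero KL against a teacher that spreads mass over plausible alternatives.

```python
import torch
import torch.nn.functional as F

# Student is already very confident in token 0 (~99% probability).
student_logits = torch.tensor([[6.0, 1.0, 0.5, 0.0, -1.0]])

# One-hot target: token 0 is "correct".
one_hot_loss = F.cross_entropy(student_logits, torch.tensor([0]))

# Soft teacher target: token 0 is still most likely, but alternatives carry real mass.
teacher_probs = torch.tensor([[0.60, 0.20, 0.10, 0.07, 0.03]])
kd_loss = F.kl_div(F.log_softmax(student_logits, dim=-1), teacher_probs,
                   reduction="batchmean")

print(f"one-hot loss: {one_hot_loss.item():.3f}")  # ~0.01: signal nearly exhausted
print(f"KD loss:      {kd_loss.item():.3f}")       # ~1.0: the tail mismatch still teaches
```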
- How this mitigates the “50 × tokens” hurdle
Scenario without KD
• After one or two passes over the 40–100B unique tokens you have, gradients collapse; the remaining ~1.9T updates are nearly redundant.
• Adding more real data of that quality/size is practically impossible.
Scenario with KD
• Because each replayed sentence is paired with rich soft targets, the student keeps getting non-zero gradients → the extra compute is useful, not redundant.
• You can safely up-sample or cache a limited corpus many times and still extract new information on each pass.
• If you do generate synthetic teacher text to reach huge token counts, it is automatically labeled (the teacher logit vectors) at no extra human cost.
In the Gemma 2 setup the small 2 B / 9 B students see:
• Natural data (web, code, science) that is replayed dozens of times, each time paired with fresh teacher logits.
• Additional teacher-generated continuations that effectively expand the corpus.
Both exploit KD’s rich target signal; neither would help much if you only used one-hot labels.
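A minimal sketch of what "replayed data, fresh teacher logits" looks like in a single training step; `student`, `teacher`, and the batch format are placeholders, and this illustrates the general recipe rather than Gemma 2's actual training code:

```python
import torch
import torch.nn.functional as F

def distillation_step(student, teacher, batch, optimizer, temperature=1.0):
    """One student update on a (possibly replayed) batch, with the teacher's
    soft targets computed on the fly instead of loaded from disk."""
    input_ids = batch["input_ids"]              # [batch, seq_len]

    with torch.no_grad():                       # teacher is frozen; inference only
        teacher_logits = teacher(input_ids)     # [batch, seq_len, vocab]

    student_logits = student(input_ids)

    loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.log_softmax(teacher_logits / temperature, dim=-1),
        log_target=True,
        reduction="batchmean",
    ) * temperature ** 2

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```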
- Why not just gather “50 × natural tokens”?
- Sheer scarcity: After deduplication and filtering for quality/safety, the public web tops out around a few × 10¹³ tokens; very little of that is high quality.
- Legal & safety filtering gets stricter every year; usable data shrinks.
- Cost: Curating, hosting and streaming tens of trillions of new tokens is far more expensive than re-using an existing cleaned corpus plus on-the-fly teacher logits.
- Signal-to-noise: If you indiscriminately scrape to reach 50 × tokens you end up with noisy or harmful text that can hurt performance.
KD lets you invest compute into extracting more value from the data you already trust, rather than hunting for ever-larger raw text dumps.
- Synthetic data angle
Yes—teacher-generated continuations are an easy way to create effectively unlimited “new” sequences. But the key benefit isn’t that they are synthetic per se; it’s that every token, synthetic or natural, is accompanied by a dense probability vector that encodes what the larger teacher knows. That’s the mechanism that keeps gradient quality high far past the compute-optimal regime.
Take-away
• KD addresses the diminishing-returns problem of small models trained for extremely long horizons.
• It does so by supplying richer gradients and by making corpus re-use (and teacher-generated text) genuinely informative.
• Therefore a 2B/9B Gemma student can productively consume 50× more token-updates than scaling laws would suggest is worthwhile with plain next-token supervision.
Can you elaborate on using KD for richer gradients vs. for training beyond the number of available tokens? They mention in the intro training on more than 50x the compute-optimal number of tokens, but why is KD relevant there? In other words, why not train on 50x natural tokens? Is it a synthetic data thing?
Thinking Summary
The user is asking about a key claim in the Gemma 2 paper: they use knowledge distillation to "simulate training beyond the number of available tokens" and train on 50x the compute-optimal number of tokens. They want to understand:

1. How KD provides richer gradients vs. just extending training.
2. Why KD is relevant for training beyond compute-optimal tokens.
3. Why not just train on 50x natural tokens instead.

Let me think through this carefully. The paper says:

- "One solution is to improve the quality of information received by the network at each training step by replacing the next token prediction task with a richer objective"
- "This approach is often used to reduce the training time of smaller models by giving them richer gradients. In this work, we instead train for large quantities of tokens with distillation in order to simulate training beyond the number of available tokens."
- They train the 2B model on 2T tokens (which is 50x the compute-optimal for a 2B model according to Chinchilla scaling laws).

The key insights here are:

1. Natural data has diminishing returns: training on more tokens helps, but only logarithmically.
2. KD provides a richer signal per token than one-hot labels.
3. This allows them to extract more learning from the same number of tokens.

Why not just use 50x natural tokens? A few reasons:

- Data quality/availability: they may not have 50x more high-quality tokens.
- Diminishing returns: natural data follows power laws, so 50x more data might only give modest improvements.
- The teacher model can provide better signal than raw data.

The "synthetic data" angle is interesting, but I think it's more subtle: they're not generating new text, they're enriching the training signal on existing text.
This is a great question that gets at the heart of what makes the Gemma 2 approach clever. Let me break down the two perspectives and why KD enables training far beyond compute-optimal token counts.
Richer Gradients vs. More Tokens
Traditional scaling: When you train on natural text with one-hot labels, you face sharply diminishing returns past the compute-optimal point; the Chinchilla scaling laws imply that large multiples of extra data buy only small further gains. So training a 2B model on 50x its "compute-optimal" tokens (2T instead of ~40B) with one-hot targets alone would yield only modest improvements, perhaps 10-20% on benchmarks.
KD's richer signal: With distillation, each token provides much more information:
- One-hot label: "The next token is 'cat'"
- KD distribution: "Next token: 45% 'cat', 20% 'kitten', 15% 'dog', 5% 'animal'..."
This soft distribution encodes the teacher's uncertainty, synonyms, and contextual nuances. The gradient from matching this full distribution teaches the student about relationships between tokens, not just which one is "correct."
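One rough way to quantify "more information per token" is the entropy of the target itself: a one-hot label has zero entropy, while a soft distribution like the (made-up) one above carries about two bits.

```python
import math

# Soft target from the example above, with the remaining 15% lumped into "<other>".
soft_target = {"cat": 0.45, "kitten": 0.20, "dog": 0.15, "animal": 0.05, "<other>": 0.15}

entropy_bits = -sum(p * math.log2(p) for p in soft_target.values())
print("one-hot target entropy: 0.00 bits")
print(f"soft target entropy:    {entropy_bits:.2f} bits")  # ~2.02 bits
```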
Why KD Enables "Training Beyond Available Tokens"
The key insight is that KD changes the effective information content per token. Consider:
- Natural tokens hit a quality ceiling: After seeing "The cat sat on the..." → "mat" a thousand times, there's little more to learn from that pattern. The one-hot signal is exhausted.
- Teacher distributions remain informative: Even on the thousandth occurrence, the teacher's distribution might reveal that "mat" (0.6), "chair" (0.2), "floor" (0.1) are all plausible, teaching the student about the underlying uncertainty and alternatives.
- Simulating unseen data: The teacher's distributions implicitly contain information about patterns it learned from its own massive training set. When it assigns 0.2 probability to "chair," it's effectively teaching the student about contexts where "chair" would appear, even if "chair" never appears in this specific training example.
Why Not Just Use 50x Natural Tokens?
Several practical reasons:
- Data availability: High-quality text is finite. They already use 13T tokens for the 27B model; finding 50x more (650T tokens) of similar quality would be extremely difficult.
- Severely diminishing returns: The log-scaling of natural data means 50x more tokens might improve performance by only ~15%, whereas their KD approach shows improvements of 10+ points on some benchmarks.
- Quality vs. quantity: Rather than scraping increasingly low-quality web text, KD lets them extract more learning from their curated 2T-token dataset by enriching each example with the teacher's knowledge.
Not Exactly "Synthetic Data"
This isn't quite synthetic data generation in the usual sense. They're not asking the teacher to generate new text. Instead, they're:
- Taking existing human text
- Running the teacher on it to get probability distributions
- Training the student to match these distributions
It's more like annotating existing data with expert knowledge than creating new data. Each token position gets upgraded from a simple label to a rich, nuanced target that encodes the teacher's full understanding.
The brilliance is that this allows small models to benefit from the teacher's exposure to vastly more data without actually needing access to that data—the knowledge is compressed into the probability distributions.