Stealing Part of a Production Language Model
Thread (64 posts)
I don't follow the definition of the model structure defined in §3.
Don't most transformer models have a much more complicated structure than that?
Where does that complexity go?
Thinking Summary
The user is asking about the paper "Stealing Part of a Production Language Model" and specifically about Section 3 where they define the model structure. They're confused about equation (1) which defines the model as: $f_\theta(x) = \text{softmax}(E^T \cdot g_\theta(x))$ where $g_\theta : \mathcal{X}^N \to \mathbb{R}^h$ computes hidden states and $E^T$ is the embedding projection matrix. The user is right to be confused - this is indeed a simplified representation. The paper is specifically focusing on the final layer of the transformer, not describing the entire architecture. The complexity of transformers (attention layers, MLPs, residual connections, layer norms, etc.) is all encapsulated within the function $g_\theta$. Let me explain this clearly.
You're absolutely right to be confused! The definition in Section 3 is indeed highly simplified. The key insight is that all the transformer complexity is hidden inside the function $g_\theta$.
The paper defines $f_\theta(x) = \text{softmax}(E^T \cdot g_\theta(x))$, where $g_\theta : \mathcal{X}^N \to \mathbb{R}^h$ computes the hidden states and $E^T \in \mathbb{R}^{l \times h}$ is the embedding projection matrix.
Here's where all the complexity goes:
Inside $g_\theta$: This function encapsulates the entire transformer architecture:
- All attention layers
- Feed-forward networks
- Layer normalizations
- Residual connections
- Positional encodings
- Everything except the final projection
So when you have a typical transformer that processes tokens through many layers to produce a final hidden state $g_\theta(x) \in \mathbb{R}^h$, that entire process is what $g_\theta$ represents.
Why this simplification works: The authors only care about stealing the final projection layer (which maps from hidden dimension $h$ to vocabulary size $l$). For their attack, they treat everything before this final projection as a black box that produces some hidden vector.
This is actually mentioned later in the paper - in Lemma 2 they note that with residual connections, you can't extract $E^T$ exactly (only up to a linear transformation), precisely because the internal structure of $g_\theta$ creates ambiguities.
The brilliance of their attack is that they don't need to understand the internal complexity - they just need to observe that the final output lies in a low-dimensional subspace determined by the hidden dimension $h$.
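To make that observation concrete, here's a minimal sketch (the `get_logits` function is hypothetical and stands in for whatever API access returns a full logit vector per prompt; this is not the paper's actual code):

```python
import numpy as np

def estimate_hidden_dim(prompts, get_logits, tol=1e-4):
    """Estimate h from full logit vectors returned by a black-box model."""
    # Each row is one logit vector of length l (the vocabulary size).
    Q = np.stack([get_logits(p) for p in prompts])  # shape (n, l), with n > h
    # Since logits = E^T @ g_theta(x), every row lies in the column space of
    # E^T, which has dimension at most h, so Q has numerical rank about h.
    s = np.linalg.svd(Q, compute_uv=False)
    return int(np.sum(s > tol * s[0]))
```

Counting the significant singular values of the stacked logit matrix is exactly how the attack recovers the hidden dimension without ever looking inside $g_\theta$.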
What's the hidden dimension of the Gemma 2 models?
Thinking Summary
The user is asking about the hidden dimension of the Gemma 2 models. Let me look at the paper document to find information about the hidden dimensions. Looking at the Gemma 2 paper, I can see in Table 1 that shows "Overview of the main model parameters and design choices". In this table, there's a row for "d_model" which represents the hidden dimension. From Table 1: - Gemma 2 2B: d_model = 2304 - Gemma 2 9B: d_model = 3584 - Gemma 2 27B: d_model = 4608 So the hidden dimensions are: - 2B model: 2304 - 9B model: 3584 - 27B model: 4608
Based on the Gemma 2 paper (Table 1), the hidden dimensions (d_model) for the Gemma 2 models are:
- Gemma 2 2B: 2304
- Gemma 2 9B: 3584
- Gemma 2 27B: 4608
These hidden dimensions follow a pattern where each larger model has a significantly wider hidden layer. The paper also notes that all three models use:
- Grouped-Query Attention (GQA) with num_groups = 2
- Head size of 256 for the 2B and 9B models, and 128 for the 27B model
- A vocabulary size of 256,128 tokens (inherited from Gemini)
These architectural details would make Gemma 2 models potentially vulnerable to the embedding projection layer extraction attack described in the first paper, though Google has likely implemented defenses after the responsible disclosure.
Gemma 2 is open weights, which may be a clue about why I am referencing it specifically…
What are the values of $l$ for the Gemma 2 models, and can you confirm it uses the same softmax structure?
Thinking Summary
The user is hinting that because Gemma 2 is open weights, we can actually verify the attack from the first paper ourselves. They want to know: 1. The value of $l$ (vocabulary size) for Gemma 2 2. Whether it uses the same softmax structure as described in the attack paper From the Gemma 2 paper, I can see: - Vocabulary size is 256,128 for all three models (this is $l$) - The paper mentions they use a standard transformer architecture with some modifications Looking at the architecture description, they mention: - "Logit soft-capping" where they cap logits in each attention layer and the final layer - The soft-capping function is: logits ← soft_cap * tanh(logits / soft_cap) - They set soft_cap to 30.0 for the final layer This is interesting because the attack paper assumes a standard softmax without this capping. However, the fundamental structure is still: 1. Hidden states from transformer → 2. Linear projection to vocabulary size → 3. Some form of normalization (in this case, soft-capped softmax) The attack should still work because: - The embedding projection matrix $E^T$ still exists (mapping from hidden dimension $h$ to vocabulary size $l$) - The rank of this matrix is still limited by $h$ - The soft-capping is just a element-wise transformation that doesn't change the fundamental low-rank structure So to answer the questions: - $l = 256,128$ for all Gemma 2 models - It uses a modified softmax with soft-capping, but the fundamental structure (hidden states → linear projection → normalization) is the same
Ah, I see what you're getting at! Since Gemma 2 is open weights, we can actually verify the attack ourselves.
From the Gemma 2 paper:
- $l$ = 256,128 for all three Gemma 2 models (they all use the same vocabulary size)
Regarding the softmax structure, Gemma 2 uses a modified version with "logit soft-capping":
logits ← soft_cap * tanh(logits / soft_cap)
where soft_cap = 30.0 for the final layer. The capping is applied before the softmax.
So the full structure is:
- $g_\theta(x)$ produces hidden states of dimension $h$
- $E^T$ projects them to vocabulary dimension $l$
- Soft-capping: logits ← 30.0 * tanh(logits / 30.0)
- Standard softmax applied to the capped logits
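Putting that together, a sketch of the final-layer computation as described above (PyTorch-style; the shapes and the soft_cap value come from this thread, not from Gemma's actual implementation):

```python
import torch

def final_layer(hidden, E, soft_cap=30.0):
    # hidden: (h,) output of g_theta(x); E: (h, l), so E.T is the l x h
    # embedding projection matrix in the attack paper's notation.
    logits = E.T @ hidden                              # (l,) raw logits
    logits = soft_cap * torch.tanh(logits / soft_cap)  # logit soft-capping
    return torch.softmax(logits, dim=-1)               # probabilities over vocab
```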
The attack from the first paper should still work because:
- The embedding projection matrix $E^T$ still exists, with shape $l \times h$
- Its rank is still at most $h$ (much smaller than $l$)
- The soft-capping is an element-wise, strictly monotonic transformation, so it is invertible: the underlying low-rank logits can be recovered before the SVD step (see the sketch below)
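For instance, one way to undo the capping before running the rank/SVD analysis, assuming the observed values are the capped logits (a sketch, not part of the original attack):

```python
import numpy as np

def uncap(capped_logits, soft_cap=30.0):
    # Inverse of: logits <- soft_cap * tanh(logits / soft_cap).
    # Valid because the capped values lie strictly inside (-soft_cap, soft_cap).
    return soft_cap * np.arctanh(capped_logits / soft_cap)
```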
Since we have access to the weights, we could directly verify:
- The hidden dimensions match what's claimed (2304, 3584, 4608)
- The rank of the final projection layer equals $h$
- The attack's SVD approach would correctly identify these dimensions
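A sketch of such a check against the published checkpoints (assumes the `transformers` library, enough memory, and access to the gated Gemma 2 repositories on Hugging Face):

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Hidden sizes reported above: 2304 (2B), 3584 (9B), 4608 (27B).
for name in ["google/gemma-2-2b", "google/gemma-2-9b", "google/gemma-2-27b"]:
    cfg = AutoConfig.from_pretrained(name)
    print(name, "hidden_size =", cfg.hidden_size, "vocab_size =", cfg.vocab_size)

# Rank of the final projection for the smallest model: it can be at most
# hidden_size, which is exactly what the SVD on queried logits recovers.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b",
                                             torch_dtype=torch.float32)
W = model.get_output_embeddings().weight.detach()  # (vocab_size, hidden_size)
rows = W[torch.randperm(W.shape[0])[:8192]]        # subsample rows: cheaper SVD
print(torch.linalg.matrix_rank(rows))              # expect 2304
```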
This is a great example of how open-weight models allow for verification of security research findings!
Thinking Summary
The user is asking about why the vocabulary size l = 256,128 can be expressed as 2^7 × 2001. Let me check: 256,128 = 2^7 × 2001 = 128 × 2001 = 256,128 ✓ This is interesting because it's a specific factorization. The vocabulary size of 256,128 is quite large and this particular factorization might have some significance. Looking at the SentencePiece tokenizer mentioned in the Gemma 2 paper, it has 256k entries. The paper mentions it's the same tokenizer as Gemma 1 and Gemini. The fact that it factors as 2^7 × 2001 is interesting. 2001 is 3 × 23 × 29, so it's not a particularly special number. The power of 2 factor (128) might be related to computational efficiency or alignment. Actually, I should double-check my arithmetic: 256,128 = 256 × 1000 + 128 = 256,000 + 128 = 256,128 Let me verify the factorization: 128 × 2001 = 128 × 2000 + 128 × 1 = 256,000 + 128 = 256,128 ✓ So yes, 256,128 = 2^7 × 2001. This might be related to how the tokenizer was constructed or some computational optimization. The power of 2 factor could be useful for memory alignment or SIMD operations.
Great observation! The vocabulary size 256,128 = 2^7 × 2001 is indeed an interesting factorization.
This specific structure might be related to:
- Computational efficiency: The 2^7 (128) factor allows for efficient memory alignment and potentially SIMD operations. Many GPU operations are optimized for power-of-2 dimensions.
- SentencePiece tokenizer design: The paper mentions they use a SentencePiece tokenizer with "split digits, preserved whitespace, and byte-level encodings." The factorization might reflect how tokens are organized:
  - Perhaps 128 special tokens or control codes
  - And 2001 "buckets" of regular tokens
- Byte-fallback encoding: With byte-level encoding as a fallback, you need at least 256 tokens for raw bytes. The structure might reserve certain ranges for different token types.
The fact that 2001 = 3 × 23 × 29 (a product of small primes) suggests this wasn't chosen for its factorization properties; rather, 256,128 was likely chosen as "approximately 256k" while having nice computational properties.
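For the record, the arithmetic is easy to check:

```python
# Both factorizations quoted above.
assert 256_128 == 2**7 * 2001 == 128 * 2001
assert 2001 == 3 * 23 * 29
```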
This vocabulary size is notably different from models like LLaMA (32,000) or GPT-4 (~100,000), reflecting Google's multilingual focus - the paper mentions this large vocabulary is "designed to work on a large number of languages."