Stealing Part of a Production Language Model
Thread (64 posts)
What is the obstacle to propagating the method through earlier layers to recover more of the model?
Thinking Summary
The user is asking about the paper "Stealing Part of a Production Language Model" and specifically why the attack method described can't be propagated to earlier layers to recover more of the model. Let me think about this carefully. The paper describes an attack that recovers the embedding projection layer (the final layer) of a transformer model. The key insight is that this final layer projects from a lower-dimensional hidden space (dimension h) to a higher-dimensional vocabulary space (dimension l), where h << l. This makes the layer low-rank, which enables the attack. The attack works by: 1. Querying the model many times with different prompts 2. Collecting the logit outputs (or reconstructing them from limited information like top-k logprobs) 3. Using SVD to find that these high-dimensional logit vectors actually lie in a lower-dimensional subspace 4. Recovering the projection matrix up to some symmetries Now, why can't this be extended to earlier layers? The paper mentions this limitation explicitly. Let me look for the relevant sections... In Section 7 (Future Work), they state: "Our attack recovers a single layer of a transformer. We see no obvious methodology to extend it beyond just a single layer, due to the non-linearity of the models. But we invite further research in this area." The key obstacles are: 1. **Non-linearity**: The intermediate layers of transformers contain non-linear activations (like ReLU, GELU, or similar), layer normalization, and other non-linear operations. The attack fundamentally relies on the linearity of the final projection layer. 2. **No direct access to intermediate representations**: The attack only has access to the final output (logits). To recover layer N-1, you'd need access to the outputs of layer N-1, but these are transformed by layer N (which you've recovered) plus all the non-linear operations. 3. **Residual connections and complex interactions**: Modern transformers have residual connections, attention mechanisms, and other architectural features that create complex dependencies between layers. The "bottom-up vs top-down" challenge presents a fundamental mathematical barrier. The attack's top-down approach breaks down due to the intricate non-linear transformations between layers. Even with a perfectly recovered final layer matrix, the intermediate layer remains obscured by complex non-linear functions that fundamentally alter the representational space. This makes a systematic, layer-by-layer recovery essentially impossible with current techniques.
The main obstacle to propagating this attack through earlier layers is non-linearity. The attack fundamentally exploits the fact that the final embedding projection layer is a simple linear transformation from a low-dimensional hidden space to a high-dimensional vocabulary space.
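To make this concrete, here is a minimal NumPy sketch of the idea: a toy with full logit access and made-up dimensions rather than the paper's exact procedure (the helper `query_logits` simply simulates an API).

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, hidden = 1000, 64                     # l >> h, as in a real model

# Unknown "victim" final layer, and a black box returning full logit vectors.
W = rng.normal(size=(vocab, hidden))

def query_logits(_prompt):
    h = rng.normal(size=hidden)              # stand-in for the hidden state g(prompt)
    return W @ h

# 1) Collect logit vectors for many distinct prompts.
Q = np.stack([query_logits(i) for i in range(2 * hidden)], axis=1)   # vocab x n

# 2) The singular values collapse after index h, revealing the hidden dimension.
s = np.linalg.svd(Q, compute_uv=False)
est_hidden = int((s > 1e-6 * s[0]).sum())
print(est_hidden)                            # 64

# 3) The top-h left singular vectors span col(W), i.e. we obtain W_tilde = W @ G
#    for some unknown invertible h x h matrix G.
U, _, _ = np.linalg.svd(Q, full_matrices=False)
W_tilde = U[:, :est_hidden]
coeffs = np.linalg.lstsq(W, W_tilde, rcond=None)[0]
print(np.abs(W_tilde - W @ coeffs).max() < 1e-6)   # True: same column space as W
```

The entire sketch rests on the output being a fixed linear map of a low-dimensional hidden state; that assumption fails for every earlier layer.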
Here are the key barriers:
1. Non-linear Operations
Earlier transformer layers contain:
- Non-linear activation functions (GELU, ReLU, etc.)
- Layer normalization or RMSNorm
- Softmax in attention mechanisms
- Residual connections that mix information across layers
The paper explicitly states: "We see no obvious methodology to extend it beyond just a single layer, due to the non-linearity of the models."
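For concreteness, here is a stripped-down pre-norm transformer block as a generic NumPy sketch (not any particular model's implementation); the comments mark where the non-linearities and residual connections sit.

```python
import numpy as np

def rms_norm(x, gamma):                      # non-linear: divides by a norm of x
    return gamma * x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + 1e-6)

def gelu(x):                                 # non-linear activation
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))

def block(x, p):
    """One pre-norm transformer block; x has shape (seq, d)."""
    h = rms_norm(x, p["g1"])
    q, k, v = h @ p["Wq"], h @ p["Wk"], h @ p["Wv"]
    scores = q @ k.T / np.sqrt(q.shape[-1])
    attn = np.exp(scores - scores.max(axis=-1, keepdims=True))
    attn = attn / attn.sum(axis=-1, keepdims=True)   # non-linear: softmax over scores
    x = x + (attn @ v) @ p["Wo"]                     # residual connection
    h = rms_norm(x, p["g2"])
    x = x + gelu(h @ p["W1"]) @ p["W2"]              # residual connection
    return x

d, seq = 16, 4
shapes = {"Wq": (d, d), "Wk": (d, d), "Wv": (d, d), "Wo": (d, d),
          "W1": (d, 4 * d), "W2": (4 * d, d), "g1": (d,), "g2": (d,)}
rng = np.random.default_rng(0)
p = {name: 0.1 * rng.normal(size=shape) for name, shape in shapes.items()}
print(block(rng.normal(size=(seq, d)), p).shape)     # (4, 16)
```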
2. No Access to Intermediate Representations
The attack only observes the final logits. To recover layer N-1, you would need to observe the outputs of layer N-1, but these are:
- Transformed by the recovered layer N
- Further modified by all intervening non-linear operations
- Not directly observable from the API
3. Information Loss
Even if you perfectly recover the final layer, you're observing only the composition

$$\text{logits} = W \cdot \mathrm{Norm}\big(f_{n-1}(f_{n-2}(\cdots f_1(\text{input})))\big)$$

You cannot simply "invert" this composition because:
- The non-linear functions may not be invertible
- Multiple different weight configurations could produce the same final output
- The dimensionality reduction at each layer loses information
4. Architectural Complexity
Modern transformers have:
- Attention mechanisms with complex query-key-value interactions
- Multiple parallel attention heads
- Residual connections that bypass layers
- Layer-wise transformations that entangle information
The paper notes this is fundamentally different from prior work on stealing ReLU networks, where specialized techniques could handle the non-linearity layer by layer. Transformer architectures are far more complex and don't admit similar cryptanalytic approaches.
The authors acknowledge this limitation and explicitly invite future research in this area, suggesting it remains an open problem in model extraction attacks.
I get that transformers are complicated. You don’t need to explain that part. So is SHA1. And Yet…
Let’s say we have access to the logits and we recovered the final projection matrix. What is the obstacle to learning the weights of the penultimate layer? Be specific. Don’t make vague claims or lists of potential issues.
Thinking Summary
The user is asking for a specific analysis of what prevents extending the attack to the penultimate layer. They're right to push back on my vague answer. Let me think through this carefully. If we have recovered the final projection matrix $W$, we can compute: - From the API: logits = $W \cdot h$ where $h$ is the hidden state after the final layer normalization - We know $W$ from the attack - So we can compute $h = W^{-1} \cdot \text{logits}$ (or more precisely, use the pseudoinverse since $W$ is not square) But wait, there's a catch. The paper shows we can only recover $W$ up to an orthogonal transformation. So we actually recover $\tilde{W} = W \cdot G$ for some orthogonal matrix $G$. So when we compute $\tilde{W}^{-1} \cdot \text{logits}$, we get $G^{-1} \cdot h$, not $h$ itself. Now, $h$ comes from the normalization layer. In most transformers, this is either LayerNorm or RMSNorm applied to the output of the penultimate layer. Let's say it's RMSNorm for simplicity: $h = \text{RMSNorm}(x) = \frac{x}{\sqrt{\text{mean}(x^2)}} \cdot \gamma + \beta$ where $x$ is the output of the penultimate layer and $\gamma, \beta$ are learned parameters. The specific obstacle is: even if we could perfectly recover $h$ (which we can't due to the orthogonal ambiguity), we cannot uniquely determine $x$ from $h$ because: 1. The normalization destroys the magnitude information - we only get the direction 2. Multiple different $x$ vectors can map to the same normalized $h$ I need to articulate the precise mathematical barrier more rigorously. The core issue is the fundamental information loss inherent in normalization transformations, which create a many-to-one mapping that prevents unique reconstruction of the original vector.
You're right to push for specificity. Here's the precise mathematical obstacle:
When you recover the final projection matrix $W$, you actually recover $\tilde{W} = W \cdot G$ for some unknown orthogonal matrix $G$ (as proven in the paper). This means when you compute the hidden states from the logits, you get:

$$\tilde{h} = \tilde{W}^{+} \cdot \text{logits} = G^{-1} h$$

where $h$ is the true hidden state after layer normalization and $\tilde{W}^{+}$ denotes the pseudoinverse.
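A quick numeric check of that identity (a NumPy sketch; the toy shapes and the synthetic orthogonal $G$ are assumptions of the example):

```python
import numpy as np

rng = np.random.default_rng(1)
vocab, hidden = 1000, 64
W = rng.normal(size=(vocab, hidden))                      # true final projection
G = np.linalg.qr(rng.normal(size=(hidden, hidden)))[0]    # unknown orthogonal ambiguity
W_tilde = W @ G                                           # what the attack recovers

h = rng.normal(size=hidden)                               # true post-norm hidden state
logits = W @ h                                            # what the API exposes

h_tilde = np.linalg.pinv(W_tilde) @ logits                # what we can compute
print(np.allclose(h_tilde, G.T @ h))                      # True: we get G^{-1} h, not h
print(np.allclose(W_tilde @ h_tilde, logits))             # True: the ambiguity is a symmetry
```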
Now, the penultimate layer produces some output $x$, which goes through normalization (let's say RMSNorm) to produce $h$:

$$h = \frac{x}{\sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2}}$$

(ignoring the learnable scale/bias for clarity)
To recover the penultimate layer weights, you need to know what $x$ values correspond to your observed $h$ values. But this is impossible because:
The normalization is a many-to-one function: for any normalized vector $h$, there are infinitely many vectors $x$ that map to it (any positive scalar multiple of $x$ gives the same $h$).
Even if you could magically remove the orthogonal ambiguity and recover the true $h$, you still cannot determine which specific $x$ produced it. Without knowing the true $x$ values, you cannot set up a system of equations to solve for the penultimate layer weights.
This is fundamentally different from the final layer case, where the transformation is purely linear and full-rank, allowing you to recover the weights up to a well-defined ambiguity.
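A minimal numeric illustration of that many-to-one collapse, in NumPy with the learnable scale and shift omitted as above:

```python
import numpy as np

def rms_norm(x):
    return x / np.sqrt((x ** 2).mean())

x = np.array([0.3, -1.2, 2.5, 0.7])
print(np.allclose(rms_norm(x), rms_norm(7.0 * x)))   # True: the scale of x is gone
```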
But I don’t care about the fact that I only get the final layer up to some orthogonal matrix because I only care about getting the model weights up to symmetry. Symmetry at every step of the model architecture. So I want an explanation of where SPECIFICALLY this breaks.
Thinking Summary
The user is asking about the specific mathematical obstacle to propagating the attack backward through the model layers. They want me to be precise and not vague. They understand that we only care about recovering weights up to symmetry. Let me think through this carefully. The paper recovers $\tilde{W} = W \cdot G$ where G is some transformation matrix. The key insight the user has is that they don't care about this ambiguity - they're fine with recovering everything up to symmetry. So what's the actual obstacle? Let me trace through the computation: 1. We observe logits = W * h where h is the output of the last hidden layer (after normalization) 2. We can recover W up to right multiplication by some matrix G 3. So we can compute h up to left multiplication by G^{-1} 4. Now h came from some normalization of the previous layer's output x 5. The normalization (whether LayerNorm or RMSNorm) maps x to h The key issue is: even if we're okay with symmetries, the normalization introduces a fundamental information loss that can't be recovered even up to symmetry. Specifically, for RMSNorm: h = x / ||x|| * sqrt(d) For LayerNorm: h = (x - mean(x)) / std(x) * gamma + beta In both cases, the norm (or scale) of x is lost. This isn't a symmetry - it's actual information destruction. But wait, the user is asking about recovering weights up to symmetry. Let's think more carefully. If the penultimate layer outputs x, and x goes through normalization to produce h, and we know h (up to some transformation), can we recover the penultimate layer weights? The penultimate layer computes something like: x = W_{n-1} * (previous activations) + bias Even if we knew x exactly, we'd need to know the previous activations to solve for W_{n-1}. But we don't know x - we only know its normalized version h. I see the core problem now. Normalization creates a fundamental information loss where multiple input vectors map to the same normalized output. This non-invertible transformation prevents precise weight recovery, even when allowing for symmetry transformations. The mapping from input to normalized output destroys critical scale and offset information needed to uniquely reconstruct the original weights. The key issue is that normalization doesn't just reparameterize - it actively collapses distinct input configurations into a single representation, making weight reconstruction fundamentally impossible. Different input vectors that normalize to the same output would require fundamentally different weight configurations to produce equivalent model behavior.
You're right to demand precision. Let me trace through exactly where the symmetry argument breaks down.
The Core Issue
When we recover the final layer, we get $\tilde{W} = W \cdot G$ for some matrix $G$. This gives us hidden states $\tilde{h} = G^{-1} h$.
The hidden state $h$ comes from normalizing the penultimate layer's output $x$:
- For RMSNorm: $h_i = \gamma_i \cdot \frac{x_i}{\sqrt{\frac{1}{d} \sum_j x_j^2}}$
- For LayerNorm: $h_i = \gamma_i \cdot \frac{x_i - \mu(x)}{\sigma(x)} + \beta_i$
Here's the precise problem: The normalization is not injective, even up to symmetry.
Why This Is Different From The Final Layer
For the final layer, the ambiguity $G$ corresponds to a symmetry of the model - you can transform $W \to W G$ and $h \to G^{-1} h$ and get identical outputs.
But for normalization, if you have some normalized vector $h$ on the unit sphere, there are infinitely many vectors $x$ that map to it: $x$, $2x$, and every other positive multiple $c x$ all normalize to the same $h$.
This is not a model symmetry. Different scalings of $x$ would require genuinely different penultimate weight matrices to produce. If the penultimate layer computes $x = W_{n-1} a$ for some activation $a$, then:
- To get $x$, you need weights $W_{n-1}$
- To get $2x$, you need weights $2 W_{n-1}$
These are not related by any symmetry of the overall model - they would produce different outputs for most inputs.
The Fundamental Difference
The key distinction is:
- Linear layers: Information preserved up to a well-defined group action (right multiplication)
- Normalization: Information genuinely destroyed (projection onto a lower-dimensional manifold)
Even if you're happy to accept all model symmetries, normalization creates a many-to-one mapping where the different pre-images would require weight configurations that are NOT related by any symmetry. The scale information is irrecoverably lost, not just reparameterized.
This is why the attack cannot propagate backwards, even in principle, even accepting all symmetries.
What is gamma_i? Don’t use any notation without defining it. Set up the problem statement cleanly.
Thinking Summary
The user is asking me to be more precise and clear about the mathematical setup. I need to define all notation explicitly and lay out the problem cleanly. Let me think through this carefully. The setup is: - We have access to logits from a transformer model - We can recover the final projection matrix W up to some transformation - We want to know why we can't propagate back to recover earlier layers Let me be very precise about the mathematical structure: 1. The final layer of a transformer outputs logits = W * h where: - W is the projection matrix (vocabulary_size × hidden_dim) - h is the hidden state after the final normalization layer (hidden_dim × 1) 2. The normalization layer (let's focus on RMSNorm for clarity) computes: h = γ ⊙ normalize(x) + β where: - x is the output of the penultimate layer (hidden_dim × 1) - normalize(x) = x / ||x||_2 - γ is a learned scale parameter (hidden_dim × 1) - β is a learned bias parameter (hidden_dim × 1) - ⊙ denotes element-wise multiplication 3. The penultimate layer computes: x = f(W_pen * a + b_pen) where: - W_pen is the weight matrix of the penultimate layer - a is the input to the penultimate layer - b_pen is the bias - f is some nonlinearity (or could be linear for the residual stream) Now, the key question is: given that we can recover W up to some transformation, why can't we recover W_pen up to some (possibly different) transformation? The issue is that normalization destroys information in a way that cannot be undone even up to symmetry.
Let me set up the problem precisely.
Problem Setup
Consider the last two layers of a transformer:
Layer n-1 (Penultimate): Takes input $a \in \mathbb{R}^{m}$ and outputs $x = W_{n-1} a + b_{n-1} \in \mathbb{R}^{d}$, where $W_{n-1} \in \mathbb{R}^{d \times m}$ and $b_{n-1} \in \mathbb{R}^{d}$ are the weights we want to recover
Normalization: Takes $x$ and outputs $h \in \mathbb{R}^{d}$
For RMSNorm:

$$h_i = \gamma_i \cdot \frac{x_i}{\sqrt{\frac{1}{d} \sum_{j=1}^{d} x_j^2}} + \beta_i$$

where $\gamma \in \mathbb{R}^{d}$ and $\beta \in \mathbb{R}^{d}$ are learned parameters (element-wise scale and shift).
Layer n (Final): Takes $h$ and outputs logits $z = W h \in \mathbb{R}^{l}$, where $W \in \mathbb{R}^{l \times d}$ is the embedding projection matrix and $l$ is the vocabulary size
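The same setup written as a toy forward pass (an illustrative NumPy sketch; the dimensions, the purely affine penultimate layer, and the names `W_pen`, `b_pen` are assumptions of this example, not notation from the paper):

```python
import numpy as np

rng = np.random.default_rng(2)
d, vocab = 64, 1000

# Layer n-1 (penultimate): x = W_pen @ a + b_pen      <- the weights we want
W_pen, b_pen = rng.normal(size=(d, d)), rng.normal(size=d)

# Normalization (RMSNorm): h_i = gamma_i * x_i / RMS(x) + beta_i
gamma, beta = rng.normal(size=d), rng.normal(size=d)

# Layer n (final): logits = W @ h                     <- recoverable up to W @ G
W = rng.normal(size=(vocab, d))

def forward(a):
    x = W_pen @ a + b_pen                             # never observable
    h = gamma * x / np.sqrt((x ** 2).mean()) + beta   # never observable
    return W @ h                                      # the only thing the API exposes

logits = forward(rng.normal(size=d))
print(logits.shape)                                   # (1000,)
```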
What We Know
From the paper's attack, we can recover $\tilde{W} = W \cdot G$ for some invertible matrix $G$.
This means from observed logits $z$, we can compute:

$$\tilde{h} = \tilde{W}^{+} z = G^{-1} h$$
The Question
Can we now recover the penultimate weights $W_{n-1}$ and $b_{n-1}$ (up to some symmetry)?
Why It Fails
The normalization function maps many different vectors $x$ to the same $h$:
- Any $x$ and $c x$ (for $c > 0$) produce the same normalized vector $\frac{x}{\sqrt{\frac{1}{d} \sum_j x_j^2}}$
- After applying the scale $\gamma$ and shift $\beta$, the outputs are still identical: $h(x) = h(c x)$
But $x$ and $c x$ would require:
- Either the same weights (if $c = 1$)
- Or weights scaled by $c$: $W_{n-1} \to c W_{n-1}$, $b_{n-1} \to c b_{n-1}$
This scaling is not a symmetry of the model. Using $c W_{n-1}$ instead of $W_{n-1}$ would change the model's behavior on other inputs. There's no transformation we can apply elsewhere in the network to compensate.
Therefore, even knowing $h$ perfectly (up to the symmetry $G$), we cannot determine which $x$ produced it, and thus cannot set up constraints to solve for $W_{n-1}$.