gemma pytorch repo
Thread (29 posts)
Thinking Summary
The user is asking about the organization of the Gemma 2 source code based on the gemma_pytorch repository. Let me analyze the code structure and highlight how Gemma 2 is organized within the codebase. From what I can see: 1. The repository supports multiple Gemma architectures (Gemma 1, Gemma 2, and Gemma 3) 2. Gemma 2 has specific model configurations and architectural features 3. There are dedicated decoder layer classes for different architectures Let me look at the key organizational aspects:
Looking at the gemma_pytorch repository structure, here's how the Gemma 2 source code is organized:
Architecture Integration
Gemma 2 is integrated into a unified codebase that supports multiple Gemma architectures (1, 2, and 3). The code uses an Architecture enum to differentiate between versions:
class Architecture(enum.Enum):
    GEMMA_1 = 1
    GEMMA_2 = 2
    GEMMA_3 = 3
Configuration System
Gemma 2 has dedicated configuration functions in gemma/config.py:
- get_config_for_2b_v2() - 2B parameter variant
- get_config_for_9b() - 9B parameter variant
- get_config_for_27b() - 27B parameter variant
Each configuration specifies Gemma 2's unique features:
- Architecture: Architecture.GEMMA_2
- Pre/Post FFW Normalization: use_pre_ffw_norm=True, use_post_ffw_norm=True
- Logit Softcapping: final_logit_softcapping=30.0, attn_logit_softcapping=50.0
- Attention Types: Alternating pattern of LOCAL_SLIDING and GLOBAL attention
- Sliding Window: sliding_window_size=4096
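Putting these fields together, the 2B configuration looks roughly like the sketch below. The field names and values are taken from the bullets above; anything else (and the exact defaults) is an assumption, so treat get_config_for_2b_v2() in gemma/config.py as authoritative.
# Illustrative sketch only - not the repo's actual get_config_for_2b_v2()
def sketch_config_for_2b_v2() -> GemmaConfig:
    return GemmaConfig(
        architecture=Architecture.GEMMA_2,
        num_hidden_layers=26,
        num_attention_heads=8,
        num_key_value_heads=4,
        hidden_size=2304,
        intermediate_size=9216,
        head_dim=256,
        use_pre_ffw_norm=True,
        use_post_ffw_norm=True,
        final_logit_softcapping=30.0,
        attn_logit_softcapping=50.0,
        sliding_window_size=4096,
        attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL] * 13,
    )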
Model Implementation
The model architecture in gemma/model.py uses conditional logic to handle different architectures:
Decoder Layer
Gemma 2 has its own decoder layer class, Gemma2DecoderLayer, with specific features:
- Additional normalization layers (pre/post feedforward)
- Support for both local sliding and global attention types
- Attention logit softcapping
Model Construction
for i in range(config.num_hidden_layers):
    if config.architecture == gemma_config.Architecture.GEMMA_1:
        self.layers.append(GemmaDecoderLayer(config))
    elif config.architecture in (gemma_config.Architecture.GEMMA_2, gemma_config.Architecture.GEMMA_3):
        attn_type = config.attn_types[i % len(config.attn_types)]
        self.layers.append(Gemma2DecoderLayer(config, attn_type))
Key Gemma 2 Features
- Mixed Attention Pattern: Alternates between local sliding window attention and global attention across layers
- Additional Normalizations: Extra RMSNorm layers before and after the MLP block
- Softcapping: Applied to both attention logits and final output logits
- Query Pre-attention Scaling: Custom scaling factor for certain variants (e.g., 27B model)
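Softcapping, used for both the attention logits (cap 50.0) and the final logits (cap 30.0), is just a scaled tanh. A minimal self-contained illustration (the helper name here is made up for this sketch, not taken from the repo):
import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly squashes values into (-cap, cap) while keeping gradients well-behaved
    return cap * torch.tanh(logits / cap)

x = torch.tensor([-200.0, -10.0, 0.0, 10.0, 200.0])
print(softcap(x, 50.0))  # attention-logit capping: values bounded near ±50
print(softcap(x, 30.0))  # final-logit capping: values bounded near ±30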
Variant Selection
The get_model_config() function maps variant strings to configurations:
- '2b-v2' → Gemma 2 2B model
- '9b' → Gemma 2 9B model
- '27b' → Gemma 2 27B model
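In use, this looks roughly like the following sketch, based on the module layout described above (the dtype/checkpoint-loading plumbing in the repo's run scripts is omitted):
from gemma import config as gemma_config
from gemma import model as gemma_model

model_config = gemma_config.get_model_config('2b-v2')  # Gemma 2 2B configuration
model = gemma_model.GemmaForCausalLM(model_config)     # builds the 26-layer Gemma 2 model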
This organization allows the codebase to support multiple Gemma architectures while sharing common components and maintaining architecture-specific features through configuration and conditional logic.
Looking at the 2b-v2 model, how does data flow from inputs through the network to the output token selection? Explain each step with references to specific functions
Thinking Summary
The user is asking about the data flow through the Gemma 2b-v2 model specifically. I need to trace through the code from input to output, referencing specific functions. Let me walk through this step by step. Looking at the code: 1. The 2b-v2 model is a Gemma 2 architecture variant 2. The main entry point appears to be the `forward` method in `GemmaForCausalLM` 3. I need to trace through embedding, model layers, and sampling Let me trace through the forward pass:
Let me trace the data flow through the Gemma 2b-v2 model from input tokens to output token selection:
1. Input Tokenization and Embedding
Function: GemmaForCausalLM.forward()
# Input: input_token_ids [batch_size, input_len]
hidden_states = self.embedder(input_token_ids)
# Output: [batch_size, input_len, hidden_size=2304]
# Normalize embeddings by sqrt(hidden_size)
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
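For 2b-v2 this normalizer is sqrt(2304) = 48, so every embedding vector is scaled by 48 before entering the first decoder layer. A quick check:
import torch

hidden_size = 2304
normalizer = torch.tensor(hidden_size ** 0.5)
print(normalizer)  # tensor(48.) since 48 * 48 = 2304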
2. Positional Encoding Setup
Function: GemmaForCausalLM.forward()
# Create frequency encodings for rotary embeddings
freqs_cis[AttentionType.LOCAL_SLIDING] = self.freqs_cis.index_select(0, input_positions)
freqs_cis[AttentionType.GLOBAL] = self.freqs_cis.index_select(0, input_positions)
3. Model Layers (26 layers for 2b-v2)
Functions: GemmaModel.forward() → Gemma2DecoderLayer.forward()
Each layer alternates between LOCAL_SLIDING and GLOBAL attention:
3a. Pre-Attention Processing
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) # RMSNorm
3b. Self-Attention
Function: GemmaAttention.forward()
# Project to Q, K, V
qkv = self.qkv_proj(hidden_states) # Linear projection
xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Reshape for multi-head attention
# 2b-v2: 8 attention heads, 4 KV heads, head_dim=256
xq = xq.view(batch_size, -1, 8, 256)
xk = xk.view(batch_size, -1, 4, 256)
xv = xv.view(batch_size, -1, 4, 256)
# Apply rotary embeddings
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)
# Update KV cache
k_cache.index_copy_(1, kv_write_indices, xk)
v_cache.index_copy_(1, kv_write_indices, xv)
# Compute attention scores
q.mul_(self.scaling) # Scale by 1/sqrt(head_dim)
scores = torch.matmul(q, k.transpose(2, 3))
# Apply attention type-specific mask
if self.attn_type == AttentionType.LOCAL_SLIDING:
    # Use sliding window of 4096 tokens
    scores = scores + local_mask
else:
    # Use global causal mask
    scores = scores + mask
# Apply attention logit softcapping (50.0 for 2b-v2)
scores = scores / 50.0
scores = torch.tanh(scores)
scores = scores * 50.0
# Softmax and apply to values
scores = F.softmax(scores.float(), dim=-1)
output = torch.matmul(scores, v)
# Project back
output = self.o_proj(output) # Linear projection
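One step glossed over above: with 8 query heads but only 4 KV heads (grouped-query attention), the cached keys and values must be expanded to match the query heads before the score matmul. The repo does this with a repeat along the head dimension; the snippet below is a self-contained sketch of that expansion with the shapes spelled out, not the repo's exact code:
import torch

batch_size, seq_len = 1, 16
num_heads, num_kv_heads, head_dim = 8, 4, 256
num_queries_per_kv = num_heads // num_kv_heads  # 2 query heads share each KV head

key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim)
value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim)

if num_kv_heads != num_heads:
    # [1, 16, 4, 256] -> [1, 16, 8, 256] so every query head has a matching key/value head
    key = torch.repeat_interleave(key, num_queries_per_kv, dim=2)
    value = torch.repeat_interleave(value, num_queries_per_kv, dim=2)

print(key.shape)  # torch.Size([1, 16, 8, 256])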
3c. Post-Attention Processing
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states # Residual connection
3d. MLP Block
Functions: Gemma2DecoderLayer.forward() → GemmaMLP.forward()
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states) # Pre-norm
# MLP with GeGLU activation
gate = self.gate_proj(hidden_states) # Project to intermediate_size=9216
gate = F.gelu(gate, approximate="tanh")
up = self.up_proj(hidden_states) # Project to intermediate_size=9216
fuse = gate * up # Element-wise multiply
hidden_states = self.down_proj(fuse) # Project back to hidden_size=2304
hidden_states = self.post_feedforward_layernorm(hidden_states) # Post-norm
hidden_states = residual + hidden_states # Residual connection
4. Final Normalization
Function: GemmaModel.forward()
hidden_states = self.norm(hidden_states) # Final RMSNorm
# Output: [batch_size, input_len, hidden_size=2304]
5. Token Selection
Function: Sampler.forward()
# Select output position
hidden_states = hidden_states.index_select(1, output_positions).squeeze(dim=1)
# Shape: [batch_size, hidden_size=2304]
# Compute logits
logits = torch.matmul(hidden_states, embedding.t())
# Shape: [batch_size, vocab_size=256000]
# Apply final logit softcapping (30.0 for 2b-v2)
logits = logits / 30.0
logits = torch.tanh(logits)
logits = logits * 30.0
# Temperature scaling (if temperature provided)
logits.div_(temperatures.unsqueeze(dim=1))
# Convert to probabilities
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Apply top-p and top-k filtering
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
# ... filtering logic ...
# Sample next token
next_token_ids = torch.multinomial(probs, num_samples=1)
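The elided filtering step keeps only the most probable tokens before sampling. Below is a generic top-p/top-k filter over sorted probabilities; it is an illustrative sketch of the idea, not the repo's Sampler code (which works with per-example top_ps/top_ks tensors):
import torch

def filter_top_p_top_k(probs: torch.Tensor, top_p: float, top_k: int) -> torch.Tensor:
    # probs: [batch_size, vocab_size], already softmaxed
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

    # Top-p: zero out tokens once the cumulative probability (excluding self) exceeds top_p
    cumsum = torch.cumsum(probs_sort, dim=-1)
    probs_sort = torch.where((cumsum - probs_sort) > top_p,
                             torch.zeros_like(probs_sort), probs_sort)

    # Top-k: zero out everything beyond the k highest-probability tokens
    ranks = torch.arange(probs.size(-1), device=probs.device).expand_as(probs_sort)
    probs_sort = torch.where(ranks >= top_k, torch.zeros_like(probs_sort), probs_sort)

    # Renormalize and scatter back to the original vocabulary order
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    return torch.zeros_like(probs).scatter_(-1, probs_idx, probs_sort)

probs = torch.softmax(torch.randn(2, 256000), dim=-1)
filtered = filter_top_p_top_k(probs, top_p=0.95, top_k=64)
next_token_ids = torch.multinomial(filtered, num_samples=1)  # shape [2, 1]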
Summary of Data Flow
- Embedding: Tokens → 2304-dim embeddings (normalized)
- 26 Transformer Layers: Each with:
  - RMSNorm → Attention (local/global) → RMSNorm → Residual
  - RMSNorm → MLP (GeGLU) → RMSNorm → Residual
- Final Norm: RMSNorm on hidden states
- Logit Generation: Project to vocabulary space
- Sampling: Softmax → Top-p/Top-k filtering → Multinomial sampling
The 2b-v2 model's unique features include:
- Alternating local (4096 window) and global attention
- Double normalization (pre/post) around MLP blocks
- Softcapping on both attention (50.0) and final logits (30.0)
- 8 attention heads with 4 KV heads (2:1 ratio)
how does the attention computation in 3b flow through to the hidden states in 3a/c
Thinking Summary
The user is asking about how the attention computation output flows from step 3b back into the hidden states in steps 3a and 3c. Let me trace this carefully through the code. Looking at the Gemma2DecoderLayer.forward() method: ```python def forward( self, hidden_states: torch.Tensor, freqs_cis: torch.Tensor, kv_write_indices: torch.Tensor, kv_cache: Tuple[torch.Tensor, torch.Tensor], mask: torch.Tensor, local_mask: torch.Tensor, ) -> torch.Tensor: # Self Attention residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states = self.self_attn( hidden_states=hidden_states, freqs_cis=freqs_cis, kv_write_indices=kv_write_indices, kv_cache=kv_cache, mask=mask, local_mask=local_mask, ) hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = residual + hidden_states # MLP residual = hidden_states if self.pre_feedforward_layernorm is not None: hidden_states = self.pre_feedforward_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) if self.post_feedforward_layernorm is not None: hidden_states = self.post_feedforward_layernorm(hidden_states) hidden_states = residual + hidden_states return hidden_states ``` So the flow is: 1. Save original hidden_states as residual 2. Apply input_layernorm to hidden_states 3. Pass normalized hidden_states through self_attn 4. Apply post_attention_layernorm to attention output 5. Add residual connection I track the attention transformation, where the original input is preserved through a clever residual connection mechanism. The self-attention output gets normalized and then added back to the initial input, allowing information flow while maintaining the original signal's integrity.
Let me trace how the attention output flows back into the hidden states through the residual connections:
Attention Block Flow in Gemma2DecoderLayer
Step 3a: Pre-Attention
# Save the original input for residual connection
residual = hidden_states # Shape: [batch_size, seq_len, 2304]
# Normalize for attention computation
hidden_states = self.input_layernorm(hidden_states) # RMSNorm
Step 3b: Attention Computation
# The attention mechanism processes the normalized hidden states
hidden_states = self.self_attn(
hidden_states=hidden_states, # Normalized input
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_cache=kv_cache,
mask=mask,
local_mask=local_mask,
)
# Output shape: [batch_size, seq_len, 2304]
Step 3c: Post-Attention Integration
# Apply post-attention normalization
hidden_states = self.post_attention_layernorm(hidden_states)
# CRITICAL: Add residual connection
hidden_states = residual + hidden_states
The Residual Connection Pattern
The key insight is the residual connection that adds the original input back to the attention output:
Original Input (residual) ──────────────────────────┐
        │                                           │
        ↓                                           │
Input LayerNorm                                     │
        │                                           │
        ↓                                           │
Self-Attention                                      │
        │                                           │
        ↓                                           │
Post-Attention LayerNorm                            │
        │                                           │
        ↓                                           │
    Addition  <─────────────────────────────────────┘
        │
        ↓
Output (new hidden_states)
Why This Matters
- Gradient Flow: The residual connection allows gradients to flow directly through the addition operation, bypassing the attention mechanism if needed.
- Information Preservation: The original input information is preserved and combined with the attention-transformed information.
- Mathematical Form:
hidden_states_out = hidden_states_in + LayerNorm(Attention(LayerNorm(hidden_states_in)))
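Written out as a tiny standalone function, the wiring looks like this (a sketch; it only assumes the norm and attention callables preserve the [batch, seq, hidden] shape, which RMSNorm and self-attention do here):
import torch

def attention_sub_block(hidden_states, input_norm, attn, post_attn_norm):
    residual = hidden_states                       # e.g. [batch, seq, 2304]
    hidden_states = input_norm(hidden_states)      # pre-attention RMSNorm
    hidden_states = attn(hidden_states)            # attention output, same shape
    hidden_states = post_attn_norm(hidden_states)  # post-attention RMSNorm
    return residual + hidden_states                # residual add

# Toy check with identity "layers": output is exactly input + input, shape preserved
x = torch.randn(1, 4, 2304)
out = attention_sub_block(x, lambda t: t, lambda t: t, lambda t: t)
print(out.shape, torch.allclose(out, 2 * x))  # torch.Size([1, 4, 2304]) True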
The Same Pattern Repeats for MLP
# After attention block completes, save new residual
residual = hidden_states # This now includes attention output
# Pre-MLP normalization
hidden_states = self.pre_feedforward_layernorm(hidden_states)
# MLP computation
hidden_states = self.mlp(hidden_states)
# Post-MLP normalization
hidden_states = self.post_feedforward_layernorm(hidden_states)
# Another residual connection
hidden_states = residual + hidden_states
Data Flow Summary
For each Gemma2 layer:
- Input: hidden_states from the previous layer
- Attention Block:
  - Save input as residual
  - Transform through normalized attention
  - Add back residual
- MLP Block:
  - Save attention output as the new residual
  - Transform through normalized MLP
  - Add back residual
- Output: Enhanced hidden_states to the next layer
This double-residual pattern with pre/post normalization is a key architectural feature of Gemma 2 that helps with training stability and allows the model to learn both local refinements (through attention/MLP) and preserve global information flow (through residuals).
If there are no initial tokens what is the initial state of attention?
Thinking Summary
The user is asking about the initial state of attention when there are no initial tokens. This is an interesting question about how the model starts generation from scratch. Let me think through this... In the context of Gemma models, typically there would always be at least one token - usually a BOS (beginning of sequence) token. But let me trace through what happens in the initial state. Looking at the code: 1. The tokenizer has a `bos_id` defined 2. In the `encode` method, there's a `bos` parameter that defaults to True 3. The model expects some input tokens But the question is about the attention mechanism's initial state. Let me look at the KV cache initialization and how attention works when starting fresh.
Looking at the code, Gemma models always start with at least one token - typically the BOS (Beginning of Sequence) token. But let me explain the initial attention state:
Initial KV Cache State
When the model starts, the KV caches are initialized to zeros:
# From generate() method
kv_caches = []
for _ in range(self.config.num_hidden_layers):
    size = (batch_size, max_seq_len, self.config.num_key_value_heads,
            self.config.head_dim)
    dtype = self.config.get_dtype()
    k_cache = torch.zeros(size=size, dtype=dtype, device=device)
    v_cache = torch.zeros(size=size, dtype=dtype, device=device)
    kv_caches.append((k_cache, v_cache))
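For 2b-v2 this allocates 26 pairs of zero tensors, each shaped [batch_size, max_seq_len, 4, 256]. As a rough size check (assuming bfloat16 and an illustrative max_seq_len of 8192):
import torch

batch_size, max_seq_len = 1, 8192          # max_seq_len chosen for illustration
num_key_value_heads, head_dim, num_layers = 4, 256, 26

k_cache = torch.zeros(batch_size, max_seq_len, num_key_value_heads, head_dim)
mib_per_tensor = k_cache.numel() * 2 / 2**20   # 2 bytes per bfloat16 element
print(mib_per_tensor)                          # 16.0 MiB per K (or V) tensor
print(mib_per_tensor * 2 * num_layers)         # 832.0 MiB for all layers' K+V caches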
First Token Processing
When processing the first token (usually BOS):
1. Tokenization
# From Tokenizer.encode()
def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]:
    t = self.sp_model.EncodeAsIds(s)
    if bos:
        t = [self.bos_id] + t  # Prepend BOS token
    return t
2. First Attention Computation
For the very first token at position 0:
# Input position is [0]
input_positions_tensor = torch.arange(0, 1) # Just position 0
# The attention mask allows the first token to only attend to itself
mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len), -2.3819763e38)
mask_tensor = torch.triu(mask_tensor, diagonal=1)
# This creates a causal mask where position 0 can only see position 0
3. Attention Mechanics
When the first token goes through attention:
# Query from first token
q = xq[:, 0:1, :, :] # Shape: [batch, 1, heads, head_dim]
# Key and Value caches
# Initially all zeros except position 0 which gets updated:
k_cache[:, 0:1, :, :] = xk # First token's key
v_cache[:, 0:1, :, :] = xv # First token's value
# Attention scores
scores = torch.matmul(q, k.transpose(2, 3))
# Only position 0 has non-zero keys, so attention focuses entirely on position 0
The "Empty" Attention State
If hypothetically there were truly no tokens:
- The model couldn't function because embedding lookup would fail
- There would be no query vectors to compute attention
The closest conceptual "empty" state is:
- KV caches: All zeros (no previous context stored)
- Attention mask: Prevents looking at future positions
- First attention: Self-attention only (token attends to itself)
Practical Initialization
In practice, generation always starts with at least a BOS token:
# Example from the codebase
prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts]
# This automatically includes BOS token due to default bos=True
So the initial attention state is:
- Position 0: Contains BOS token embeddings
- KV Cache: Zeros everywhere except position 0 after first forward pass
- Attention Pattern: First token can only attend to itself
- Subsequent tokens: Can attend to all previous positions including BOS
This is why language models can generate coherent text from just a BOS token - it provides the initial hidden state that seeds all subsequent generation.
can you explain the attention mask mechanism? why do we need to prevent the token from attending to future tokens? wouldn’t that be the case by default since they haven’t arrived yet? or is it that they’re all being preprocessed in this stage?
Thinking Summary
Great question! The user is asking about the attention mask mechanism and why we need to prevent tokens from attending to future tokens. They're wondering if this would happen naturally since future tokens haven't arrived yet, or if all tokens are being processed together. This is a fundamental concept in transformer architectures, especially for causal language models. Let me explain: 1. During training and when processing multiple tokens at once (like during prompt processing), all tokens ARE present in the sequence at the same time 2. The mask is needed to maintain causality - each token should only see previous tokens 3. During inference, especially token-by-token generation, the mask might seem redundant but it's still used for consistency Let me trace through the code to show this clearly.
Excellent question! You've hit on a key insight about how transformers process sequences. Let me explain why the attention mask is necessary.
The Key Insight: Parallel Processing
During both training and inference (especially prefill), all tokens in a sequence are processed simultaneously, not one by one. This is what makes transformers efficient.
Example: Processing "The cat sat"
When the model processes this sequence, all three tokens are in memory at once:
# All tokens are embedded simultaneously
input_token_ids = [234, 1876, 3457] # "The", "cat", "sat"
hidden_states = self.embedder(input_token_ids)
# Shape: [1, 3, 2304] - all 3 tokens have embeddings
Why We Need the Mask
Without a mask, when computing attention for position 0 ("The"), the attention mechanism would see:
- Position 0: "The" ✓
- Position 1: "cat" ✗ (shouldn't see this yet!)
- Position 2: "sat" ✗ (shouldn't see this yet!)
Here's what happens in the attention computation:
# Query from position 0 ("The")
q_0 = query[0] # Shape: [heads, head_dim]
# WITHOUT mask, attention scores would be:
scores = torch.matmul(q_0, all_keys.transpose())
# This computes attention to ALL positions including future ones!
# WITH mask:
mask_tensor = torch.triu(torch.full(size, -inf), diagonal=1)
# Creates:
# [[ 0, -inf, -inf],
# [ 0, 0, -inf],
# [ 0, 0, 0]]
scores = scores + mask[0] # Adds -inf to positions 1 and 2
# After softmax, -inf becomes 0 probability
The Mask Creation
# From the code:
mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len), -2.3819763e38)
mask_tensor = torch.triu(mask_tensor, diagonal=1)
This creates an upper triangular matrix of -infinity values:
Position: 0 1 2 3 ...
0 [ 0 -inf -inf -inf ...] # Token 0 sees only position 0
1 [ 0 0 -inf -inf ...] # Token 1 sees positions 0,1
2 [ 0 0 0 -inf ...] # Token 2 sees positions 0,1,2
3 [ 0 0 0 0 ...] # Token 3 sees positions 0,1,2,3
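This is easy to reproduce at a small scale; the snippet below builds a 4x4 causal mask the same way and shows that softmax turns the masked positions into exactly zero attention weight:
import torch

seq_len = 4
mask = torch.triu(torch.full((seq_len, seq_len), -2.3819763e38), diagonal=1)
print(mask)  # large negative values above the diagonal, 0.0 on and below it

scores = torch.zeros(seq_len, seq_len) + mask
print(torch.softmax(scores, dim=-1))
# Row 0: [1.00, 0.00, 0.00, 0.00] ... Row 3: [0.25, 0.25, 0.25, 0.25]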
Two Scenarios Where This Matters
1. Training/Prompt Processing (Parallel)
# Processing entire prompt at once
prompt = "The cat sat on the"
tokens = [234, 1876, 3457, 892, 234] # All 5 tokens
# Without mask, token 0 would attend to all 5 tokens!
# With mask, token 0 only attends to itself
2. Generation with KV Cache
# Even during generation, the KV cache contains ALL positions
k_cache = torch.zeros(batch, max_seq_len, heads, dim)
v_cache = torch.zeros(batch, max_seq_len, heads, dim)
# When generating token at position 5:
# - Positions 0-4: Have actual key/values
# - Positions 5-max: Are zeros
# The mask ensures we don't attend to the zero-filled future positions
The Sliding Window Complication
For Gemma 2's local attention, there's an additional mask:
if self.attn_type == AttentionType.LOCAL_SLIDING:
    # Only attend to the last 4096 tokens
    sliding_mask = torch.triu(all_ones, -sliding_window_size + 1) * torch.tril(
        all_ones, sliding_window_size - 1)
    mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)
Summary
The attention mask is necessary because:
- Parallel Processing: Multiple tokens are processed simultaneously during training and prefill
- Causal Consistency: Ensures each token only sees its past, maintaining autoregressive property
- KV Cache Safety: Prevents attending to uninitialized future cache positions
- Training-Inference Parity: Same masking pattern works for both training and inference
Without the mask, the model would "cheat" by looking at future tokens during training, breaking the fundamental assumption that language models predict the next token based only on previous context.
Thinking Summary
The user wants me to explain the sliding window attention mechanism in Gemma 2. Let me look at the code carefully and explain how it works. From the code: ```python if self.attn_type == AttentionType.LOCAL_SLIDING and self.sliding_window_size is not None and local_mask is not None: mask = local_mask ``` And in the forward method: ```python local_mask_tensor = mask_tensor + torch.tril( torch.full((1, 1, max_seq_len, max_seq_len), -2.3819763e38, device=device), diagonal=-self.config.sliding_window_size, ) if self.config.sliding_window_size else None ``` For Gemma 2b-v2, the sliding window size is 4096. Let me break this down step by step.
Sliding Window Attention in Gemma 2
The sliding window mechanism limits how far back each token can "see" in the sequence. For Gemma 2b-v2, this window is 4096 tokens.
The Basic Concept
Instead of attending to ALL previous tokens, each token only attends to the most recent 4096 tokens:
Regular Attention (Global):
Token 5000 can see: [0, 1, 2, ..., 4999, 5000]   ← itself and every previous token

Sliding Window Attention (Local):
Token 5000 can see: [905, 906, ..., 4999, 5000]  ← only the most recent 4096 tokens (including itself)
                      ↑
                      5000 - 4096 + 1 = 905
How the Mask is Created
Let's break down the mask creation code:
# Start with the regular causal mask
mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len), -2.3819763e38)
mask_tensor = torch.triu(mask_tensor, diagonal=1) # Upper triangular with -inf
# Create the sliding window mask
local_mask_tensor = mask_tensor + torch.tril(
torch.full((1, 1, max_seq_len, max_seq_len), -2.3819763e38),
diagonal=-self.config.sliding_window_size, # -4096 for 2b-v2
)
Step-by-Step Mask Construction
- Start with causal mask:
Position: 0 1 2 3 4 5 ...
0 [ 0.0 -inf -inf -inf -inf -inf ...]
1 [ 0.0 0.0 -inf -inf -inf -inf ...]
2 [ 0.0 0.0 0.0 -inf -inf -inf ...]
3 [ 0.0 0.0 0.0 0.0 -inf -inf ...]
...
- Add lower triangular mask with diagonal=-4096:
torch.tril(full_matrix_of_-inf, diagonal=-4096)
This puts -inf at every position that is 4096 or more tokens before the current one.
- Result for position 5000:
Position:     0    1  ...  904   905  906 ... 4999 5000  5001 ...
5000       [-inf -inf ... -inf   0.0  0.0 ...  0.0  0.0  -inf ...]
             ↑______________↑    ↑___________________↑     ↑
             Can't see these     Can see these 4096        Can't see
             (too far back)      most recent tokens        future
                                 (including itself)
Visual Example with Small Window (size=3)
Sliding window size = 3
Position: 0 1 2 3 4 5
0 [ 0.0 -inf -inf -inf -inf -inf] # Sees: [0]
1 [ 0.0 0.0 -inf -inf -inf -inf] # Sees: [0,1]
2 [ 0.0 0.0 0.0 -inf -inf -inf] # Sees: [0,1,2]
3 [ 0.0 0.0 0.0 0.0 -inf -inf] # Sees: [0,1,2,3] (full window)
4 [-inf 0.0 0.0 0.0 0.0 -inf] # Sees: [1,2,3,4] (slides!)
5 [-inf -inf 0.0 0.0 0.0 0.0] # Sees: [2,3,4,5]
↑
Can't see position 0 anymore (outside window)
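You can verify this window behaviour directly by building the local mask the same way (causal mask plus a tril offset by the window size) and printing which positions remain visible:
import torch

seq_len, window = 6, 3
neg_inf = -2.3819763e38

causal = torch.triu(torch.full((seq_len, seq_len), neg_inf), diagonal=1)
local = causal + torch.tril(torch.full((seq_len, seq_len), neg_inf), diagonal=-window)

for i in range(seq_len):
    visible = torch.nonzero(local[i] == 0).flatten().tolist()
    print(f"position {i} sees {visible}")
# position 0 sees [0]
# position 1 sees [0, 1]
# position 2 sees [0, 1, 2]
# position 3 sees [1, 2, 3]
# position 4 sees [2, 3, 4]
# position 5 sees [3, 4, 5]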
In the Attention Computation
def forward(self, hidden_states, ...):
    # ... compute queries, keys, values ...

    # Compute attention scores
    scores = torch.matmul(q, k.transpose(2, 3))

    # Apply the appropriate mask
    if self.attn_type == AttentionType.LOCAL_SLIDING:
        mask = local_mask   # Use sliding window mask
    else:
        mask = global_mask  # Use full causal mask

    scores = scores + mask              # Add -inf to blocked positions
    scores = F.softmax(scores, dim=-1)  # -inf → 0 probability
Why Use Sliding Windows?
- Efficiency: Reduces computation from O(n²) to O(n·w) where w=window size
- Memory: Limits memory usage for very long sequences
- Local Context: Forces model to focus on recent context
- Layer Alternation: Gemma 2 alternates between local and global attention:
# From config for 2b-v2
attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL] * 13
# Pattern: [Local, Global, Local, Global, ...] for 26 layers
This alternation allows the model to:
- Capture fine-grained local patterns (sliding window layers)
- Maintain long-range dependencies (global attention layers)
The 4096 token window is large enough to capture most relevant context while still providing computational benefits for very long sequences.