gemma pytorch repo
Thread (29 posts)
Thinking Summary
The user is asking about the organization of the Gemma 2 source code based on the gemma_pytorch repository. Let me analyze the code structure and highlight how Gemma 2 is organized within the codebase. From what I can see:
1. The repository supports multiple Gemma architectures (Gemma 1, Gemma 2, and Gemma 3)
2. Gemma 2 has specific model configurations and architectural features
3. There are dedicated decoder layer classes for different architectures
Let me look at the key organizational aspects:
Looking at the gemma_pytorch repository structure, here's how the Gemma 2 source code is organized:
Architecture Integration
Gemma 2 is integrated into a unified codebase that supports multiple Gemma architectures (1, 2, and 3). The code uses an Architecture enum to differentiate between versions:
class Architecture(enum.Enum):
    GEMMA_1 = 1
    GEMMA_2 = 2
    GEMMA_3 = 3
Configuration System
Gemma 2 has dedicated configuration functions in gemma/config.py:
- get_config_for_2b_v2() - 2B parameter variant
- get_config_for_9b() - 9B parameter variant
- get_config_for_27b() - 27B parameter variant
Each configuration specifies Gemma 2's unique features (a sketch of these settings follows the list):
- Architecture: Architecture.GEMMA_2
- Pre/Post FFW Normalization: use_pre_ffw_norm=True, use_post_ffw_norm=True
- Logit Softcapping: final_logit_softcapping=30.0, attn_logit_softcapping=50.0
- Attention Types: alternating pattern of LOCAL_SLIDING and GLOBAL attention
- Sliding Window: sliding_window_size=4096
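Putting those settings together, here is a minimal sketch of what a Gemma 2 (2b-v2) configuration carries. Gemma2ConfigSketch is a hypothetical stand-in, not the repo's class; the real config object in gemma/config.py has more fields and may be structured differently. The values are the 2b-v2 settings described in this thread.
from dataclasses import dataclass, field
import enum
from typing import List

class Architecture(enum.Enum):
    GEMMA_1 = 1
    GEMMA_2 = 2
    GEMMA_3 = 3

class AttentionType(enum.Enum):
    GLOBAL = 1
    LOCAL_SLIDING = 2

@dataclass
class Gemma2ConfigSketch:
    # Hypothetical stand-in; field names follow the list above.
    architecture: Architecture = Architecture.GEMMA_2
    hidden_size: int = 2304
    num_hidden_layers: int = 26
    use_pre_ffw_norm: bool = True
    use_post_ffw_norm: bool = True
    final_logit_softcapping: float = 30.0
    attn_logit_softcapping: float = 50.0
    sliding_window_size: int = 4096
    attn_types: List[AttentionType] = field(
        default_factory=lambda: [AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL])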
Model Implementation
The model architecture in gemma/model.py uses conditional logic to handle different architectures:
Decoder Layer
Gemma 2 has its own decoder layer class Gemma2DecoderLayer with specific features:
- Additional normalization layers (pre/post feedforward)
- Support for both local sliding and global attention types
- Attention logit softcapping
Model Construction
for i in range(config.num_hidden_layers):
    if config.architecture == gemma_config.Architecture.GEMMA_1:
        self.layers.append(GemmaDecoderLayer(config))
    elif config.architecture in (gemma_config.Architecture.GEMMA_2, gemma_config.Architecture.GEMMA_3):
        attn_type = config.attn_types[i % len(config.attn_types)]
        self.layers.append(Gemma2DecoderLayer(config, attn_type))
Key Gemma 2 Features
- Mixed Attention Pattern: Alternates between local sliding window attention and global attention across layers
- Additional Normalizations: Extra RMSNorm layers before and after the MLP block
- Softcapping: Applied to both attention logits and final output logits (see the sketch after this list)
- Query Pre-attention Scaling: Custom scaling factor for certain variants (e.g., 27B model)
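The softcapping transform itself is a one-liner: scale down by the cap, squash with tanh, scale back up, which smoothly bounds values to (-cap, cap). A minimal sketch using the cap values listed above (not the repo's exact code):
import torch

def soft_cap(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly bounds x to the open interval (-cap, cap) while staying differentiable.
    return cap * torch.tanh(x / cap)

attn_scores = torch.randn(2, 8, 16, 16) * 200                     # deliberately large logits
print(soft_cap(attn_scores, 50.0).abs().max())                    # always < 50.0 (attn_logit_softcapping)
print(soft_cap(torch.randn(2, 256000) * 100, 30.0).abs().max())   # always < 30.0 (final_logit_softcapping)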
Variant Selection
The get_model_config() function maps variant strings to configurations (sketched below):
- '2b-v2' → Gemma 2 2B model
- '9b' → Gemma 2 9B model
- '27b' → Gemma 2 27B model
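A minimal sketch of the kind of dispatch get_model_config() performs; the exact signature and body in the repo may differ, this just mirrors the variant strings above (the import path is assumed from the earlier mention of gemma/config.py):
from gemma.config import get_config_for_2b_v2, get_config_for_9b, get_config_for_27b

def get_model_config_sketch(variant: str):
    # Hypothetical dispatch table mirroring the Gemma 2 variant strings listed above.
    configs = {
        '2b-v2': get_config_for_2b_v2,
        '9b': get_config_for_9b,
        '27b': get_config_for_27b,
    }
    if variant not in configs:
        raise ValueError(f'Unknown Gemma 2 variant: {variant}')
    return configs[variant]()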
This organization allows the codebase to support multiple Gemma architectures while sharing common components and maintaining architecture-specific features through configuration and conditional logic.
Looking at the 2b-v2 model, how does data flow from inputs through the network to the output token selection? Explain each step with references to specific functions
Thinking Summary
The user is asking about the data flow through the Gemma 2b-v2 model specifically. I need to trace through the code from input to output, referencing specific functions. Let me walk through this step by step. Looking at the code:
1. The 2b-v2 model is a Gemma 2 architecture variant
2. The main entry point appears to be the `forward` method in `GemmaForCausalLM`
3. I need to trace through embedding, model layers, and sampling
Let me trace through the forward pass:
Let me trace the data flow through the Gemma 2b-v2 model from input tokens to output token selection:
1. Input Tokenization and Embedding
Function: GemmaForCausalLM.forward()
# Input: input_token_ids [batch_size, input_len]
hidden_states = self.embedder(input_token_ids)
# Output: [batch_size, input_len, hidden_size=2304]
# Normalize embeddings by sqrt(hidden_size)
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
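For 2b-v2 the normalizer is a concrete number: sqrt(2304) = 48, so every embedding vector is scaled up by a factor of 48 before entering the first layer. A quick check:
import torch

hidden_size = 2304
normalizer = torch.tensor(hidden_size ** 0.5)
print(normalizer)  # tensor(48.)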
2. Positional Encoding Setup
Function: GemmaForCausalLM.forward()
# Create frequency encodings for rotary embeddings
freqs_cis[AttentionType.LOCAL_SLIDING] = self.freqs_cis.index_select(0, input_positions)
freqs_cis[AttentionType.GLOBAL] = self.freqs_cis.index_select(0, input_positions)
3. Model Layers (26 layers for 2b-v2)
Function: GemmaModel.forward() → Gemma2DecoderLayer.forward()
Each layer alternates between LOCAL_SLIDING and GLOBAL attention:
3a. Pre-Attention Processing
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) # RMSNorm
3b. Self-Attention
Function: GemmaAttention.forward()
# Project to Q, K, V
qkv = self.qkv_proj(hidden_states) # Linear projection
xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Reshape for multi-head attention
# 2b-v2: 8 attention heads, 4 KV heads, head_dim=256
xq = xq.view(batch_size, -1, 8, 256)
xk = xk.view(batch_size, -1, 4, 256)
xv = xv.view(batch_size, -1, 4, 256)
# Apply rotary embeddings
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)
# Update KV cache
k_cache.index_copy_(1, kv_write_indices, xk)
v_cache.index_copy_(1, kv_write_indices, xv)
# Compute attention scores
q.mul_(self.scaling) # Scale by 1/sqrt(head_dim)
scores = torch.matmul(q, k.transpose(2, 3))
# Apply attention type-specific mask
if self.attn_type == AttentionType.LOCAL_SLIDING:
    # Use sliding window of 4096 tokens
    scores = scores + local_mask
else:
    # Use global causal mask
    scores = scores + mask
# Apply attention logit softcapping (50.0 for 2b-v2)
scores = scores / 50.0
scores = torch.tanh(scores)
scores = scores * 50.0
# Softmax and apply to values
scores = F.softmax(scores.float(), dim=-1)
output = torch.matmul(scores, v)
# Project back
output = self.o_proj(output) # Linear projection
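One detail the simplified snippet above glosses over: with 8 query heads but only 4 KV heads (grouped-query attention), the key/value tensors have to be broadcast across query-head groups before the q·k^T matmul lines up. A minimal sketch of one common way to do this; the repo may implement the expansion differently:
import torch

batch, seq_len, head_dim = 2, 16, 256
num_heads, num_kv_heads = 8, 4                        # 2b-v2 values from above

q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_kv_heads, seq_len, head_dim)
v = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# Repeat each KV head so every group of num_heads // num_kv_heads query heads
# attends over its shared key/value head.
group = num_heads // num_kv_heads                     # 2:1 ratio for 2b-v2
k = torch.repeat_interleave(k, group, dim=1)          # [2, 8, 16, 256]
v = torch.repeat_interleave(v, group, dim=1)

scores = torch.matmul(q * head_dim ** -0.5, k.transpose(2, 3))
output = torch.matmul(torch.softmax(scores, dim=-1), v)
print(output.shape)                                   # torch.Size([2, 8, 16, 256])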
3c. Post-Attention Processing
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states # Residual connection
3d. MLP Block
Function: Gemma2DecoderLayer.forward() → GemmaMLP.forward()
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states) # Pre-norm
# MLP with GeGLU activation
gate = self.gate_proj(hidden_states) # Project to intermediate_size=9216
gate = F.gelu(gate, approximate="tanh")
up = self.up_proj(hidden_states) # Project to intermediate_size=9216
fuse = gate * up # Element-wise multiply
hidden_states = self.down_proj(fuse) # Project back to hidden_size=2304
hidden_states = self.post_feedforward_layernorm(hidden_states) # Post-norm
hidden_states = residual + hidden_states # Residual connection
4. Final Normalization
Function: GemmaModel.forward()
hidden_states = self.norm(hidden_states) # Final RMSNorm
# Output: [batch_size, input_len, hidden_size=2304]
5. Token Selection
Function: Sampler.forward()
# Select output position
hidden_states = hidden_states.index_select(1, output_positions).squeeze(dim=1)
# Shape: [batch_size, hidden_size=2304]
# Compute logits
logits = torch.matmul(hidden_states, embedding.t())
# Shape: [batch_size, vocab_size=256000]
# Apply final logit softcapping (30.0 for 2b-v2)
logits = logits / 30.0
logits = torch.tanh(logits)
logits = logits * 30.0
# Temperature scaling (if temperature provided)
logits.div_(temperatures.unsqueeze(dim=1))
# Convert to probabilities
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Apply top-p and top-k filtering
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
# ... filtering logic ...
# Sample next token
next_token_ids = torch.multinomial(probs, num_samples=1)
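The elided "filtering logic" is the usual top-p (nucleus) plus top-k truncation over the sorted probabilities. Here is a generic sketch of how that step is commonly written; it is not necessarily the exact code in Sampler.forward(), and top_p / top_k are per-call sampling parameters:
import torch

def filter_top_p_top_k(probs: torch.Tensor, top_p: float, top_k: int) -> torch.Tensor:
    # probs: [batch_size, vocab_size] softmax output.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

    # Top-p: zero out tokens once the cumulative probability (excluding the current token) exceeds top_p.
    cumsum = torch.cumsum(probs_sort, dim=-1)
    top_p_mask = (cumsum - probs_sort) > top_p
    probs_sort = torch.where(top_p_mask, torch.zeros_like(probs_sort), probs_sort)

    # Top-k: keep only the k highest-probability tokens.
    top_k_mask = torch.arange(probs_sort.size(-1), device=probs.device) >= top_k
    probs_sort = torch.where(top_k_mask, torch.zeros_like(probs_sort), probs_sort)

    # Renormalize and scatter back to the original vocabulary order.
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    return torch.zeros_like(probs).scatter_(-1, probs_idx, probs_sort)

probs = torch.softmax(torch.randn(2, 256000), dim=-1)
next_token_ids = torch.multinomial(filter_top_p_top_k(probs, top_p=0.95, top_k=64), num_samples=1)
print(next_token_ids.shape)  # torch.Size([2, 1])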
Summary of Data Flow
- Embedding: Tokens → 2304-dim embeddings (normalized)
- 26 Transformer Layers: Each with:
- RMSNorm → Attention (local/global) → RMSNorm → Residual
- RMSNorm → MLP (GeGLU) → RMSNorm → Residual
- Final Norm: RMSNorm on hidden states
- Logit Generation: Project to vocabulary space
- Sampling: Softmax → Top-p/Top-k filtering → Multinomial sampling
The 2b-v2 model's unique features include:
- Alternating local (4096 window) and global attention
- Double normalization (pre/post) around MLP blocks
- Softcapping on both attention (50.0) and final logits (30.0)
- 8 attention heads with 4 KV heads (2:1 ratio)
Thinking Summary
The user is asking about the `apply_rotary_emb` function, which is used for rotary positional embeddings (RoPE) in the Gemma model. Let me look at the implementation and explain what it does.
The apply_rotary_emb function implements Rotary Position Embeddings (RoPE), which encode positional information into the query and key vectors. Here's how it works:
Function Implementation
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""Applies the rotary embedding to the query and key tensors."""
x_ = torch.view_as_complex(
torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
dim=-1))
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
-1).transpose(1, 2)
return x_out
Step-by-Step Breakdown
1. Input Reshaping
x.transpose(1, 2).float()
- Input x shape: [batch_size, seq_len, num_heads, head_dim]
- After transpose: [batch_size, num_heads, seq_len, head_dim]
2. Split into Pairs
torch.chunk(x.transpose(1, 2).float(), 2, dim=-1)
- Splits the head_dim (256 for 2b-v2) into two halves
- Creates two tensors of shape [batch_size, num_heads, seq_len, 128]
3. Create Complex Numbers
torch.view_as_complex(torch.stack(..., dim=-1))
- Stacks the two halves as real and imaginary parts
- Creates complex tensor: [batch_size, num_heads, seq_len, 128] (complex64)
- Because of the half-split above, each pair (x[i], x[i+128]) becomes a complex number x[i] + i*x[i+128]; this implementation pairs the two halves of head_dim rather than adjacent elements
4. Apply Rotation
x_ * freqs_cis
- freqs_cis contains precomputed complex exponentials, one rotation e^(i*m*θ_j) per position m and frequency index j
- Multiplication in complex space performs rotation
- This rotates each 2D subspace by an angle m*θ_j determined by the position
5. Convert Back to Real
torch.view_as_real(x_ * freqs_cis)
- Converts complex numbers back to real pairs
- Shape: [batch_size, num_heads, seq_len, 128, 2]
6. Reshape to Original Format
torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
- Rearranges the real/imaginary pairs back to original dimension order
- Final transpose returns to original shape structure
The Mathematics
RoPE applies a rotation matrix to each pair of elements $(x_i, x_{i+d/2})$ of a head vector, where $d$ is the head dimension (256 here):

$$\begin{pmatrix} x'_i \\ x'_{i+d/2} \end{pmatrix} = \begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix} \begin{pmatrix} x_i \\ x_{i+d/2} \end{pmatrix}$$

where $m$ is the token position and $\theta_i = 10000^{-2i/d}$, matching the frequencies computed below.
Precomputed Frequencies
The freqs_cis values are precomputed in precompute_freqs_cis:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # e^(i*freqs)
return freqs_cis
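A quick usage check, assuming the 2b-v2 head_dim of 256 and using an 8192-position table here purely for illustration; the output holds one complex rotation per position and per rotated pair:
import torch

freqs_cis = precompute_freqs_cis(dim=256, end=8192)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([8192, 128]) torch.complex64
print(freqs_cis[0][:4])                  # position 0: 1+0j everywhere, i.e. no rotation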
Why RoPE?
RoPE provides several advantages:
- Relative position encoding: The dot product between rotated queries and keys depends on their relative positions
- Long-range decay: More distant positions have less similarity
- Efficient computation: No additional parameters needed
- Extrapolation: Can handle sequences longer than training length
For Gemma 2b-v2, RoPE is applied to both queries and keys before attention computation, encoding positional information directly into the representations.
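To make the relative-position property concrete, here is a small check using the apply_rotary_emb and precompute_freqs_cis functions shown above: the dot product between a rotated query at position m and a rotated key at position n depends only on the offset m - n.
import torch

freqs_cis = precompute_freqs_cis(dim=256, end=64)

# One batch element, one token, one head, head_dim=256 (the 2b-v2 layout).
q = torch.randn(1, 1, 1, 256)
k = torch.randn(1, 1, 1, 256)

def rotated_dot(q_pos: int, k_pos: int) -> float:
    # Rotate the same q/k content as if they sat at the given absolute positions.
    q_rot = apply_rotary_emb(q, freqs_cis[q_pos:q_pos + 1])
    k_rot = apply_rotary_emb(k, freqs_cis[k_pos:k_pos + 1])
    return (q_rot * k_rot).sum().item()

d1 = rotated_dot(5, 3)
d2 = rotated_dot(10, 8)      # same relative offset of 2
print(abs(d1 - d2) < 1e-3)   # True: only m - n matters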