gemma pytorch repo
Thread (29 posts)
Thinking Summary
The user is asking about the organization of the Gemma 2 source code based on the gemma_pytorch repository. Let me analyze the code structure and highlight how Gemma 2 is organized within the codebase. From what I can see:
1. The repository supports multiple Gemma architectures (Gemma 1, Gemma 2, and Gemma 3)
2. Gemma 2 has specific model configurations and architectural features
3. There are dedicated decoder layer classes for different architectures
Let me look at the key organizational aspects:
Looking at the gemma_pytorch repository structure, here's how the Gemma 2 source code is organized:
Architecture Integration
Gemma 2 is integrated into a unified codebase that supports multiple Gemma architectures (1, 2, and 3). The code uses an Architecture enum to differentiate between versions:
class Architecture(enum.Enum):
    GEMMA_1 = 1
    GEMMA_2 = 2
    GEMMA_3 = 3
Configuration System
Gemma 2 has dedicated configuration functions in gemma/config.py:
- get_config_for_2b_v2() - 2B parameter variant
- get_config_for_9b() - 9B parameter variant
- get_config_for_27b() - 27B parameter variant
Each configuration specifies Gemma 2's unique features:
- Architecture: Architecture.GEMMA_2
- Pre/Post FFW Normalization: use_pre_ffw_norm=True, use_post_ffw_norm=True
- Logit Softcapping: final_logit_softcapping=30.0, attn_logit_softcapping=50.0
- Attention Types: Alternating pattern of LOCAL_SLIDING and GLOBAL attention
- Sliding Window: sliding_window_size=4096
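Taken together, a condensed sketch of what get_config_for_2b_v2() might return, assembled from the bullets above and the dimensions quoted later in this thread (the exact dataclass layout and field names in gemma/config.py may differ):

# Illustrative sketch only; assumes GemmaConfig, Architecture, and AttentionType
# are defined in gemma/config.py as referenced elsewhere in this thread.
def get_config_for_2b_v2() -> GemmaConfig:
    return GemmaConfig(
        architecture=Architecture.GEMMA_2,
        num_hidden_layers=26,
        num_attention_heads=8,
        num_key_value_heads=4,
        hidden_size=2304,
        intermediate_size=9216,
        head_dim=256,
        use_pre_ffw_norm=True,
        use_post_ffw_norm=True,
        final_logit_softcapping=30.0,
        attn_logit_softcapping=50.0,
        attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL],
        sliding_window_size=4096,
    )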
Model Implementation
The model architecture in gemma/model.py uses conditional logic to handle different architectures:
Decoder Layer
Gemma 2 has its own decoder layer class, Gemma2DecoderLayer, with specific features:
- Additional normalization layers (pre/post feedforward)
- Support for both local sliding and global attention types
- Attention logit softcapping
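A skeletal sketch of the layer's constructor that ties these features together (an illustration assembled from the list above, not verbatim repo code; sub-module names follow those used later in this thread):

from torch import nn

# Assumes GemmaAttention, GemmaMLP, and RMSNorm from gemma/model.py.
class Gemma2DecoderLayer(nn.Module):
    def __init__(self, config, attn_type):
        super().__init__()
        self.self_attn = GemmaAttention(config, attn_type)           # local-sliding or global attention
        self.mlp = GemmaMLP(config)                                  # GeGLU feed-forward block
        self.input_layernorm = RMSNorm(config.hidden_size)           # before attention
        self.post_attention_layernorm = RMSNorm(config.hidden_size)  # after attention
        # Gemma 2's extra normalizations around the MLP block:
        self.pre_feedforward_layernorm = RMSNorm(config.hidden_size)
        self.post_feedforward_layernorm = RMSNorm(config.hidden_size)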
Model Construction
for i in range(config.num_hidden_layers):
    if config.architecture == gemma_config.Architecture.GEMMA_1:
        self.layers.append(GemmaDecoderLayer(config))
    elif config.architecture in (gemma_config.Architecture.GEMMA_2, gemma_config.Architecture.GEMMA_3):
        attn_type = config.attn_types[i % len(config.attn_types)]
        self.layers.append(Gemma2DecoderLayer(config, attn_type))
Key Gemma 2 Features
- Mixed Attention Pattern: Alternates between local sliding window attention and global attention across layers
- Additional Normalizations: Extra RMSNorm layers before and after the MLP block
- Softcapping: Applied to both attention logits and final output logits (see the sketch after this list)
- Query Pre-attention Scaling: Custom scaling factor for certain variants (e.g., 27B model)
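Softcapping itself is a one-line operation; a minimal standalone sketch of the pattern (cap · tanh(x / cap)), with the cap values taken from the configuration above:

import torch

def softcap(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly squashes logits into the open interval (-cap, cap).
    return cap * torch.tanh(x / cap)

attn_scores = softcap(torch.randn(2, 8, 16, 16), cap=50.0)   # attention logits
final_logits = softcap(torch.randn(2, 256000), cap=30.0)     # output logits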
Variant Selection
The get_model_config() function maps variant strings to configurations:
- '2b-v2' → Gemma 2 2B model
- '9b' → Gemma 2 9B model
- '27b' → Gemma 2 27B model
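A plausible shape for that dispatch, sketched from the variant list above (the real function in gemma/config.py also covers the Gemma 1 and Gemma 3 variants and may be structured differently):

def get_model_config(variant: str) -> GemmaConfig:
    # Only the Gemma 2 branches from the list above are shown here.
    if variant == '2b-v2':
        return get_config_for_2b_v2()
    if variant == '9b':
        return get_config_for_9b()
    if variant == '27b':
        return get_config_for_27b()
    raise ValueError(f'Unknown variant: {variant}')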
This organization allows the codebase to support multiple Gemma architectures while sharing common components and maintaining architecture-specific features through configuration and conditional logic.
Looking at the 2b-v2 model, how does data flow from inputs through the network to the output token selection? Explain each step with references to specific functions
Thinking Summary
The user is asking about the data flow through the Gemma 2b-v2 model specifically. I need to trace through the code from input to output, referencing specific functions. Let me walk through this step by step. Looking at the code:
1. The 2b-v2 model is a Gemma 2 architecture variant
2. The main entry point appears to be the `forward` method in `GemmaForCausalLM`
3. I need to trace through embedding, model layers, and sampling
Let me trace through the forward pass:
Let me trace the data flow through the Gemma 2b-v2 model from input tokens to output token selection:
1. Input Tokenization and Embedding
Function: GemmaForCausalLM.forward()
# Input: input_token_ids [batch_size, input_len]
hidden_states = self.embedder(input_token_ids)
# Output: [batch_size, input_len, hidden_size=2304]
# Normalize embeddings by sqrt(hidden_size)
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
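For the 2b-v2 hidden size, this normalizer works out to a round number; a quick check (illustration, not repo code):

import math
print(math.sqrt(2304))   # 48.0 -- every embedding vector is scaled by 48 before the first layer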
2. Positional Encoding Setup
Function: GemmaForCausalLM.forward()
# Create frequency encodings for rotary embeddings
freqs_cis[AttentionType.LOCAL_SLIDING] = self.freqs_cis.index_select(0, input_positions)
freqs_cis[AttentionType.GLOBAL] = self.freqs_cis.index_select(0, input_positions)
3. Model Layers (26 layers for 2b-v2)
Function: GemmaModel.forward() → Gemma2DecoderLayer.forward()
Layers alternate between LOCAL_SLIDING and GLOBAL attention types:
3a. Pre-Attention Processing
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) # RMSNorm
3b. Self-Attention
Function: GemmaAttention.forward()
# Project to Q, K, V
qkv = self.qkv_proj(hidden_states) # Linear projection
xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Reshape for multi-head attention
# 2b-v2: 8 attention heads, 4 KV heads, head_dim=256
xq = xq.view(batch_size, -1, 8, 256)
xk = xk.view(batch_size, -1, 4, 256)
xv = xv.view(batch_size, -1, 4, 256)
# Apply rotary embeddings
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)
# Update KV cache
k_cache.index_copy_(1, kv_write_indices, xk)
v_cache.index_copy_(1, kv_write_indices, xv)
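# Editorial note (hedged): the 2b-v2 config has 8 query heads but only 4 KV heads
# (grouped-query attention), so the cached keys/values are expanded to match the
# query heads before the score computation below -- along the lines of:
#   key   = torch.repeat_interleave(key,   num_queries_per_kv, dim=2)   # 4 -> 8 heads
#   value = torch.repeat_interleave(value, num_queries_per_kv, dim=2)
# (variable names are illustrative, not verbatim repo code)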
# Compute attention scores
q.mul_(self.scaling) # Scale by 1/sqrt(head_dim)
scores = torch.matmul(q, k.transpose(2, 3))
# Apply attention logit softcapping (50.0 for 2b-v2) to the raw scores
# (done before the mask is added, so tanh never squashes the masked-out values)
scores = scores / 50.0
scores = torch.tanh(scores)
scores = scores * 50.0

# Apply the attention-type-specific mask
if self.attn_type == AttentionType.LOCAL_SLIDING:
    # Sliding-window causal mask (4096-token window)
    scores = scores + local_mask
else:
    # Global causal mask
    scores = scores + mask
# Softmax and apply to values
scores = F.softmax(scores.float(), dim=-1)
output = torch.matmul(scores, v)
# Project back
output = self.o_proj(output) # Linear projection
3c. Post-Attention Processing
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states # Residual connection
3d. MLP Block
Function: Gemma2DecoderLayer.forward() (norms and residual) → GemmaMLP.forward() (GeGLU projections)
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states) # Pre-norm
# MLP with GeGLU activation
gate = self.gate_proj(hidden_states) # Project to intermediate_size=9216
gate = F.gelu(gate, approximate="tanh")
up = self.up_proj(hidden_states) # Project to intermediate_size=9216
fuse = gate * up # Element-wise multiply
hidden_states = self.down_proj(fuse) # Project back to hidden_size=2304
hidden_states = self.post_feedforward_layernorm(hidden_states) # Post-norm
hidden_states = residual + hidden_states # Residual connection
4. Final Normalization
Function: GemmaModel.forward()
hidden_states = self.norm(hidden_states) # Final RMSNorm
# Output: [batch_size, input_len, hidden_size=2304]
5. Token Selection
Function: Sampler.forward()
# Select output position
hidden_states = hidden_states.index_select(1, output_positions).squeeze(dim=1)
# Shape: [batch_size, hidden_size=2304]
# Compute logits
logits = torch.matmul(hidden_states, embedding.t())
# Shape: [batch_size, vocab_size=256000]
# Apply final logit softcapping (30.0 for 2b-v2)
logits = logits / 30.0
logits = torch.tanh(logits)
logits = logits * 30.0
# Temperature scaling (if temperature provided)
logits.div_(temperatures.unsqueeze(dim=1))
# Convert to probabilities
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Apply top-p and top-k filtering
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
# ... filtering logic ...
# Sample next token
next_token_ids = torch.multinomial(probs, num_samples=1)
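The elided filtering step follows the standard nucleus (top-p) / top-k pattern; a self-contained sketch of that pattern (a generic implementation, not the repo's verbatim code):

import torch

def filter_top_p_top_k(probs: torch.Tensor, top_p: float, top_k: int) -> torch.Tensor:
    # Sort descending, zero out tokens outside the nucleus (cumulative mass > top_p)
    # and beyond the top_k ranks, renormalize, then restore the original vocab order.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    cum = torch.cumsum(probs_sort, dim=-1)
    probs_sort = torch.where((cum - probs_sort) > top_p, torch.zeros_like(probs_sort), probs_sort)
    probs_sort[..., top_k:] = 0.0
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    return torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))

# e.g. next_token_ids = torch.multinomial(filter_top_p_top_k(probs, 0.95, 64), num_samples=1)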
Summary of Data Flow
- Embedding: Tokens → 2304-dim embeddings (normalized)
- 26 Transformer Layers: Each with:
- RMSNorm → Attention (local/global) → RMSNorm → Residual
- RMSNorm → MLP (GeGLU) → RMSNorm → Residual
- Final Norm: RMSNorm on hidden states
- Logit Generation: Project to vocabulary space
- Sampling: Softmax → Top-p/Top-k filtering → Multinomial sampling
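Putting the pieces together, end-to-end generation with the 2b-v2 variant looks roughly like this (a hedged sketch: the load/generate signatures and checkpoint path are assumptions, see scripts/run.py in the repo for canonical usage):

import torch
from gemma import config as gemma_config
from gemma import model as gemma_model

model_config = gemma_config.get_model_config('2b-v2')
model = gemma_model.GemmaForCausalLM(model_config)
model.load_weights('/path/to/gemma-2-2b-ckpt')   # hypothetical checkpoint path
model = model.to('cuda').eval()

# Runs the embed -> 26 layers -> norm -> sample loop traced above, one token at a time.
result = model.generate('The capital of France is', device='cuda', output_len=16)
print(result)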
The 2b-v2 model's unique features include:
- Alternating local (4096 window) and global attention
- Double normalization (pre/post) around MLP blocks
- Softcapping on both attention (50.0) and final logits (30.0)
- 8 attention heads with 4 KV heads (2:1 ratio)
Thinking Summary
The user is asking about the "cis" in "freqs_cis". Looking at the code, I can see this is related to rotary positional embeddings (RoPE). The "cis" stands for "cosine + i·sine", which is a mathematical notation for complex numbers in polar form. Let me look at the precompute_freqs_cis function to explain this:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, rope_scaling_factor: int = 1) -> torch.Tensor:
    """Precomputes the frequency cis."""
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    freqs = freqs / rope_scaling_factor
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
The key line is torch.polar(torch.ones_like(freqs), freqs), which creates complex numbers in polar form where the magnitude is 1 and the angle is the frequency. This creates e^(i·freq) = cos(freq) + i·sin(freq), which is why it's called "cis" (cosine + i·sine).
The "cis" in freqs_cis
stands for "cosine + i·sine", which is a mathematical notation for complex numbers in polar form. This is used for Rotary Position Embeddings (RoPE) in the model.
Looking at the precompute_freqs_cis function:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Precomputes the frequency cis."""
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
The key operation is torch.polar(torch.ones_like(freqs), freqs), which creates unit-magnitude complex numbers e^(iθ) = cos(θ) + i·sin(θ), where the angle θ is the position-dependent frequency. This form is often abbreviated as cis(θ) in mathematics, hence the name freqs_cis.
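A two-line check of what torch.polar produces (illustrative):

import torch
angles = torch.tensor([0.0, 1.0])
print(torch.polar(torch.ones_like(angles), angles))
# tensor([1.0000+0.0000j, 0.5403+0.8415j])  -- i.e. cos(angle) + i*sin(angle) for each entry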
How it's used in the model:
In apply_rotary_emb, these complex numbers are used to rotate the query and key vectors:
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies the rotary embedding to the query and key tensors."""
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    # ... reshaping logic ...
    return x_out
The multiplication x_ * freqs_cis performs a rotation in the complex plane, which bakes positional information directly into the queries and keys used by the attention mechanism. This is more parameter-efficient than learned absolute positional embeddings and, combined with RoPE scaling, can help the model handle longer sequences than it was trained on.
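A tiny concrete example of that rotation (illustrative, not repo code):

import torch
theta = torch.tensor([0.5])
cis = torch.polar(torch.ones_like(theta), theta)            # e^(i*0.5)
pair = torch.view_as_complex(torch.tensor([[1.0, 0.0]]))    # the 2-D pair (1, 0) as a complex number
print(torch.view_as_real(pair * cis))                       # ~[[0.8776, 0.4794]] = (cos 0.5, sin 0.5)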
For Gemma 2b-v2:
- Uses head dimension of 256
- Precomputes rotations for up to 2× max sequence length (16,384 positions)
- Default theta of 10,000 (standard RoPE base frequency)
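Tying those numbers together, the rotation table for 2b-v2 would be built with a call along these lines (argument values taken from the list above; the exact call site in gemma/model.py may differ):

freqs_cis = precompute_freqs_cis(dim=256, end=8192 * 2, theta=10000.0)
print(freqs_cis.shape, freqs_cis.dtype)   # torch.Size([16384, 128]) torch.complex64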