Numinex

Thread (29 posts)

Thinking Summary

The user is asking about the organization of the Gemma 2 source code based on the gemma_pytorch repository. Let me analyze the code structure and highlight how Gemma 2 is organized within the codebase. From what I can see:

  1. The repository supports multiple Gemma architectures (Gemma 1, Gemma 2, and Gemma 3)
  2. Gemma 2 has specific model configurations and architectural features
  3. There are dedicated decoder layer classes for different architectures

Let me look at the key organizational aspects:

Looking at the gemma_pytorch repository structure, here's how the Gemma 2 source code is organized:

Architecture Integration

Gemma 2 is integrated into a unified codebase that supports multiple Gemma architectures (1, 2, and 3). The code uses an Architecture enum to differentiate between versions:

class Architecture(enum.Enum):
    GEMMA_1 = 1
    GEMMA_2 = 2
    GEMMA_3 = 3

Configuration System

Gemma 2 has dedicated configuration functions in gemma/config.py:

  • get_config_for_2b_v2() - 2B parameter variant
  • get_config_for_9b() - 9B parameter variant
  • get_config_for_27b() - 27B parameter variant

Each configuration specifies Gemma 2's unique features:

  • Architecture: Architecture.GEMMA_2
  • Pre/Post FFW Normalization: use_pre_ffw_norm=True, use_post_ffw_norm=True
  • Logit Softcapping: final_logit_softcapping=30.0, attn_logit_softcapping=50.0
  • Attention Types: Alternating pattern of LOCAL_SLIDING and GLOBAL attention
  • Sliding Window: sliding_window_size=4096
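
Putting these together, a configuration for the 2b-v2 variant might look roughly like the sketch below. This is an illustration assembled from the values cited elsewhere in this thread (26 layers, 2304-dim hidden state, 9216-dim intermediate, 8 query / 4 KV heads, head_dim 256), not a verbatim copy of gemma/config.py.

# Illustrative sketch; GemmaConfig, Architecture, and AttentionType are the
# types defined in gemma/config.py.
def get_config_for_2b_v2() -> GemmaConfig:
    return GemmaConfig(
        architecture=Architecture.GEMMA_2,
        num_hidden_layers=26,
        num_attention_heads=8,
        num_key_value_heads=4,
        hidden_size=2304,
        intermediate_size=9216,
        head_dim=256,
        use_pre_ffw_norm=True,
        use_post_ffw_norm=True,
        final_logit_softcapping=30.0,
        attn_logit_softcapping=50.0,
        attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL],
        sliding_window_size=4096,
    )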

Model Implementation

The model architecture in gemma/model.py uses conditional logic to handle different architectures:

Decoder Layer

Gemma 2 has its own decoder layer class Gemma2DecoderLayer with specific features:

  • Additional normalization layers (pre/post feedforward)
  • Support for both local sliding and global attention types
  • Attention logit softcapping
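
A minimal sketch of how such a layer is wired, assuming the attention, MLP, and RMSNorm building blocks from gemma/model.py (illustrative, not the repository's exact code):

# Sketch of a Gemma 2 decoder layer; GemmaAttention, GemmaMLP, and RMSNorm
# are assumed to be the classes defined in gemma/model.py.
from torch import nn

class Gemma2DecoderLayer(nn.Module):
    def __init__(self, config, attn_type):
        super().__init__()
        self.self_attn = GemmaAttention(config, attn_type=attn_type)  # LOCAL_SLIDING or GLOBAL
        self.mlp = GemmaMLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size)
        self.post_attention_layernorm = RMSNorm(config.hidden_size)
        # The two extra norms that Gemma 1 layers do not have
        # (enabled by use_pre_ffw_norm / use_post_ffw_norm):
        self.pre_feedforward_layernorm = RMSNorm(config.hidden_size)
        self.post_feedforward_layernorm = RMSNorm(config.hidden_size)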

Model Construction

for i in range(config.num_hidden_layers):
    if config.architecture == gemma_config.Architecture.GEMMA_1:
        self.layers.append(GemmaDecoderLayer(config))
    elif config.architecture in (gemma_config.Architecture.GEMMA_2, gemma_config.Architecture.GEMMA_3):
        attn_type = config.attn_types[i % len(config.attn_types)]
        self.layers.append(Gemma2DecoderLayer(config, attn_type))

Key Gemma 2 Features

  1. Mixed Attention Pattern: Alternates between local sliding window attention and global attention across layers
  2. Additional Normalizations: Extra RMSNorm layers before and after the MLP block
  3. Softcapping: Applied to both attention logits and final output logits (see the helper sketched after this list)
  4. Query Pre-attention Scaling: Custom scaling factor for certain variants (e.g., 27B model)
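
The softcapping pattern (item 3) is the same in both places and can be written as a small helper; the repository applies the divide / tanh / multiply steps inline, as the next post shows, but the effect is this:

import torch

def soft_cap(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly squash values into the range (-cap, cap): x -> cap * tanh(x / cap)
    return cap * torch.tanh(x / cap)

# Gemma 2 uses cap=50.0 for attention logits and cap=30.0 for the final output logits.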

Variant Selection

The get_model_config() function maps variant strings to configurations:

  • '2b-v2' → Gemma 2 2B model
  • '9b' → Gemma 2 9B model
  • '27b' → Gemma 2 27B model
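
A sketch of that dispatch (illustrative; the real get_model_config in gemma/config.py also covers the Gemma 1 and Gemma 3 variants):

# Illustrative sketch of variant-to-config dispatch.
def get_model_config(variant: str) -> GemmaConfig:
    if variant == '2b-v2':
        return get_config_for_2b_v2()
    elif variant == '9b':
        return get_config_for_9b()
    elif variant == '27b':
        return get_config_for_27b()
    raise ValueError(f'Unknown variant: {variant}')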

This organization allows the codebase to support multiple Gemma architectures while sharing common components and maintaining architecture-specific features through configuration and conditional logic.

Thinking Summary

The user is asking about the data flow through the Gemma 2b-v2 model specifically. I need to trace through the code from input to output, referencing specific functions. Let me walk through this step by step. Looking at the code:

  1. The 2b-v2 model is a Gemma 2 architecture variant
  2. The main entry point appears to be the forward method in GemmaForCausalLM
  3. I need to trace through embedding, model layers, and sampling

Let me trace through the forward pass:

Let me trace the data flow through the Gemma 2b-v2 model from input tokens to output token selection:

1. Input Tokenization and Embedding

Function: GemmaForCausalLM.forward()

# Input: input_token_ids [batch_size, input_len]
hidden_states = self.embedder(input_token_ids)
# Output: [batch_size, input_len, hidden_size=2304]

# Normalize embeddings by sqrt(hidden_size)
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
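
With hidden_size = 2304 for 2b-v2, the normalizer is √2304 = 48, so every token embedding is scaled by 48 before it enters the first decoder layer.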

2. Positional Encoding Setup

Function: GemmaForCausalLM.forward()

# Create frequency encodings for rotary embeddings
freqs_cis[AttentionType.LOCAL_SLIDING] = self.freqs_cis.index_select(0, input_positions)
freqs_cis[AttentionType.GLOBAL] = self.freqs_cis.index_select(0, input_positions)

3. Model Layers (26 layers for 2b-v2)

Functions: GemmaModel.forward() → Gemma2DecoderLayer.forward()

Each layer alternates between LOCAL_SLIDING and GLOBAL attention:

3a. Pre-Attention Processing

residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)  # RMSNorm

3b. Self-Attention

Function: GemmaAttention.forward()

# Project to Q, K, V
qkv = self.qkv_proj(hidden_states)  # Linear projection
xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

# Reshape for multi-head attention
# 2b-v2: 8 attention heads, 4 KV heads, head_dim=256
xq = xq.view(batch_size, -1, 8, 256)
xk = xk.view(batch_size, -1, 4, 256)
xv = xv.view(batch_size, -1, 4, 256)

# Apply rotary embeddings
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)

# Update KV cache
k_cache.index_copy_(1, kv_write_indices, xk)
v_cache.index_copy_(1, kv_write_indices, xv)

# Compute attention scores
q.mul_(self.scaling)  # Scale by 1/sqrt(head_dim)
scores = torch.matmul(q, k.transpose(2, 3))

# Apply attention type-specific mask
if self.attn_type == AttentionType.LOCAL_SLIDING:
    # Use sliding window of 4096 tokens
    scores = scores + local_mask
else:
    # Use global causal mask
    scores = scores + mask

# Apply attention logit softcapping (50.0 for 2b-v2)
scores = scores / 50.0
scores = torch.tanh(scores)
scores = scores * 50.0

# Softmax and apply to values
scores = F.softmax(scores.float(), dim=-1)
output = torch.matmul(scores, v)

# Project back
output = self.o_proj(output)  # Linear projection
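
One step the snippet above glosses over: with 8 query heads but only 4 KV heads (grouped-query attention), the cached keys and values have to be expanded to match the query heads before the score matmul. This is typically done by repeating each KV head, as in the standalone sketch below (illustrative shapes, not the repository's exact code):

import torch

# Grouped-query attention for 2b-v2: 4 KV heads serve 8 query heads,
# so each cached key/value head is repeated twice along the head dimension.
key = torch.randn(1, 16, 4, 256)    # [batch, seq_len, num_kv_heads, head_dim]
num_queries_per_kv = 8 // 4          # = 2
key = torch.repeat_interleave(key, num_queries_per_kv, dim=2)
print(key.shape)                     # torch.Size([1, 16, 8, 256]), matching the 8 query heads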

3c. Post-Attention Processing

hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states  # Residual connection

3d. MLP Block

Functions: Gemma2DecoderLayer.forward() → GemmaMLP.forward()

residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states)  # Pre-norm

# MLP with GeGLU activation
gate = self.gate_proj(hidden_states)  # Project to intermediate_size=9216
gate = F.gelu(gate, approximate="tanh")
up = self.up_proj(hidden_states)      # Project to intermediate_size=9216
fuse = gate * up                       # Element-wise multiply
hidden_states = self.down_proj(fuse)  # Project back to hidden_size=2304

hidden_states = self.post_feedforward_layernorm(hidden_states)  # Post-norm
hidden_states = residual + hidden_states  # Residual connection
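
Dimensionally, the MLP expands the 2304-dim hidden state to a 9216-dim intermediate (a 4× expansion) and projects it back to 2304, so the residual addition lines up.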

4. Final Normalization

Function: GemmaModel.forward()

hidden_states = self.norm(hidden_states)  # Final RMSNorm
# Output: [batch_size, input_len, hidden_size=2304]

5. Token Selection

Function: Sampler.forward()

# Select output position
hidden_states = hidden_states.index_select(1, output_positions).squeeze(dim=1)
# Shape: [batch_size, hidden_size=2304]

# Compute logits
logits = torch.matmul(hidden_states, embedding.t())
# Shape: [batch_size, vocab_size=256000]

# Apply final logit softcapping (30.0 for 2b-v2)
logits = logits / 30.0
logits = torch.tanh(logits)
logits = logits * 30.0

# Temperature scaling (if temperature provided)
logits.div_(temperatures.unsqueeze(dim=1))

# Convert to probabilities
probs = torch.softmax(logits, dim=-1, dtype=torch.float)

# Apply top-p and top-k filtering
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
# ... filtering logic ...

# Sample next token
next_token_ids = torch.multinomial(probs, num_samples=1)
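
The elided filtering step is standard nucleus (top-p) sampling (top-k is handled analogously). A generic sketch of the idea, not necessarily the repository's exact implementation:

import torch

def top_p_filter(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    # Keep the smallest set of highest-probability tokens whose cumulative
    # probability covers top_p, zero out the rest, then renormalize.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(probs_sort, dim=-1)
    mask = (cumulative - probs_sort) > top_p        # tokens outside the nucleus
    probs_sort = probs_sort.masked_fill(mask, 0.0)
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    # Scatter the filtered probabilities back into vocabulary order
    return torch.zeros_like(probs).scatter_(-1, probs_idx, probs_sort)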

Summary of Data Flow

  1. Embedding: Tokens → 2304-dim embeddings (normalized)
  2. 26 Transformer Layers: Each with:
    • RMSNorm → Attention (local/global) → RMSNorm → Residual
    • RMSNorm → MLP (GeGLU) → RMSNorm → Residual
  3. Final Norm: RMSNorm on hidden states
  4. Logit Generation: Project to vocabulary space
  5. Sampling: Softmax → Top-p/Top-k filtering → Multinomial sampling

The 2b-v2 model's unique features include:

  • Alternating local (4096 window) and global attention
  • Double normalization (pre/post) around MLP blocks
  • Softcapping on both attention (50.0) and final logits (30.0)
  • 8 attention heads with 4 KV heads (2:1 ratio)
Thinking Summary

The user is asking about the "cis" in "freqs_cis". Looking at the code, I can see this is related to rotary positional embeddings (RoPE). The "cis" stands for "cosine + i*sine", which is a mathematical notation for complex numbers in polar form. Let me look at the precompute_freqs_cis function to explain this: ```python def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, rope_scaling_factor:int = 1) -> torch.Tensor: """Precomputes the frequency cis.""" freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) freqs = freqs/rope_scaling_factor t = torch.arange(end, device=freqs.device) freqs = torch.outer(t, freqs).float() freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 return freqs_cis ``` The key line is `torch.polar(torch.ones_like(freqs), freqs)` which creates complex numbers in polar form where the magnitude is 1 and the angle is the frequency. This creates e^(i*freq) = cos(freq) + i*sin(freq), which is why it's called "cis" (cosine + i*sine).

The "cis" in freqs_cis stands for "cosine + i·sine", which is a mathematical notation for complex numbers in polar form. This is used for Rotary Position Embeddings (RoPE) in the model.

Looking at the precompute_freqs_cis function:

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Precomputes the frequency cis."""
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

The key operation is torch.polar(torch.ones_like(freqs), freqs), which creates complex numbers:

$e^{i\theta} = \cos(\theta) + i\sin(\theta)$

This is often abbreviated as "cis(θ)" in mathematics, hence freqs_cis.
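
A quick standalone check of that identity with torch.polar (illustrative, not from the repository):

import torch

theta = torch.tensor([0.0, 1.0])
z = torch.polar(torch.ones_like(theta), theta)   # magnitude 1, angle theta
# z = cos(theta) + i*sin(theta) -> tensor([1.0000+0.0000j, 0.5403+0.8415j])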

How it's used in the model:

In apply_rotary_emb, these complex numbers are used to rotate the query and key vectors:

def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies the rotary embedding to the query and key tensors."""
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    # ... reshaping logic ...
    return x_out

The multiplication x_ * freqs_cis performs a rotation in the complex plane, which encodes positional information into the attention mechanism. This is more efficient than traditional positional embeddings and allows the model to extrapolate to longer sequences than it was trained on.
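
Concretely, the chunk/stack in apply_rotary_emb pairs head dimension $j$ with dimension $j + d/2$ as a complex number $x_1 + i x_2$, and multiplying by $e^{i m\theta}$ for position $m$ is a rotation by angle $m\theta$:

$$(x_1 + i x_2)\,e^{i m\theta} = (x_1\cos m\theta - x_2\sin m\theta) + i\,(x_1\sin m\theta + x_2\cos m\theta)$$

Because queries and keys are each rotated by their own position, the query-key dot product depends only on the positional offset $m - n$, which is how RoPE injects relative position into attention.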

For Gemma 2b-v2:

  • Uses head dimension of 256
  • Precomputes rotations for up to 2× max sequence length (16,384 positions)
  • Default theta of 10,000 (standard RoPE base frequency)