Numinex

Thread (29 posts)

Thinking Summary

The user is asking about the organization of the Gemma 2 source code based on the gemma_pytorch repository. Let me analyze the code structure and highlight how Gemma 2 is organized within the codebase. From what I can see:

  1. The repository supports multiple Gemma architectures (Gemma 1, Gemma 2, and Gemma 3)
  2. Gemma 2 has specific model configurations and architectural features
  3. There are dedicated decoder layer classes for different architectures

Let me look at the key organizational aspects:

Looking at the gemma_pytorch repository structure, here's how the Gemma 2 source code is organized:

Architecture Integration

Gemma 2 is integrated into a unified codebase that supports multiple Gemma architectures (1, 2, and 3). The code uses an Architecture enum to differentiate between versions:

class Architecture(enum.Enum):
    GEMMA_1 = 1
    GEMMA_2 = 2
    GEMMA_3 = 3

Configuration System

Gemma 2 has dedicated configuration functions in gemma/config.py:

  • get_config_for_2b_v2() - 2B parameter variant
  • get_config_for_9b() - 9B parameter variant
  • get_config_for_27b() - 27B parameter variant

Each configuration specifies Gemma 2's unique features:

  • Architecture: Architecture.GEMMA_2
  • Pre/Post FFW Normalization: use_pre_ffw_norm=True, use_post_ffw_norm=True
  • Logit Softcapping: final_logit_softcapping=30.0, attn_logit_softcapping=50.0
  • Attention Types: Alternating pattern of LOCAL_SLIDING and GLOBAL attention
  • Sliding Window: sliding_window_size=4096
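
Assembled, here is a sketch of what get_config_for_2b_v2() plausibly returns. The field values are the ones quoted elsewhere in this thread; the GemmaConfig and AttentionType names and any omitted fields (vocab size, RoPE settings, tokenizer path, etc.) are assumptions to check against gemma/config.py:

def get_config_for_2b_v2() -> GemmaConfig:
    # Sketch only -- values taken from the feature list above and the
    # 2b-v2 walkthrough later in this thread.
    return GemmaConfig(
        architecture=Architecture.GEMMA_2,
        num_hidden_layers=26,
        num_attention_heads=8,
        num_key_value_heads=4,
        hidden_size=2304,
        intermediate_size=9216,
        head_dim=256,
        use_pre_ffw_norm=True,
        use_post_ffw_norm=True,
        final_logit_softcapping=30.0,
        attn_logit_softcapping=50.0,
        attn_types=(AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL),
        sliding_window_size=4096,
    )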

Model Implementation

The model architecture in gemma/model.py uses conditional logic to handle different architectures:

Decoder Layer

Gemma 2 has its own decoder layer class Gemma2DecoderLayer with specific features:

  • Additional normalization layers (pre/post feedforward)
  • Support for both local sliding and global attention types
  • Attention logit softcapping
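
A rough skeleton of the constructor shows where those pieces live. The attribute names follow the ones used later in this thread, but the exact signatures are an assumption, not a quote from the repository:

class Gemma2DecoderLayer(nn.Module):
    # Sketch: only the Gemma 2-specific members are shown.
    def __init__(self, config: GemmaConfig, attn_type: AttentionType):
        super().__init__()
        self.attn_type = attn_type
        self.self_attn = GemmaAttention(config, attn_type=attn_type)
        self.mlp = GemmaMLP(config)
        # Norms shared with the Gemma 1 layer:
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # The two extra norms around the MLP block that Gemma 1 does not have:
        self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)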

Model Construction

for i in range(config.num_hidden_layers):
    if config.architecture == gemma_config.Architecture.GEMMA_1:
        self.layers.append(GemmaDecoderLayer(config))
    elif config.architecture in (gemma_config.Architecture.GEMMA_2, gemma_config.Architecture.GEMMA_3):
        attn_type = config.attn_types[i % len(config.attn_types)]
        self.layers.append(Gemma2DecoderLayer(config, attn_type))

Key Gemma 2 Features

  1. Mixed Attention Pattern: Alternates between local sliding window attention and global attention across layers
  2. Additional Normalizations: Extra RMSNorm layers before and after the MLP block
  3. Softcapping: Applied to both attention logits and final output logits (a minimal helper is sketched after this list)
  4. Query Pre-attention Scaling: Custom scaling factor for certain variants (e.g., 27B model)
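
The softcapping in item 3 is just a scaled tanh squashing. A minimal stand-alone helper for illustration (the repository inlines the three steps rather than defining a function like this):

import torch

def softcap(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly bounds values to (-cap, cap): x -> cap * tanh(x / cap).
    return cap * torch.tanh(x / cap)

# scores = softcap(scores, 50.0)   # attention logits (attn_logit_softcapping)
# logits = softcap(logits, 30.0)   # final output logits (final_logit_softcapping)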

Variant Selection

The get_model_config() function maps variant strings to configurations:

  • '2b-v2' → Gemma 2 2B model
  • '9b' → Gemma 2 9B model
  • '27b' → Gemma 2 27B model
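
A hedged usage sketch tying the pieces together (module paths follow the gemma/config.py and gemma/model.py files mentioned above; the exact get_model_config() signature may differ):

from gemma import config as gemma_config
from gemma import model as gemma_model

config = gemma_config.get_model_config('2b-v2')   # selects the Gemma 2 2B configuration
model = gemma_model.GemmaForCausalLM(config)      # builds the stack of Gemma2DecoderLayers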

This organization allows the codebase to support multiple Gemma architectures while sharing common components and maintaining architecture-specific features through configuration and conditional logic.

Thinking Summary

The user is asking about the data flow through the Gemma 2b-v2 model specifically. I need to trace through the code from input to output, referencing specific functions. Let me walk through this step by step. Looking at the code:

  1. The 2b-v2 model is a Gemma 2 architecture variant
  2. The main entry point appears to be the `forward` method in `GemmaForCausalLM`
  3. I need to trace through embedding, model layers, and sampling

Let me trace through the forward pass:

Let me trace the data flow through the Gemma 2b-v2 model from input tokens to output token selection:

1. Input Tokenization and Embedding

Function: GemmaForCausalLM.forward()

# Input: input_token_ids [batch_size, input_len]
hidden_states = self.embedder(input_token_ids)
# Output: [batch_size, input_len, hidden_size=2304]

# Normalize embeddings by sqrt(hidden_size)
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer

2. Positional Encoding Setup

Function: GemmaForCausalLM.forward()

# Create frequency encodings for rotary embeddings
freqs_cis[AttentionType.LOCAL_SLIDING] = self.freqs_cis.index_select(0, input_positions)
freqs_cis[AttentionType.GLOBAL] = self.freqs_cis.index_select(0, input_positions)

3. Model Layers (26 layers for 2b-v2)

Function: GemmaModel.forward() → Gemma2DecoderLayer.forward()

Each layer alternates between LOCAL_SLIDING and GLOBAL attention:

3a. Pre-Attention Processing

residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)  # RMSNorm

3b. Self-Attention

Function: GemmaAttention.forward()

# Project to Q, K, V
qkv = self.qkv_proj(hidden_states)  # Linear projection
xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

# Reshape for multi-head attention
# 2b-v2: 8 attention heads, 4 KV heads, head_dim=256
xq = xq.view(batch_size, -1, 8, 256)
xk = xk.view(batch_size, -1, 4, 256)
xv = xv.view(batch_size, -1, 4, 256)

# Apply rotary embeddings
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)

# Update KV cache
k_cache.index_copy_(1, kv_write_indices, xk)
v_cache.index_copy_(1, kv_write_indices, xv)

# Compute attention scores
q.mul_(self.scaling)  # Scale by 1/sqrt(head_dim)
scores = torch.matmul(q, k.transpose(2, 3))

# Apply attention type-specific mask
if self.attn_type == AttentionType.LOCAL_SLIDING:
    # Use sliding window of 4096 tokens
    scores = scores + local_mask
else:
    # Use global causal mask
    scores = scores + mask

# Apply attention logit softcapping (50.0 for 2b-v2)
scores = scores / 50.0
scores = torch.tanh(scores)
scores = scores * 50.0

# Softmax and apply to values
scores = F.softmax(scores.float(), dim=-1)
output = torch.matmul(scores, v)

# Project back
output = self.o_proj(output)  # Linear projection
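
The local_mask used above restricts each query to the most recent sliding_window_size = 4096 keys. The sketch below shows one generic way such a mask can be derived from an additive causal mask; it is an illustration, not the repository's exact construction:

import torch

def add_sliding_window(causal_mask: torch.Tensor, window: int) -> torch.Tensor:
    # causal_mask: [..., seq_len, seq_len] additive mask (0 where attention is
    # allowed, a large negative value where it is not). On top of the causal
    # constraint, also forbid keys more than `window - 1` positions behind the query.
    seq_len = causal_mask.shape[-1]
    pos = torch.arange(seq_len, device=causal_mask.device)
    too_old = (pos.unsqueeze(-1) - pos.unsqueeze(0)) >= window  # query_pos - key_pos
    return causal_mask.masked_fill(too_old, torch.finfo(causal_mask.dtype).min)

# Toy check with a tiny sequence and window:
causal = torch.full((16, 16), torch.finfo(torch.float32).min).triu(1)
local_mask = add_sliding_window(causal, window=4)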

3c. Post-Attention Processing

hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states  # Residual connection

3d. MLP Block

Function: Gemma2DecoderLayer.forward() → GemmaMLP.forward()

residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states)  # Pre-norm

# MLP with GeGLU activation
gate = self.gate_proj(hidden_states)  # Project to intermediate_size=9216
gate = F.gelu(gate, approximate="tanh")
up = self.up_proj(hidden_states)      # Project to intermediate_size=9216
fuse = gate * up                       # Element-wise multiply
hidden_states = self.down_proj(fuse)  # Project back to hidden_size=2304

hidden_states = self.post_feedforward_layernorm(hidden_states)  # Post-norm
hidden_states = residual + hidden_states  # Residual connection

4. Final Normalization

Function: GemmaModel.forward()

hidden_states = self.norm(hidden_states)  # Final RMSNorm
# Output: [batch_size, input_len, hidden_size=2304]

5. Token Selection

Function: Sampler.forward()

# Select output position
hidden_states = hidden_states.index_select(1, output_positions).squeeze(dim=1)
# Shape: [batch_size, hidden_size=2304]

# Compute logits
logits = torch.matmul(hidden_states, embedding.t())
# Shape: [batch_size, vocab_size=256000]

# Apply final logit softcapping (30.0 for 2b-v2)
logits = logits / 30.0
logits = torch.tanh(logits)
logits = logits * 30.0

# Temperature scaling (if temperature provided)
logits.div_(temperatures.unsqueeze(dim=1))

# Convert to probabilities
probs = torch.softmax(logits, dim=-1, dtype=torch.float)

# Apply top-p and top-k filtering
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
# ... filtering logic ...

# Sample next token
next_token_ids = torch.multinomial(probs, num_samples=1)
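
The elided filtering step is standard nucleus (top-p) plus top-k filtering. Here is a self-contained sketch of the idea, with arbitrary p and k values; it is an illustration of the technique, not the repository's exact code:

import torch

def filter_top_p_top_k(probs: torch.Tensor, p: float, k: int) -> torch.Tensor:
    # probs: [batch_size, vocab_size], already softmaxed.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(probs_sort, dim=-1)
    # Top-p: zero out tokens once the probability mass before them already exceeds p.
    probs_sort = probs_sort.masked_fill(cumulative - probs_sort > p, 0.0)
    # Top-k: zero out everything ranked at or beyond position k.
    ranks = torch.arange(probs_sort.shape[-1], device=probs.device)
    probs_sort = probs_sort.masked_fill(ranks >= k, 0.0)
    # Renormalize and scatter back to vocabulary order.
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    return torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))

# next_token_ids = torch.multinomial(filter_top_p_top_k(probs, p=0.95, k=100), num_samples=1)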

Summary of Data Flow

  1. Embedding: Tokens → 2304-dim embeddings (normalized)
  2. 26 Transformer Layers: Each with:
    • RMSNorm → Attention (local/global) → RMSNorm → Residual
    • RMSNorm → MLP (GeGLU) → RMSNorm → Residual
  3. Final Norm: RMSNorm on hidden states
  4. Logit Generation: Project to vocabulary space
  5. Sampling: Softmax → Top-p/Top-k filtering → Multinomial sampling

The 2b-v2 model's unique features include:

  • Alternating local (4096 window) and global attention
  • Double normalization (pre/post) around MLP blocks
  • Softcapping on both attention (50.0) and final logits (30.0)
  • 8 attention heads with 4 KV heads (2:1 ratio)
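
On the last point: with 8 query heads sharing 4 KV heads, each key/value head serves two query heads (grouped-query attention). Before computing scores, K and V are expanded to match the number of query heads, roughly as below; this is a sketch with made-up shapes, and the repository performs an equivalent repeat when the head counts differ:

import torch

batch_size, seq_len, head_dim = 1, 16, 256
num_heads, num_kv_heads = 8, 4
num_queries_per_kv = num_heads // num_kv_heads   # 2 query heads per KV head

key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim)
value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim)
if num_kv_heads != num_heads:
    key = torch.repeat_interleave(key, num_queries_per_kv, dim=2)
    value = torch.repeat_interleave(value, num_queries_per_kv, dim=2)
print(key.shape)   # torch.Size([1, 16, 8, 256]) -- now one KV slice per query head
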
Thinking Summary

The user is asking about the `apply_rotary_emb` function, which is used for rotary positional embeddings (RoPE) in the Gemma model. Let me look at the implementation and explain what it does.

The apply_rotary_emb function implements Rotary Position Embeddings (RoPE), which encode positional information into the query and key vectors. Here's how it works:

Function Implementation

def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies the rotary embedding to the query and key tensors."""
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
                    dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
                          -1).transpose(1, 2)
    return x_out

Step-by-Step Breakdown

1. Input Reshaping

x.transpose(1, 2).float()
  • Input x shape: [batch_size, seq_len, num_heads, head_dim]
  • After transpose: [batch_size, num_heads, seq_len, head_dim]

2. Split into Two Halves

torch.chunk(x.transpose(1, 2).float(), 2, dim=-1)
  • Splits the head_dim (256 for 2b-v2) into two halves
  • Creates two tensors of shape [batch_size, num_heads, seq_len, 128]

3. Create Complex Numbers

torch.view_as_complex(torch.stack(..., dim=-1))
  • Stacks the two halves as real and imaginary parts
  • Creates complex tensor: [batch_size, num_heads, seq_len, 128] (complex64)
  • Because of the half-split above, element j of the first half is paired with element j of the second half, giving the complex number x[j] + i*x[j + 128] (rather than interleaved (x[2i], x[2i+1]) pairs)

4. Apply Rotation

x_ * freqs_cis
  • freqs_cis contains precomputed complex exponentials $e^{i\theta}$
  • Multiplication in complex space performs rotation
  • This rotates each 2D subspace by angle θ based on position

5. Convert Back to Real

torch.view_as_real(x_ * freqs_cis)
  • Converts complex numbers back to real pairs
  • Shape: [batch_size, num_heads, seq_len, 128, 2]

6. Reshape to Original Format

torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
  • Rearranges the real/imaginary pairs back to original dimension order
  • Final transpose returns to original shape structure

The Mathematics

RoPE applies a rotation matrix to pairs of elements:

$$\begin{bmatrix} x'_{2i} \\ x'_{2i+1} \end{bmatrix} = \begin{bmatrix} \cos(\theta) & -\sin(\theta) \\ \sin(\theta) & \cos(\theta) \end{bmatrix} \begin{bmatrix} x_{2i} \\ x_{2i+1} \end{bmatrix}$$

Where $\theta = \text{position} \times \frac{1}{10000^{2i/d}}$
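
A quick numerical check (not from the post) that multiplying by the complex factor $e^{i\theta}$ performs exactly this rotation:

import math
import torch

theta = 0.3
pair = torch.tensor([1.0, 2.0])                       # (x_2i, x_2i+1)
rotated = torch.view_as_complex(pair) * torch.polar(torch.tensor(1.0), torch.tensor(theta))
rot = torch.tensor([[math.cos(theta), -math.sin(theta)],
                    [math.sin(theta),  math.cos(theta)]])
assert torch.allclose(torch.view_as_real(rotated), rot @ pair)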

Precomputed Frequencies

The freqs_cis are precomputed in precompute_freqs_cis:

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # e^(i*freqs)
    return freqs_cis
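
A small usage check: for 2b-v2 (head_dim = 256) the table holds head_dim / 2 = 128 complex rotation factors per position. The end value of 8192 below is arbitrary, not a claim about the model's configured context length:

freqs_cis = precompute_freqs_cis(dim=256, end=8192)
print(freqs_cis.shape, freqs_cis.dtype)   # torch.Size([8192, 128]) torch.complex64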

Why RoPE?

RoPE provides several advantages:

  1. Relative position encoding: The dot product between rotated queries and keys depends only on their relative positions (see the identity after this list)
  2. Long-range decay: More distant positions have less similarity
  3. Efficient computation: No additional parameters needed
  4. Extrapolation: Can handle sequences longer than training length
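
A one-line justification of point 1 (standard RoPE algebra, not quoted from the repository): because the rotation matrices are orthogonal and compose additively,

$$\langle R_{\theta_m} q,\; R_{\theta_n} k \rangle = q^\top R_{\theta_m}^\top R_{\theta_n} k = q^\top R_{\theta_n - \theta_m} k$$

so the attention score between a query at position m and a key at position n depends only on the offset n − m.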

For Gemma 2b-v2, RoPE is applied to both queries and keys before attention computation, encoding positional information directly into the representations.
