Lecture: The Transformer Architecture Step by Step

Learning Objectives

Implement every component of the transformer architecture in PyTorch from scratch. Understand why each component exists and what would happen if it were removed. Connect the mathematical description in "Attention Is All You Need" to working code.

1. Setup: The Problem We're Solving

We're building a decoder-only transformer (GPT-style) for language modeling: given a sequence of tokens, predict the next token at every position. Trained on enough text, this single objective pushes the model to learn language structure, factual knowledge, and some reasoning ability.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Hyperparameters matching the smallest GPT-2 configuration
VOCAB_SIZE = 50257   # GPT-2 vocabulary size
CONTEXT_LEN = 1024   # Maximum sequence length
D_MODEL = 768        # Embedding dimension
N_HEADS = 12         # Number of attention heads
D_FF = 3072          # Feed-forward inner dimension (4 * D_MODEL)
N_LAYERS = 12        # Number of transformer layers
DROPOUT = 0.1
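
Because the objective is next-token prediction, the training targets come from the text itself: the target at each position is the token that follows it. Here is a minimal sketch with made-up token IDs. The names tokens, inputs, and targets are illustrative and are not part of the model code:

```python
import torch

# A hypothetical batch of 2 sequences of token IDs (values chosen arbitrarily).
tokens = torch.tensor([[464, 3290, 318, 257, 1332],
                       [40, 588, 3303, 4981, 13]])

# Inputs drop the last token; targets drop the first.
# targets[b, t] is the token that follows inputs[b, t].
inputs = tokens[:, :-1]    # (2, 4)
targets = tokens[:, 1:]    # (2, 4)

for t in range(inputs.size(1)):
    context = inputs[0, : t + 1].tolist()
    print(f"context {context} -> next token {targets[0, t].item()}")
```

Both slices are non-contiguous views of tokens (targets.is_contiguous() returns False), so flatten them with reshape rather than view when computing the loss.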

2. Token and Positional Embeddings

Each token ID maps to a learnable embedding vector. Position information is added via learned positional embeddings:

class Embeddings(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embed = nn.Embedding(VOCAB_SIZE, D_MODEL)   # one vector per token ID
        self.pos_embed = nn.Embedding(CONTEXT_LEN, D_MODEL)    # one vector per position
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):
        # x: (batch, seq_len) token IDs
        B, T = x.shape
        positions = torch.arange(T, device=x.device)

        tok = self.token_embed(x)           # (B, T, D_MODEL)
        pos = self.pos_embed(positions)     # (T, D_MODEL)

        return self.dropout(tok + pos)      # pos broadcasts over the batch dimension
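
The following is a quick shape check of the broadcasting in forward. It uses tiny made-up sizes and standalone nn.Embedding layers rather than the Embeddings class. Feeding the same token ID at every position shows why position embeddings are needed. The token vectors are identical, and only the added position vectors tell the positions apart:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
V, MAX_T, D = 10, 8, 4            # tiny vocab, context length, embedding size

token_embed = nn.Embedding(V, D)
pos_embed = nn.Embedding(MAX_T, D)

x = torch.tensor([[7, 7, 7],      # same token at three positions
                  [1, 2, 3]])     # (B=2, T=3)
B, T = x.shape

tok = token_embed(x)              # (2, 3, 4)
pos = pos_embed(torch.arange(T))  # (3, 4)
h = tok + pos                     # pos broadcasts to (2, 3, 4)

print(h.shape)                               # torch.Size([2, 3, 4])
print(torch.equal(tok[0, 0], tok[0, 2]))     # True: identical token vectors
print(torch.allclose(h[0, 0], h[0, 2]))      # False: positions now differ
```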

3. Scaled Dot-Product Attention

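This is the core operation. Each token's query is compared against the keys of every token it is allowed to see. With the causal mask, that means only itself and earlier tokens. The resulting weights are then used to take a weighted sum of the value vectors: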

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q, K, V: (batch, heads, seq_len, d_head)
    mask: broadcastable to (batch, heads, seq_len, seq_len); 0 marks blocked positions
    Returns: output (batch, heads, seq_len, d_head),
             attn_weights (batch, heads, seq_len, seq_len)
    """
    d_k = Q.size(-1)

    # Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # scores: (batch, heads, seq_len, seq_len)

    # Apply causal mask (decoder: each token attends only to itself and earlier tokens)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Softmax over the key dimension: each row of weights sums to 1
    attn_weights = F.softmax(scores, dim=-1)

    # Weighted sum of values
    output = torch.matmul(attn_weights, V)
    return output, attn_weights
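
Here is a sanity check for the function above. It assumes PyTorch 2.0 or later, which provides torch.nn.functional.scaled_dot_product_attention. The block repeats the computation under the hypothetical name manual_attention so that it runs on its own. It then confirms two things: masked positions get zero weight, and the result matches PyTorch's built-in version with is_causal=True:

```python
import math
import torch
import torch.nn.functional as F

def manual_attention(Q, K, V, mask=None):
    # Same computation as the lecture's function, repeated so this block is self-contained.
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights

torch.manual_seed(0)
B, H, T, d = 2, 3, 5, 8
Q, K, V = (torch.randn(B, H, T, d) for _ in range(3))
mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)

out, w = manual_attention(Q, K, V, mask)

# Future positions (above the diagonal) get exactly zero weight; each row sums to 1.
print(w[0, 0].triu(diagonal=1).abs().max().item())     # 0.0
print(torch.allclose(w.sum(-1), torch.ones(B, H, T)))  # True

# Built-in equivalent; is_causal=True applies the same lower-triangular mask internally.
ref = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(torch.allclose(out, ref, atol=1e-5))             # True
```

In practice the built-in function is usually faster because it can dispatch to fused kernels. The hand-written version is useful in this lecture because it returns attn_weights for inspection.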

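Why the scaling factor? Suppose the components of a query q and a key k are independent with mean 0 and variance 1. Then their dot product q·k has variance d_k, so its typical magnitude grows like sqrt(d_k). Large scores push softmax into saturation: its output approaches a one-hot vector, and its gradients with respect to every score approach zero, which stalls learning. Dividing by sqrt(d_k) restores unit variance. In this model d_k = d_head = 768 / 12 = 64, so unscaled scores would have a standard deviation of about 8. That is enough to saturate softmax.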
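
The variance argument can be checked empirically. This sketch draws random queries and keys with unit-variance Gaussian components. That is an assumption, since real activations are not exactly Gaussian:

```python
import math
import torch

torch.manual_seed(0)
d_k = 64                      # d_head = 768 / 12 in this model
n = 10_000                    # number of random query/key pairs

q = torch.randn(n, d_k)       # components ~ N(0, 1)
k = torch.randn(n, d_k)
dots = (q * k).sum(dim=-1)    # n unscaled dot products

raw_std = dots.std().item()
scaled_std = (dots / math.sqrt(d_k)).std().item()
print(f"std unscaled: {raw_std:.2f}  (theory: sqrt(64) = 8)")
print(f"std scaled:   {scaled_std:.2f}  (theory: 1)")

# One row of 10 scores at each scale: unscaled scores give a much peakier softmax.
scores = torch.randn(10) * math.sqrt(d_k)
p_raw = torch.softmax(scores, dim=-1)
p_scaled = torch.softmax(scores / math.sqrt(d_k), dim=-1)
print(f"largest weight unscaled: {p_raw.max().item():.3f}")
print(f"largest weight scaled:   {p_scaled.max().item():.3f}")
```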

4. Multi-Head Attention

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        assert D_MODEL % N_HEADS == 0, "D_MODEL must be divisible by N_HEADS"
        self.d_head = D_MODEL // N_HEADS

        self.W_Q = nn.Linear(D_MODEL, D_MODEL)
        self.W_K = nn.Linear(D_MODEL, D_MODEL)
        self.W_V = nn.Linear(D_MODEL, D_MODEL)
        self.W_O = nn.Linear(D_MODEL, D_MODEL)
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x, mask=None):
        B, T, C = x.shape  # batch, seq_len, d_model

        # Project to Q, K, V and split D_MODEL into N_HEADS slices of size d_head
        Q = self.W_Q(x).view(B, T, N_HEADS, self.d_head).transpose(1, 2)
        K = self.W_K(x).view(B, T, N_HEADS, self.d_head).transpose(1, 2)
        V = self.W_V(x).view(B, T, N_HEADS, self.d_head).transpose(1, 2)
        # Each: (B, N_HEADS, T, d_head)

        # Attention, computed for all heads in parallel
        out, _ = scaled_dot_product_attention(Q, K, V, mask)

        # Recombine heads: (B, N_HEADS, T, d_head) -> (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection mixes the heads; dropout regularizes the residual branch
        return self.dropout(self.W_O(out))
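
The view/transpose pair in forward only rearranges data; no values change. A small check with made-up sizes shows which channels each head receives and why .contiguous() is needed before the final .view(). Here B, T, n_heads, and d_head are illustrative and are not the model's hyperparameters:

```python
import torch

torch.manual_seed(0)
B, T, n_heads, d_head = 2, 5, 4, 3
C = n_heads * d_head
x = torch.randn(B, T, C)

# Split: head h receives channels h*d_head to (h+1)*d_head of every position.
heads = x.view(B, T, n_heads, d_head).transpose(1, 2)            # (B, n_heads, T, d_head)
print(torch.equal(heads[0, 1, 2], x[0, 2, d_head:2 * d_head]))  # True

# Stand-in for the attention output: a contiguous (B, n_heads, T, d_head) tensor.
out = heads.contiguous()
print(out.transpose(1, 2).is_contiguous())  # False: .view() on this would raise an error

# Merge: transpose back, copy into contiguous memory, then flatten the heads.
merged = out.transpose(1, 2).contiguous().view(B, T, C)
print(torch.equal(merged, x))  # True
```

out.transpose(1, 2).reshape(B, T, C) is an equivalent one-liner that copies only when needed.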

5. Feed-Forward Network

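The feed-forward network is two linear layers with a GELU nonlinearity between them, applied to each position independently. It holds about two-thirds of each block's parameters, and interpretability research suggests that much of a model's factual knowledge is stored in these layers: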

class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(D_MODEL, D_FF),   # expand: D_MODEL -> D_FF
            nn.GELU(),                  # smooth alternative to ReLU: negative inputs give small negative outputs, not all 0
            nn.Linear(D_FF, D_MODEL),   # project back: D_FF -> D_MODEL
            nn.Dropout(DROPOUT),
        )

    def forward(self, x):
        return self.net(x)
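
To make the GELU comment concrete: GELU(x) = x · Φ(x), where Φ is the standard normal CDF. This sketch compares GELU with ReLU at a few points. It also includes the tanh approximation that GPT-2's original code used; the approximate="tanh" argument assumes PyTorch 1.12 or later:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0])

relu = F.relu(x)
gelu = F.gelu(x)                              # exact: x * Phi(x)
gelu_tanh = F.gelu(x, approximate="tanh")     # tanh approximation

for xi, r, g, gt in zip(x.tolist(), relu.tolist(), gelu.tolist(), gelu_tanh.tolist()):
    print(f"x={xi:+.1f}  relu={r:+.4f}  gelu={g:+.4f}  gelu_tanh={gt:+.4f}")
```

At x = -1, GELU returns about -0.159 where ReLU returns 0. Unlike ReLU, GELU does not map every negative input to exactly zero, and it is smooth at 0. The two GELU variants agree to within about 1e-3 at these points.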

6. Transformer Block with Pre-Norm

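Most modern LLMs, including GPT-2, use pre-normalization (LayerNorm applied to each sublayer's input) rather than the original paper's post-normalization (LayerNorm applied after the residual addition). Pre-norm never normalizes the residual stream itself, so each block's input flows unchanged to its output alongside the sublayer's contribution. This makes deep stacks more stable to train. GPT-2 also adds a final LayerNorm after the last block, which is the ln_f in the full model below: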

class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(D_MODEL)
        self.attn = MultiHeadAttention()
        self.ln2 = nn.LayerNorm(D_MODEL)
        self.ff = FeedForward()

    def forward(self, x, mask=None):
        # Pre-norm + residual connection for attention
        x = x + self.attn(self.ln1(x), mask)
        # Pre-norm + residual connection for FFN
        x = x + self.ff(self.ln2(x))
        return x
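
Here is a minimal sketch of why pre-norm is gentler on deep stacks. ZeroSublayer is a hypothetical stand-in for attention or the FFN whose output is all zeros, as if the sublayer had not learned anything yet. With pre-norm, the block is then an exact identity. With post-norm, LayerNorm still rescales the residual stream at every layer:

```python
import torch
import torch.nn as nn

class ZeroSublayer(nn.Module):
    """Hypothetical sublayer that always outputs zeros (zero-initialized Linear)."""
    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, d)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        return self.proj(x)

torch.manual_seed(0)
D = 8
ln = nn.LayerNorm(D)
sub = ZeroSublayer(D)
x = torch.randn(2, 5, D) * 3 + 1    # residual stream with arbitrary scale and offset

pre_norm = x + sub(ln(x))           # pattern used in TransformerBlock above
post_norm = ln(x + sub(x))          # original "Attention Is All You Need" pattern

print(torch.equal(pre_norm, x))     # True: residual stream passes through untouched
print(torch.allclose(post_norm, x)) # False: LayerNorm rescales the stream
```

With N_LAYERS pre-norm blocks, the residual stream is the embedding plus every sublayer's output, so gradients reach early layers through plain additions. In post-norm, each of those additions is followed by a normalization.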

7. The Full GPT Model

class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = Embeddings()
        self.blocks = nn.ModuleList([TransformerBlock() for _ in range(N_LAYERS)])
        self.ln_f = nn.LayerNorm(D_MODEL)   # final LayerNorm: pre-norm leaves the last block's output unnormalized
        self.lm_head = nn.Linear(D_MODEL, VOCAB_SIZE, bias=False)

        # Causal mask (lower triangular): 1 where key position <= query position
        self.register_buffer('causal_mask',
            torch.tril(torch.ones(CONTEXT_LEN, CONTEXT_LEN)).view(
                1, 1, CONTEXT_LEN, CONTEXT_LEN))

    def forward(self, idx, targets=None):
        B, T = idx.shape
        assert T <= CONTEXT_LEN, f"sequence length {T} exceeds CONTEXT_LEN={CONTEXT_LEN}"

        x = self.embed(idx)
        mask = self.causal_mask[:, :, :T, :T]

        for block in self.blocks:
            x = block(x, mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, VOCAB_SIZE)

        loss = None
        if targets is not None:
            # reshape (not view): targets is often a non-contiguous slice such as tokens[:, 1:]
            loss = F.cross_entropy(
                logits.reshape(-1, VOCAB_SIZE),
                targets.reshape(-1)
            )
        return logits, loss
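
Once trained, the model generates text by repeatedly predicting one token and appending it. The sketch below assumes model(idx) returns (logits, loss) as GPT.forward does. Its input is cropped to the last context_len tokens because the positional embedding table has only CONTEXT_LEN rows. The names generate and CountingModel are hypothetical, and CountingModel is a dummy that lets the block run without a trained model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.no_grad()
def generate(model, idx, max_new_tokens, context_len, temperature=1.0, greedy=False):
    """Autoregressively extend idx (B, T) by max_new_tokens tokens."""
    model.eval()
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_len:]         # keep only positions the model supports
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :]                # last position predicts the next token
        if greedy:
            next_id = logits.argmax(dim=-1, keepdim=True)
        else:
            probs = F.softmax(logits / temperature, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)   # append and feed back in
    return idx

class CountingModel(nn.Module):
    """Hypothetical stand-in for GPT: always predicts (token + 1) mod vocab_size."""
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, idx, targets=None):
        B, T = idx.shape
        logits = torch.full((B, T, self.vocab_size), -1e9)
        nxt = (idx + 1) % self.vocab_size
        logits.scatter_(2, nxt.unsqueeze(-1), 0.0)   # put all probability on nxt
        return logits, None

out = generate(CountingModel(10), torch.tensor([[7]]), max_new_tokens=5,
               context_len=4, greedy=True)
print(out.tolist())  # [[7, 8, 9, 0, 1, 2]]
```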

Exercise 1

Count the total number of parameters in this GPT model. Compare to GPT-2 (117M, 345M, 762M, 1.5B). What configuration produces each size? What's the main scaling axis?

Exercise 2

Remove the causal mask from the attention computation. What would the model now be able to do that it shouldn't? What type of model does this produce? (Hint: think BERT.)

Exercise 3 (Advanced)

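Replace the learned absolute positional embeddings with RoPE (Rotary Position Embedding). The key idea is to rotate Q and K by position-dependent angles before computing attention scores, instead of adding position vectors to the token embeddings. Implement it and compare training curves against the learned-embedding baseline on a tiny dataset.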
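
This is not a solution, but a quick check of the property RoPE is built on, using one 2-D pair of coordinates in plain Python. If the query at position m is rotated by angle m·θ and the key at position n by n·θ, their dot product depends only on the offset m - n. A full implementation applies this to every pair of dimensions in d_head, each with its own frequency. The RoPE paper uses θ_i = 10000^(-2i/d_head) for i = 0, 1, ..., d_head/2 - 1. The helper names rotate and dot are hypothetical:

```python
import math

def rotate(vec, pos, theta):
    """Rotate a 2-D vector by angle pos * theta (one RoPE frequency pair)."""
    x, y = vec
    a = pos * theta
    return (x * math.cos(a) - y * math.sin(a),
            x * math.sin(a) + y * math.cos(a))

def dot(u, v):
    return u[0] * v[0] + u[1] * v[1]

q, k = (0.3, -1.2), (0.8, 0.5)
theta = 0.1

# Query at position m, key at position n.
s1 = dot(rotate(q, 5, theta), rotate(k, 2, theta))      # m - n = 3
s2 = dot(rotate(q, 105, theta), rotate(k, 102, theta))  # m - n = 3, shifted by 100
s3 = dot(rotate(q, 5, theta), rotate(k, 4, theta))      # m - n = 1
print(f"{s1:.6f} {s2:.6f} {s3:.6f}")  # s1 and s2 match; s3 differs
```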