Implement every component of the transformer architecture in PyTorch from scratch. Understand why each component exists and what would happen if it were removed. Connect the mathematical description in "Attention Is All You Need" to working code.
1. Setup: The Problem We're Solving
We're building a decoder-only transformer (GPT-style) for language modeling: given a sequence of tokens, predict the next token. Trained on enough text, this simple objective pushes the model to pick up language structure, factual knowledge, and some ability to reason.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# Hyperparameters (these match the GPT-2 small configuration)
VOCAB_SIZE = 50257 # GPT-2 vocabulary size
CONTEXT_LEN = 1024 # Maximum sequence length
D_MODEL = 768 # Embedding dimension
N_HEADS = 12 # Number of attention heads
D_FF = 3072 # Feed-forward inner dimension (4 * D_MODEL)
N_LAYERS = 12 # Number of transformer layers
DROPOUT = 0.1
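To make the next-token objective concrete, here is a minimal sketch of how input/target pairs are commonly built from a single token stream: the targets are the inputs shifted left by one position. The token IDs and the `make_pairs` helper are made up for illustration and are not part of the model code.

```python
import torch

def make_pairs(tokens: torch.Tensor, context_len: int):
    """Split a 1-D token stream into (input, target) windows.

    Hypothetical helper for illustration: targets are inputs shifted by one.
    """
    n_windows = (tokens.numel() - 1) // context_len
    usable = n_windows * context_len
    inputs = tokens[:usable].view(n_windows, context_len)
    targets = tokens[1:usable + 1].view(n_windows, context_len)
    return inputs, targets

stream = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17, 18])  # made-up token IDs
x, y = make_pairs(stream, context_len=4)
print(x)  # rows: [10, 11, 12, 13] and [14, 15, 16, 17]
print(y)  # rows: [11, 12, 13, 14] and [15, 16, 17, 18]
```

Every position in a window supplies its own training example, so one window of length 4 yields 4 next-token predictions.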
2. Token and Positional Embeddings
Each token ID maps to a learnable embedding vector. Position information is added via learned positional embeddings:
class Embeddings(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embed = nn.Embedding(VOCAB_SIZE, D_MODEL)
        self.pos_embed = nn.Embedding(CONTEXT_LEN, D_MODEL)
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):
        # x: (batch, seq_len) token IDs
        B, T = x.shape
        positions = torch.arange(T, device=x.device)
        tok = self.token_embed(x)        # (B, T, D_MODEL)
        pos = self.pos_embed(positions)  # (T, D_MODEL)
        return self.dropout(tok + pos)   # pos broadcasts over the batch dimension
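The broadcast in the last line is easy to misread, so here is a small standalone check with toy sizes (the sizes are assumptions chosen for readability, not the real configuration). It confirms that the same position vector is added to every sequence in the batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, context_len, d_model = 100, 16, 8  # toy sizes, not the real config

token_embed = nn.Embedding(vocab_size, d_model)
pos_embed = nn.Embedding(context_len, d_model)

x = torch.randint(0, vocab_size, (2, 5))       # (B=2, T=5) token IDs
positions = torch.arange(x.shape[1])           # (T,)

tok = token_embed(x)                           # (2, 5, 8)
pos = pos_embed(positions)                     # (5, 8)
out = tok + pos                                # pos broadcasts to (2, 5, 8)

# The same position vector is added to every sequence in the batch.
same_pos = torch.allclose(out[0] - tok[0], out[1] - tok[1], atol=1e-6)
print(tuple(out.shape), same_pos)
```

A sequence longer than `context_len` would index past the end of `pos_embed` and raise an error, which is why generation code usually crops its input to the most recent `CONTEXT_LEN` tokens.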
3. Scaled Dot-Product Attention
The core operation. Every token "queries" the tokens it is allowed to see (itself and all earlier tokens, once the causal mask is applied), producing a weighted sum of their values:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q, K, V: (batch, heads, seq_len, d_head)
    Returns: output (batch, heads, seq_len, d_head) and
             attention weights (batch, heads, seq_len, seq_len)
    """
    d_k = Q.size(-1)
    # Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # scores: (batch, heads, seq_len, seq_len)
    # Apply causal mask (decoder: can only attend to current and past tokens)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # Softmax over the last dimension (the keys)
    attn_weights = F.softmax(scores, dim=-1)
    # Weighted sum of values
    output = torch.matmul(attn_weights, V)
    return output, attn_weights
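As a quick check on the masking step, the following standalone sketch repeats the score, mask, and softmax computation on a toy sequence (sizes are assumptions for illustration). Masked positions receive a score of negative infinity, so softmax assigns them a weight of exactly zero.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d_head = 4, 8                                      # toy sizes
Q = torch.randn(1, 1, T, d_head)
K = torch.randn(1, 1, T, d_head)

mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)  # 1 = may attend
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head)
scores = scores.masked_fill(mask == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)

print(weights[0, 0])
# Entries above the diagonal are exactly 0 and each row sums to 1.
# The first token can only attend to itself, so weights[0, 0, 0, 0] is 1.
```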
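Why the scaling factor? If the components of Q and K are roughly independent with zero mean and unit variance, their dot product has variance d_k, so its typical magnitude grows like sqrt(d_k). At d_k=64 that is a standard deviation of about 8, large enough for softmax to put nearly all of its weight on a single key. In that saturated regime the gradients through softmax are close to zero. Dividing by sqrt(d_k) brings the scores back to roughly unit variance.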
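The effect of the scaling factor can be checked numerically. This sketch assumes queries and keys with independent, unit-variance components (an idealization; trained models will differ) and compares the spread of dot products with and without the division by sqrt(d_k).

```python
import math
import torch

torch.manual_seed(0)
d_k, n = 64, 10_000
q = torch.randn(n, d_k)  # assumed: independent, unit-variance components
k = torch.randn(n, d_k)

dots = (q * k).sum(dim=-1)                       # n independent dot products
raw_std = dots.std().item()                      # close to sqrt(64) = 8
scaled_std = (dots / math.sqrt(d_k)).std().item()  # close to 1
print(raw_std, scaled_std)

# Effect on softmax: the same 16 scores, unscaled vs. scaled.
scores = dots[:16]
unscaled_max = torch.softmax(scores, dim=0).max().item()
scaled_max = torch.softmax(scores / math.sqrt(d_k), dim=0).max().item()
print(unscaled_max, scaled_max)  # the unscaled distribution is more peaked
```

Dividing the scores by a constant greater than 1 acts like raising the softmax temperature, so the largest weight can only shrink or stay the same.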
4. Multi-Head Attention
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        assert D_MODEL % N_HEADS == 0
        self.d_head = D_MODEL // N_HEADS
        self.W_Q = nn.Linear(D_MODEL, D_MODEL)
        self.W_K = nn.Linear(D_MODEL, D_MODEL)
        self.W_V = nn.Linear(D_MODEL, D_MODEL)
        self.W_O = nn.Linear(D_MODEL, D_MODEL)
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x, mask=None):
        B, T, C = x.shape  # batch, seq_len, d_model
        # Project to Q, K, V and split the last dimension into heads
        Q = self.W_Q(x).view(B, T, N_HEADS, self.d_head).transpose(1, 2)
        K = self.W_K(x).view(B, T, N_HEADS, self.d_head).transpose(1, 2)
        V = self.W_V(x).view(B, T, N_HEADS, self.d_head).transpose(1, 2)
        # Each: (B, N_HEADS, T, d_head)
        # Attention, computed independently per head
        out, _ = scaled_dot_product_attention(Q, K, V, mask)
        # Recombine heads: (B, N_HEADS, T, d_head) -> (B, T, D_MODEL)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        # Output projection followed by dropout
        return self.dropout(self.W_O(out))
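The `view`/`transpose` pair that splits and recombines heads is the part of this class most prone to silent shape bugs. The sketch below walks through it on a toy tensor (sizes are assumptions for illustration) and checks that head h sees one contiguous slice of the model dimension and that merging undoes splitting.

```python
import torch

torch.manual_seed(0)
B, T, n_heads, d_head = 2, 5, 4, 3  # toy sizes
d_model = n_heads * d_head          # 12
x = torch.randn(B, T, d_model)

# Split: (B, T, d_model) -> (B, T, n_heads, d_head) -> (B, n_heads, T, d_head)
heads = x.view(B, T, n_heads, d_head).transpose(1, 2)
print(tuple(heads.shape))           # (2, 4, 5, 3)

# Head 1 is simply columns 3..5 of the original features.
print(torch.equal(heads[:, 1], x[..., 3:6]))

# Stand-in for the attention output: a fresh tensor laid out as
# (B, n_heads, T, d_head), as torch.matmul would produce.
attn_out = heads.contiguous()

# Merge: transposing gives a non-contiguous view, so .contiguous()
# is required before .view can flatten the last two dimensions.
merged = attn_out.transpose(1, 2).contiguous().view(B, T, d_model)
print(torch.equal(merged, x))
```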
5. Feed-Forward Network
Two linear layers with a GELU nonlinearity, applied to each position independently. This is where much of the model's "knowledge storage" is thought to happen:
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(D_MODEL, D_FF),
            nn.GELU(),  # Smooth alternative to ReLU; negative inputs still pass a small gradient
            nn.Linear(D_FF, D_MODEL),
            nn.Dropout(DROPOUT),
        )

    def forward(self, x):
        return self.net(x)
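To see concretely how GELU differs from ReLU, this short sketch evaluates both on a few inputs and compares their gradients at x = -1. It uses PyTorch's default exact GELU, x * Phi(x), where Phi is the standard normal CDF.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0])
relu_out = F.relu(x)  # negative inputs are clipped to exactly 0
gelu_out = F.gelu(x)  # negative inputs give small negative outputs
print(relu_out)
print(gelu_out)

# Gradient at x = -1: ReLU passes nothing back, GELU passes a small signal.
grads = {}
for name, fn in [("relu", F.relu), ("gelu", F.gelu)]:
    v = torch.tensor(-1.0, requires_grad=True)
    fn(v).backward()
    grads[name] = v.grad.item()
print(grads)
```

With ReLU, a unit whose pre-activation is negative receives no gradient at all for that input; GELU keeps a small gradient there.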
6. Transformer Block with Pre-Norm
Most modern LLMs, GPT-2 included, use pre-normalization (LayerNorm before the sublayer) rather than the original paper's post-normalization (LayerNorm after the residual addition). Pre-norm tends to train more stably in deep networks:
class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(D_MODEL)
        self.attn = MultiHeadAttention()
        self.ln2 = nn.LayerNorm(D_MODEL)
        self.ff = FeedForward()

    def forward(self, x, mask=None):
        # Pre-norm + residual connection for attention
        x = x + self.attn(self.ln1(x), mask)
        # Pre-norm + residual connection for FFN
        x = x + self.ff(self.ln2(x))
        return x
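One commonly cited reason for the stability difference is that pre-norm leaves the residual path untouched, while post-norm passes it through LayerNorm at every layer. The sketch below contrasts the two orderings using a sublayer that outputs zeros. The wrapper classes and the `zero_sublayer` helper are hypothetical names introduced only for this illustration.

```python
import torch
import torch.nn as nn

class PreNormResidual(nn.Module):
    """x + sublayer(LayerNorm(x)): the ordering used in this section."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer

    def forward(self, x):
        return x + self.sublayer(self.norm(x))

class PostNormResidual(nn.Module):
    """LayerNorm(x + sublayer(x)): the ordering in the original paper."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer

    def forward(self, x):
        return self.norm(x + self.sublayer(x))

def zero_sublayer(d_model):
    # Hypothetical stand-in for attention/FFN that always outputs zeros.
    layer = nn.Linear(d_model, d_model)
    nn.init.zeros_(layer.weight)
    nn.init.zeros_(layer.bias)
    return layer

torch.manual_seed(0)
d_model = 8
x = 5.0 * torch.randn(2, 3, d_model) + 2.0  # deliberately not normalized

pre_out = PreNormResidual(d_model, zero_sublayer(d_model))(x)
post_out = PostNormResidual(d_model, zero_sublayer(d_model))(x)

print(torch.allclose(pre_out, x))   # True: the residual stream passes through unchanged
print(torch.allclose(post_out, x))  # False: the stream itself was re-normalized
```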
7. The Full GPT Model
class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = Embeddings()
        self.blocks = nn.ModuleList([TransformerBlock() for _ in range(N_LAYERS)])
        self.ln_f = nn.LayerNorm(D_MODEL)
        self.lm_head = nn.Linear(D_MODEL, VOCAB_SIZE, bias=False)
        # Causal mask (lower triangular)
        self.register_buffer(
            'causal_mask',
            torch.tril(torch.ones(CONTEXT_LEN, CONTEXT_LEN)).view(
                1, 1, CONTEXT_LEN, CONTEXT_LEN),
        )

    def forward(self, idx, targets=None):
        B, T = idx.shape
        x = self.embed(idx)
        mask = self.causal_mask[:, :, :T, :T]
        for block in self.blocks:
            x = block(x, mask)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, VOCAB_SIZE)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, VOCAB_SIZE),
                targets.view(-1)
            )
        return logits, loss
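A useful sanity check on the loss computation: if the logits are all equal, the model predicts a uniform distribution over the vocabulary and the cross-entropy is exactly ln(VOCAB_SIZE). The standalone sketch below reproduces the same flattening as `forward` on a toy batch (the batch shape is an assumption for illustration).

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 50257                      # same value as VOCAB_SIZE above
B, T = 2, 4                             # toy batch
logits = torch.zeros(B, T, vocab_size)  # all-equal logits = uniform prediction
targets = torch.randint(0, vocab_size, (B, T))

# Flatten (B, T, V) -> (B*T, V) and (B, T) -> (B*T,), as in GPT.forward
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
print(loss.item(), math.log(vocab_size))  # both about 10.82
```

A randomly initialized model usually starts somewhere in the neighborhood of this value. A first-step loss that is far higher often points to an initialization or data-alignment problem.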
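Count the total number of parameters in this GPT model. Compare to the GPT-2 family (originally reported as 117M, 345M, 762M, 1.5B; the released small model actually has about 124M). Note that GPT-2 shares one weight matrix between the token embedding and the output head, while lm_head here is a separate matrix, so expect a higher count until you tie them. What configuration produces each size? What's the main scaling axis?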
Remove the causal mask from the attention computation. What would the model now be able to do that it shouldn't? What type of model does this produce? (Hint: think BERT.)
Replace the learned absolute positional embeddings with RoPE (Rotary Position Embedding). The key: apply rotation matrices to Q and K before computing attention scores. Implement and compare training curves on a tiny dataset.
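As a possible starting point for the RoPE exercise, here is a minimal sketch of the rotation itself. It assumes the interleaved-pair convention (dimensions 2i and 2i+1 are rotated together) and the commonly used base of 10000; some implementations pair dimensions differently. The function names `rope_angles` and `apply_rope` are hypothetical, and wiring this into `MultiHeadAttention` (and removing `pos_embed`) is left to the exercise.

```python
import torch

def rope_angles(seq_len, d_head, base=10000.0):
    """Return cos and sin tables of shape (seq_len, d_head // 2)."""
    # One rotation frequency per pair of dimensions
    inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
    pos = torch.arange(seq_len).float()
    angles = torch.outer(pos, inv_freq)
    return angles.cos(), angles.sin()

def apply_rope(x, cos, sin):
    """Rotate each (even, odd) pair of x by a position-dependent angle.

    x: (batch, heads, seq_len, d_head) with even d_head.
    """
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    rot_even = x_even * cos - x_odd * sin
    rot_odd = x_even * sin + x_odd * cos
    # Re-interleave the pairs back into the last dimension
    return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)

torch.manual_seed(0)
T, d_head = 6, 8  # toy sizes
cos, sin = rope_angles(T, d_head)

# Same query and key content at every position, so any variation in the
# scores comes from position alone.
q = torch.randn(d_head).expand(1, 1, T, d_head)
k = torch.randn(d_head).expand(1, 1, T, d_head)
q_rot, k_rot = apply_rope(q, cos, sin), apply_rope(k, cos, sin)

scores = (q_rot @ k_rot.transpose(-2, -1))[0, 0]  # (T, T)
# scores[i, j] depends only on i - j, so every diagonal is constant.
print(torch.allclose(scores[1:, 1:], scores[:-1, :-1], atol=1e-4))  # True
```

The check at the end illustrates the property that motivates RoPE: after rotation, the query-key dot product depends on the relative offset between positions rather than on their absolute values. V is left unrotated.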