Understanding Transformers: The Architecture Behind Modern AI
Attention mechanisms, positional encoding, why the 'query-key-value' metaphor actually makes sense, and how a transformer processes a sentence — step by step.
GPT, Gemini, Claude, BERT — every major language model today is built on the transformer architecture. Introduced in the 2017 paper "Attention Is All You Need," it replaced RNNs almost overnight. Understanding it means understanding modern AI.
This is a bottom-up walkthrough: we start with the problem transformers solve, then build up each piece.
The Problem with Sequences
Before transformers, the standard tool for sequences was the RNN (recurrent neural network). RNNs process tokens one at a time, passing a hidden state forward:
h_t = f(h_{t-1}, x_t)
This is fundamentally sequential — you can't process token 10 until you've processed tokens 1 through 9. That means:
- No parallelism during training — slow on modern GPUs
- Long-range dependencies are lossy — information from early tokens fades as the sequence grows
Transformers solve both by processing the entire sequence at once, using attention to let every token directly look at every other token.
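The contrast can be sketched in a few lines of PyTorch. This is a toy illustration with made-up weight matrices, not a real model: the RNN loop is forced to run step by step, while the attention version relates every pair of tokens in a single matrix product.

```python
import torch

seq_len, d = 6, 8
x = torch.randn(seq_len, d)  # toy sequence of 6 token embeddings

# RNN-style: strictly sequential -- step t needs the hidden state from t-1
W_h, W_x = torch.randn(d, d) * 0.1, torch.randn(d, d) * 0.1
h = torch.zeros(d)
for t in range(seq_len):
    h = torch.tanh(h @ W_h + x[t] @ W_x)  # can't start step t until t-1 is done

# Attention-style: one matrix product scores every token against every other at once
scores = x @ x.T                                  # (seq, seq) -- all pairs in parallel
weights = torch.softmax(scores / d ** 0.5, dim=-1)
out = weights @ x                                 # every token sees every other token
print(out.shape)  # torch.Size([6, 8])
```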
Self-Attention: The Core Idea
Imagine reading the sentence: "The animal didn't cross the street because it was too tired."
What does "it" refer to? As a human, you attend to "animal" — not "street". Self-attention gives a neural network a way to do this: for each token, compute a weighted sum over all other tokens, where the weights reflect relevance.
Query, Key, Value
For each token, we project its embedding into three vectors:
- Query (Q): what this token is looking for
- Key (K): what this token offers to others
- Value (V): the actual information it contributes
Then attention is computed as:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

The dot product QK^T measures how much each query matches each key. We scale by sqrt(d_k) to prevent the dot products from getting too large (which would push softmax into regions with near-zero gradients). Then softmax turns the scores into a probability distribution: how much to attend to each token. Finally we take a weighted sum of the values.
In Code
```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    # Scores: (batch, heads, seq, seq)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights
```
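A quick sanity check: feed random tensors through the function and confirm the shapes and that each row of the attention weights is a valid probability distribution. (The function is restated here so the snippet runs on its own; the tensor sizes are arbitrary.)

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights

# Random Q, K, V: batch of 2, 4 heads, 5 tokens, d_k = 16
Q = torch.randn(2, 4, 5, 16)
K = torch.randn(2, 4, 5, 16)
V = torch.randn(2, 4, 5, 16)

out, w = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # torch.Size([2, 4, 5, 16]) -- same shape as V
print(w.shape)    # torch.Size([2, 4, 5, 5])  -- one weight per (query, key) pair
# Every row of the weight matrix sums to 1
print(torch.allclose(w.sum(dim=-1), torch.ones(2, 4, 5)))  # True
```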
Multi-Head Attention
Running attention once is good. Running it h times in parallel, each with different learned projections, is better — each "head" can learn to attend to different types of relationships.
```python
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = torch.nn.Linear(d_model, d_model)
        self.W_k = torch.nn.Linear(d_model, d_model)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        # (batch, seq, d_model) -> (batch, heads, seq, d_k)
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, x, mask=None):
        batch_size = x.size(0)
        Q = self.split_heads(self.W_q(x), batch_size)
        K = self.split_heads(self.W_k(x), batch_size)
        V = self.split_heads(self.W_v(x), batch_size)
        attn_output, _ = scaled_dot_product_attention(Q, K, V, mask)
        # Concatenate heads: (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.num_heads * self.d_k)
        return self.W_o(attn_output)
```
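The split-and-merge reshaping is where most bugs hide, so it's worth verifying in isolation that the round trip is lossless. This standalone sketch (arbitrary sizes, no learned weights) mirrors the `split_heads` logic and its inverse:

```python
import torch

batch, seq, d_model, num_heads = 2, 7, 64, 8
d_k = d_model // num_heads
x = torch.randn(batch, seq, d_model)

# Split: (batch, seq, d_model) -> (batch, heads, seq, d_k)
split = x.view(batch, seq, num_heads, d_k).transpose(1, 2)
print(split.shape)  # torch.Size([2, 8, 7, 8])

# Merge back: (batch, heads, seq, d_k) -> (batch, seq, d_model)
merged = split.transpose(1, 2).contiguous().view(batch, seq, d_model)
print(torch.equal(merged, x))  # True -- splitting and merging loses nothing
```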
Positional Encoding
Attention has no built-in notion of order: it treats all positions symmetrically. We fix this by adding a positional encoding to each token embedding before attention. The original paper uses sinusoids:

PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

Different frequencies encode different positional scales: low frequencies capture long-range position, high frequencies capture fine-grained order. Modern models often use rotary position embeddings (RoPE) instead, but the principle is the same.
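The sinusoidal formula translates directly to code. A minimal sketch (the function name is ours, the formula is the paper's):

```python
import torch

def sinusoidal_positions(max_len, d_model):
    # One sin/cos pair per pair of dimensions, as in the original paper
    pos = torch.arange(max_len).unsqueeze(1).float()   # (max_len, 1)
    i = torch.arange(0, d_model, 2).float()            # (d_model/2,)
    freqs = pos / (10000 ** (i / d_model))             # (max_len, d_model/2)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(freqs)   # even dimensions get sine
    pe[:, 1::2] = torch.cos(freqs)   # odd dimensions get cosine
    return pe

pe = sinusoidal_positions(max_len=50, d_model=64)
print(pe.shape)   # torch.Size([50, 64])
# Position 0 encodes as sin(0)=0 in even dims, cos(0)=1 in odd dims
print(pe[0, :4])  # tensor([0., 1., 0., 1.])
```

This table is computed once and simply added to the token embeddings.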
The Full Transformer Block
One transformer block = multi-head attention + a feedforward network, with residual connections and layer normalization around each:
x = LayerNorm(x + MultiHeadAttention(x))
x = LayerNorm(x + FFN(x))
The feedforward network is just two linear layers with a nonlinearity:
```python
class TransformerBlock(torch.nn.Module):
    def __init__(self, d_model, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(d_model, ff_dim),
            torch.nn.GELU(),
            torch.nn.Linear(ff_dim, d_model),
        )
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)
        self.drop = torch.nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = self.norm1(x + self.drop(self.attn(x, mask)))
        x = self.norm2(x + self.drop(self.ff(x)))
        return x
```
A full model is just N of these blocks stacked.
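To see what stacking looks like without rebuilding everything, PyTorch ships an equivalent block as `torch.nn.TransformerEncoderLayer` (same attention + FFN + residual + norm structure, with minor implementation differences from the class above). A sketch using the built-in modules:

```python
import torch

d_model, num_heads, ff_dim, N = 64, 8, 256, 6
layer = torch.nn.TransformerEncoderLayer(
    d_model=d_model, nhead=num_heads, dim_feedforward=ff_dim,
    activation="gelu", batch_first=True,
)
model = torch.nn.TransformerEncoder(layer, num_layers=N)  # N identical blocks

x = torch.randn(2, 10, d_model)  # (batch, seq, d_model)
out = model(x)
print(out.shape)  # torch.Size([2, 10, 64]) -- shape preserved through all N blocks
```

Because each block maps (batch, seq, d_model) to the same shape, blocks compose freely; depth is just a hyperparameter.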
Why Residual Connections Matter
The x + ... pattern (residual / skip connections) is not cosmetic. Without them, gradients in deep networks tend to vanish. With them, gradients have a "highway" directly back to early layers — this is what allows training 96-layer models.
Encoder vs Decoder vs Encoder-Decoder
The original transformer was encoder-decoder (for translation). Modern language models simplify:
| Architecture | Used by | What it does |
|---|---|---|
| Encoder-only | BERT, RoBERTa | Bidirectional attention, used for classification/embeddings |
| Decoder-only | GPT, Llama, Claude | Causal (masked) attention, autoregressive text generation |
| Encoder-Decoder | T5, original Transformer | Encoder reads input, decoder generates output |
In a decoder, each token can only attend to previous tokens (not future ones). This is enforced by the causal mask: attention scores for future positions are set to -inf before softmax, so their weights become exactly zero.
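Constructing the mask is a one-liner: a lower-triangular matrix, built here with `torch.tril`. A minimal sketch of how it zeroes out attention to the future:

```python
import torch

seq_len = 5
# Lower-triangular matrix: position i may attend to positions 0..i only
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

scores = torch.randn(seq_len, seq_len)
scores = scores.masked_fill(causal_mask == 0, float('-inf'))
weights = torch.softmax(scores, dim=-1)

# The first token has no past, so all its attention falls on itself
print(weights[0])  # tensor([1., 0., 0., 0., 0.])
```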
What Makes This Architecture Powerful
Three things:
- Parallelism — the entire sequence is processed in parallel with a handful of matrix multiplications. Training is GPU-friendly.
- Direct connections — any token can attend to any other in a single step, regardless of distance. Long-range dependencies don't degrade.
- Expressiveness — multi-head attention learns many different relationships simultaneously; the FFN adds non-linear transformation at each position.
Scale this up with more layers, bigger d_model, and more data — and you get GPT-4.