machine-learning · transformers · nlp · deep-learning

Understanding Transformers: The Architecture Behind Modern AI

Attention mechanisms, positional encoding, why the 'query-key-value' metaphor actually makes sense, and how a transformer processes a sentence — step by step.



GPT, Gemini, Claude, BERT: every major language model today is built on the transformer architecture. It was introduced in the 2017 paper "Attention Is All You Need" (Vaswani et al.), and within a few years it had largely replaced RNNs for language tasks. Understanding it means understanding modern AI.

This is a bottom-up walkthrough: we start with the problem transformers solve, then build up each piece.

The Problem with Sequences

Before transformers, the standard tool for sequences was the RNN (recurrent neural network). RNNs process tokens one at a time, passing a hidden state forward:

h_t = f(h_{t-1}, x_t)

This is fundamentally sequential — you can't process token 10 until you've processed tokens 1 through 9. That means:

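  • No parallelism across time steps during training, so RNNs make poor use of modern GPUs
  • Long-range dependencies are lossy: information from early tokens fades as the sequence grows, and gradients flowing back through many steps tend to vanish or explode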
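To see both problems in miniature, here is a toy scalar RNN in plain Python. It is a sketch, not how real RNNs are built: the weights `w_h` and `w_x` are hypothetical fixed constants rather than learned matrices. The loop shows the sequential dependency, and `influence_of_first_input` (also a hypothetical helper) estimates by finite differences how much nudging the first input changes the final hidden state.

```python
import math


def rnn_forward(xs, w_h=0.5, w_x=1.0, h0=0.0):
    """Toy scalar RNN: h_t = tanh(w_h * h_{t-1} + w_x * x_t).

    w_h and w_x are hypothetical fixed weights; real RNNs use learned matrices.
    """
    h = h0
    states = []
    for x in xs:
        # Step t needs h from step t-1, so this loop cannot run in parallel across t.
        h = math.tanh(w_h * h + w_x * x)
        states.append(h)
    return states


def influence_of_first_input(n, w_h=0.5, w_x=1.0, eps=1e-6):
    """Finite-difference estimate of d h_n / d x_1 for an all-zero input of length n."""
    base = [0.0] * n
    nudged = [eps] + [0.0] * (n - 1)
    return (rnn_forward(nudged, w_h, w_x)[-1] - rnn_forward(base, w_h, w_x)[-1]) / eps


for n in (1, 5, 10, 20):
    print(n, influence_of_first_input(n))
```

With these particular weights the influence shrinks roughly like 0.5^(n-1), so after 10 tokens the first input barely registers. Other weights change the rate of decay, but not the step-by-step loop.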

Transformers solve both by processing the entire sequence at once, using attention to let every token directly look at every other token.

Self-Attention: The Core Idea

Imagine reading the sentence: "The animal didn't cross the street because it was too tired."

What does "it" refer to? A human reader attends to "animal", not "street". Self-attention gives a neural network a way to do the same: for each token, it computes a weighted sum over all tokens in the sequence (including the token itself), where the weights reflect relevance.

Query, Key, Value

For each token, we multiply its embedding by three learned weight matrices (W_q, W_k, W_v in the code below) to get three vectors:

  • Query (Q): what this token is looking for
  • Key (K): what this token offers to others
  • Value (V): the actual information it contributes
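A minimal sketch of the projection step, assuming a 2-dimensional embedding and hand-picked matrices (real models learn these weights, and `matvec` is just an illustrative helper). The point is that all three vectors come from the same embedding, multiplied by three different matrices.

```python
def matvec(W, x):
    """Multiply a matrix W (given as a list of rows) by a vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


# Hypothetical 2-d embedding for one token.
embedding = [1.0, 2.0]

# Hand-picked stand-ins for the learned projection matrices.
W_q = [[1.0, 0.0], [0.0, 1.0]]
W_k = [[0.0, 1.0], [1.0, 0.0]]
W_v = [[0.5, 0.0], [0.0, 3.0]]

q = matvec(W_q, embedding)  # what this token is looking for
k = matvec(W_k, embedding)  # what this token offers to others
v = matvec(W_v, embedding)  # the information it contributes
print(q, k, v)  # [1.0, 2.0] [2.0, 1.0] [0.5, 6.0]
```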

Stacking these vectors row-wise for all tokens into matrices $Q$, $K$, and $V$, the attention output is:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

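The product $QK^T$ holds the dot product of every query with every key, measuring how well each query matches each key. We divide by $\sqrt{d_k}$, where $d_k$ is the key dimension, because dot products grow with dimension: if the entries of $q$ and $k$ are independent with zero mean and unit variance, $q \cdot k$ has variance $d_k$. Unscaled, large scores push softmax into saturated regions with near-zero gradients. Softmax is then applied to each row, turning one query's scores into a probability distribution over tokens: how much to attend to each. Finally, we take the weighted sum of the values.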
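A quick numerical check of the scaling argument, assuming (as above) that query and key entries are independent standard normals. The spread of raw dot products grows like $\sqrt{d_k}$, and dividing by $\sqrt{d_k}$ brings it back to roughly 1 whatever the dimension. This is plain Python with no dependencies; `dot_product_std` is a hypothetical helper name.

```python
import random
import statistics


def dot_product_std(d_k, n_samples=1000, seed=0):
    """Empirical standard deviation of q . k when q and k have i.i.d. N(0, 1) entries."""
    rng = random.Random(seed)
    dots = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(sum(qi * ki for qi, ki in zip(q, k)))
    return statistics.pstdev(dots)


for d_k in (4, 64, 256):
    raw = dot_product_std(d_k)
    # Theory: std(q . k) = sqrt(d_k), so the scaled column should hover near 1.
    print(f"d_k={d_k:4d}  std(q.k)={raw:6.2f}  scaled={raw / d_k ** 0.5:4.2f}")
```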

In Code

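```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q, K, V: (batch, heads, seq, d_k); mask must broadcast to (batch, heads, seq, seq)
    d_k = Q.size(-1)
    # Scores: (batch, heads, seq, seq); entry [i, j] is query i's match with key j
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        # Positions where mask == 0 get -inf, so softmax gives them zero weight
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # Normalize over the key dimension: each query's weights sum to 1
    weights = F.softmax(scores, dim=-1)
    # Weighted sum of values: (batch, heads, seq, d_k)
    return torch.matmul(weights, V), weights
```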
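For checking the arithmetic and the mask convention by hand, here is the same computation on plain Python lists for a single sequence and a single head, with no batch or head dimensions. It is an illustrative sketch (`attention_lists` is a hypothetical name); like the PyTorch version above, positions where `mask` is 0 are blocked.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    top = max(row)  # subtracting the max avoids overflow; -inf entries become weight 0
    exps = [math.exp(s - top) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_lists(Q, K, V, mask=None):
    """Scaled dot-product attention for one sequence and one head.

    Q: n_q rows of length d_k; K: n_k rows of length d_k; V: n_k rows of length d_v.
    mask (optional): n_q x n_k, where mask[i][j] == 0 blocks query i from key j.
    Assumes every query row keeps at least one unmasked key.
    """
    d_k = len(Q[0])
    scores = [[sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K] for q in Q]
    if mask is not None:
        scores = [[s if mask[i][j] != 0 else float("-inf") for j, s in enumerate(row)]
                  for i, row in enumerate(scores)]
    weights = [softmax(row) for row in scores]
    # Output row i = sum over j of weights[i][j] * V[j]
    d_v = len(V[0])
    out = [[sum(w * v[c] for w, v in zip(w_row, V)) for c in range(d_v)] for w_row in weights]
    return out, weights


Q = [[1.0, 0.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[10.0, 0.0], [0.0, 10.0]]
out, weights = attention_lists(Q, K, V)
masked_out, masked_weights = attention_lists(Q, K, V, mask=[[1, 0]])
```

Here the first key scores $1/\sqrt{2}$ and the second scores 0, so the query puts about 67% of its weight on the first value. Masking the second key moves all of the weight to the first.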

Multi-Head Attention

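Running attention once is good. Running it $h$ times in parallel, each with its own learned projections, is better: each "head" can learn to attend to a different type of relationship. Each head works in a smaller space of size $d_k = d_{model} / h$, so the total cost stays similar to a single full-width attention.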

```python
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads  # per-head dimension
        self.num_heads = num_heads
        # One full-width projection per role; split_heads carves it into num_heads slices
        self.W_q = torch.nn.Linear(d_model, d_model)
        self.W_k = torch.nn.Linear(d_model, d_model)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)  # mixes the concatenated heads

    def split_heads(self, x, batch_size):
        # (batch, seq, d_model) -> (batch, heads, seq, d_k)
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, x, mask=None):
        # x: (batch, seq, d_model); mask must broadcast to (batch, heads, seq, seq)
        batch_size = x.size(0)
        Q = self.split_heads(self.W_q(x), batch_size)
        K = self.split_heads(self.W_k(x), batch_size)
        V = self.split_heads(self.W_v(x), batch_size)
        attn_output, _ = scaled_dot_product_attention(Q, K, V, mask)
        # Concatenate heads: (batch, heads, seq, d_k) -> (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.num_heads * self.d_k)
        return self.W_o(attn_output)
```
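To make `split_heads` and the concatenation concrete without PyTorch, here is a toy version that assumes every projection (W_q, W_k, W_v, W_o) is the identity. That is an unrealistic simplification chosen to isolate the slicing; `toy_multi_head` and `attend` are hypothetical names. Each head attends using only its own contiguous slice of `d_k` features, and the per-head outputs are glued back into `d_model` features per token.

```python
import math


def softmax(row):
    top = max(row)
    exps = [math.exp(s - top) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def attend(Q, K, V):
    """Single-head scaled dot-product attention on lists (no mask)."""
    d_k = len(Q[0])
    out = []
    for q in Q:
        w = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K])
        out.append([sum(wi * v[c] for wi, v in zip(w, V)) for c in range(len(V[0]))])
    return out


def toy_multi_head(x, num_heads):
    """Multi-head self-attention where every projection is assumed to be the identity.

    x: seq_len rows of length d_model. Head h sees features h*d_k .. (h+1)*d_k - 1
    of each token, the same contiguous slicing that view/transpose gives in split_heads.
    """
    d_model = len(x[0])
    assert d_model % num_heads == 0
    d_k = d_model // num_heads
    per_head = []
    for h in range(num_heads):
        xs = [row[h * d_k:(h + 1) * d_k] for row in x]  # (seq, d_k) slice for head h
        per_head.append(attend(xs, xs, xs))           # Q = K = V: self-attention
    # Concatenate the heads' outputs for each token back to length d_model
    return [[f for h in range(num_heads) for f in per_head[h][t]] for t in range(len(x))]


x = [[1.0, 0.0, 0.0, 1.0],
     [0.0, 1.0, 1.0, 0.0]]
two_heads = toy_multi_head(x, num_heads=2)
one_head = toy_multi_head(x, num_heads=1)
```

Because each head computes its own softmax on its own slice, the model gets $h$ separate attention patterns for every pair of tokens instead of one, which is why `two_heads` and `one_head` differ even with identical inputs.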

Positional Encoding

Attention has no built-in notion of order: $QK^T$ treats all positions symmetrically, so shuffling the input tokens just shuffles the outputs the same way. We fix this by adding a positional encoding to each token embedding before the first layer:

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

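Each pair of dimensions oscillates at a different frequency, so different dimensions encode different positional scales: high-frequency dimensions (small $i$) distinguish neighboring positions, while low-frequency dimensions (large $i$) vary slowly and capture coarse, long-range position. Many modern models use rotary position embeddings (RoPE) instead, which rotate the query and key vectors by position-dependent angles rather than adding a vector to the embedding, but they rely on the same idea of sinusoids at geometrically spaced frequencies.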
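The two formulas translate almost line for line into code. A minimal sketch for a single position, assuming an even `d_model` (`sinusoidal_encoding` is a hypothetical name, not a library function):

```python
import math


def sinusoidal_encoding(pos, d_model):
    """Positional encoding vector for one position, following the sin/cos formulas."""
    assert d_model % 2 == 0, "this sketch assumes an even d_model"
    pe = [0.0] * d_model
    for i in range(d_model // 2):
        angle = pos / (10000 ** (2 * i / d_model))
        pe[2 * i] = math.sin(angle)      # PE(pos, 2i)
        pe[2 * i + 1] = math.cos(angle)  # PE(pos, 2i+1)
    return pe


# Pair i completes one cycle every 2*pi * 10000**(2i/d_model) positions:
# about 6.3 positions for i = 0, growing geometrically for larger i.
for pos in range(3):
    print(pos, [round(v, 3) for v in sinusoidal_encoding(pos, 8)])
```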

The Full Transformer Block

One transformer block = multi-head attention + a feedforward network, with residual connections and layer normalization around each:

```
x = LayerNorm(x + MultiHeadAttention(x))
x = LayerNorm(x + FFN(x))
```
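Layer normalization itself is simple: it normalizes each token's vector to zero mean and unit variance across its features, then applies a learned scale and shift. A plain-Python sketch of one residual-plus-norm step follows; the default `gamma = 1` and `beta = 0`, and the stand-in sublayer `scale_half`, are assumptions for illustration.

```python
import math


def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """Normalize one token's feature vector, then apply a per-feature scale and shift.

    gamma and beta stand in for the learned parameters (defaults: 1 and 0).
    """
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # biased variance, as torch.nn.LayerNorm uses
    gamma = gamma if gamma is not None else [1.0] * n
    beta = beta if beta is not None else [0.0] * n
    return [g * (v - mean) / math.sqrt(var + eps) + b for v, g, b in zip(x, gamma, beta)]


def residual_then_norm(x, sublayer):
    """One post-norm step: x = LayerNorm(x + sublayer(x))."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


# Hypothetical sublayer standing in for attention or the FFN.
def scale_half(v):
    return [0.5 * t for t in v]


y = residual_then_norm([1.0, 2.0, 3.0, 4.0], scale_half)
print([round(t, 3) for t in y])  # roughly zero mean and unit variance across the 4 features
```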

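The feedforward network is just two linear layers with a nonlinearity between them (GELU here; the original paper used ReLU), applied to each position independently. Here is the full block: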

```python
class TransformerBlock(torch.nn.Module):
    def __init__(self, d_model, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        # Position-wise feedforward network: expand to ff_dim, apply GELU, project back
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(d_model, ff_dim),
            torch.nn.GELU(),
            torch.nn.Linear(ff_dim, d_model),
        )
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)
        self.drop = torch.nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Post-norm, as in the original paper: residual add first, then LayerNorm
        x = self.norm1(x + self.drop(self.attn(x, mask)))
        x = self.norm2(x + self.drop(self.ff(x)))
        return x
```

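A full model is N of these blocks stacked, with token embeddings plus positional encodings at the input and, in a language model, a final linear layer that maps each position to vocabulary logits.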

Why Residual Connections Matter

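The x + ... pattern (residual or skip connections) is not cosmetic. Without it, gradients in deep networks tend to vanish as they are multiplied through many layers. With it, each layer's output is its input plus a correction, so gradients have a direct "highway" back to early layers. Together with layer normalization, this is what makes very deep stacks trainable (GPT-3, for example, has 96 layers).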
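The "highway" can be made concrete with a toy scalar network built from hypothetical layers $f(x) = w \tanh(x)$ with $w = 0.5$. By the chain rule, a plain layer multiplies the gradient by $f'(x)$, while a residual layer $x + f(x)$ multiplies it by $1 + f'(x)$, which can never drop below 1 here because $f'(x) \ge 0$.

```python
import math


def chain_gradient(num_layers, w=0.5, x0=1.0, residual=False):
    """d(output)/d(input) through a stack of toy scalar layers f(x) = w * tanh(x).

    Plain stack:    x_next = f(x)      -> each layer multiplies the gradient by f'(x)
    Residual stack: x_next = x + f(x)  -> each layer multiplies it by 1 + f'(x)
    """
    x, grad = x0, 1.0
    for _ in range(num_layers):
        f_prime = w * (1.0 - math.tanh(x) ** 2)  # derivative of w * tanh(x)
        if residual:
            grad *= 1.0 + f_prime
            x = x + w * math.tanh(x)
        else:
            grad *= f_prime
            x = w * math.tanh(x)
    return grad


print("plain   :", chain_gradient(96))
print("residual:", chain_gradient(96, residual=True))
```

Through 96 plain layers the gradient shrinks below 1e-20; through 96 residual layers it stays at or above 1. Real networks are far more complicated, but the structural difference (a product of small factors versus a product of factors near 1) is the same.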

Encoder vs Decoder vs Encoder-Decoder

The original transformer was encoder-decoder (for translation). Many later models keep only one of the two halves:

| Architecture | Used by | What it does |
|---|---|---|
| Encoder-only | BERT, RoBERTa | Bidirectional attention; used for classification and embeddings |
| Decoder-only | GPT, Llama, Claude | Causal (masked) attention; autoregressive text generation |
| Encoder-decoder | T5, original Transformer | Encoder reads the input; decoder generates the output |

In a decoder, each token can attend only to itself and earlier tokens, never to future ones. This is enforced by the causal mask: the scores for future positions are set to $-\infty$ before the softmax, so they receive exactly zero weight.
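A sketch of the causal mask in plain Python, using the same convention as `scaled_dot_product_attention` above (1 = allowed, 0 = blocked). In PyTorch an equivalent mask can be built with `torch.tril(torch.ones(n, n))`. The helper names here are hypothetical, and equal raw scores are used so that only the mask shapes the weights.

```python
import math


def causal_mask(n):
    """mask[i][j] = 1 if query i may attend to key j (j <= i), else 0."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]


def masked_softmax(scores, mask):
    """Row-wise softmax after replacing blocked scores with -inf."""
    out = []
    for row, mrow in zip(scores, mask):
        masked = [s if m != 0 else float("-inf") for s, m in zip(row, mrow)]
        top = max(masked)  # finite: the diagonal is never blocked by a causal mask
        exps = [math.exp(s - top) for s in masked]  # exp(-inf) == 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


# Equal raw scores, so the mask alone shapes the weights.
scores = [[0.0] * 4 for _ in range(4)]
weights = masked_softmax(scores, causal_mask(4))
for row in weights:
    print([round(w, 3) for w in row])
```

Row $i$ spreads its weight evenly over positions $0..i$ and puts exactly zero on later positions.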

What Makes This Architecture Powerful

Three things:

  1. Parallelism: attention over the whole sequence is computed with a few large matrix multiplications instead of a step-by-step loop, so training is GPU-friendly.
  2. Direct connections: any token can attend to any other in a single step, regardless of distance, so long-range information does not have to survive a long chain of recurrent updates. The price is that attention cost grows quadratically with sequence length.
  3. Expressiveness: multi-head attention learns many different relationships simultaneously, and the FFN adds a non-linear transformation at each position.

Scale this up with more layers, a bigger $d_{model}$, and more data, and you get large language models like GPT-4.
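To get a feel for how size scales, here is a parameter count for the `TransformerBlock` class above, derived by hand from its layers: every `nn.Linear` has a bias, and each `LayerNorm` has a weight and a bias of size `d_model`. The function names are hypothetical, `num_heads` does not appear because splitting into heads does not change any weight shapes, and embeddings and the output layer are left out.

```python
def transformer_block_params(d_model, ff_dim):
    """Trainable parameters in one TransformerBlock as defined above."""
    attention = 4 * (d_model * d_model + d_model)  # W_q, W_k, W_v, W_o: weight + bias each
    feedforward = (d_model * ff_dim + ff_dim) + (ff_dim * d_model + d_model)  # two Linears
    layer_norms = 2 * (2 * d_model)  # norm1 and norm2: weight + bias each
    return attention + feedforward + layer_norms


def stack_params(num_layers, d_model, ff_dim):
    """Parameters in num_layers stacked blocks (embeddings and output layer excluded)."""
    return num_layers * transformer_block_params(d_model, ff_dim)


print(transformer_block_params(512, 2048))  # 3152384
print(stack_params(6, 512, 2048))           # 18914304
```

With the original paper's base settings ($d_{model} = 512$, feedforward size 2048), one block has 3,152,384 parameters, and a six-block stack like that paper's base encoder has 18,914,304. Doubling both sizes roughly quadruples the count, since the weight matrices grow with the product of their dimensions.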