Concept
Attention is how the decoder looks back at earlier tokens. For every
position i, it asks: "given what I'm computing here, which earlier
positions matter, and how much?"
The trick is that no rule like "look back at position j" is ever
hand-coded. Instead, each position broadcasts three vectors:
- a query — what this position is asking about,
- a key — what this position can answer with,
- a value — what this position has to offer.
Then for every pair (query at i, key at j) we compute a similarity
score (a dot product). A softmax across each row turns those scores
into a probability distribution over which positions to read from, and
the weighted average of the values under that distribution is the
attention output.
To keep the decoder causal — so position i can only see positions
≤ i — we add a "mask" that pushes future scores to −∞ before the
softmax. Future weights end up at zero.
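To make the mask concrete, here is a minimal, self-contained sketch of the mask-then-softmax step on three tokens. The scores are invented for illustration; the widget below shows the model's real ones.
// Invented 3×3 scores, purely for illustration.
const scores = [
  [0.9, 0.3, 0.5],
  [0.2, 1.1, 0.4],
  [0.7, 0.6, 1.3],
];

// Causal mask: -Infinity above the diagonal, 0 elsewhere.
const masked = scores.map((row, i) =>
  row.map((s, j) => (j > i ? -Infinity : s))
);

// Row-wise softmax; exp(-Infinity) = 0, so future weights vanish.
const weights = masked.map((row) => {
  const max = Math.max(...row);                 // subtract the max for stability
  const exps = row.map((s) => Math.exp(s - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
});

// weights[0] is [1, 0, 0]: position 0 can only read from itself.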
Try it below. The four panels are the actual values from one block of the toy decoder. Hover any query row in the scores or weights matrix to highlight the keys it pays attention to.
Maths
For a single head with head dimension d_head = d_model / n_heads:
Q = X · W_q // [S, d_head]
K = X · W_k // [S, d_head]
V = X · W_v // [S, d_head]
scores = Q · Kᵀ / √d_head // [S, S]
masked = scores + M // M = causal mask
weights = softmax(masked, dim=-1) // each row sums to 1
output = weights · V // [S, d_head]
Multi-head attention runs H of these in parallel on disjoint
slices of d_model, concatenates the per-head outputs, and projects
through W_o:
heads_h = Attention(X · W_q[h], X · W_k[h], X · W_v[h])
concat = [heads_0, heads_1, …, heads_{H-1}] // [S, d_model]
output = concat · W_o // [S, d_model]
Each head can specialise in a different relationship — one might track which token came right before, another which token shares an opening parenthesis, and so on. With this toy model and random weights you'll see less interpretable patterns, but the shape of the computation is identical.
Code
The attention head, in TypeScript, is shorter than its description:
// src/lib/transformer/attention.ts (excerpt)
export function singleHeadAttention(Q: number[][], K: number[][], V: number[][]) {
  const seqLen = Q.length;
  const d_head = Q[0]?.length ?? 0;
  const scale = 1 / Math.sqrt(Math.max(1, d_head)); // 1/√d_head, guarded for empty input
  const Kt = transpose(K);                          // [d_head, S]
  const raw = matmul(Q, Kt);                        // [S, S]
  const scores = raw.map((row) => row.map((v) => v * scale));
  const mask = causalMask(seqLen);                  // -Inf above the diagonal
  const weights = softmaxRows(scores, mask);        // row-wise stable softmax
  const output = matmul(weights, V);                // [S, d_head]
  return { scores, weights, output };
}
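The excerpt leans on two helpers, causalMask and softmaxRows, that aren't shown. One plausible implementation of each, as a sketch rather than the file's exact code:
// Sketches of the assumed helpers, not necessarily the repo's exact code.
function causalMask(seqLen: number): number[][] {
  return Array.from({ length: seqLen }, (_, i) =>
    Array.from({ length: seqLen }, (_, j) => (j > i ? -Infinity : 0))
  );
}

function softmaxRows(scores: number[][], mask: number[][]): number[][] {
  return scores.map((row, i) => {
    const masked = row.map((s, j) => s + mask[i][j]);
    const max = Math.max(...masked);            // stabilise before exponentiating
    const exps = masked.map((s) => Math.exp(s - max));
    const sum = exps.reduce((a, b) => a + b, 0);
    return exps.map((e) => e / sum);
  });
}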
Multi-head simply slices the model dimension into n_heads groups, runs
singleHeadAttention on each, concatenates, and projects through
W_o; a sketch of that wrapper is below. The widget above hits
/api/compute/attention, which calls the exact function that
verify-maths checks against the PyTorch fixtures.
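And one way that wrapper might look, reusing matmul and the per-head weight matrices from the maths section. The signature here is an assumption for illustration, not the repo's actual API:
// Illustrative multi-head wrapper; parameter shapes follow the maths above.
// Wq[h], Wk[h], Wv[h] are [d_model, d_head]; Wo is [d_model, d_model].
function multiHeadAttention(
  X: number[][],                               // [S, d_model]
  Wq: number[][][],
  Wk: number[][][],
  Wv: number[][][],
  Wo: number[][],
): number[][] {
  const heads = Wq.map((_, h) =>
    singleHeadAttention(matmul(X, Wq[h]), matmul(X, Wk[h]), matmul(X, Wv[h])).output
  );
  // Concatenate the per-head outputs along the feature axis: [S, d_model].
  const concat = heads[0].map((_, s) => heads.flatMap((head) => head[s]));
  return matmul(concat, Wo);                   // [S, d_model]
}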