transformer-explainer

Attention

Q/K/V, scores, mask, softmax, output. The headline act.

Concept

Attention is how the decoder looks back at earlier tokens. For every position i, it asks: "given what I'm computing here, which earlier positions matter, and how much?"

The trick is that we never hand-code a rule like "look back at position j". Instead, each position produces three vectors:

  • a query — what this position is asking about,
  • a key — what this position can answer with,
  • a value — what this position has to offer.

Then for every pair (query at i, key at j) we compute a similarity score: the dot product of the two vectors, scaled by 1/√d_head. Applying a softmax across each row turns those scores into a probability distribution over which positions to read from. The attention output at position i is the average of the value vectors, weighted by that distribution.
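The three steps above (score, softmax, weighted average) can be sketched for a single query as follows. This is a minimal illustration, not the project's code: the names dot, softmax, and attendOne are hypothetical, and no causal mask is applied yet.

```typescript
// Hypothetical helper names for illustration; not taken from the project's source.
function dot(a: number[], b: number[]): number {
  return a.reduce((sum, ai, k) => sum + ai * (b[k] ?? 0), 0);
}

function softmax(xs: number[]): number[] {
  const max = Math.max(...xs); // subtract the max for numerical stability
  const exps = xs.map((x) => Math.exp(x - max));
  const total = exps.reduce((s, e) => s + e, 0);
  return exps.map((e) => e / total);
}

// Attention for one query vector against every key/value pair (no mask).
function attendOne(
  q: number[],
  keys: number[][],
  values: number[][],
): { weights: number[]; output: number[] } {
  const scale = 1 / Math.sqrt(Math.max(1, q.length));
  const scores = keys.map((k) => dot(q, k) * scale); // one score per key
  const weights = softmax(scores); // non-negative, sums to 1
  const width = values[0]?.length ?? 0;
  // Weighted average of the value vectors, one output column at a time.
  const output = Array.from({ length: width }, (_, c) =>
    values.reduce((sum, v, j) => sum + (weights[j] ?? 0) * (v[c] ?? 0), 0),
  );
  return { weights, output };
}

// The query matches keys 0 and 2, so their values dominate the output.
const demoAttend = attendOne(
  [1, 0],
  [[1, 0], [0, 1], [1, 0]],
  [[10, 0], [0, 10], [10, 0]],
);
console.log(demoAttend.weights, demoAttend.output);
```

Keys 0 and 2 are identical here, so they receive equal weight, and key 1 (orthogonal to the query) receives less.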

To keep the decoder causal, so that position i can only see positions ≤ i, we add a mask that sets the scores for future positions (j > i) to −∞ before the softmax. Since exp(−∞) = 0, those positions end up with a weight of exactly zero.
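A small sketch of the masking step, assuming an additive mask (0 on and below the diagonal, −∞ above it). The functions below are hypothetical stand-ins that reuse the helper names causalMask and softmaxRows for readability; the project's own versions may be implemented differently.

```typescript
// Hypothetical stand-ins; the project's causalMask/softmaxRows may differ.
function causalMask(seqLen: number): number[][] {
  // mask[i][j] = 0 when j <= i (visible), -Infinity when j > i (future).
  return Array.from({ length: seqLen }, (_, i) =>
    Array.from({ length: seqLen }, (_, j) => (j <= i ? 0 : -Infinity)),
  );
}

function softmaxRows(scores: number[][], mask: number[][]): number[][] {
  return scores.map((row, i) => {
    const masked = row.map((s, j) => s + (mask[i]?.[j] ?? 0));
    // The diagonal is never masked, so the row maximum is finite for finite scores.
    const max = Math.max(...masked);
    const exps = masked.map((m) => Math.exp(m - max)); // exp(-Infinity) === 0
    const total = exps.reduce((a, b) => a + b, 0);
    return exps.map((e) => e / total);
  });
}

const demoScores = [
  [1, 2, 3],
  [1, 2, 3],
  [1, 2, 3],
];
const demoWeights = softmaxRows(demoScores, causalMask(3));
console.log(demoWeights);
```

Row 0 can only see itself, so all of its weight lands on position 0. Each later row spreads its weight over one more position, and every entry above the diagonal is zero.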

Try it below. The four panels are the actual values from one block of the toy decoder. Hover any query row in the scores or weights matrix to highlight the keys it pays attention to.

Maths

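For a single head with head dimension d_head = d_model / n_heads, sequence length S, input X of shape [S, d_model], and projection matrices W_q, W_k, W_v of shape [d_model, d_head]: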

Q = X · W_q              // [S, d_head]
K = X · W_k              // [S, d_head]
V = X · W_v              // [S, d_head]
scores = Q · Kᵀ / √d_head            // [S, S]
masked = scores + M                  // M = causal mask
weights = softmax(masked, dim=-1)    // each row sums to 1
output = weights · V                 // [S, d_head]
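The same steps collapse into a single expression. The LaTeX below restates the lines above and adds nothing new, apart from spelling out the mask entries M_{ij} under the assumption that the mask is additive:

```latex
\[
\mathrm{Attention}(Q, K, V)
  = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_{\text{head}}}} + M\right) V,
\qquad
M_{ij} =
\begin{cases}
  0       & \text{if } j \le i, \\
  -\infty & \text{if } j > i.
\end{cases}
\]
\[
\text{weights}_{ij}
  = \frac{\exp(\text{masked}_{ij})}{\sum_{k=1}^{S} \exp(\text{masked}_{ik})},
\qquad
\text{output}_{i} = \sum_{j=1}^{S} \text{weights}_{ij}\, V_{j}.
\]
```

Because exp(−∞) = 0, the sums effectively run over j ≤ i only.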

Multi-head attention runs H = n_heads of these in parallel on disjoint, d_head-wide slices of d_model, concatenates the per-head outputs, and projects through W_o:

head_h  = Attention(X · W_q[h], X · W_k[h], X · W_v[h])   // [S, d_head]
concat  = [head_0, head_1, …, head_{H-1}]       // [S, d_model]
output  = concat · W_o                         // [S, d_model]

Each head can specialise in a different relationship — one might track which token came right before, another which token shares an opening parenthesis, and so on. With this toy model and random weights you'll see less interpretable patterns, but the shape of the computation is identical.

Code

The attention head, in TypeScript, is shorter than its description:

// src/lib/transformer/attention.ts (excerpt)
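export function singleHeadAttention(Q: number[][], K: number[][], V: number[][]) {
  const seqLen = Q.length;
  const d_head = Q[0]?.length ?? 0;
  // Scale by 1/√d_head; Math.max avoids dividing by zero on empty input.
  const scale = 1 / Math.sqrt(Math.max(1, d_head));

  // scores[i][j] = (Q[i] · K[j]) * scale
  const Kt = transpose(K);
  const raw = matmul(Q, Kt);
  const scores = raw.map((row) => row.map((v) => v * scale));

  const mask = causalMask(seqLen); // -Inf above diagonal
  const weights = softmaxRows(scores, mask); // row-wise stable softmax
  // Each output row is the weighted average of the rows of V.
  const output = matmul(weights, V);
  // scores are returned scaled but unmasked; the mask is applied inside softmaxRows.
  return { scores, weights, output };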
}

Multi-head attention slices the model dimension into n_heads groups, runs singleHeadAttention on each, concatenates the results, and projects through W_o. The widget above calls /api/compute/attention, which runs the same function that verify-maths checks against the PyTorch fixtures.
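The slice, attend, concatenate, project sequence might look roughly like the sketch below. It is self-contained and does not reproduce the project's multi-head code: matmul, headOutput, sliceCols, and multiHeadAttention are hypothetical names, the weight matrices are assumed to be [d_model, d_model], and d_model is assumed to be divisible by nHeads. Projecting with the full matrix and then slicing columns is equivalent to multiplying by the per-head slice W_q[h].

```typescript
type Matrix = number[][];

function matmul(a: Matrix, b: Matrix): Matrix {
  const cols = b[0]?.length ?? 0;
  return a.map((row) =>
    Array.from({ length: cols }, (_, c) =>
      row.reduce((sum, v, k) => sum + v * (b[k]?.[c] ?? 0), 0),
    ),
  );
}

// Causal scaled dot-product attention for one head; returns only the output.
function headOutput(Q: Matrix, K: Matrix, V: Matrix): Matrix {
  const scale = 1 / Math.sqrt(Math.max(1, Q[0]?.length ?? 0));
  const weights = Q.map((q, i) => {
    // Keys at future positions (j > i) are masked to -Infinity.
    const masked = K.map((k, j) =>
      j <= i ? q.reduce((s, v, d) => s + v * (k[d] ?? 0), 0) * scale : -Infinity,
    );
    const max = Math.max(...masked);
    const exps = masked.map((m) => Math.exp(m - max));
    const total = exps.reduce((a, b) => a + b, 0);
    return exps.map((e) => e / total);
  });
  return matmul(weights, V);
}

// Columns [start, end) of a matrix.
function sliceCols(m: Matrix, start: number, end: number): Matrix {
  return m.map((row) => row.slice(start, end));
}

function multiHeadAttention(
  X: Matrix, Wq: Matrix, Wk: Matrix, Wv: Matrix, Wo: Matrix, nHeads: number,
): Matrix {
  const dModel = X[0]?.length ?? 0;
  const dHead = dModel / nHeads; // assumes dModel is divisible by nHeads
  const Q = matmul(X, Wq); // [S, d_model]
  const K = matmul(X, Wk);
  const V = matmul(X, Wv);
  const heads = Array.from({ length: nHeads }, (_, h) =>
    headOutput(
      sliceCols(Q, h * dHead, (h + 1) * dHead),
      sliceCols(K, h * dHead, (h + 1) * dHead),
      sliceCols(V, h * dHead, (h + 1) * dHead),
    ),
  );
  // Concatenate per-head rows along the feature axis: [S, d_model].
  const concat = X.map((_, i) => heads.flatMap((head) => head[i] ?? []));
  return matmul(concat, Wo);
}

// Identity weights keep the arithmetic easy to follow.
const identity = (n: number): Matrix =>
  Array.from({ length: n }, (_, i) => Array.from({ length: n }, (_, j) => (i === j ? 1 : 0)));
const I4 = identity(4);
const demoX: Matrix = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]];
const demoOut = multiHeadAttention(demoX, I4, I4, I4, I4, 2);
console.log(demoOut);
```

With identity weights, position 0 can only attend to itself, so its output row equals its input row. Changing a later token leaves earlier output rows untouched, which is a quick sanity check that the mask is doing its job.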
