GPT-2 Source Code Notes, Part 3: The Transformer Anatomy (model.py)

The GPT-2 model definition shows the transformer in its pure, uncluttered form. This article breaks down the block structure, residual pattern, MLP expansion, and weight tying.

The GPT-2 model definition in TensorFlow 1.x is small enough to walk through end to end and understand the entire architecture. It contains the canonical decoder-only transformer structure: embeddings, attention blocks, MLP blocks, residuals, layer norms, and a weight-tied output head.

The transformer block

def block(x, scope, *, past, hparams):
    with tf.variable_scope(scope):
        nx = x.shape[-1].value
        a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams)
        x = x + a
        m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams)
        x = x + m
        return x, present

This is the core pattern: normalize the input, self-attend, add the residual; normalize again, run the feed-forward MLP, add the residual. Note that the norm comes before each sublayer (pre-norm), and this block shape is still what basically every modern model uses.
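The residual pattern above can be sketched framework-free in NumPy. The `attn_fn` and `mlp_fn` placeholders are hypothetical stand-ins for the real sublayers; `layer_norm` is a minimal version of what `norm()` does (without the learned gain and bias):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last (embedding) axis, as GPT-2's norm() does.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def block(x, attn_fn, mlp_fn):
    x = x + attn_fn(layer_norm(x))  # normalize, self-attend, residual add
    x = x + mlp_fn(layer_norm(x))   # normalize, feed forward, residual add
    return x

# Smoke test with identity sublayers: shape is preserved end to end.
x = np.random.randn(2, 5, 8)
y = block(x, lambda h: h, lambda h: h)
print(y.shape)  # (2, 5, 8)
```

Because each sublayer's output is *added* to its input, the block is a perturbation of the identity, which is what makes stacks of dozens of these blocks trainable.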

The MLP feedforward

def mlp(x, scope, n_state, *, hparams):
    with tf.variable_scope(scope):
        nx = x.shape[-1].value
        h = gelu(conv1d(x, 'c_fc', n_state))
        h2 = conv1d(h, 'c_proj', nx)
        return h2

A projection up, a GELU, and a projection back down. The block calls this with n_state = nx*4, so the hidden layer is four times the embedding width. This is the feedforward network used inside each block, and the capacity of the transformer is driven heavily by this expansion width.
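The same MLP can be sketched with plain matrix multiplies in NumPy (the weight names here are hypothetical; in the real code `conv1d` holds the equivalent parameters). The `gelu` below is the tanh approximation GPT-2 uses:

```python
import numpy as np

def gelu(x):
    # Tanh approximation of GELU, as in GPT-2's source.
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def mlp(x, w_fc, b_fc, w_proj, b_proj):
    h = gelu(x @ w_fc + b_fc)    # project up to 4*nx, then nonlinearity
    return h @ w_proj + b_proj   # project back down to nx

nx = 8
rng = np.random.default_rng(0)
w_fc = rng.standard_normal((nx, 4 * nx)); b_fc = np.zeros(4 * nx)
w_proj = rng.standard_normal((4 * nx, nx)); b_proj = np.zeros(nx)

x = rng.standard_normal((2, 5, nx))
print(mlp(x, w_fc, b_fc, w_proj, b_proj).shape)  # (2, 5, 8)
```

The 4x factor is a convention rather than a law, but it means the two MLP matrices hold roughly two thirds of each block's parameters.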

Cache shape declaration

def past_shape(*, hparams, batch_size=None, sequence=None):
    return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, hparams.n_embd // hparams.n_head]

The cache holds keys and values for each layer, and the shape makes the per-layer and per-head separation explicit. Modern paged KV systems are built on top of this same conceptual structure.
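To make the axes concrete, here is the cache instantiated in NumPy with GPT-2 small's hyperparameters (n_layer=12, n_head=12, n_embd=768); the slicing shows where one layer's keys and values live:

```python
import numpy as np

# GPT-2 small hyperparameters.
n_layer, n_head, n_embd = 12, 12, 768
batch, sequence = 1, 10
head_dim = n_embd // n_head  # 64

# [batch, n_layer, 2, n_head, sequence, head_dim]; axis 2 separates K and V.
past = np.zeros((batch, n_layer, 2, n_head, sequence, head_dim), dtype=np.float32)

k_layer0 = past[:, 0, 0]  # keys for layer 0:   (batch, n_head, sequence, head_dim)
v_layer0 = past[:, 0, 1]  # values for layer 0: same shape
print(k_layer0.shape)  # (1, 12, 10, 64)
```

During generation the sequence axis grows by one per token, which is exactly why the cache trades memory for avoiding recomputation of all earlier keys and values.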

Forward pass and weight tying

presents = []
for layer, past in enumerate(pasts):
    h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
    presents.append(present)
results['present'] = tf.stack(presents, axis=1)
h = norm(h, 'ln_f')

h_flat = tf.reshape(h, [batch*sequence, hparams.n_embd])
logits = tf.matmul(h_flat, wte, transpose_b=True)
logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
results['logits'] = logits

After all blocks, apply a final normalization. Then project hidden states to vocabulary space by multiplying by the transpose of the embedding matrix. This is classic weight tying. It saves parameters and keeps embeddings and logits consistent.
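The tied output head reduces to a single transposed matmul; a NumPy sketch with GPT-2's actual vocabulary and embedding sizes (the `rng`-initialized weights are placeholders):

```python
import numpy as np

n_vocab, n_embd, batch, sequence = 50257, 768, 1, 4
rng = np.random.default_rng(0)
wte = rng.standard_normal((n_vocab, n_embd)).astype(np.float32)  # token embeddings
h = rng.standard_normal((batch, sequence, n_embd)).astype(np.float32)

# Output head reuses the input embedding matrix, transposed.
h_flat = h.reshape(batch * sequence, n_embd)
logits = (h_flat @ wte.T).reshape(batch, sequence, n_vocab)
print(logits.shape)  # (1, 4, 50257)
```

Tying avoids a separate n_vocab x n_embd output matrix, which at these sizes is about 38.6M parameters, roughly a third of GPT-2 small's total.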

Why this file is worth reading

This is the transformer in its most direct and unoptimized expression. No fused kernels. No flash attention. No RoPE. Just the basic concepts clearly spelled out in code.

If you want to understand transformers from source code, this file is the correct scale to learn from. Once you internalize it, understanding larger models becomes mainly a matter of scale and optimization rather than decoding hidden abstractions.