GPT-2 Source Code Notes, Part 2: Sampling and Generation (sample.py)

Sampling in GPT-2 is almost minimal. Here we look directly at top-k filtering, temperature, and the one-token-at-a-time autoregressive loop that shaped early generation behavior.

The sampling file in GPT-2 is very short, but it already contains the pieces that modern inference code still builds on: temperature scaling, top-k filtering, and a token-by-token generation loop that reuses a key/value (KV) cache.

Top-k in plain TensorFlow

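import tensorflow as tf  # TensorFlow 1.x, as in the original repository

def top_k_logits(logits, k):
    if k == 0:
        # no truncation
        return logits

    def _top_k():
        # the k largest logits in each row, sorted in descending order
        values, _ = tf.nn.top_k(logits, k=k)
        # the k-th largest logit per row, kept as a column so it broadcasts
        min_values = values[:, -1, tf.newaxis]
        # replace every logit below that threshold with a huge negative number
        return tf.where(
            logits < min_values,
            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
            logits,
        )

    # graph-time check, for when k is a tensor the Python `k == 0` test cannot see
    return tf.cond(
        tf.equal(k, 0),
        lambda: logits,
        lambda: _top_k(),
    )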

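This is simple. If k is zero, all logits are kept. Otherwise, the function finds the k-th largest logit in each row and replaces every logit below it with -1e10. Those logits are not zeroed out (a logit of zero would still carry probability); they are pushed so far down that after softmax the corresponding tokens get effectively zero probability. Tokens tied with the k-th value survive the comparison, so a row can occasionally keep more than k tokens.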

Top-k here is not abstract theory; the entire filter is a handful of lines you can read directly.
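To check that behavior without TensorFlow, here is a pure-Python sketch of the same filter for a single row of logits, followed by a softmax. The names top_k_filter and softmax are illustrative, not part of GPT-2; the -1e10 mask value and the tie handling copy the TensorFlow code.

```python
import math

MASK_VALUE = -1e10  # the same sentinel that top_k_logits writes into masked slots


def top_k_filter(logits, k):
    """Mask every logit below the k-th largest in one row; k == 0 means no truncation."""
    if k == 0:
        return list(logits)
    kth_largest = sorted(logits, reverse=True)[k - 1]  # like values[:, -1] after tf.nn.top_k
    # Keep x when x >= threshold, mirroring tf.where(logits < min_values, ...): ties survive.
    return [x if x >= kth_largest else MASK_VALUE for x in logits]


def softmax(logits):
    """Numerically stable softmax; exp(-1e10 - max) underflows to exactly 0.0."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


row = [2.0, 1.0, 0.5, -1.0, 3.0]
filtered = top_k_filter(row, k=2)
print(filtered)  # [2.0, -10000000000.0, -10000000000.0, -10000000000.0, 3.0]
print([round(p, 4) for p in softmax(filtered)])  # [0.2689, 0.0, 0.0, 0.0, 0.7311]
```

The masked tokens end up with a probability of exactly zero, and the two survivors share all of the probability mass in proportion to their original logits.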

One forward pass with the cache

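def step(hparams, tokens, past=None):
    # forward pass over `tokens`, attending to any cached keys/values in `past`
    lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)
    # keep only the first n_vocab columns of the logits
    logits = lm_output['logits'][:, :, :hparams.n_vocab]
    # keys/values computed for `tokens` during this pass
    presents = lm_output['present']
    # pin the static shape [batch, n_layer, 2, n_head, sequence, head_dim]
    presents.set_shape(model.past_shape(hparams=hparams, batch_size=tokens.shape[0]))
    return {
        'logits': logits,
        'presents': presents,
    }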

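This function runs one step of autoregressive inference. The prompt is run through the model once, with past=None, to fill the cache; after that, each call passes only the newest token together with the cache instead of re-running the entire sequence.

The cache (past going in, presents coming out) holds the attention keys and values that every layer has already computed for earlier positions. Because attention is causal, those values never change once computed, so they can be reused. In GPT-2 the cache is a single tensor of shape [batch, n_layer, 2, n_head, sequence, n_embd // n_head], where the 2 separates keys from values and the sequence axis is second from the end. This is the KV cache idea in its simplest form.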
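Why reusing the cache gives the same answer as recomputing everything can be checked on a toy example. The sketch below is a single attention head in pure Python with made-up 2x2 projection matrices (W_Q, W_K and W_V are hypothetical, not GPT-2 weights). It compares the output for the newest position computed with a growing key/value cache against a full recomputation over the whole prefix. In a real multi-layer model the same argument holds layer by layer, because causal masking means the hidden states of earlier positions never change.

```python
import math

# Hypothetical fixed projections standing in for learned weights (2-dim toy model).
W_Q = [[1.0, 0.5], [-0.3, 0.8]]
W_K = [[0.2, -1.0], [0.7, 0.4]]
W_V = [[1.5, 0.0], [0.1, -0.6]]


def matvec(matrix, vector):
    return [sum(m * x for m, x in zip(row, vector)) for row in matrix]


def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query over the given positions."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def last_position_without_cache(inputs):
    """Recompute keys and values for the whole prefix, then attend from the last position."""
    keys = [matvec(W_K, x) for x in inputs]
    values = [matvec(W_V, x) for x in inputs]
    return attend(matvec(W_Q, inputs[-1]), keys, values)


class KVCache:
    """Stores keys/values of earlier positions so each step projects only the new input."""

    def __init__(self):
        self.keys = []
        self.values = []

    def step(self, x):
        self.keys.append(matvec(W_K, x))    # analogous to concatenating presents onto past
        self.values.append(matvec(W_V, x))
        return attend(matvec(W_Q, x), self.keys, self.values)


inputs = [[1.0, 0.0], [0.5, -1.0], [0.2, 0.3], [-0.7, 0.9]]
cache = KVCache()
for t in range(len(inputs)):
    cached = cache.step(inputs[t])
    recomputed = last_position_without_cache(inputs[: t + 1])
    assert all(math.isclose(a, b) for a, b in zip(cached, recomputed))
print("cached and recomputed outputs match at every step")
```

The cached version does one key and one value projection per step, while the recomputing version redoes them for the entire prefix every time; the outputs are identical.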

The generation loop

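def sample_sequence(*, hparams, length, start_token=None, batch_size=None,
                    context=None, temperature=1, top_k=0):
    ...
    def body(past, prev, output):
        # run the model on the newest token(s), reusing the cached keys/values
        next_outputs = step(hparams, prev, past=past)
        # logits for the last position only, divided by the temperature
        logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
        logits = top_k_logits(logits, k=top_k)
        # tf.multinomial (TF 1.x) samples from unnormalized logits
        samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
        return [
            # grow the cache along the sequence axis (second from the end);
            # when there is no cache yet, the new presents start it
            next_outputs['presents'] if past is None
            else tf.concat([past, next_outputs['presents']], axis=-2),
            samples,  # becomes prev for the next iteration
            tf.concat([output, samples], axis=1),  # append to the generated sequence
        ]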

The flow is:

  1. Run the model on the newest token, reusing the cache.
  2. Take the logits for the last position.
  3. Divide them by the temperature.
  4. Apply top-k filtering.
  5. Sample one token from the softmax of those logits.
  6. Append that token to the output.
  7. Append the new present values to the cache.
  8. Loop, feeding the sampled token back in as prev.
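The steps above can be mirrored in plain Python to experiment without TensorFlow. This is a minimal sketch, not GPT-2 code: toy_next_token_logits is a hypothetical stand-in for the model, the helper names are made up for illustration, and since the toy model has no attention there is no cache, so step 7 is left out and the whole sequence is passed each time.

```python
import math
import random

MASK_VALUE = -1e10  # same sentinel the TensorFlow top_k_logits uses


def top_k_filter(logits, k):
    """Mask every logit below the k-th largest; k == 0 disables the filter."""
    if k == 0:
        return list(logits)
    kth_largest = sorted(logits, reverse=True)[k - 1]
    return [x if x >= kth_largest else MASK_VALUE for x in logits]


def sample_from_logits(logits, rng):
    """Draw index i with probability softmax(logits)[i], like tf.multinomial."""
    top = max(logits)
    weights = [math.exp(x - top) for x in logits]  # unnormalized softmax
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


def toy_next_token_logits(tokens, vocab_size=5):
    """Hypothetical model stand-in: strongly favors (last token + 1) mod vocab_size."""
    favored = (tokens[-1] + 1) % vocab_size
    return [3.0 if i == favored else 0.0 for i in range(vocab_size)]


def generate(prompt, length, temperature=1.0, top_k=0, seed=0):
    """Append `length` sampled tokens to `prompt`, one token per iteration."""
    rng = random.Random(seed)  # seeded so runs are reproducible
    output = list(prompt)
    for _ in range(length):
        logits = toy_next_token_logits(output)          # steps 1-2: next-token logits
        logits = [x / temperature for x in logits]      # step 3: temperature
        logits = top_k_filter(logits, top_k)            # step 4: top-k
        token = sample_from_logits(logits, rng)         # step 5: sample one token
        output.append(token)                            # step 6: append to output
        # step 7 (cache update) has no counterpart: the toy model has no attention
    return output


print(generate([0], length=4, top_k=1))  # [0, 1, 2, 3, 4]
print(generate([0], length=8, temperature=2.0, seed=1))
```

With top_k=1 only the highest logit survives the filter, so sampling turns into greedy decoding and the result no longer depends on the seed. Temperature must be positive here, since the logits are divided by it.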

It is surprisingly straightforward. Modern inference systems still run the same loop conceptually; they add optimizations around it, such as batching many requests together and managing cache memory more carefully.

Why this file is useful to understand now

This minimal version of the decoding loop is very instructive. You can see how LLMs generate text one token at a time. You can see how few control knobs there originally were (here, only temperature and top_k). You can trace every transformation from logits to sampled token.

Once you understand this source file, other decoding methods (nucleus sampling, beam search, contrastive decoding, etc.) become easier to reason about, because each one changes some part of this same basic loop: how the logits are filtered, how the next token is chosen or scored, or how many partial sequences are kept alive at once.
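To make that concrete, nucleus (top-p) sampling only swaps out the filtering step. Instead of keeping a fixed number k of tokens, it keeps the smallest set of most likely tokens whose combined probability reaches p. Below is a minimal pure-Python sketch of such a filter; the name top_p_filter, the -1e10 mask value borrowed from top_k_logits, and the choice to keep the token that crosses the threshold are illustrative assumptions, not code from this file.

```python
import math


def top_p_filter(logits, p, mask_value=-1e10):
    """Keep the smallest set of highest-probability tokens whose total probability reaches p."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Visit tokens from most to least likely, accumulating probability mass.
    order = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
    keep = set()
    cumulative = 0.0
    for i in order:
        keep.add(i)  # the token that crosses the threshold is kept as well
        cumulative += probs[i]
        if cumulative >= p:
            break
    return [x if i in keep else mask_value for i, x in enumerate(logits)]


logits = [math.log(0.5), math.log(0.3), math.log(0.2)]
print(top_p_filter(logits, p=0.7))  # the 0.2-probability token is masked
```

Using this function in place of the top-k filter is the only change needed; temperature, sampling, and the cache handling stay the same.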