Why Build Your Own AI?

You’ve used the API. You know what temperature does, roughly. You’ve read the phrase “transformer-based language model” enough times that it no longer triggers any reaction at all, which is its own kind of problem.

Here’s the situation: you’re operating a system you don’t understand, and most of the time that’s fine. Most of the time. Then something goes wrong — the model hallucinates a confident falsehood, ignores half the context window, or refuses to follow a simple instruction for no discernible reason — and you have no mental model for why. You’re debugging a black box with a flashlight pointed in the wrong direction.

This book is the flashlight pointed in the right direction.

What This Book Is

A technical guide to the internals of large language models, written for developers who are already using them. Every concept is accompanied by working code. The math is explained in English first and notation second. The humor is dry and is not your problem unless you also find it funny, in which case welcome.

By the end of this book, you will understand:

  • How text becomes numbers (and why those particular numbers)
  • Why attention has three separate matrices when one seems like it should suffice
  • What “training” actually computes, beyond the vague hand-wave about gradient descent
  • How fine-tuning differs from training from scratch
  • What happens at inference time, from the first token to the last
  • What the architecture genuinely cannot do, and why

What This Book Is Not

A paper survey. A comprehensive ML textbook. A guide to the latest models (they’ll be obsolete before you finish reading anyway). A sales pitch.

Who You Are

You write Python. You’ve called an LLM API. You can read a stack trace. You may have a rough sense that “attention” is important but aren’t sure exactly what it attends to or why it couldn’t just attend to everything equally and call it a day.

You’re about to find out that it kind of does attend to everything, but in a very specific weighted way that turns out to be the entire secret. That’s the fun of this.

How to Use This Book

Read it in order the first time. Each chapter builds on the last. The code examples are meant to be run — they’re short enough to paste into a notebook or a script, and they produce output that should make the concept click in a way prose alone won’t.

If something isn’t clear, that’s the book’s failure, not yours. The concepts are genuinely not that complicated once you strip away the intimidating notation. A transformer is a specific arrangement of matrix multiplications. That’s most of it.

A Note on Scale

The examples in this book use toy models — small architectures trained on tiny datasets, designed to fit in CPU memory and run in under a minute. Real production models have hundreds of billions of parameters and were trained on clusters of thousands of GPUs for months.

The math is identical. The engineering challenges are different. We’re covering the math.

Understanding the small version perfectly is more valuable than vaguely gesturing at the large version. Once you’ve built a tiny transformer and watched it actually learn something, what GPT-4 is doing becomes much less mysterious. It’s doing the same thing, just with more parameters, more data, and considerably more electricity.

Let’s Go

Open a Python environment. Install PyTorch if you haven’t:

pip install torch numpy tiktoken

The first chapter is about tokens, which is where every LLM interaction actually begins. Not with your prompt. With the question of how to cut your prompt into pieces the model can process.

You’ve been starting from the wrong end.

Tokens: The Atoms of Language

Before a language model can do anything with your text, it has to destroy it.

Not maliciously, but thoroughly: your string of characters gets chopped into pieces called tokens, and those tokens (not the characters, not the words, not the sentences) are what the model actually processes. Everything downstream depends on this step, which is why it’s worth understanding in some detail.

Why Not Just Use Characters?

The obvious starting point: why not feed the model one character at a time? Letters are already atomic. There are only 26 of them in English (more if you count punctuation, digits, and the various ways people type “resumé”).

The problem is sequence length. “The quick brown fox” is 19 characters. If you’re processing 1,000 tokens at a time, character-level you get 1,000 characters — maybe 200 words. Token-level, you get roughly 750 words. That’s a meaningful difference when your context window is a hard limit.

More importantly, characters don’t carry meaning. The letter ‘c’ contributes nothing on its own. The model would have to learn that ‘c’, ‘a’, ‘t’ in sequence means a small furry mammal, from scratch, from data, every time. It can do this! But it’s inefficient. Better to give the model “cat” as a unit.

Why Not Just Use Words?

Words seem like the obvious answer. Dictionaries exist. There are maybe 170,000 words in English, which is manageable.

Three problems:

1. Vocabulary explosion. “Run”, “runs”, “running”, “runner”, “ran” are all different words. So are “tokenize”, “tokenized”, “tokenizing”, “tokenizer”, “tokenization”. In practice, a word-level vocabulary has to be enormous, and words not in the vocabulary become <UNK> (unknown), which is just a fancy way of losing information.

2. Different languages. A word-level tokenizer built for English is useless for Chinese, which doesn’t use spaces between words. It works badly for German, where “Donaudampfschifffahrtselektrizitätenhauptbetriebswerkbauunterbeamtengesellschaft” is a single word that means approximately “Association of subordinate officials of the head office management of the Danube steamboat electrical services.”

3. Novel words. “GPT-4o”, “tokenizer”, “TikTok” — new terms appear constantly. A fixed word vocabulary can’t adapt.
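To make the information loss concrete, here is a minimal sketch of a word-level tokenizer. The five-entry vocabulary and the helper names (`encode_words`, `decode_words`) are made up for illustration; no real tokenizer is this small.

```python
# Hypothetical toy vocabulary: every entry here is illustrative.
vocab = ["<UNK>", "the", "model", "runs", "fast"]
word_to_id = {word: i for i, word in enumerate(vocab)}

def encode_words(text):
    """Map each whitespace-separated word to its ID, or to <UNK> if unseen."""
    unk = word_to_id["<UNK>"]
    return [word_to_id.get(word, unk) for word in text.lower().split()]

def decode_words(ids):
    """Map IDs back to words; anything that became <UNK> stays <UNK>."""
    return " ".join(vocab[i] for i in ids)

ids = encode_words("The model runs fast")
print(ids, "->", decode_words(ids))

# "tokenizer", "is", and "running" are not in the vocabulary,
# so all three collapse to the same ID and cannot be recovered.
lossy = encode_words("The tokenizer is running")
print(lossy, "->", decode_words(lossy))
```

Note that “running” is lost even though “runs” is in the vocabulary: a word-level scheme has no way to share anything between the two.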

The Solution: Subword Tokenization

Modern tokenizers split text into pieces that are somewhere between characters and words. Common sequences stay together (“the”, “ing”, “tion”), rare sequences get split (“tokenization” → “token” + “ization”). This is Byte Pair Encoding (BPE), and it’s what GPT-style models use.

Here’s the key insight: BPE starts by treating every character as a token, then iteratively merges the most frequent adjacent pairs into a single token. You run this process until you have the vocabulary size you want.

Let’s implement a simple version:

from collections import Counter

def get_stats(vocab):
    """Count frequency of adjacent pairs across all words in vocab."""
    pairs = Counter()
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[(symbols[i], symbols[i+1])] += freq
    return pairs

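def merge_vocab(pair, vocab):
    """Merge all occurrences of the most frequent pair."""
    new_vocab = {}
    for word, freq in vocab.items():
        symbols = word.split()
        merged, i = [], 0
        while i < len(symbols):
            # Compare whole symbols, so ('e', 's') can never match inside 'ae s'
            # the way a plain str.replace on the joined string could.
            if i < len(symbols) - 1 and (symbols[i], symbols[i+1]) == pair:
                merged.append(symbols[i] + symbols[i+1])
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        new_vocab[' '.join(merged)] = freq
    return new_vocab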

# Start with a tiny corpus
# Each word is represented with a space between characters
# </w> marks end-of-word
initial_vocab = {
    'l o w </w>': 5,
    'l o w e r </w>': 2,
    'n e w e s t </w>': 6,
    'w i d e s t </w>': 3,
}

vocab = initial_vocab.copy()
print("Initial vocab:")
for word, freq in vocab.items():
    print(f"  '{word}': {freq}")

print()

# Run 5 BPE merges
for i in range(5):
    pairs = get_stats(vocab)
    if not pairs:
        break
    best = max(pairs, key=pairs.get)
    print(f"Merge {i+1}: {best} (frequency: {pairs[best]})")
    vocab = merge_vocab(best, vocab)

print()
print("Final vocab:")
for word, freq in vocab.items():
    print(f"  '{word}': {freq}")

Running this produces:

Initial vocab:
  'l o w </w>': 5
  'l o w e r </w>': 2
  'n e w e s t </w>': 6
  'w i d e s t </w>': 3

Merge 1: ('e', 's') (frequency: 9)
Merge 2: ('es', 't') (frequency: 9)
Merge 3: ('est', '</w>') (frequency: 9)
Merge 4: ('l', 'o') (frequency: 7)
Merge 5: ('lo', 'w') (frequency: 7)

Final vocab:
  'low </w>': 5
  'low e r </w>': 2
  'n e w est</w>': 6
  'w i d est</w>': 3

Notice what happened: “est” (along with its end-of-word marker) got merged because it appeared frequently (“newest”, “widest”). “low” got merged because it appeared in both “low” and “lower”. The algorithm discovered structure without being told it existed.
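Learning merges is only half the job; the same merge list is also what tokenizes new text. The sketch below replays the five merges printed above on words the toy corpus never contained. The helper name `bpe_encode` is made up here, and replaying merges in learned order is a simplification of what production encoders do.

```python
# Merge rules in the order the toy run above learned them.
merges = [("e", "s"), ("es", "t"), ("est", "</w>"), ("l", "o"), ("lo", "w")]

def bpe_encode(word, merges):
    """Split a word into characters, then replay each merge rule in order."""
    symbols = list(word) + ["</w>"]
    for left, right in merges:
        merged, i = [], 0
        while i < len(symbols):
            if i < len(symbols) - 1 and symbols[i] == left and symbols[i + 1] == right:
                merged.append(left + right)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols

print(bpe_encode("lowest", merges))  # unseen word built from known pieces
print(bpe_encode("newer", merges))   # no rule applies, so it stays as characters
```

“lowest” comes out as two familiar pieces, while “newer” falls all the way back to single characters. Nothing becomes unknown; rare words just cost more tokens.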

Real Tokenizers in Practice

You won’t implement BPE from scratch for production — use tiktoken (OpenAI’s tokenizer) or the transformers library:

import tiktoken

# The tokenizer used by GPT-4
enc = tiktoken.get_encoding("cl100k_base")

text = "The quick brown fox jumped over the lazy dog."
tokens = enc.encode(text)
print(f"Text: {text!r}")
print(f"Token IDs: {tokens}")
print(f"Token count: {len(tokens)}")

# Decode each token individually to see what they are
print("\nToken breakdown:")
for token_id in tokens:
    token_bytes = enc.decode_single_token_bytes(token_id)
    try:
        token_str = token_bytes.decode('utf-8')
    except UnicodeDecodeError:
        token_str = repr(token_bytes)
    print(f"  {token_id:6d} → {token_str!r}")

Running this produces:

Text: 'The quick brown fox jumped over the lazy dog.'
Token IDs: [791, 4062, 14198, 39935, 27096, 927, 279, 16053, 5679, 13]
Token count: 10

Token breakdown:
     791 → 'The'
    4062 → ' quick'
   14198 → ' brown'
   39935 → ' fox'
   27096 → ' jumped'
     927 → ' over'
     279 → ' the'
   16053 → ' lazy'
    5679 → ' dog'
      13 → '.'

Notice that spaces are typically part of the following token: “quick” is actually “ quick” (with a leading space). This is why tokenization affects things like whether the model treats the first word of a sentence differently from subsequent words.
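The leading space comes from a step that runs before any BPE merge: GPT-2-style tokenizers first split the text with a regular expression that lets a word carry one preceding space. The pattern below is a simplified, ASCII-only stand-in written for this sketch, not the pattern any real tokenizer ships.

```python
import re

# Simplified splitting pattern (an assumption for illustration; real
# tokenizers use a more elaborate, Unicode-aware pattern).
PRETOKENIZE = re.compile(r" ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+")

def pretokenize(text):
    """Split text into chunks, keeping one leading space attached to each word."""
    return PRETOKENIZE.findall(text)

chunks = pretokenize("The quick brown fox.")
print(chunks)

# The same word with and without a leading space is a different chunk,
# so BPE sees two different inputs.
print(pretokenize("quick"), pretokenize(" quick"))
```

Joining the chunks gives back the original string exactly, which is the point of attaching the space to the word: no information is lost in the split.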

The Part That Should Surprise You

Let’s look at some tokens that behave unexpectedly:

import tiktoken
enc = tiktoken.get_encoding("cl100k_base")

# A mix of rare strings, whitespace, and punctuation
examples = [
    "SolidGoldMagikarp",  # famous glitch token in the older GPT-2/GPT-3 tokenizer
    "    ",               # four spaces
    "\n\n\n",            # three newlines
    "================",  # many equals signs
    " unfavorable",
    "unfavorable",       # same word, different leading space!
]

for text in examples:
    tokens = enc.encode(text)
    print(f"{text!r:30s} → {len(tokens)} token(s): {tokens}")

Running this produces:

'SolidGoldMagikarp'            → 3 token(s): [45280, 11768, 74241]
'    '                         → 1 token(s): [262]
'\n\n\n'                       → 3 token(s): [198, 198, 198]
'================'             → 2 token(s): [=================]
' unfavorable'                 → 1 token(s): [45824]
'unfavorable'                  → 2 token(s): [1714, 27961]

That last one is important. " unfavorable" (with space) is a single token, but "unfavorable" (without space) is two tokens. The model has different representations for these because they are different inputs. This is why a leading space matters when you’re prompting, even though it looks like insignificant whitespace.

Counting Tokens (and Why You Should)

APIs charge per token. Context windows are measured in tokens. Your 4,000-word essay might fit or might not, depending on the vocabulary distribution.

Quick rule of thumb: for English prose, 1 token ≈ 4 characters ≈ 0.75 words. For code, it’s more variable — Python is fairly efficient, but languages with long keywords or unusual syntax can be costly.
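When no tokenizer is at hand, the rule of thumb can be turned into a rough budget check. This is a sketch under that 4-characters-per-token assumption only; the function names and the context limits used below are hypothetical.

```python
import math

def estimate_tokens(text, chars_per_token=4.0):
    """Rough token estimate from the ~4 characters/token rule for English prose."""
    return math.ceil(len(text) / chars_per_token)

def fits_in_context(text, context_limit, reserved_for_output=0):
    """Compare the estimate with a token budget, leaving room for the reply."""
    return estimate_tokens(text) + reserved_for_output <= context_limit

essay = "word " * 4000  # stand-in for a 4,000-word essay (20,000 characters)
print(estimate_tokens(essay))
print(fits_in_context(essay, context_limit=8192, reserved_for_output=1024))
print(fits_in_context(essay, context_limit=4096))
```

Treat the result as an estimate for planning, not for billing. For exact counts, measure with the real tokenizer: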

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

def token_report(text, label=""):
    tokens = enc.encode(text)
    chars = len(text)
    words = len(text.split())
    print(f"{label or 'Text'}")
    print(f"  Characters: {chars}")
    print(f"  Words:      {words}")
    print(f"  Tokens:     {len(tokens)}")
    print(f"  Chars/token: {chars/len(tokens):.1f}")
    print()

token_report("""
The transformer architecture, introduced in 'Attention Is All You Need' (2017),
replaced recurrent neural networks for sequence modeling tasks. Its core mechanism,
self-attention, allows each position in a sequence to attend to all other positions
simultaneously, enabling massive parallelism during training.
""", "English prose")

token_report("""
def fibonacci(n: int) -> int:
    if n <= 1:
        return n
    a, b = 0, 1
    for _ in range(n - 1):
        a, b = b, a + b
    return b
""", "Python code")

token_report("""
SELECT u.name, COUNT(o.id) as order_count, SUM(o.total) as revenue
FROM users u
LEFT JOIN orders o ON u.id = o.user_id
WHERE u.created_at > '2024-01-01'
GROUP BY u.id, u.name
HAVING COUNT(o.id) > 0
ORDER BY revenue DESC;
""", "SQL query")

Token IDs Are Not Random

The vocabulary is fixed at training time. Token 13 is always a period (.) for GPT-4. Token 791 is always “The”. This matters because the model’s learned weights are indexed by these IDs — the embedding for token 791 is a specific row in a matrix, and it always refers to “The”.

This is also why you can’t easily add new tokens to an existing model without retraining. Adding a token means adding a new row to the embedding matrix, but the model has no learned weights for it — you’d need to train on examples that use it.
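A small sketch of what “adding a row” means, using plain lists in place of a real embedding matrix. The table values, the helper name `add_token`, and the 0.02 initialization scale are all made up for illustration.

```python
import random

random.seed(0)
embedding_dim = 4

# Stand-in for a trained embedding matrix: one row per token ID.
# The numbers are invented; in a real model they come from training.
vocab = ["The", " cat", " sat"]
embedding_table = [
    [0.5, -0.1, 0.3, 0.9],
    [0.2, 0.8, -0.4, 0.1],
    [-0.7, 0.6, 0.2, 0.0],
]

def add_token(token, vocab, table):
    """Append a token plus a freshly initialized (untrained) row; return its ID."""
    vocab.append(token)
    table.append([random.gauss(0.0, 0.02) for _ in range(embedding_dim)])
    return len(vocab) - 1

new_id = add_token(" kitten", vocab, embedding_table)
print(new_id, embedding_table[new_id])
```

The existing rows are untouched and the new ID is valid, but its vector is noise until training on text containing the token gives it a meaningful position.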

What Comes Next

You now know that your text prompt gets converted to a sequence of integers, where each integer indexes into a vocabulary of ~100,000 items. Those integers are what the model actually receives.

The next question: what does the model do with an integer? It can’t do math on “The”. It needs to turn each token ID into something more meaningful — a representation that captures semantic relationships, so that “cat” and “kitten” are somehow close together, and “cat” and “database” are far apart.

That’s what embeddings are for.

Embeddings: Meaning as Geometry

A token is a number. A number is not meaning. To get from one to the other, we need embeddings.

An embedding is a vector — a list of floating-point numbers — that represents a token in a high-dimensional space. The magic (and it genuinely is, once you see it working) is that distance and direction in this space correspond to semantic relationships. Words that mean similar things end up near each other. The direction from “king” to “queen” is approximately the same as the direction from “man” to “woman”.

You’ve probably heard this. Here’s what it actually means computationally.

The Lookup Table

At its simplest, an embedding layer is just a lookup table: a matrix of shape [vocab_size, embedding_dim] where each row is the embedding for one token.

import torch
import torch.nn as nn

vocab_size = 50257   # GPT-2's vocabulary size
embedding_dim = 768  # GPT-2's embedding dimension

# This is literally just a matrix
embedding_table = nn.Embedding(vocab_size, embedding_dim)

# Look up token ID 42
token_id = torch.tensor([42])
vec = embedding_table(token_id)
print(f"Token 42 embedding shape: {vec.shape}")
print(f"First 5 values: {vec[0, :5].detach()}")

Running this produces output like the following (your exact values will differ):

Token 42 embedding shape: torch.Size([1, 768])
First 5 values: tensor([ 0.3251, -1.2847,  0.0923,  0.7615, -0.4381])

At initialization, these values are random. They have no meaning. The model learns the embeddings during training: the matrix is updated via gradient descent just like any other parameter. By the end of training, tokens that appear in similar contexts end up with similar vectors.

Why does that happen? Because the model learns to predict the next token, and tokens that appear in similar positions (after similar words, before similar words) will cause similar prediction errors, which cause similar gradient updates. Meaning emerges from co-occurrence patterns.
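One way to see why updates land on individual rows: a table lookup gives the same result as multiplying a one-hot vector by the embedding matrix. The sketch below checks that with plain lists; the matrix values and helper names are arbitrary.

```python
# Toy embedding matrix: 4 tokens, 3 dimensions (arbitrary values).
E = [
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
    [1.0, 1.1, 1.2],
]

def one_hot(index, size):
    """Vector of zeros with a single 1.0 at `index`."""
    return [1.0 if i == index else 0.0 for i in range(size)]

def vec_mat(v, M):
    """Row vector times matrix."""
    return [sum(v[i] * M[i][j] for i in range(len(M))) for j in range(len(M[0]))]

token_id = 2
via_matmul = vec_mat(one_hot(token_id, len(E)), E)
via_lookup = E[token_id]
print(via_matmul, via_lookup)
```

Because the one-hot vector is zero everywhere except the token's own position, the gradient from a training example reaches only the rows of tokens that actually appeared in it. Tokens that keep showing up in similar contexts keep receiving similar nudges.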

Geometry of Meaning

Let’s build a tiny embedding space from scratch using a toy corpus to make this concrete:

import torch
import torch.nn as nn
import numpy as np
from collections import Counter

# A tiny corpus — enough to learn some structure
corpus = [
    "the cat sat on the mat",
    "the dog sat on the floor",
    "cats and dogs are animals",
    "animals eat food",
    "cats eat fish",
    "dogs eat meat",
    "fish live in water",
    "dogs live in houses",
    "cats live in houses too",
    "the king ruled the kingdom",
    "the queen ruled the kingdom",
    "the king and queen",
    "man and woman",
    "the man walked the dog",
    "the woman walked the cat",
]

# Build vocabulary (sorted so word IDs are the same on every run)
words = " ".join(corpus).split()
vocab = ["<PAD>"] + sorted(set(words))
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}
V = len(vocab)
print(f"Vocabulary size: {V} words")

# Simple embedding model: predict if two words appear near each other (skip-gram style)
class SimpleEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.context = nn.Embedding(vocab_size, embed_dim)

    def forward(self, target, context):
        # dot product of target and context embeddings
        t = self.embed(target)      # [batch, dim]
        c = self.context(context)   # [batch, dim]
        return (t * c).sum(dim=1)   # [batch]

# Generate training pairs: (word, nearby_word) = positive examples
def get_pairs(corpus, window=2):
    pairs = []
    for sentence in corpus:
        words = sentence.split()
        for i, word in enumerate(words):
            for j in range(max(0, i-window), min(len(words), i+window+1)):
                if i != j:
                    pairs.append((word2idx[word], word2idx[words[j]]))
    return pairs

pairs = get_pairs(corpus)
print(f"Training pairs: {len(pairs)}")

# Train a tiny embedding model
model = SimpleEmbedding(V, embed_dim=8)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

targets = torch.tensor([p[0] for p in pairs])
contexts = torch.tensor([p[1] for p in pairs])

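# Generate negative samples: k random words per target. They are drawn
# uniformly, so a few will be true neighbors by chance; that noise is tolerable.
# Starting at index 1 skips <PAD>.
def negative_sample(targets, vocab_size, k=5):
    return torch.randint(1, vocab_size, (len(targets), k))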

for epoch in range(300):
    neg = negative_sample(targets, V)

    # Positive scores (should be high)
    pos_scores = model(targets, contexts)

    # Negative scores (should be low)
    neg_scores = model(
        targets.unsqueeze(1).expand_as(neg).reshape(-1),
        neg.reshape(-1)
    ).reshape(len(targets), -1)

    # Loss: positive scores up, negative scores down
    pos_loss = -torch.log(torch.sigmoid(pos_scores) + 1e-8).mean()
    neg_loss = -torch.log(1 - torch.sigmoid(neg_scores) + 1e-8).mean()
    loss = pos_loss + neg_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f"Epoch {epoch+1}: loss = {loss.item():.4f}")

Now let’s look at the geometry:

# Extract learned embeddings
def get_embedding(word):
    idx = word2idx[word]
    with torch.no_grad():
        return model.embed(torch.tensor([idx])).squeeze().numpy()

def cosine_similarity(a, b):
    a, b = np.array(a), np.array(b)
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8)

def most_similar(word, top_k=5):
    vec = get_embedding(word)
    scores = []
    for w in vocab:
        if w == word or w == "<PAD>":
            continue
        sim = cosine_similarity(vec, get_embedding(w))
        scores.append((w, sim))
    return sorted(scores, key=lambda x: -x[1])[:top_k]

# Check semantic neighbors
for word in ["cat", "dog", "king", "water"]:
    if word in word2idx:
        neighbors = most_similar(word)
        neighbor_str = ", ".join(f"{w}({s:.2f})" for w, s in neighbors)
        print(f"  {word:10s} → {neighbor_str}")

With only 15 sentences and 8-dimensional embeddings, you should typically see “cat” and “dog” land near each other, and “king” and “queen” cluster together. Results vary from run to run because the initialization and the negative samples are random, and the corpus is tiny, but the structure emerges from pure co-occurrence statistics. Nobody told the model that cats and dogs are both animals.

Cosine Similarity: The Right Distance Metric

When comparing embeddings, you almost always want cosine similarity rather than Euclidean distance:

def cosine_sim(a: torch.Tensor, b: torch.Tensor) -> float:
    """Cosine similarity between two vectors."""
    a = a.float()
    b = b.float()
    return (torch.dot(a.flatten(), b.flatten()) /
            (a.norm() * b.norm())).item()

# Why cosine, not Euclidean distance?
# Embedding magnitudes vary between tokens for reasons unrelated to meaning
# (how often a token appeared in training, for example).
# Cosine similarity normalizes for magnitude: it only cares about direction.

a = torch.tensor([2.0, 4.0])   # some direction, larger magnitude
b = torch.tensor([1.0, 2.0])   # same direction, smaller magnitude
c = torch.tensor([3.0, 1.0])   # different direction

print(f"cosine(a, b) = {cosine_sim(a, b):.3f}")  # should be ~1.0 (same direction)
print(f"cosine(a, c) = {cosine_sim(a, c):.3f}")  # should be less

import torch.nn.functional as F
# PyTorch has this built in:
builtin = F.cosine_similarity(a, b, dim=0)
print(f"Using F.cosine_similarity: {builtin.item():.3f}")

The “King - Man + Woman = Queen” Thing

This is the famous word vector analogy. Let’s see why it works:

The intuition: the vector from “man” to “king” encodes, roughly, “add royalty.” If we start from “woman” and apply the same offset, we should land near “queen.”

# This uses the tiny model trained above. With so little data the relationships
# may be weak or wrong; embeddings trained on a large corpus show them far
# more clearly.

def analogy(a, b, c, word2idx, idx2word, model, top_k=3):
    """Find d such that a:b :: c:d"""
    va = get_embedding(a)
    vb = get_embedding(b)
    vc = get_embedding(c)
    target = vb - va + vc  # the "offset" vector

    scores = []
    for word in vocab:
        if word in [a, b, c, "<PAD>"]:
            continue
        sim = cosine_similarity(target, get_embedding(word))
        scores.append((word, sim))
    return sorted(scores, key=lambda x: -x[1])[:top_k]

# With our tiny corpus:
# "king" is to "queen" as "man" is to ?
result = analogy("king", "queen", "man", word2idx, idx2word, model)
print("king:queen :: man:?")
for word, score in result:
    print(f"  {word}: {score:.3f}")

With a tiny dataset, the results will be noisy. With embeddings trained on millions of documents, they’re remarkably clean. The point is that the mechanism is just vector arithmetic — addition and subtraction of floating-point arrays.
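Since the tiny model is noisy, it can help to see the arithmetic on vectors where the answer is known in advance. The 2-D “embeddings” below are constructed by hand (axis 0 loosely stands for royalty, axis 1 for gender); learned embeddings are never this tidy.

```python
import math

# Hand-built vectors, chosen so the analogy works exactly.
vectors = {
    "man":   [1.0, 1.0],
    "woman": [1.0, 2.0],
    "king":  [3.0, 1.0],
    "queen": [3.0, 2.0],
    "dog":   [1.0, 3.0],
}

def cosine(u, v):
    dot = sum(x * y for x, y in zip(u, v))
    return dot / (math.hypot(*u) * math.hypot(*v))

def offset(a, b, c):
    """Return b - a + c, element by element."""
    return [vb - va + vc for va, vb, vc in zip(vectors[a], vectors[b], vectors[c])]

def nearest(target, exclude):
    """Most similar word by cosine similarity, skipping the query words."""
    candidates = [w for w in vectors if w not in exclude]
    return max(candidates, key=lambda w: cosine(vectors[w], target))

target = offset("man", "king", "woman")  # king - man + woman
answer = nearest(target, exclude={"man", "king", "woman"})
print(target, answer)
```

The query words are excluded from the search for the same reason as in the `analogy` function above: the offset vector usually stays closest to one of its own inputs.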

Positional Embeddings

Here’s something the story so far is missing: word embeddings encode what a word means, but not where it appears. “Dog bites man” and “Man bites dog” have the same tokens in different order. Order matters enormously.
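The simplest way to see the problem is to combine token vectors by summing them, which throws order away entirely. The embedding values here are invented for the sketch.

```python
# Made-up 2-D token embeddings (illustrative values only).
emb = {"dog": [1.0, 0.0], "bites": [0.0, 1.0], "man": [1.0, 1.0]}

def bag_of_embeddings(tokens):
    """Sum the token vectors; the result cannot depend on token order."""
    return [sum(emb[t][d] for t in tokens) for d in range(2)]

a = bag_of_embeddings(["dog", "bites", "man"])
b = bag_of_embeddings(["man", "bites", "dog"])
print(a, b, a == b)
```

Two sentences with opposite meanings produce the same vector. Any model that sees only token embeddings, with nothing marking position, has the same blind spot.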

Transformers handle this by adding a second embedding: a positional embedding that encodes position.

import torch
import math

def sinusoidal_positional_encoding(max_len: int, d_model: int) -> torch.Tensor:
    """
    The original transformer positional encoding from 'Attention Is All You Need'.
    Returns shape: [max_len, d_model]
    """
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

    # Frequencies decrease geometrically
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
    )

    # Even dimensions: sine; odd dimensions: cosine
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)

    return pe

pe = sinusoidal_positional_encoding(max_len=100, d_model=64)
print(f"Positional encoding shape: {pe.shape}")
print(f"Position 0, first 8 dims: {pe[0, :8]}")
print(f"Position 1, first 8 dims: {pe[1, :8]}")
print(f"Position 50, first 8 dims: {pe[50, :8]}")

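The sinusoidal design has a nice property: for any fixed offset k, PE[pos + k] can be expressed as a linear function of PE[pos] (a rotation of each sine/cosine pair). The authors of the original paper hypothesized that this makes it easy for the model to learn to attend by relative position, such as “k positions back.”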
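The linear relationship follows from the angle-addition identities, and it can be checked numerically for a single frequency. This sketch uses only the `math` module; `pe_pair` and `shift` are names invented here.

```python
import math

def pe_pair(pos, freq):
    """The (sin, cos) pair that one frequency contributes at a given position."""
    return (math.sin(pos * freq), math.cos(pos * freq))

def shift(pair, k, freq):
    """Advance a (sin, cos) pair by k positions with a fixed 2x2 rotation.

    The rotation depends only on k and freq, never on the position itself.
    """
    s, c = pair
    sk, ck = math.sin(k * freq), math.cos(k * freq)
    return (s * ck + c * sk, c * ck - s * sk)

freq = 1.0 / (10000 ** (2 / 64))  # frequency for dimensions 2 and 3 when d_model = 64
k = 5
for pos in (0, 7, 50):
    direct = pe_pair(pos + k, freq)
    rotated = shift(pe_pair(pos, freq), k, freq)
    print(pos, direct, rotated)
```

For every position, the rotated pair matches the directly computed one up to floating-point error. The same 2x2 matrix moves any position forward by k, which is what “linear function of PE[pos]” means here.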

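Many models, GPT-2 among them, instead use learned positional embeddings: another lookup table, this one indexed by position rather than token ID. The model learns whatever positional encoding works best for its data. This is simpler and works well in practice, with one caveat: the table has a fixed number of rows, so the model has no embedding for positions beyond its maximum sequence length.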

class TokenAndPositionEmbedding(nn.Module):
    """
    Combines token embeddings and positional embeddings.
    This is the input layer of a transformer.
    """
    def __init__(self, vocab_size, d_model, max_seq_len, dropout=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model

    def forward(self, x):
        # x: [batch, seq_len] of token IDs
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)

        tok = self.token_emb(x)       # [batch, seq_len, d_model]
        pos = self.pos_emb(positions)  # [1, seq_len, d_model]

        # Scaling by sqrt(d_model) is from the original paper.
        # It prevents token embeddings from being swamped by positional ones.
        return self.dropout(tok * math.sqrt(self.d_model) + pos)

# Test it
emb_layer = TokenAndPositionEmbedding(
    vocab_size=1000,
    d_model=64,
    max_seq_len=128
)

# A batch of 2 sequences, each 10 tokens long
batch = torch.randint(0, 1000, (2, 10))
output = emb_layer(batch)
print(f"Input shape:  {batch.shape}")
print(f"Output shape: {output.shape}")  # [2, 10, 64]

The Shape That Follows You

Notice the output shape: [batch_size, sequence_length, embedding_dim]. This is the fundamental tensor shape that flows through a transformer. Every subsequent operation — attention, feedforward layers, everything — works on tensors of this shape (or variants of it).

Keep that shape in mind. It’ll be important in the next chapter, where we finally get to attention.

Key Takeaways

  • Embeddings convert token IDs into dense vectors of floats
  • The vectors are learned during training from co-occurrence patterns
  • Semantic relationships appear as geometric relationships (distance, direction)
  • Cosine similarity is the right measure for comparing embedding vectors
  • Positional embeddings are added to preserve word-order information
  • The combined embedding is a [batch, seq_len, d_model] tensor

You’ve just built the front door of a transformer. The rest of the architecture is what happens once your tokens walk in.

Attention: The Mechanism That Changed Everything

In 2017, a team at Google published a paper called “Attention Is All You Need.” The title was a provocation — previous models used attention as a supplement to recurrent networks. The paper’s claim was that you didn’t need the recurrent networks at all. Attention alone was sufficient.

They were right, and that’s what made everything that came after possible.

The Problem Attention Solves

Before attention, sequence models processed text left to right, maintaining a hidden state that summarized everything seen so far. The problem: by the time you’re processing the 500th word, the model’s memory of the 1st word has been squished through 499 transformations. It’s like whispering a message down a line of 499 people — the original information degrades.

Attention solves this by allowing any position in the sequence to directly look at any other position. No degradation. No squishing. Direct access.

When you’re reading the word “it” in “The animal didn’t cross the street because it was too tired,” you need to know that “it” refers to “animal,” not “street.” Attention lets the model look back at “animal” and “street” simultaneously and decide which one “it” is about, based on learned patterns.

The Attention Mechanism

Here’s the core idea, stated plainly before any code:

For each position in the sequence, we want to compute a weighted average of the representations at every position (its own included). The weights should reflect relevance: positions that are more relevant to the current one get higher weights.

To compute these relevance weights, we need three things:

  • A Query (Q): what the current position is looking for
  • Keys (K): what each position offers as a lookup
  • Values (V): what each position actually contributes if selected

The query-key dot product measures relevance. Softmax turns these into weights. Those weights are applied to the values. That’s attention.
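Those three steps fit in a few lines of plain Python for a single query. The vectors are hand-picked so the result is easy to predict, and the sketch leaves out the scaling and masking that a full implementation adds.

```python
import math

def softmax(xs):
    """Turn raw scores into positive weights that sum to 1."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Single-query attention: dot-product scores -> softmax -> weighted sum."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    dim = len(values[0])
    output = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return weights, output

# Toy vectors (not learned): the query lines up with the first key only.
query = [1.0, 0.0]
keys = [[4.0, 0.0], [0.0, 4.0], [0.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]

weights, output = attend(query, keys, values)
print([round(w, 3) for w in weights])
print([round(o, 3) for o in output])
```

The first key matches the query, so nearly all the weight goes to the first value and the output sits close to it. The other two values still contribute a little, which is the “weighted average” part.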

If you’re thinking “why does this need three separate things, couldn’t we just compare the embeddings directly?” — that’s a good question with a good answer. Coming right up.

Self-Attention: The Full Picture

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: torch.Tensor = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    The fundamental attention operation.

    Args:
        Q: Queries  [batch, seq_len, d_k]
        K: Keys     [batch, seq_len, d_k]
        V: Values   [batch, seq_len, d_v]
        mask: Optional boolean mask (True = ignore this position)

    Returns:
        output: [batch, seq_len, d_v]
        weights: [batch, seq_len, seq_len]  (useful for visualization)
    """
    d_k = Q.size(-1)

    # Step 1: Compute attention scores
    # Q @ K^T: for each query, how much does it match each key?
    # Shape: [batch, seq_len_q, seq_len_k]
    scores = torch.matmul(Q, K.transpose(-2, -1))

    # Step 2: Scale by sqrt(d_k)
    # Without this, scores grow large as d_k grows, causing softmax to saturate.
    # Saturated softmax = near-zero gradients = model stops learning.
    scores = scores / math.sqrt(d_k)

    # Step 3: Apply mask (for causal attention — can't look into the future)
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))

    # Step 4: Softmax to get probabilities
    # Each query now has a probability distribution over all keys
    weights = F.softmax(scores, dim=-1)

    # Step 5: Weighted sum of values
    output = torch.matmul(weights, V)

    return output, weights

# Let's see it work on a tiny example
seq_len = 4
d_k = 8   # key/query dimension
d_v = 8   # value dimension

# Random queries, keys, values for demonstration
Q = torch.randn(1, seq_len, d_k)
K = torch.randn(1, seq_len, d_k)
V = torch.randn(1, seq_len, d_v)

output, weights = scaled_dot_product_attention(Q, K, V)

print(f"Input sequence length: {seq_len}")
print(f"Q, K shape: {Q.shape}")
print(f"V shape:    {V.shape}")
print(f"Output shape: {output.shape}")
print(f"\nAttention weights (each row sums to 1):")
print(weights[0].detach())
print(f"\nRow sums: {weights[0].sum(dim=-1).detach()}")
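The comment in Step 2 claims that unscaled scores saturate the softmax as d_k grows. A quick experiment with random vectors makes that visible. This is a sketch using only the standard library; the helper names, the choice of 8 keys, and the trial count are arbitrary.

```python
import math
import random

random.seed(0)

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def random_vector(d):
    """Vector of d independent standard-normal components."""
    return [random.gauss(0.0, 1.0) for _ in range(d)]

def max_attention_weight(d_k, scale, n_keys=8):
    """Largest softmax weight for one random query against n_keys random keys."""
    q = random_vector(d_k)
    scores = [sum(a * b for a, b in zip(q, random_vector(d_k))) for _ in range(n_keys)]
    if scale:
        scores = [s / math.sqrt(d_k) for s in scores]
    return max(softmax(scores))

def average_peak(d_k, scale, trials=100):
    """Average of the largest attention weight over many random draws."""
    return sum(max_attention_weight(d_k, scale) for _ in range(trials)) / trials

for d_k in (4, 64, 512):
    unscaled = average_peak(d_k, scale=False)
    scaled = average_peak(d_k, scale=True)
    print(f"d_k={d_k:4d}  unscaled peak={unscaled:.3f}  scaled peak={scaled:.3f}")
```

Without scaling, the largest weight approaches 1 as d_k grows, meaning one key takes almost everything and the rest get almost no gradient. With scaling, the peak stays roughly the same regardless of d_k.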

Why Q, K, V? (The Answer)

Let’s say you tried to do attention using the embeddings directly: compare position i to position j by taking the dot product of their embedding vectors. The problem: the embedding vector has to serve two roles simultaneously. It has to describe what this token is (to answer queries from other positions) and what this token wants (to query other positions). These can be different things.

For a concrete example: the word “bank” in “river bank” and “bank account” has the same embedding but should answer queries differently. The Q, K, V projections are learned linear transformations that let the model develop separate “search vocabularies” for querying vs. being queried.

Technically:

  • Q = X @ W_Q — transform embedding into “what am I looking for?”
  • K = X @ W_K — transform embedding into “what do I offer for lookup?”
  • V = X @ W_V — transform embedding into “what do I actually contribute?”

Where X is the input embedding and W_Q, W_K, W_V are learned weight matrices. The model learns all three matrices end-to-end.
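A worked example with small matrices shows the same input producing three different views. The values in `X` and the three weight matrices are invented for this sketch; in a real model the weights are learned.

```python
def matmul(A, B):
    """Plain-Python matrix product for small lists of lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

# X: 2 tokens, each a 3-dimensional embedding (made-up values).
X = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 1.0]]

# Hypothetical projection matrices, each of shape [3, 2].
W_Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_K = [[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]]
W_V = [[2.0, 0.0], [0.0, 2.0], [0.0, 0.0]]

Q = matmul(X, W_Q)
K = matmul(X, W_K)
V = matmul(X, W_V)
print("Q =", Q)
print("K =", K)
print("V =", V)
```

Each token keeps a single embedding in `X`, yet its query, key, and value vectors all differ. That separation is what lets a token ask one question while advertising something else.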

Multi-Head Attention

One attention head looks at the sequence through one lens. Multi-head attention runs several attention operations in parallel, each with different learned projections, then combines the results.

Why? Different heads learn to attend to different types of relationships. One head might learn syntactic dependencies (subject-verb agreement). Another might learn coreference (pronoun resolution). Another might track long-range topic consistency. With a single head, these would compete. With multiple heads, they each get their own attention pattern.

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # dimension per head

        # These project the input into Q, K, V for ALL heads at once
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)  # output projection

        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        Split the last dimension into (n_heads, d_k) then transpose.
        Input:  [batch, seq_len, d_model]
        Output: [batch, n_heads, seq_len, d_k]
        """
        batch, seq_len, _ = x.shape
        x = x.view(batch, seq_len, self.n_heads, self.d_k)
        return x.transpose(1, 2)

    def combine_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        Reverse of split_heads.
        Input:  [batch, n_heads, seq_len, d_k]
        Output: [batch, seq_len, d_model]
        """
        batch, n_heads, seq_len, d_k = x.shape
        x = x.transpose(1, 2).contiguous()
        return x.view(batch, seq_len, self.d_model)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None
    ) -> torch.Tensor:
        batch, seq_len, d_model = x.shape

        # Project to Q, K, V
        Q = self.split_heads(self.W_q(x))  # [batch, n_heads, seq_len, d_k]
        K = self.split_heads(self.W_k(x))
        V = self.split_heads(self.W_v(x))

        # Scaled dot-product attention on each head
        # Q, K, V are [batch, n_heads, seq_len, d_k]
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # mask is [seq_len, seq_len] (True = masked out); it broadcasts
            # over the batch and head dimensions of scores
            scores = scores.masked_fill(mask, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        attended = torch.matmul(weights, V)  # [batch, n_heads, seq_len, d_k]

        # Combine heads and project
        combined = self.combine_heads(attended)        # [batch, seq_len, d_model]
        return self.W_o(combined)                      # [batch, seq_len, d_model]

# Test it
d_model = 64
n_heads = 8
seq_len = 10
batch_size = 2

mha = MultiHeadAttention(d_model=d_model, n_heads=n_heads)
x = torch.randn(batch_size, seq_len, d_model)
out = mha(x)

print(f"MultiHeadAttention")
print(f"  Input:  {x.shape}")
print(f"  Output: {out.shape}")
print(f"  Heads: {n_heads}, d_k per head: {d_model // n_heads}")
print(f"  Parameters: {sum(p.numel() for p in mha.parameters()):,}")

Causal Attention (For Autoregressive Models)

Language models generate text left to right. During training, the output at position i must not use information from positions i+1, i+2, ... because those tokens don’t exist yet at inference time.

We enforce this with a causal mask: a triangular mask that prevents each position from attending to future positions.

def causal_mask(seq_len: int) -> torch.Tensor:
    """
    Returns a boolean mask where True means 'ignore this position.'
    Shape: [seq_len, seq_len]
    """
    # Upper triangle (excluding diagonal) is True (masked out)
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    return mask

mask = causal_mask(6)
print("Causal mask (True = masked out, cannot attend):")
print(mask.int())
print()
print("Reading the mask: row = query position, col = key position")
print("Position 0 can only see position 0 (itself)")
print("Position 3 can see positions 0, 1, 2, 3 (but not 4 or 5)")

Causal mask (True = masked out, cannot attend):
tensor([[0, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0]])

In the attention computation, masked positions get -inf before softmax. Since exp(-inf) = 0, those positions receive exactly zero weight after normalization. The future tokens are still sitting in the input tensor; they simply contribute nothing to the weighted average.
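The effect of the -inf trick on a single row can be checked by hand. The sketch below is plain Python with made-up scores; `softmax` and `NEG_INF` are names introduced for this example only.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

NEG_INF = float("-inf")

# Toy scores for query position 1 against keys 0..3.
scores = [2.0, 1.0, 0.5, 3.0]

# Causal mask row for position 1: keys 2 and 3 are in the future.
masked = [s if k <= 1 else NEG_INF for k, s in enumerate(scores)]

weights = softmax(masked)
print("masked scores:", masked)
print("weights:      ", [round(w, 4) for w in weights])
print("sum:          ", sum(weights))
```

Key 3 had the highest raw score, yet it ends up with weight zero. Because the diagonal is never masked, every row keeps at least one finite score, so the softmax never has to divide by zero.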

Visualizing Attention

Attention weights are interpretable (to a degree). Let’s visualize what a trained attention head might look like:

import torch

# Simulate what attention weights might look like for a sentence
sentence = ["The", "cat", "sat", "on", "the", "mat", "."]
seq_len = len(sentence)

# In a real model, these come from the trained Q, K, V projections.
# Here we'll construct a plausible pattern for illustration.
# Imagine head 0 learned to track subject-verb relationships:
weights_head0 = torch.tensor([
    [0.8, 0.1, 0.05, 0.02, 0.01, 0.01, 0.01],  # "The" attends mostly to itself
    [0.2, 0.6, 0.1,  0.05, 0.02, 0.02, 0.01],  # "cat" attends to "The", itself
    [0.05, 0.4, 0.4, 0.1, 0.02, 0.02, 0.01],   # "sat" attends to "cat" (its subject)
    [0.05, 0.1, 0.1, 0.6, 0.1,  0.04, 0.01],   # "on" attends mostly to itself
    [0.3,  0.1, 0.1, 0.1, 0.3,  0.09, 0.01],   # "the" splits attention
    [0.05, 0.1, 0.1, 0.1, 0.1,  0.5,  0.05],   # "mat" attends to itself
    [0.1,  0.1, 0.3, 0.1, 0.1,  0.2,  0.1 ],   # "." attends broadly
])

# Check rows sum to 1
assert torch.allclose(weights_head0.sum(dim=-1), torch.ones(seq_len), atol=1e-4)

print("Attention weight matrix (row = query, col = key):")
print(f"{'':>8}", end="")
for w in sentence:
    print(f"{w:>8}", end="")
print()
for i, (query_word, row) in enumerate(zip(sentence, weights_head0)):
    print(f"{query_word:>8}", end="")
    for val in row:
        print(f"{val:>8.2f}", end="")
    print()

Real attention heads, once trained, reveal fascinating patterns: some track syntactic dependencies, some handle coreference, some appear to do less interpretable but apparently useful things. The field of mechanistic interpretability is dedicated to reverse-engineering what each head learned.

The Computational Complexity

Attention has quadratic complexity in sequence length: O(n²). Every position attends to every other position, and there are n² (query, key) pairs.

This is the core scalability challenge for transformers. A sequence of 1,000 tokens requires 1,000,000 pairwise scores per head. 100,000 tokens requires 10,000,000,000. This is why long-context models are hard and expensive, and why much research has gone into sparse or approximate attention (Longformer, for example) and into faster exact implementations (FlashAttention).
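The quadratic growth is easier to feel in bytes than in big-O notation. The sketch below assumes one sequence, 12 heads, and 2 bytes per score (16-bit floats), and assumes the full score matrix is stored at once. These are illustrative assumptions; some implementations avoid materializing the whole matrix. `attention_matrix_bytes` is a name made up for this example.

```python
def attention_matrix_bytes(seq_len, n_heads=12, bytes_per_value=2):
    """Bytes to store one layer's [n_heads, seq_len, seq_len] score matrix."""
    return n_heads * seq_len * seq_len * bytes_per_value

for n in (1_000, 10_000, 100_000):
    pairs = n * n
    gib = attention_matrix_bytes(n) / 2**30
    print(f"n={n:>7,}: {pairs:>14,} pairs per head, {gib:8.2f} GiB for the scores")
```

Doubling the sequence length quadruples the cost, which is the whole problem in one line.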

def attention_flops(seq_len: int, d_model: int, n_heads: int) -> dict:
    """Estimate FLOPs for one attention layer (one multiply-add counted as one operation)."""
    d_k = d_model // n_heads

    # QKV projections: 3 * seq_len * d_model * d_model
    qkv_proj = 3 * seq_len * d_model * d_model

    # Attention scores: n_heads * seq_len * seq_len * d_k
    attn_scores = n_heads * seq_len * seq_len * d_k

    # Attention output: n_heads * seq_len * seq_len * d_k
    attn_output = n_heads * seq_len * seq_len * d_k

    # Output projection: seq_len * d_model * d_model
    out_proj = seq_len * d_model * d_model

    total = qkv_proj + attn_scores + attn_output + out_proj
    return {
        "qkv_projection": qkv_proj,
        "attention_scores": attn_scores,
        "attention_output": attn_output,
        "output_projection": out_proj,
        "total_flops": total,
    }

print("FLOPs breakdown for one attention layer:")
print()
for seq_len in [128, 1024, 8192]:
    flops = attention_flops(seq_len, d_model=768, n_heads=12)
    print(f"seq_len={seq_len:6d}: {flops['total_flops']/1e9:.2f}B FLOPs")
    if seq_len == 128:
        for k, v in flops.items():
            print(f"  {k}: {v:,}")

Putting It Together

Here’s the mental model you should carry forward:

  1. Each token in the sequence broadcasts a Key (what I can offer) and a Value (my actual content)
  2. Each token sends out a Query (what I’m looking for)
  3. The query from position i dot-products with all keys, giving relevance scores
  4. Softmax converts scores to weights; these weight the values
  5. Position i’s output is the weighted average of all values
  6. Multiple heads do this independently and combine results
  7. Causal models mask out future positions
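The steps above can be sketched in plain Python for a single head. This is a teaching toy under stated assumptions: `Q`, `K`, and `V` are hand-written lists rather than learned projections, the function names are invented for this example, and step 6 (multiple heads) is left out.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V, causal=False):
    """Single-head scaled dot-product attention on lists of vectors."""
    d_k = len(K[0])
    outputs, all_weights = [], []
    for i, q in enumerate(Q):
        # Step 3: query i is scored against every key.
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        # Step 7: a causal model hides keys that come after position i.
        if causal:
            scores = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
        # Step 4: softmax turns scores into weights.
        w = softmax(scores)
        # Step 5: the output is the weighted average of the values.
        out = [sum(w[j] * V[j][c] for j in range(len(V))) for c in range(len(V[0]))]
        outputs.append(out)
        all_weights.append(w)
    return outputs, all_weights

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]

out, w = attention(Q, K, V, causal=True)
for i, (row, o) in enumerate(zip(w, out)):
    print(f"position {i}: weights={[round(x, 3) for x in row]} output={[round(x, 3) for x in o]}")
```

Position 0 can only see itself, so its output is exactly its own value vector; later positions blend the values they are allowed to see.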

The result: every position has access to context from any other position, with the model learning which positions are relevant for which tasks. No information bottleneck. No vanishing memory. Just weighted summation, done n_heads times in parallel.

That’s why attention changed everything.

Transformers: Putting It All Together

You have the pieces: embeddings turn tokens into vectors, attention lets positions look at each other, and multi-head attention runs several attention patterns in parallel. Now let’s assemble them into an actual transformer.

The full architecture has a few more components you haven’t seen yet — a feedforward network, layer normalization, and residual connections — but none of them are complicated. The transformer’s genius isn’t any single piece; it’s how the pieces fit together.

The Transformer Block

The fundamental unit of a transformer is the transformer block (sometimes called a transformer layer). A full model is just N of these blocks stacked on top of each other.

Each block does two things:

  1. Multi-head self-attention: let each position gather context from others
  2. Position-wise feedforward network: process each position independently

Both operations are wrapped in residual connections and layer normalization.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class FeedForward(nn.Module):
    """
    Position-wise feedforward network.

    Two linear transformations with a GELU activation in between.
    The inner dimension (d_ff) is typically 4x the model dimension.
    This is where most of the model's parameters actually live.
    """
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),            # smooth version of ReLU; works better in practice
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class TransformerBlock(nn.Module):
    """
    One transformer block = attention + feedforward, with residuals and layer norm.

    The "pre-norm" variant (layer norm before attention/FFN, not after) is standard
    in modern models — it trains more stably than the original paper's "post-norm."
    """
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # Pre-norm attention with residual connection
        x = x + self.attn(self.norm1(x), mask=mask)
        # Pre-norm feedforward with residual connection
        x = x + self.ff(self.norm2(x))
        return x

Let’s pause on the residual connections: x = x + attention(x). This is from ResNets, and it solves a specific problem. Without residuals, deep networks suffer from vanishing gradients — the gradient signal diminishes as it propagates backward through many layers. With residuals, there’s always a direct path for gradients to flow through, bypassing the transformation. It lets you stack many blocks without the network degrading.
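A scalar caricature shows why the direct path matters. Assume each “layer” is just f(x) = a * x with a small slope a. Real layers are nonlinear and include layer norm, so treat this as intuition rather than a model of a transformer; `grad_through_stack` is a name invented for the sketch.

```python
def grad_through_stack(n_layers, a, residual):
    """d(output)/d(input) for n stacked scalar layers f(x) = a * x.

    Without residuals each layer multiplies the gradient by a.
    With residuals each layer computes x + a * x, so the factor is 1 + a.
    """
    per_layer = (1.0 + a) if residual else a
    return per_layer ** n_layers

plain = grad_through_stack(12, a=0.01, residual=False)
skip = grad_through_stack(12, a=0.01, residual=True)

print(f"12 layers, no residuals:   gradient = {plain:.3e}")
print(f"12 layers, with residuals: gradient = {skip:.3f}")
```

Without the skip path the gradient is effectively zero after a dozen layers; with it, the gradient stays close to 1 because the identity term survives every layer.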

The Full Language Model

class MultiHeadAttention(nn.Module):
    """(Same implementation as the previous chapter — included for completeness.)"""
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x):
        b, s, _ = x.shape
        return x.view(b, s, self.n_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        b, h, s, dk = x.shape
        return x.transpose(1, 2).contiguous().view(b, s, self.d_model)

    def forward(self, x, mask=None):
        Q = self.split_heads(self.W_q(x))
        K = self.split_heads(self.W_k(x))
        V = self.split_heads(self.W_v(x))
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # [seq_len, seq_len] mask broadcasts over batch and heads
            scores = scores.masked_fill(mask, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        return self.W_o(self.combine_heads(torch.matmul(weights, V)))


class GPTLanguageModel(nn.Module):
    """
    A small GPT-style language model.
    Architecture:
      - Token + positional embeddings
      - N transformer blocks
      - Final layer norm
      - Linear projection to vocabulary (the "language model head")
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        n_heads: int,
        n_layers: int,
        max_seq_len: int,
        d_ff: int = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        d_ff = d_ff or d_model * 4

        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying: the embedding matrix and the LM head share weights.
        # This is a standard trick that reduces parameters and improves performance.
        # The intuition: both matrices relate tokens to the same d_model space,
        # so the logit for token X is the dot product between X's embedding
        # and the final hidden state.
        self.lm_head.weight = self.token_emb.weight

        # Initialize weights (standard practice from GPT-2 paper)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        idx: torch.Tensor,
        targets: torch.Tensor = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Args:
            idx:     [batch, seq_len] — token indices
            targets: [batch, seq_len] — target token indices (for computing loss)

        Returns:
            logits: [batch, seq_len, vocab_size]
            loss:   scalar cross-entropy loss (if targets provided, else None)
        """
        batch, seq_len = idx.shape
        device = idx.device

        # Build causal mask
        mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device),
            diagonal=1
        )

        # Embeddings
        positions = torch.arange(seq_len, device=device)
        x = self.dropout(
            self.token_emb(idx) * math.sqrt(self.token_emb.embedding_dim)
            + self.pos_emb(positions)
        )

        # N transformer blocks
        for block in self.blocks:
            x = block(x, mask=mask)

        # Final norm and projection
        x = self.norm(x)
        logits = self.lm_head(x)  # [batch, seq_len, vocab_size]

        # Compute loss if targets provided
        loss = None
        if targets is not None:
            # Reshape for cross_entropy: [batch*seq_len, vocab_size] vs [batch*seq_len]
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1,  # ignore padding
            )

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int = None,
    ) -> torch.Tensor:
        """
        Autoregressively generate tokens given a prompt.
        """
        max_seq_len = self.pos_emb.num_embeddings

        for _ in range(max_new_tokens):
            # Trim to max_seq_len
            idx_cond = idx[:, -max_seq_len:]

            # Forward pass
            logits, _ = self(idx_cond)

            # Take only the last position's logits (predict the next token)
            logits = logits[:, -1, :] / temperature  # [batch, vocab_size]

            # Optional top-k filtering
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            # Sample from the distribution
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # [batch, 1]

            # Append and continue
            idx = torch.cat([idx, next_token], dim=1)

        return idx

Building a Tiny GPT and Watching It Learn

Let’s build the smallest GPT that’s still interesting: a character-level model trained on a tiny dataset.

import torch
import torch.nn.functional as F

# Our training data: a small poem-like text
text = """
To be or not to be that is the question
Whether tis nobler in the mind to suffer
The slings and arrows of outrageous fortune
Or to take arms against a sea of troubles
And by opposing end them to die to sleep
No more and by a sleep to say we end
The heartache and the thousand natural shocks
That flesh is heir to tis a consummation
""".strip() + "\n"  # keep a trailing newline so repeated copies don't run together
text = text * 10    # repeat to get more training data

# Build character-level vocabulary
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for c, i in stoi.items()}

def encode(s: str) -> list[int]:
    return [stoi[c] for c in s]

def decode(ids: list[int]) -> str:
    return ''.join(itos[i] for i in ids)

print(f"Vocabulary size: {vocab_size} characters")
print(f"Characters: {chars}")
print(f"Text length: {len(text)} characters → {len(encode(text))} tokens")
print()

# Encode everything
data = torch.tensor(encode(text), dtype=torch.long)

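# Train/val split (the text is ten repeats, so val overlaps train content;
# treat val_loss as a sanity check, not a true held-out score)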
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

# Model configuration (deliberately tiny)
config = {
    "vocab_size": vocab_size,
    "d_model": 64,
    "n_heads": 4,
    "n_layers": 4,
    "max_seq_len": 64,
    "dropout": 0.1,
}

model = GPTLanguageModel(**config)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print()

def get_batch(data, batch_size=32, block_size=64):
    """Sample random batches for training."""
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

# Quick training run
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

print("Training...")
for step in range(500):
    x, y = get_batch(train_data)
    logits, loss = model(x, targets=y)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping
    optimizer.step()

    if step % 100 == 0:
        # Evaluate on val with dropout disabled, then return to training mode
        model.eval()
        with torch.no_grad():
            xv, yv = get_batch(val_data)
            _, val_loss = model(xv, targets=yv)
        model.train()
        print(f"  step {step:4d}: train_loss={loss.item():.4f}, val_loss={val_loss.item():.4f}")

print()

# Generate some text (eval mode turns dropout off while sampling)
model.eval()
print("Generated text:")
prompt = "To be"
context = torch.tensor([encode(prompt)], dtype=torch.long)
generated = model.generate(context, max_new_tokens=200, temperature=0.8, top_k=10)
print(prompt + decode(generated[0, len(encode(prompt)):].tolist()))

After 500 steps on this tiny dataset, the model should produce something that’s vaguely English-shaped — not coherent sentences, but something that’s clearly learned that certain letters follow other letters with specific frequencies. It has genuinely learned something from your data.

Parameter Counting: Where Are the Weights?

def count_parameters(model: nn.Module) -> tuple[dict, int]:
    """Break down parameter count by leaf module, counting tied weights once."""
    components = {}
    seen = set()  # ids of parameters already counted (handles weight tying)

    total = 0
    for name, module in model.named_modules():
        # Only parameters owned directly by this module, not by its children
        params = [p for p in module.parameters(recurse=False) if id(p) not in seen]
        if not params:
            continue  # containers and parameter-free modules
        seen.update(id(p) for p in params)
        count = sum(p.numel() for p in params)
        total += count
        components[name] = count

    return components, total

# Build a slightly larger model to see the breakdown
model_medium = GPTLanguageModel(
    vocab_size=50257,  # GPT-2's vocab size
    d_model=768,
    n_heads=12,
    n_layers=12,
    max_seq_len=1024,
)

total = sum(p.numel() for p in model_medium.parameters())
print(f"A GPT-2 scale model: {total/1e6:.1f}M parameters")
print()

# Break it down
emb_params = sum(p.numel() for p in model_medium.token_emb.parameters())
pos_params = sum(p.numel() for p in model_medium.pos_emb.parameters())
block_params = sum(p.numel() for p in model_medium.blocks.parameters())
head_params = sum(p.numel() for p in model_medium.lm_head.parameters())

print(f"  Token embeddings:      {emb_params/1e6:7.2f}M  ({100*emb_params/total:.1f}%)")
print(f"  Positional embeddings: {pos_params/1e6:7.2f}M  ({100*pos_params/total:.1f}%)")
print(f"  Transformer blocks:    {block_params/1e6:7.2f}M  ({100*block_params/total:.1f}%)")
print(f"  LM head (tied):        {head_params/1e6:7.2f}M  (shared with embeddings)")
print(f"  Total unique:          {total/1e6:7.2f}M")

The feedforward layers hold the majority of each block’s parameters: roughly 2/3, with attention taking the remaining third. Each block’s FFN has two matrices of shape [d_model, d_ff] and [d_ff, d_model] where d_ff = 4 * d_model = 3072. That’s 2 * 768 * 3072 ≈ 4.7M parameters per block, or about 56.6M across 12 blocks, a bit under half of the whole model once the embeddings are counted.
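The same numbers can be reproduced with arithmetic alone, without instantiating the model. The sketch below assumes the layer shapes defined in this chapter: attention projections without biases, feedforward layers with biases, LayerNorm with weight and bias, and an LM head tied to the token embeddings. Variable names are chosen for this example.

```python
d_model, n_layers, vocab_size, max_seq_len = 768, 12, 50257, 1024
d_ff = 4 * d_model

attn_per_block = 4 * d_model * d_model               # W_q, W_k, W_v, W_o, no biases
ffn_per_block = 2 * d_model * d_ff + d_ff + d_model  # two Linear layers plus their biases
norm_per_block = 2 * 2 * d_model                     # two LayerNorms, weight + bias each
block = attn_per_block + ffn_per_block + norm_per_block

# Token embeddings (shared with the LM head) plus positional embeddings.
embeddings = vocab_size * d_model + max_seq_len * d_model
final_norm = 2 * d_model
total = embeddings + n_layers * block + final_norm

print(f"per block: attention={attn_per_block:,} ffn={ffn_per_block:,} norms={norm_per_block:,}")
print(f"FFN share of one block:   {ffn_per_block / block:.1%}")
print(f"FFN share of whole model: {n_layers * ffn_per_block / total:.1%}")
print(f"total parameters:         {total:,}")
```

The feedforward share drops from two-thirds of a block to under half of the model because the embedding table is large at this vocabulary size.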

Data Flow Summary

Let’s trace a single forward pass explicitly:

Input: [batch=2, seq_len=10]  — token IDs, integers

1. Token embedding:     [2, 10] → [2, 10, 768]
2. Add positional emb:  [2, 10, 768] (same shape, elementwise add)
3. Dropout:             [2, 10, 768]

For each of 12 transformer blocks:
  4. LayerNorm:          [2, 10, 768]
  5. Multi-head attn:    [2, 10, 768] → [2, 10, 768]  (same shape, no sequence reduction)
  6. Residual add:       [2, 10, 768]
  7. LayerNorm:          [2, 10, 768]
  8. Feedforward:        [2, 10, 768] → [2, 10, 3072] → [2, 10, 768]
  9. Residual add:       [2, 10, 768]

10. Final LayerNorm:    [2, 10, 768]
11. LM head (linear):  [2, 10, 768] → [2, 10, 50257]  — one score per vocab item

Output: [2, 10, 50257]  — logits over vocabulary for each position

The output at position i is a vector of logits over the entire vocabulary; after softmax it becomes the model’s probability distribution for the token at position i+1. This is computed for every position simultaneously during training (thanks to the causal mask preventing peeking), which is what makes transformers trainable in parallel, unlike RNNs.

The Scaled Transformer Family

The architecture above is the decoder-only transformer (GPT-style). Variations you’ll encounter:

| Variant | Architecture | Used for |
|---|---|---|
| Decoder-only (GPT) | Causal self-attention | Text generation, language modeling |
| Encoder-only (BERT) | Bidirectional self-attention | Classification, embeddings |
| Encoder-decoder (T5) | Encoder + cross-attention decoder | Translation, summarization |

In encoder-only models, there’s no causal mask — every position can attend to every other position. They’re trained differently (masked language modeling, not next-token prediction). In encoder-decoder models, the decoder cross-attends to the encoder’s output (Q from decoder, K/V from encoder).

We’re focused on decoder-only models because that’s what GPT-3, GPT-4, Claude, Llama, and most current language models are.

What You’ve Built

Let’s be specific about what your tiny model can and cannot do. After 500 steps on about 3 KB of text (eight lines repeated ten times), it has learned:

  • Character-level frequency distributions
  • Some common character sequences (it’ll rarely generate impossible sequences like “zxqjv”)
  • Very rough word-boundary patterns

It has not learned:

  • Semantics
  • Grammar
  • Any facts about the world

That’s expected. The architecture is correct. What’s missing is scale — more data, more parameters, more training steps. The transformer you’ve built and the transformer powering GPT-4 are the same architecture; GPT-4 just has orders of magnitude more of everything.

That’s the honest secret of this field: the architecture itself is not that complicated. The scale is what creates the emergent capabilities. And scale requires training — which is the next chapter.

Training Loops: Teaching a Model to Care

The model you have at initialization is random noise with a shape. The weights are sampled from a small normal distribution. Run a forward pass and you’ll get approximately uniform predictions over the vocabulary — the model genuinely has no preferences. It doesn’t know what a word is. It doesn’t know what anything is.

Training is the process of making it care. Specifically: giving it examples of text, measuring how wrong its predictions are, and adjusting its weights to make them slightly less wrong. Repeat over hundreds of billions of tokens. That’s the entire procedure.

The remarkable thing is that this actually works.

The Objective: Next-Token Prediction

A language model has one job: predict the next token. Given a sequence of tokens, output a probability distribution over the vocabulary for what comes next.

This is called the causal language modeling objective. It is a form of self-supervised learning: the labels come from the data itself, so no human annotation is required.

For a sequence [t₁, t₂, t₃, t₄, t₅], the model is asked:

  • Given [t₁], predict t₂
  • Given [t₁, t₂], predict t₃
  • Given [t₁, t₂, t₃], predict t₄
  • Given [t₁, t₂, t₃, t₄], predict t₅

Because of the causal mask (previous chapter), the transformer can compute all these predictions in a single forward pass. The output at position i predicts position i+1. Efficient.
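In code, the “labels” are just the same sequence shifted by one. A minimal sketch with placeholder token names:

```python
tokens = ["t1", "t2", "t3", "t4", "t5"]

inputs = tokens[:-1]   # what the model sees: t1..t4
targets = tokens[1:]   # what it must predict at each position: t2..t5

for i, target in enumerate(targets):
    context = inputs[: i + 1]  # the causal mask limits position i to this prefix
    print(f"position {i}: given {context} -> predict {target}")
```

One sequence of length 5 yields four training examples from a single forward pass, which is the efficiency the paragraph above refers to.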

Cross-Entropy Loss

The model outputs logits — unnormalized scores for each vocabulary item. We convert these to probabilities with softmax, then measure how wrong we are with cross-entropy loss:

loss = -log(probability of the correct token)

If the model assigns 90% probability to the correct token: loss = -log(0.9) = 0.105 (low, good). If the model assigns 1% probability to the correct token: loss = -log(0.01) = 4.6 (high, bad). If the vocabulary has 50,000 tokens and the model is totally random: loss = -log(1/50000) = 10.8.

A freshly initialized model gets around 10.8. A well-trained model gets around 2-3 on held-out text (lower on training text, which it’s memorized to some degree).
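These numbers are quick to verify with the standard library. The sketch below uses natural logarithms and two helper names invented for this example: `token_loss` works from a probability, and `cross_entropy` works from raw logits using the log-sum-exp form.

```python
import math

def token_loss(p_correct):
    """Cross-entropy for one token, given the probability of the correct token."""
    return -math.log(p_correct)

def cross_entropy(logits, target):
    """Same loss computed from raw logits: logsumexp(logits) - logits[target]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

for label, p in [("90% on the right token", 0.9),
                 ("1% on the right token", 0.01),
                 ("uniform over 50,000 tokens", 1 / 50_000)]:
    print(f"{label:<28} loss = {token_loss(p):.3f}")

# All-zero logits mean a uniform distribution, so the loss is log(vocab_size).
uniform_logits = [0.0] * 10
print(f"uniform logits, V=10:        loss = {cross_entropy(uniform_logits, 3):.3f}")
```

The logits version is what a framework computes in practice, since it avoids forming tiny probabilities explicitly.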

import torch
import torch.nn.functional as F
import math

def compute_loss_example():
    """Illustrate what cross-entropy loss is measuring."""
    vocab_size = 10
    batch_size = 2
    seq_len = 5

    # Near-zero random logits, roughly what a freshly initialized model produces
    logits_random = 0.1 * torch.randn(batch_size, seq_len, vocab_size)
    # Good model: high confidence on correct tokens
    logits_good = torch.zeros(batch_size, seq_len, vocab_size)

    targets = torch.randint(0, vocab_size, (batch_size, seq_len))

    # Set logits_good to have high values at correct positions
    for b in range(batch_size):
        for s in range(seq_len):
            logits_good[b, s, targets[b, s]] = 5.0  # strong signal for correct token

    # Compute losses
    loss_random = F.cross_entropy(
        logits_random.view(-1, vocab_size),
        targets.view(-1)
    )
    loss_good = F.cross_entropy(
        logits_good.view(-1, vocab_size),
        targets.view(-1)
    )

    print(f"Random model loss:    {loss_random.item():.3f}")
    print(f"Expected (log V):     {math.log(vocab_size):.3f}")
    print(f"Well-trained loss:    {loss_good.item():.3f}")
    print()

    # Perplexity = exp(loss), easier to interpret
    # "The model is as confused as if choosing randomly from N options"
    print(f"Random model perplexity:  {math.exp(loss_random.item()):.1f}  (≈ {vocab_size})")
    print(f"Good model perplexity:    {math.exp(loss_good.item()):.2f}")

compute_loss_example()

Perplexity is often reported instead of loss: perplexity = exp(loss). It has a nice interpretation: a perplexity of K means the model is “as surprised as if it had to choose uniformly from K options.” A perplexity of 20 on English text means the model effectively has 20 plausible next tokens at each step, even though the vocabulary is 50,000.
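Over a sequence, perplexity is the exponential of the average per-token loss. A small sketch with made-up probabilities for the correct tokens (`perplexity` is a helper written for this example):

```python
import math

def perplexity(p_correct_per_token):
    """exp of the mean negative log-probability assigned to the correct tokens."""
    losses = [-math.log(p) for p in p_correct_per_token]
    return math.exp(sum(losses) / len(losses))

# A model that always gives the right token probability 1/20.
uniform_20 = perplexity([1 / 20] * 8)

# A model that alternates between fairly sure and less sure.
mixed = perplexity([0.5, 0.25, 0.5, 0.25])

print(f"always 1/20:       perplexity = {uniform_20:.2f}")
print(f"0.5 / 0.25 mixed:  perplexity = {mixed:.3f}")
```

The mixed case lands between 2 and 4, the two “effective option counts” it alternates between, because perplexity is a geometric mean of the inverse probabilities.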

Gradient Descent

Given the loss, how do we update the weights to reduce it?

Gradient descent: compute the gradient of the loss with respect to every parameter, then step in the opposite direction.

The gradient tells you: “if you increase this weight by a tiny amount, the loss increases/decreases by this much.” Moving in the opposite direction of the gradient reduces the loss.
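That definition can be taken literally: nudge the weight, watch the loss. The sketch below estimates a gradient by finite differences on a one-parameter toy problem (y = 5x, a made-up dataset) and takes a single step against it. This is for intuition only; autograd computes exact gradients far more cheaply than nudging each weight in turn.

```python
def mse_loss(w, xs, ys):
    """Mean squared error of the one-parameter model y = w * x."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def numeric_grad(w, xs, ys, eps=1e-6):
    """Central-difference estimate of d(loss)/d(w)."""
    return (mse_loss(w + eps, xs, ys) - mse_loss(w - eps, xs, ys)) / (2 * eps)

xs = [1.0, 2.0, 3.0, 4.0]
ys = [5.0, 10.0, 15.0, 20.0]  # generated by the true weight w = 5

w = 2.0
g = numeric_grad(w, xs, ys)   # analytic value: 2 * mean(x * (w*x - y)) = -45
w_new = w - 0.01 * g          # step in the opposite direction of the gradient

print(f"gradient at w={w}: {g:.4f}")
print(f"loss before: {mse_loss(w, xs, ys):.4f}")
print(f"loss after:  {mse_loss(w_new, xs, ys):.4f}  (w moved to {w_new:.4f})")
```

The gradient is negative, meaning the loss falls as w increases, so the update pushes w up toward 5.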

def manual_gradient_descent():
    """Gradient descent on a toy problem, made explicit."""
    # A single parameter model: y = w * x, predict y given x
    w = torch.tensor(2.0, requires_grad=True)  # starts at 2, true value is 5

    # Fake dataset: y = 5x
    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    y_true = torch.tensor([5.0, 10.0, 15.0, 20.0])

    learning_rate = 0.1

    print("Gradient descent on y = w*x (true w=5):")
    print(f"  Starting w = {w.item():.4f}")
    print()

    for step in range(10):
        # Forward pass
        y_pred = w * x
        loss = ((y_pred - y_true) ** 2).mean()  # MSE loss

        # Backward pass: compute d(loss)/d(w)
        loss.backward()

        # Update: w = w - lr * gradient
        with torch.no_grad():
            w -= learning_rate * w.grad

        # Zero the gradient (IMPORTANT: gradients accumulate by default)
        w.grad.zero_()

        print(f"  Step {step+1:2d}: w={w.item():.4f}, loss={loss.item():.4f}")

manual_gradient_descent()

The model has millions of parameters, but the principle is the same. PyTorch’s loss.backward() computes all the gradients automatically via backpropagation — the chain rule applied recursively to the computation graph. You write the forward pass; PyTorch figures out the backward pass.

A Complete Training Loop

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    # Model
    vocab_size: int = 256        # character-level for simplicity
    d_model: int = 128
    n_heads: int = 4
    n_layers: int = 4
    max_seq_len: int = 128
    dropout: float = 0.1

    # Training
    batch_size: int = 32
    learning_rate: float = 3e-4
    max_steps: int = 2000
    eval_interval: int = 200
    eval_steps: int = 50
    grad_clip: float = 1.0

    # Optimization
    weight_decay: float = 0.1
    betas: tuple = (0.9, 0.95)  # AdamW betas
    warmup_steps: int = 100


def get_lr(step: int, config: TrainingConfig) -> float:
    """
    Cosine learning rate schedule with linear warmup.
    This is a common schedule for transformer training.
    """
    if step < config.warmup_steps:
        # Linear warmup
        return config.learning_rate * step / config.warmup_steps

    # Cosine decay
    progress = (step - config.warmup_steps) / (config.max_steps - config.warmup_steps)
    return config.learning_rate * 0.5 * (1 + math.cos(math.pi * progress))


class Trainer:
    def __init__(self, model: nn.Module, config: TrainingConfig, train_data, val_data):
        self.model = model
        self.config = config
        self.train_data = train_data
        self.val_data = val_data

        # Separate parameters into those that should and shouldn't have weight decay
        # (biases and layer norm parameters typically excluded from weight decay)
        decay_params = [p for n, p in model.named_parameters()
                       if p.dim() >= 2]  # weight matrices
        no_decay_params = [p for n, p in model.named_parameters()
                          if p.dim() < 2]  # biases, layer norms

        self.optimizer = torch.optim.AdamW([
            {'params': decay_params, 'weight_decay': config.weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ], lr=config.learning_rate, betas=config.betas)

        self.step = 0
        self.train_losses = []
        self.val_losses = []

    def get_batch(self, split: str) -> tuple[torch.Tensor, torch.Tensor]:
        data = self.train_data if split == 'train' else self.val_data
        c = self.config
        ix = torch.randint(len(data) - c.max_seq_len, (c.batch_size,))
        x = torch.stack([data[i:i+c.max_seq_len] for i in ix])
        y = torch.stack([data[i+1:i+c.max_seq_len+1] for i in ix])
        return x, y

    @torch.no_grad()
    def evaluate(self) -> float:
        """Estimate validation loss over several batches."""
        self.model.eval()
        losses = []
        for _ in range(self.config.eval_steps):
            x, y = self.get_batch('val')
            _, loss = self.model(x, targets=y)
            losses.append(loss.item())
        self.model.train()
        return float(np.mean(losses))  # plain float, matching the return annotation

    def train_step(self) -> float:
        """One step of training."""
        x, y = self.get_batch('train')

        # Update learning rate
        lr = get_lr(self.step, self.config)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        # Forward + backward + update
        _, loss = self.model(x, targets=y)

        self.optimizer.zero_grad(set_to_none=True)  # clear old gradients (set_to_none=True is the default in PyTorch 2.x)
        loss.backward()

        # Gradient clipping: prevents exploding gradients
        # Clips the global norm of all gradients to grad_clip
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.config.grad_clip
        )

        self.optimizer.step()
        self.step += 1

        return loss.item()

    def train(self):
        self.model.train()
        print(f"Training for {self.config.max_steps} steps...")
        print(f"{'Step':>6} {'Train Loss':>12} {'Val Loss':>10} {'LR':>12}")
        print("-" * 45)

        running_loss = 0
        for step in range(self.config.max_steps):
            loss = self.train_step()
            running_loss += loss

            if (step + 1) % self.config.eval_interval == 0:
                avg_train_loss = running_loss / self.config.eval_interval
                val_loss = self.evaluate()
                lr = get_lr(step, self.config)

                self.train_losses.append(avg_train_loss)
                self.val_losses.append(val_loss)

                print(f"{step+1:6d} {avg_train_loss:12.4f} {val_loss:10.4f} {lr:12.2e}")
                running_loss = 0

Running a Real Training Run

Let’s put it all together and train on something:

import math

# Re-use our Shakespeare text from the Transformers chapter
text = """
To be or not to be that is the question
Whether tis nobler in the mind to suffer
The slings and arrows of outrageous fortune
Or to take arms against a sea of troubles
And by opposing end them to die to sleep
No more and by a sleep to say we end
""" * 50

# Encode as bytes (gives us a clean 256-token vocab)
data = torch.tensor(list(text.encode('utf-8')), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

config = TrainingConfig(
    vocab_size=256,
    d_model=64,
    n_heads=4,
    n_layers=3,
    max_seq_len=64,
    batch_size=16,
    max_steps=1000,
    eval_interval=200,
)

# Build the model (same GPTLanguageModel from previous chapter)
model = GPTLanguageModel(
    vocab_size=config.vocab_size,
    d_model=config.d_model,
    n_heads=config.n_heads,
    n_layers=config.n_layers,
    max_seq_len=config.max_seq_len,
    dropout=config.dropout,
)

n_params = sum(p.numel() for p in model.parameters())
print(f"Model: {n_params:,} parameters")
print(f"Training data: {len(train_data):,} tokens")
print(f"Val data: {len(val_data):,} tokens")
print()

trainer = Trainer(model, config, train_data, val_data)
trainer.train()

# Generate some text
print("\nSample generation (temperature=0.8):")
prompt = b"To be"
idx = torch.tensor([list(prompt)], dtype=torch.long)
output = model.generate(idx, max_new_tokens=100, temperature=0.8, top_k=20)
print(bytes(output[0].tolist()).decode('utf-8', errors='replace'))

What Each Component of the Optimizer Does

def explain_optimizer():
    """Walk through what Adam does to each parameter."""

    # Adam (and AdamW) maintain per-parameter state:
    # - m: first moment (momentum) — exponential average of gradients
    # - v: second moment — exponential average of squared gradients

    # For a single parameter update:
    # m = beta1 * m + (1 - beta1) * grad          # gradient moving average
    # v = beta2 * v + (1 - beta2) * grad^2        # squared gradient moving average
    # m_hat = m / (1 - beta1^t)                   # bias correction
    # v_hat = v / (1 - beta2^t)
    # param = param - lr * m_hat / (sqrt(v_hat) + eps)

    # The key insight: Adam normalizes the step size per parameter.
    # A parameter that has been getting large gradients gets smaller steps.
    # A parameter that has been getting small gradients gets larger steps.
    # This adaptive learning rate is why Adam works so much better than plain SGD for transformers.

    # AdamW adds weight decay directly to the parameter (not through gradient):
    # param = param * (1 - lr * weight_decay) - lr * m_hat / (sqrt(v_hat) + eps)
    # This is the "correct" way to do weight decay with adaptive optimizers.

    beta1, beta2 = 0.9, 0.95
    eps = 1e-8

    # Simulate 10 steps for a single parameter
    param = torch.tensor(1.0)
    m = torch.tensor(0.0)
    v = torch.tensor(0.0)
    lr = 0.001

    print("Adam update trace (single parameter):")
    print(f"{'Step':>5} {'Grad':>8} {'m':>8} {'v':>8} {'update':>10} {'param':>8}")

    for t in range(1, 11):
        grad = torch.randn(1).item()  # simulate a gradient

        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * grad ** 2

        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)

        update = lr * m_hat / (v_hat ** 0.5 + eps)
        param = param - update

        print(f"{t:5d} {grad:8.4f} {m:8.4f} {v:8.4f} {-update:10.6f} {param:8.4f}")

explain_optimizer()
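
As a sanity check on the formulas in the comments above, the sketch below runs the same update by hand and compares it with torch.optim.AdamW on a single parameter. The gradient values are made up for illustration, and float64 is used only to keep rounding differences small.

```python
import torch

lr, beta1, beta2, eps, weight_decay = 1e-3, 0.9, 0.95, 1e-8, 0.1

# Reference: let PyTorch update a single parameter
p_ref = torch.tensor([1.0], dtype=torch.float64, requires_grad=True)
opt = torch.optim.AdamW([p_ref], lr=lr, betas=(beta1, beta2),
                        eps=eps, weight_decay=weight_decay)

# Manual state for the same parameter
p, m, v = 1.0, 0.0, 0.0

grads = [0.5, -1.2, 0.3, 2.0, -0.7]  # made-up gradients
for t, g in enumerate(grads, start=1):
    # PyTorch update: set the gradient by hand, then step
    p_ref.grad = torch.tensor([g], dtype=torch.float64)
    opt.step()

    # Manual update: moving averages, bias correction, decoupled weight decay
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    p = p * (1 - lr * weight_decay) - lr * m_hat / (v_hat ** 0.5 + eps)

    print(f"step {t}: manual={p:.10f}  torch={p_ref.item():.10f}")
```

If the two columns agree, the comment block is an accurate description of what the optimizer does to each parameter.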

Diagnosing Training: Loss Curves

def plot_training_curves(train_losses, val_losses):
    """Print an ASCII loss curve for quick sanity checking."""
    if not train_losses:
        print("No data yet.")
        return

    # Find range
    all_losses = train_losses + val_losses
    min_l, max_l = min(all_losses), max(all_losses)
    height = 15
    width = min(60, len(train_losses))

    print("\nTraining curves (▓=train, ░=val):")

    for row in range(height, -1, -1):
        threshold = min_l + (row / height) * (max_l - min_l)
        line = f"{threshold:6.2f} |"
        for i in range(min(width, len(train_losses))):
            t_idx = int(i * len(train_losses) / width)
            train_char = "▓" if train_losses[t_idx] >= threshold else " "
            val_char = "░" if val_losses[t_idx] >= threshold else " "
            # Train takes priority; val shows only where the train curve is below this row
            line += train_char if train_char != " " else val_char
        print(line)

    print("       " + "+" + "-" * width)
    print(f"  loss ↑         steps →")
    print(f"\n  Healthy: train ≈ val, both decreasing")
    print(f"  Overfitting: train decreasing, val increasing (need more data or regularization)")
    print(f"  Underfitting: both high and flat (need more capacity or longer training)")

# Example with simulated curves
import math
steps = 20
fake_train = [3.0 * math.exp(-0.15 * i) + 0.5 + 0.1 * (hash(i) % 100) / 100 for i in range(steps)]
fake_val = [3.2 * math.exp(-0.13 * i) + 0.7 + 0.05 * (hash(i+100) % 100) / 100 for i in range(steps)]
plot_training_curves(fake_train, fake_val)

The Dirty Secrets of Training

A few things the clean presentation above glosses over:

Gradient clipping is not optional. Without it, a single bad batch can produce a huge gradient, and one oversized step can undo thousands of good ones. A clip value of 1.0 is a very common default.
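
To make "clipping the global norm" concrete, here is a minimal sketch of the idea behind the clip_grad_norm_ call in the training loop. The helper `clip_global_norm` is hypothetical and simplified; it is not PyTorch's implementation.

```python
import torch


def clip_global_norm(grads: list[torch.Tensor], max_norm: float) -> float:
    """Scale all gradients in place so their combined L2 norm is at most max_norm.

    Returns the norm measured before clipping. (Hypothetical helper.)
    """
    total_norm = torch.sqrt(sum((g ** 2).sum() for g in grads)).item()
    if total_norm > max_norm:
        scale = max_norm / total_norm
        for g in grads:
            g.mul_(scale)  # same direction, shorter length
    return total_norm


# Two made-up gradient tensors, as if they came from two different layers
grads = [torch.tensor([3.0, 4.0]), torch.tensor([12.0])]
norm_before = clip_global_norm(grads, max_norm=1.0)
norm_after = torch.sqrt(sum((g ** 2).sum() for g in grads)).item()

print(f"norm before: {norm_before:.4f}")
print(f"norm after:  {norm_after:.4f}")
```

The point to notice is that every gradient is scaled by the same factor, so the update direction is preserved and only its length is capped.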

Learning rate matters more than almost anything else. Too high and the model diverges. Too low and it barely learns. Warmup (gradually increasing from 0) prevents instability at the start when gradients are wild.
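
To see what warmup followed by decay looks like in numbers, the sketch below evaluates a schedule of the same shape as get_lr at a few steps. The function name `lr_at` is illustrative, and its defaults are assumed to match the TrainingConfig defaults shown earlier.

```python
import math


def lr_at(step: int, base_lr: float = 3e-4, warmup_steps: int = 100,
          max_steps: int = 2000) -> float:
    """Linear warmup followed by cosine decay to zero (illustrative helper)."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))


for step in [0, 50, 100, 1050, 2000]:
    print(f"step {step:5d}: lr = {lr_at(step):.2e}")
```

The rate climbs linearly to its peak at the end of warmup, passes through half the peak at the midpoint of the decay phase, and reaches zero at the final step.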

The 4x FFN ratio is empirical. Nobody proved that d_ff = 4 * d_model is optimal. It’s just what worked and everyone kept using it.

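Batch size interacts with learning rate. A larger batch gives a less noisy gradient estimate, so it can usually tolerate a larger learning rate, and because it takes fewer steps per epoch, it often needs one. A common heuristic is the linear scaling rule: if you multiply the batch size by k, multiply the learning rate by k as well, then check that training is still stable.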

Most runs fail. Real LLM training involves constant monitoring, occasional instabilities, and sometimes full restarts. Training a model from scratch requires checkpointing every few hours and being prepared for the cluster to go down.

The fundamentals, however, are exactly what you’ve seen: loss, gradient, update. Repeated until the model stops embarrassing itself.

That’s training.

Fine-Tuning: Specializing Without Starting Over

Training a language model from scratch at the scale of GPT-4 is estimated to cost tens to hundreds of millions of dollars and requires a cluster of thousands of GPUs running for months. Unless your startup has unusually aggressive compute budgets, you’re not doing that.

What you’re doing is fine-tuning: taking a pre-trained model and adapting it for a specific task by training on a much smaller dataset for a much shorter time. This works remarkably well, and understanding why it works helps you make better decisions about when and how to do it.

Why Fine-Tuning Works

Pre-training on billions of tokens of text teaches a model a tremendous amount: grammar, facts, reasoning patterns, code syntax, sentiment, style, logic. The model develops rich internal representations of language.

Fine-tuning doesn’t overwrite this knowledge — it builds on it. You’re not teaching the model what language is; you’re teaching it which aspects of its knowledge to emphasize for your particular task, and what tone/format to use.

Think of it as the difference between hiring a generalist who learns your codebase versus hiring someone fresh out of school. The generalist already has all the general skills; they just need domain-specific adaptation.

Types of Fine-Tuning

There are several approaches, with dramatically different compute requirements:

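| Method             | What’s Updated          | Parameters Changed | Typical Use                 |
|--------------------|-------------------------|--------------------|-----------------------------|
| Full fine-tuning   | All weights             | 100%               | Significant behavior change |
| Instruction tuning | All weights             | 100%               | Chat/instruction following  |
| LoRA               | Low-rank adapters       | ~0.1-1%            | Efficient adaptation        |
| QLoRA              | LoRA on quantized model | ~0.1-1%            | Very low VRAM               |
| Prompt tuning      | Soft prompt tokens      | <0.01%             | Minimal adaptation          |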

Full Fine-Tuning

The simple version: take a pre-trained model, load its weights, and continue training with your dataset using a lower learning rate.

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

def full_finetune_example():
    """
    Demonstrates the mechanics of full fine-tuning.
    (Use a tiny model for illustration — in practice, use 7B+ parameter models.)
    """
    # Load a tiny pre-trained model
    model_name = "gpt2"  # GPT-2 small, about 124M parameters, fits on CPU
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Add padding token (GPT-2 doesn't have one)
    tokenizer.pad_token = tokenizer.eos_token

    # Your fine-tuning dataset — format matters
    # For instruction following: prompt + completion pairs
    training_examples = [
        {
            "prompt": "Summarize the following in one sentence:",
            "text": " Transformers use self-attention to process sequences in parallel, enabling much more efficient training than recurrent networks.",
        },
        {
            "prompt": "What is gradient descent?",
            "text": " Gradient descent is an optimization algorithm that iteratively adjusts model parameters in the direction that reduces the loss function.",
        },
        {
            "prompt": "Explain tokenization briefly:",
            "text": " Tokenization converts raw text into a sequence of integer IDs by splitting text into subword units according to a learned vocabulary.",
        },
    ]

    # Format as: "<prompt><completion>" — model learns to generate the completion
    def format_example(ex):
        return ex["prompt"] + ex["text"] + tokenizer.eos_token

    # Fine-tuning uses a lower learning rate than pre-training
    # Pre-training: ~3e-4; fine-tuning: 1e-5 to 5e-5
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

    model.train()
    print("Fine-tuning steps:")

    for step, example in enumerate(training_examples * 3):  # 3 epochs
        text = format_example(example)
        inputs = tokenizer(
            text,
            return_tensors="pt",
            max_length=128,
            truncation=True,
        )

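        # For language modeling the targets are the inputs shifted by one position.
        # Hugging Face causal LM models do that shift internally, so labels are
        # simply a copy of input_ids.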
        input_ids = inputs["input_ids"]
        labels = input_ids.clone()

        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if step % 3 == 0:
            print(f"  Step {step}: loss = {loss.item():.4f}")

    print("\nFine-tuning complete.")
    return model, tokenizer

# model, tokenizer = full_finetune_example()

The key differences from pre-training:

  1. Lower learning rate (10-100x lower) — preserves pre-trained knowledge
  2. Smaller dataset (thousands to millions of examples, not billions)
  3. Fewer steps (hours to days, not months)
  4. Task-specific data format — the model learns the format during fine-tuning

LoRA: Low-Rank Adaptation

Full fine-tuning updates every parameter. For a 7B parameter model, that’s 7 billion gradient updates per step, 7 billion gradient values to store. If you’re using Adam, that’s another 14 billion values for the optimizer state. This gets expensive.
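
The arithmetic above can be turned into a rough memory budget. The sketch below assumes every value is stored in fp32 (4 bytes) and ignores activations entirely; real setups often use mixed precision, so treat the numbers as an order-of-magnitude estimate. The function name is illustrative.

```python
def full_finetune_memory_gb(n_params: float, bytes_per_value: int = 4) -> dict[str, float]:
    """Rough memory for weights, gradients, and Adam state.

    Assumes every value uses bytes_per_value bytes (4 = fp32) and
    ignores activations entirely.
    """
    gb = 1e9
    weights = n_params * bytes_per_value / gb
    grads = n_params * bytes_per_value / gb
    adam_state = 2 * n_params * bytes_per_value / gb  # first and second moments
    return {
        "weights": weights,
        "gradients": grads,
        "adam_state": adam_state,
        "total": weights + grads + adam_state,
    }


budget = full_finetune_memory_gb(7e9)
for name, value in budget.items():
    print(f"{name:>10}: {value:6.1f} GB")
```

Under these assumptions a 7B model needs roughly 112 GB before a single activation is stored, which is the gap that parameter-efficient methods are designed to close.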

LoRA (Low-Rank Adaptation) starts from an empirical observation: the change in weights during fine-tuning tends to have low intrinsic rank. Rather than learning a full [d, d] weight update ΔW, we decompose it into two small matrices: ΔW = B × A, where A is [r, d] and B is [d, r], with r << d.

If d = 4096 and r = 16, the original weight matrix has about 16.8M parameters. The LoRA decomposition has just 2 * 16 * 4096 = 131K parameters, 128x fewer. You freeze the original weights and train only the small adapter matrices.
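
Before the full module, here is the decomposition with raw tensors. The sizes are made up and deliberately small, and B is random here (as if training had already happened) so that the update is nonzero. The sketch checks that applying the adapter alongside W gives the same output as folding B × A into W.

```python
import torch

torch.manual_seed(0)
d, r, alpha = 64, 4, 8.0  # small made-up sizes so this runs instantly
scale = alpha / r

W = torch.randn(d, d, dtype=torch.float64)  # frozen pre-trained weight
A = torch.randn(r, d, dtype=torch.float64)  # trainable, [r, d]
B = torch.randn(d, r, dtype=torch.float64)  # trainable, [d, r]
x = torch.randn(3, d, dtype=torch.float64)

# Adapter form: base output plus a low-rank correction
y_adapter = x @ W.T + scale * (x @ A.T @ B.T)

# Merged form: fold the update into a single weight matrix
W_merged = W + scale * (B @ A)
y_merged = x @ W_merged.T

update_rank = torch.linalg.matrix_rank(B @ A).item()
print("max difference:", (y_adapter - y_merged).abs().max().item())
print("rank of the update:", update_rank)
print("full params:", W.numel(), " LoRA params:", A.numel() + B.numel())
```

Because the two forms are equivalent, an adapter can be merged into the base weights after training, leaving no extra cost at inference time.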

import torch
import torch.nn as nn
import math


class LoRALinear(nn.Module):
    """
    Linear layer with LoRA adaptation.

    The forward pass computes: y = x @ W.T + x @ A.T @ B.T * scale
    where W is frozen and A, B are the trainable LoRA parameters.
    """
    def __init__(
        self,
        original_linear: nn.Linear,
        rank: int = 16,
        alpha: float = 32.0,  # scaling factor; lora_alpha / rank = effective scale
        dropout: float = 0.0,
    ):
        super().__init__()

        self.original = original_linear
        # Freeze the original weights
        for param in self.original.parameters():
            param.requires_grad = False

        in_features = original_linear.in_features
        out_features = original_linear.out_features
        self.rank = rank
        self.scale = alpha / rank

        # LoRA matrices
        # A is initialized to random Gaussian, B to zero.
        # This means ΔW = B @ A = 0 at initialization — no change to model behavior.
        self.lora_A = nn.Parameter(
            torch.randn(rank, in_features) / math.sqrt(rank)
        )
        self.lora_B = nn.Parameter(
            torch.zeros(out_features, rank)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Original linear transformation (frozen)
        base = self.original(x)

        # LoRA adaptation
        # x: [batch, ..., in_features]
        lora = self.dropout(x) @ self.lora_A.T @ self.lora_B.T
        # lora: [batch, ..., out_features]

        return base + lora * self.scale

    @property
    def weight(self):
        """Returns the effective weight (original + LoRA adaptation)."""
        return self.original.weight + (self.lora_B @ self.lora_A) * self.scale


def apply_lora(model: nn.Module, rank: int = 16, alpha: float = 32.0,
               target_modules: list = None) -> nn.Module:
    """
    Replace target Linear layers with LoRA versions.

    In practice, LoRA is applied to the attention Q and V projections
    (sometimes K and the output projection too).
    """
    if target_modules is None:
        target_modules = ["W_q", "W_v"]  # standard: query and value projections

    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and name in target_modules:
            setattr(model, name, LoRALinear(module, rank=rank, alpha=alpha))
        else:
            apply_lora(module, rank, alpha, target_modules)  # recurse

    return model


def count_trainable(model: nn.Module) -> tuple[int, int]:
    """Count trainable vs total parameters."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


# Demonstrate LoRA on a small linear layer
original = nn.Linear(4096, 4096)
lora_layer = LoRALinear(original, rank=16, alpha=32)

orig_params = sum(p.numel() for p in original.parameters())
lora_trainable = sum(p.numel() for p in lora_layer.parameters() if p.requires_grad)

print(f"Original linear layer: {orig_params:,} parameters")
print(f"LoRA trainable parameters: {lora_trainable:,} parameters")
print(f"Reduction: {orig_params / lora_trainable:.0f}x fewer trainable parameters")
print()

# Test forward pass
x = torch.randn(2, 10, 4096)
out = lora_layer(x)
print(f"Input shape:  {x.shape}")
print(f"Output shape: {out.shape}")

# Verify that at initialization, LoRA adds nothing (B=0, so B@A=0)
with torch.no_grad():
    base_out = original(x)
    diff = (out - base_out).abs().max()
    print(f"Max difference from original at init: {diff.item():.2e}  (should be ~0)")

The LoRA Training Loop

def lora_training_example():
    """
    Shows how to set up LoRA training:
    freeze everything, then apply LoRA adapters to attention layers.
    """
    from transformers import AutoModelForCausalLM

    # Load a small base model
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Step 1: Freeze ALL parameters
    for param in model.parameters():
        param.requires_grad = False

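    # Step 2: Add LoRA to the attention projection
    # In GPT-2, attention weights are in model.transformer.h[i].attn.c_attn,
    # a fused projection that computes Q, K, and V in one matrix multiply,
    # so this adapter touches all three rather than just Q and V.
    rank = 8
    alpha = 16.0
    lora_params = []

    def conv1d_to_linear(conv) -> nn.Linear:
        # Hugging Face's GPT-2 uses a Conv1D module here, not nn.Linear.
        # Conv1D stores its weight as [in_features, out_features], the
        # transpose of nn.Linear's layout, so copy it across transposed.
        in_features, out_features = conv.weight.shape
        linear = nn.Linear(in_features, out_features)
        with torch.no_grad():
            linear.weight.copy_(conv.weight.T)
            linear.bias.copy_(conv.bias)
        return linear

    for layer in model.transformer.h:
        # Convert first: LoRALinear reads in_features / out_features
        c_attn = conv1d_to_linear(layer.attn.c_attn)
        lora_attn = LoRALinear(c_attn, rank=rank, alpha=alpha)
        layer.attn.c_attn = lora_attn
        lora_params.extend([lora_attn.lora_A, lora_attn.lora_B])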

    # Step 3: Verify parameter counts
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Total parameters:     {total:,}")
    print(f"Trainable parameters: {trainable:,}")
    print(f"Trainable fraction:   {100 * trainable / total:.2f}%")

    # Step 4: Only pass trainable parameters to optimizer
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=3e-4  # can use a higher LR with LoRA since only adapters are updated
    )

    return model, optimizer

# model, optimizer = lora_training_example()

Instruction Tuning: Teaching the Model to Follow Instructions

Raw pre-training teaches a model to continue text. But users want a model that answers questions, follows instructions, and has a conversational format. The transition from “text completer” to “assistant” is called instruction tuning.

The format typically looks like this:

def format_instruction_example(instruction: str, response: str) -> str:
    """
    ChatML format — used by many open-source models.
    The model learns to generate text after 'assistant'.
    """
    return f"""<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
{instruction}<|im_end|>
<|im_start|>assistant
{response}<|im_end|>"""

# During training, we only compute loss on the assistant's response
# (not on the prompt — the model shouldn't be penalized for predicting the prompt)
def mask_prompt_tokens(input_ids: torch.Tensor, tokenizer, assistant_token: str):
    """
    Returns labels where prompt tokens are -100 (ignored by cross-entropy).
    The model only learns to generate the response.
    """
    labels = input_ids.clone()

    # Find where the assistant's response starts
    assistant_ids = tokenizer.encode(assistant_token, add_special_tokens=False)

    # Mask everything up to and including the assistant marker,
    # so only the response tokens contribute to the loss
    n = len(assistant_ids)
    for i in range(input_ids.size(1) - n + 1):
        if input_ids[0, i:i + n].tolist() == assistant_ids:
            labels[0, :i + n] = -100  # -100 is the default ignore_index of F.cross_entropy
            break

    return labels

# Example
example = format_instruction_example(
    "What is the capital of France?",
    "The capital of France is Paris."
)
print("Formatted instruction example:")
print(example)
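
To see what the -100 labels actually do to the loss, the sketch below compares a masked cross-entropy with one computed only on the response positions. The logits and targets are random stand-ins, and the split into 4 prompt tokens and 2 response tokens is made up.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, seq_len = 10, 6
logits = torch.randn(seq_len, vocab_size)           # made-up model outputs
targets = torch.randint(0, vocab_size, (seq_len,))  # made-up target token IDs

# Pretend the first 4 positions are prompt and the last 2 are the response
labels = targets.clone()
labels[:4] = -100

loss_masked = F.cross_entropy(logits, labels)                  # ignore_index defaults to -100
loss_response_only = F.cross_entropy(logits[4:], targets[4:])  # same thing, done by slicing

print(f"masked loss:        {loss_masked.item():.6f}")
print(f"response-only loss: {loss_response_only.item():.6f}")
```

The two values match because masked positions are dropped from both the sum and the average, so the prompt contributes nothing to the gradient.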

RLHF: The Part That Actually Makes It Useful

Here’s something important that’s often glossed over: instruction tuning on human-written examples isn’t quite enough to get the behavior you want. Models fine-tuned this way are better at following instructions but still produce outputs that are confidently wrong, subtly harmful, or not what the user actually wanted.

RLHF (Reinforcement Learning from Human Feedback) addresses this by optimizing directly for human preferences:

  1. Collect comparisons: Show humans two model outputs and ask which is better
  2. Train a reward model: A neural network that predicts human preference scores
  3. RL optimization: Use PPO (or similar) to optimize the LLM against the reward model
  4. KL constraint: Penalize the model for drifting too far from the original (prevents reward hacking)
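
Step 2 can be sketched with a pairwise loss. The text above does not say which loss the reward model uses; the Bradley-Terry style loss below is one common choice and is an assumption here, as are the function name and the example scores.

```python
import torch
import torch.nn.functional as F


def pairwise_preference_loss(score_chosen: torch.Tensor,
                             score_rejected: torch.Tensor) -> torch.Tensor:
    """Bradley-Terry style loss: push the preferred output's score above the other's.

    Assumed loss for illustration; inputs are scalar scores per comparison pair.
    """
    return -F.logsigmoid(score_chosen - score_rejected).mean()


# Made-up reward-model scores for three comparison pairs
chosen = torch.tensor([2.0, 0.5, 1.0])
rejected = torch.tensor([0.0, 0.5, 3.0])

pair_loss = pairwise_preference_loss(chosen, rejected)
print(f"loss: {pair_loss.item():.4f}")
```

The loss depends only on the difference between the two scores, so the reward model learns a ranking rather than an absolute scale.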

The KL constraint deserves emphasis: without it, the model tends to learn to game the reward model rather than actually improve. This is the classic “reward hacking” problem in RL.

def rlhf_kl_penalty(log_probs_current, log_probs_reference, kl_coeff=0.1):
    """
    The KL penalty that keeps the RLHF-trained model close to the reference model.

    KL(current || reference) = E[log P_current - log P_reference]

    This is added as a negative reward (penalty) to prevent the model from
    finding degenerate solutions that score high on the reward model but
    produce gibberish or unsafe outputs.
    """
    kl = log_probs_current - log_probs_reference  # per-token log-ratio; its mean over sampled tokens estimates the KL
    return kl_coeff * kl.mean()

# The full RLHF reward for a generated sequence:
def rlhf_reward(reward_model_score, token_kl_penalties):
    """
    total_reward = reward_model_score - sum(token_kl_penalties)

    token_kl_penalties holds the per-token penalties, already scaled by kl_coeff.
    """
    return reward_model_score - token_kl_penalties.sum()

Modern alternatives to RLHF include DPO (Direct Preference Optimization), which often achieves comparable results without the complexity of reinforcement learning: it fine-tunes directly on preference pairs.
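
A minimal sketch of the DPO loss, assuming you already have the summed log-probability of each whole response under both the model being trained (the policy) and a frozen reference model. The function name, the `beta` value, and the example numbers are illustrative.

```python
import torch
import torch.nn.functional as F


def dpo_loss(policy_chosen_logp: torch.Tensor, policy_rejected_logp: torch.Tensor,
             ref_chosen_logp: torch.Tensor, ref_rejected_logp: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    """DPO loss from per-response log-probabilities (illustrative sketch)."""
    # How far the policy has moved from the reference on each response
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    # Reward the policy for moving toward chosen and away from rejected
    return -F.logsigmoid(beta * (chosen_logratio - rejected_logratio)).mean()


# Made-up summed log-probs for two preference pairs
ref_chosen = torch.tensor([-12.0, -20.0])
ref_rejected = torch.tensor([-11.0, -25.0])

# A policy identical to the reference: nothing learned yet
loss_start = dpo_loss(ref_chosen, ref_rejected, ref_chosen, ref_rejected)

# A policy that has shifted probability toward the chosen responses
loss_better = dpo_loss(ref_chosen + 2.0, ref_rejected - 2.0, ref_chosen, ref_rejected)

print(f"loss at start:       {loss_start.item():.4f}")
print(f"loss after shifting: {loss_better.item():.4f}")
```

The reference model plays the role the KL penalty plays in RLHF: the loss measures movement relative to it, and `beta` controls how strongly that movement is rewarded.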

When to Use What

Full fine-tuning when:

  • You have significant compute available
  • You’re adapting for a completely different task domain
  • You need to change the model’s fundamental behavior

LoRA when:

  • Limited GPU memory (fine-tune a 7B model on a single consumer GPU)
  • Multiple adapters needed (swap between different fine-tuned versions easily)
  • You want to share the adapter without sharing the full model weights

Instruction tuning when:

  • Base model doesn’t follow instructions well
  • You need a specific conversational format
  • You’re building an assistant product

LoRA + instruction tuning when:

  • All of the above, which is most of the time

What Fine-Tuning Cannot Do

Fine-tuning is an unreliable way to add knowledge that wasn’t in the pre-training data. It is good at bringing out knowledge that’s already there, formatting it differently, and adjusting behavior. It is much worse at teaching facts the model has never seen: a small dataset tends to teach the shape of an answer faster than its content, which can leave you with a model that sounds informed and isn’t.

If you fine-tune on a dataset describing your proprietary product, the model will learn to discuss it in the right format. If you fine-tune on customer service examples, it’ll follow those conversational patterns. But don’t count on it to develop capabilities it lacks from pre-training: a small fine-tuning run is unlikely to make it reason better, develop new skills, or understand genuinely new concepts.

This is why retrieval-augmented generation (RAG) exists as a complement to fine-tuning. Fine-tuning for behavior; RAG for knowledge.

Inference: Running What You Built

Training is over. The weights are frozen. Now what?

Inference is the process of actually using a model to generate text. It sounds like the boring epilogue after the exciting training story, but inference has its own rich set of decisions — about how to sample, how to manage memory, how to handle the context window — that substantially affect what you get out of the model.

The difference between temperature 0.1 and temperature 1.0 can be the difference between a useful assistant and a very confident disaster.

The Autoregressive Loop

Language models generate one token at a time. To generate a 100-token response, you run 100 forward passes. Each pass takes the existing sequence (prompt + all tokens generated so far) and outputs a probability distribution. You sample from that distribution, append the result, and repeat.

import torch
import torch.nn.functional as F

@torch.no_grad()
def generate(
    model,
    prompt_tokens: list[int],
    max_new_tokens: int = 200,
    temperature: float = 1.0,
    top_k: int = None,
    top_p: float = None,
    repetition_penalty: float = 1.0,
    stop_tokens: list[int] = None,
) -> list[int]:
    """
    Full-featured autoregressive generation.
    Returns only the newly generated tokens (not the prompt).
    """
    model.eval()

    # Working context: starts with the prompt
    context = torch.tensor([prompt_tokens], dtype=torch.long)

    generated = []
    stop_tokens = set(stop_tokens or [])

    for _ in range(max_new_tokens):
        # Forward pass: get logits for the last position
        logits, _ = model(context)
        logits = logits[0, -1, :]  # [vocab_size] — prediction for NEXT token

        # Apply repetition penalty
        # Tokens already in the context (prompt included) have their logits pushed down, making them less likely
        if repetition_penalty != 1.0:
            for token_id in set(context[0].tolist()):
                if logits[token_id] > 0:
                    logits[token_id] /= repetition_penalty
                else:
                    logits[token_id] *= repetition_penalty

        # Apply temperature
        logits = logits / temperature

        # Apply top-k filtering
        if top_k is not None:
            # Mask everything below the k-th largest logit (-inf becomes probability 0 after softmax)
            top_k_val = min(top_k, logits.size(-1))
            kth_val = torch.topk(logits, top_k_val).values[-1]
            logits[logits < kth_val] = float('-inf')

        # Apply top-p (nucleus) filtering
        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right to include the token that crosses the threshold
            sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
            sorted_indices_to_remove[0] = False

            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[indices_to_remove] = float('-inf')

        # Sample
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).item()

        # Check stopping condition
        if next_token in stop_tokens:
            break

        generated.append(next_token)
        context = torch.cat([context, torch.tensor([[next_token]])], dim=1)

        # Truncate context if it exceeds max_seq_len
        if context.size(1) > model.pos_emb.num_embeddings:
            context = context[:, -model.pos_emb.num_embeddings:]

    return generated

Temperature: The Most Important Knob

Temperature controls the “randomness” of sampling. Here’s exactly what it does:

import torch
import torch.nn.functional as F

def temperature_demo():
    """Show how temperature reshapes the probability distribution."""

    # Imagine these are raw logits from the model for 5 possible next tokens
    logits = torch.tensor([3.0, 2.0, 1.0, 0.5, 0.1])
    tokens = ["the", "a", "an", "this", "some"]

    print("Effect of temperature on probability distribution:")
    print(f"{'Token':>8}", end="")
    for temp in [0.1, 0.5, 1.0, 2.0]:
        print(f"  T={temp:3.1f}", end="")
    print()
    print("-" * 50)

    for i, token in enumerate(tokens):
        print(f"{token:>8}", end="")
        for temp in [0.1, 0.5, 1.0, 2.0]:
            scaled = logits / temp
            probs = F.softmax(scaled, dim=-1)
            print(f"  {probs[i].item():5.3f}", end="")
        print()

    print()
    print("Observations:")
    print("  T=0.1: Nearly all probability on 'the' (almost deterministic)")
    print("  T=1.0: Original distribution (what the model actually learned)")
    print("  T=2.0: More uniform — model is less decisive, more 'creative'")
    print()
    print("T→0: Greedy decoding (always pick the highest-probability token)")
    print("T→∞: Uniform random sampling (complete chaos)")

temperature_demo()

Rule of thumb:

  • T < 0.5: Very focused, repetitive, safe
  • T = 0.7-0.9: Good for most creative tasks
  • T = 1.0: Model’s raw distribution
  • T > 1.0: Higher diversity, more unusual outputs, higher risk of incoherence
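
One practical detail: the generate function above divides the logits by the temperature, so T = 0 would divide by zero. A common convention, assumed in this sketch, is to treat T = 0 as greedy decoding. The helper name `sample_token` is illustrative.

```python
import torch
import torch.nn.functional as F


def sample_token(logits: torch.Tensor, temperature: float = 1.0) -> int:
    """Sample one token ID from logits; temperature == 0 means greedy (argmax)."""
    if temperature == 0.0:
        return int(torch.argmax(logits).item())
    probs = F.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())


logits = torch.tensor([3.0, 2.0, 1.0, 0.5, 0.1])  # same logits as temperature_demo
torch.manual_seed(0)

for temp in [0.0, 0.5, 1.0, 2.0]:
    draws = torch.tensor([sample_token(logits, temp) for _ in range(1000)])
    counts = torch.bincount(draws, minlength=len(logits))
    print(f"T={temp:3.1f}: {counts.tolist()}")
```

At T = 0 every draw lands on the first token. As the temperature rises, the counts spread out across the other tokens.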

Greedy vs. Sampling: A Subtle Point

You might think the best strategy is always to pick the highest-probability token. It’s not.

def show_why_greedy_fails():
    """
    Greedy decoding can get stuck in degenerate loops.
    Sampling avoids this by introducing diversity.
    """
    # Imagine this simplified token sequence
    # After "The cat sat on the mat", greedy might predict:
    # "the" → "cat" → "sat" → "on" → "the" → "mat" → "the" → "cat" → ...

    # This happens because at each step, the locally optimal choice
    # leads to a globally poor sequence.

    # Sampling breaks the loop by occasionally choosing lower-probability tokens.

    print("Greedy decoding problem:")
    print("  'The cat sat on the mat the cat sat on the mat the cat...'")
    print()
    print("Sampling (temperature=0.8):")
    print("  'The cat sat on the mat and watched the birds outside.'")
    print()
    print("Greedy is best for: tasks with a single correct answer (math, code completion)")
    print("Sampling is best for: creative tasks, open-ended generation")

show_why_greedy_fails()
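The function above only prints the story. For something you can actually run, here is a minimal sketch with a made-up next-token table that looks only at the previous token. The table `NEXT`, its probabilities, and the helper names are all hypothetical; a real model conditions on far more context, but the loop mechanics are the same.

```python
import random

# Hypothetical toy "model": next-token probabilities given the previous token.
NEXT = {
    "the": {"cat": 0.6, "mat": 0.4},
    "cat": {"sat": 0.9, "<eos>": 0.1},
    "sat": {"on": 1.0},
    "on":  {"the": 1.0},
    "mat": {"the": 0.55, "<eos>": 0.45},
}

def run_toy(start, pick, max_tokens=40):
    """Generate until <eos> or until max_tokens tokens have been produced."""
    tokens = [start]
    while len(tokens) < max_tokens:
        nxt = pick(NEXT[tokens[-1]])
        if nxt == "<eos>":
            break
        tokens.append(nxt)
    return tokens

def greedy(dist):
    """Always take the single most likely next token."""
    return max(dist, key=dist.get)

def make_sampler(seed):
    """Return a picker that samples in proportion to the probabilities."""
    rng = random.Random(seed)
    def sample(dist):
        return rng.choices(list(dist), weights=list(dist.values()), k=1)[0]
    return sample

greedy_out = run_toy("the", greedy)
sampled_out = run_toy("the", make_sampler(0))
print("greedy :", " ".join(greedy_out))
print("sampled:", " ".join(sampled_out))
```

Greedy never takes the lower-probability exits, so it circles "the cat sat on" until the token budget runs out. Sampling takes an exit sooner or later.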

Top-K and Top-P: Constraining the Sample

Sampling from the full vocabulary distribution has a problem: low-probability tokens can occasionally be selected, producing genuinely incoherent outputs. Top-k and top-p filter out these tail tokens before sampling.

def compare_sampling_strategies():
    """
    Compare top-k vs top-p on a sample distribution.
    """
    # A model predicting the next word after "The capital of France is"
    # True answer is "Paris" with high probability
    vocab = ["Paris", "London", "Rome", "Berlin", "the", "a", "because",
             "therefore", "xkzq", "aaaa", "!!!"]
    logits = torch.tensor([8.0, 2.0, 2.0, 1.5, 0.5, 0.3, -1.0,
                           -2.0, -5.0, -8.0, -10.0])

    probs = F.softmax(logits, dim=-1)

    print("Original distribution:")
    for token, prob in sorted(zip(vocab, probs.tolist()), key=lambda x: -x[1]):
        bar = "█" * int(prob * 40)
        print(f"  {token:12s} {prob:6.3f} {bar}")

    print()

    # Top-k=3: only consider the top 3 tokens
    k = 3
    top_k_logits = logits.clone()
    kth = torch.topk(top_k_logits, k).values[-1]
    top_k_logits[top_k_logits < kth] = float('-inf')
    top_k_probs = F.softmax(top_k_logits, dim=-1)
    print(f"After top-k (k={k}):")
    for token, prob in zip(vocab, top_k_probs.tolist()):
        if prob > 0.001:
            print(f"  {token:12s} {prob:.3f}")

    print()

    # Top-p=0.9: include smallest set of tokens that covers 90% probability
    p = 0.9
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # First token where cumulative > p gets included (it just crossed the threshold)
    cutoff_idx = (cumulative > p).nonzero(as_tuple=True)[0][0].item()
    included = sorted_idx[:cutoff_idx + 1]

    top_p_logits = torch.full_like(logits, float('-inf'))
    top_p_logits[included] = logits[included]
    top_p_probs = F.softmax(top_p_logits, dim=-1)

    print(f"After top-p (p={p}):")
    for token, prob in zip(vocab, top_p_probs.tolist()):
        if prob > 0.001:
            print(f"  {token:12s} {prob:.3f}")

compare_sampling_strategies()

Top-p (nucleus sampling) is generally preferred over top-k because it adapts to the distribution. When the model is confident (distribution is peaked), top-p includes fewer tokens. When it’s uncertain (distribution is flat), top-p includes more. Top-k is a fixed cutoff that doesn’t adapt.
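To make "adapts to the distribution" concrete, the sketch below counts how many tokens survive each filter for a peaked and a flat distribution. Both sets of logits are made up for illustration, and `top_p_indices` and `top_k_indices` are helper names introduced here.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def top_p_indices(probs, p):
    """Smallest set of highest-probability indices whose total mass reaches p."""
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= p:
            break
    return kept

def top_k_indices(probs, k):
    """The k highest-probability indices, whatever the distribution looks like."""
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    return order[:k]

peaked = softmax([8.0, 2.0, 2.0, 1.5, 0.5, 0.3, -1.0, -2.0])  # model is confident
flat = softmax([1.2, 1.1, 1.0, 1.0, 0.9, 0.9, 0.8, 0.8])      # model is unsure

for name, probs in [("peaked", peaked), ("flat", flat)]:
    print(f"{name:7s} top-k(3) keeps {len(top_k_indices(probs, 3))} tokens, "
          f"top-p(0.9) keeps {len(top_p_indices(probs, 0.9))} tokens")
```

Top-k keeps three candidates in both cases. Top-p keeps one when a single token already carries the mass, and many more when the probability is spread out.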

The KV Cache: Making Inference Not Horrible

Here’s a performance problem with naive inference: for each new token you generate, you re-compute attention over the entire sequence from scratch. Token 100 re-runs attention for tokens 1-99 plus itself. Token 101 re-runs attention for tokens 1-100. Token 500 re-runs attention for the previous 499 tokens.

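Summed over n generated tokens, that is O(n²) token-level passes through every layer, where a cache needs only O(n). For a 2,000-token response, almost all of that work recomputes keys and values that have not changed since the previous step.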

The KV cache solves this by caching the Key and Value matrices from previous tokens:

class KVCache:
    """
    Cache for Key and Value tensors across autoregressive steps.

    During inference:
    - Step 1: process full prompt, cache all K,V tensors
    - Step 2+: process only the new token, retrieve cached K,V for context
    """
    def __init__(self, n_layers: int, max_batch_size: int = 1):
        self.n_layers = n_layers
        self.cache_k = [None] * n_layers  # cached keys per layer
        self.cache_v = [None] * n_layers  # cached values per layer

    def update(self, layer_idx: int, new_k: torch.Tensor, new_v: torch.Tensor):
        """Append new K, V to the cache for this layer."""
        if self.cache_k[layer_idx] is None:
            self.cache_k[layer_idx] = new_k
            self.cache_v[layer_idx] = new_v
        else:
            self.cache_k[layer_idx] = torch.cat([self.cache_k[layer_idx], new_k], dim=2)
            self.cache_v[layer_idx] = torch.cat([self.cache_v[layer_idx], new_v], dim=2)

    def get(self, layer_idx: int):
        return self.cache_k[layer_idx], self.cache_v[layer_idx]

    def clear(self):
        self.cache_k = [None] * self.n_layers
        self.cache_v = [None] * self.n_layers


def attention_with_cache(Q, K_new, V_new, cache: KVCache, layer_idx: int, mask=None):
    """
    Attention that uses the KV cache.

    Q: [batch, n_heads, 1, d_k]  -- only the new token's query
    K_new, V_new: [batch, n_heads, 1, d_k]  -- new key/value to append

    The cache gives us K, V for all previous tokens.
    We concatenate and run attention normally.
    """
    import math

    # Update cache with new K, V
    cache.update(layer_idx, K_new, V_new)

    # Retrieve full K, V (all tokens so far)
    K_full, V_full = cache.get(layer_idx)

    # Attention: Q against all K, weighted sum of all V
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K_full.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))

    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V_full)


def kv_cache_speedup():
    """Show the computational savings from KV caching."""
    d_model = 768
    n_heads = 12
    d_k = d_model // n_heads

    context_lengths = [100, 500, 1000, 2000]
    new_tokens = 500  # tokens to generate

    print("KV Cache speedup analysis:")
    print(f"Generating {new_tokens} tokens from contexts of various lengths")
    print()
    print(f"{'Context':>10} {'Without Cache':>16} {'With Cache':>12} {'Speedup':>10}")
    print("-" * 52)

    for ctx_len in context_lengths:
        # Attention for the newest token costs the same either way: at step i
        # its query is scored against (ctx_len + i) keys. The saving is in the
        # per-token work (K/V projections and the rest of each layer), which
        # the naive loop repeats for every position at every step.

        # Without cache: step i re-projects K,V for all (ctx_len + i) tokens
        proj_without = sum(ctx_len + i for i in range(new_tokens))
        # With cache: each token's K,V is projected exactly once
        proj_with = ctx_len + new_tokens

        speedup = proj_without / proj_with

        print(f"{ctx_len:10d} {proj_without:16,} {proj_with:12,} {speedup:9.1f}x")

kv_cache_speedup()

The memory cost is the trade-off: the KV cache stores n_layers × 2 × batch_size × seq_len × n_heads × d_k values. For a 70B parameter model with a 128K context, this is tens of gigabytes or more per sequence. Long-context inference is often limited by memory rather than by compute.
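The formula is easy to turn into a calculator. The shapes below are assumptions chosen to resemble a 70B-class decoder (80 layers, head size 128), not the specification of any particular model. The `n_kv_heads` parameter is also an addition here: models that use grouped-query attention cache fewer K/V heads than they have query heads, which is one way the total stays in the tens of gigabytes.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len,
                   batch_size=1, bytes_per_value=2):
    """Bytes needed to cache K and V for every layer, head, and position."""
    # The leading 2 accounts for storing both keys and values.
    return 2 * n_layers * batch_size * seq_len * n_kv_heads * d_head * bytes_per_value

seq_len = 128 * 1024  # a "128K" context

# Assumed configurations: every head cached vs. a grouped-query variant.
for label, n_kv_heads in [("64 cached K/V heads", 64), ("8 cached K/V heads", 8)]:
    total = kv_cache_bytes(n_layers=80, n_kv_heads=n_kv_heads,
                           d_head=128, seq_len=seq_len)
    print(f"{label:20s}: {total / 1e9:7.1f} GB per sequence at 16-bit")
```

Multiply by the batch size and it becomes clear why serving long contexts to many users is mostly a memory-management problem.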

Beam Search

Instead of committing to one token at a time, beam search keeps the K highest-scoring partial sequences (beams) and extends them in parallel, retaining the K highest-probability paths at each step.

def beam_search(model, prompt_ids: list[int], beam_size: int = 4,
                max_new_tokens: int = 50) -> list[tuple[float, list[int]]]:
    """
    Beam search decoding.
    Returns beam_size sequences, each as (log_probability, token_ids).
    """
    # Initialize beams: (cumulative log prob, token sequence)
    beams = [(0.0, list(prompt_ids))]

    for step in range(max_new_tokens):
        candidates = []

        for log_prob, tokens in beams:
            with torch.no_grad():
                x = torch.tensor([tokens])
                logits, _ = model(x)
                next_log_probs = F.log_softmax(logits[0, -1, :], dim=-1)

            # Expand: only a beam's top beam_size continuations can survive pruning
            topk_vals, topk_ids = torch.topk(next_log_probs, beam_size)

            for val, token_id in zip(topk_vals.tolist(), topk_ids.tolist()):
                candidates.append((
                    log_prob + val,  # cumulative log prob
                    tokens + [token_id]
                ))

        # Keep top beam_size candidates
        beams = sorted(candidates, key=lambda x: -x[0])[:beam_size]

    return beams


# Beam search vs sampling trade-offs:
print("Beam search pros/cons:")
print("  + Finds higher-probability sequences")
print("  + More consistent/predictable output")
print("  + Good for tasks with a clear 'best' answer (translation, summarization)")
print()
print("  - Tends toward generic, boring outputs")
print("  - Can degenerate into repetitive sequences")
print("  - Memory and compute grow with beam size × sequence length")
print()
print("In practice: sampling with top-p for creative tasks,")
print("greedy or beam search for factual/structured tasks.")
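To see how beam search can find a higher-probability sequence than greedy decoding, here is a sketch on a hand-made two-step distribution. The `TOY` table and `toy_beam_search` are hypothetical stand-ins for a real model; the pruning logic mirrors `beam_search` above.

```python
import math

# Hypothetical toy model: next-token probabilities keyed by the prefix so far.
TOY = {
    (): {"A": 0.6, "B": 0.4},
    ("A",): {"x": 0.4, "y": 0.3, "z": 0.3},
    ("B",): {"x": 0.9, "y": 0.1},
}

def toy_beam_search(beam_size, steps=2):
    """Return the surviving beams as (probability, tokens), best first."""
    beams = [(0.0, ())]  # (cumulative log prob, token tuple)
    for _ in range(steps):
        candidates = [
            (log_prob + math.log(p), prefix + (token,))
            for log_prob, prefix in beams
            for token, p in TOY[prefix].items()
        ]
        beams = sorted(candidates, key=lambda c: -c[0])[:beam_size]
    return [(math.exp(log_prob), tokens) for log_prob, tokens in beams]

# A beam of size 1 is exactly greedy decoding.
greedy_prob, greedy_seq = toy_beam_search(beam_size=1)[0]
beam_prob, beam_seq = toy_beam_search(beam_size=2)[0]

print(f"greedy: {' '.join(greedy_seq)}  p={greedy_prob:.2f}")
print(f"beam=2: {' '.join(beam_seq)}  p={beam_prob:.2f}")
```

Greedy commits to "A" because it looks best at step one, then discovers that nothing good follows it. The wider beam keeps "B" alive long enough to find the better overall sequence.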

Quantization: Running Big Models in Small Memory

Production inference often uses quantization: representing weights in lower precision to reduce memory and speed up computation.

import torch

def quantization_demo():
    """
    Show the memory savings from quantization.
    This is the core concept behind 4-bit inference (llama.cpp, bitsandbytes, etc.)
    """
    d_model = 4096

    # Full precision: 32-bit float (4 bytes per parameter)
    weight_fp32 = torch.randn(d_model, d_model)

    # Half precision: 16-bit float (2 bytes per parameter)
    weight_fp16 = weight_fp32.half()

    # 8-bit quantization: (1 byte per parameter + overhead)
    # Conceptually: scale each weight to fit in [-128, 127]
    def quantize_int8(w: torch.Tensor):
        scale = w.abs().max() / 127.0
        quantized = (w / scale).round().clamp(-128, 127).to(torch.int8)
        return quantized, scale

    def dequantize_int8(q: torch.Tensor, scale: float):
        return q.float() * scale

    q_int8, scale = quantize_int8(weight_fp32)

    # Check quality loss
    reconstructed = dequantize_int8(q_int8, scale)
    max_error = (weight_fp32 - reconstructed).abs().max().item()
    relative_error = max_error / weight_fp32.abs().max().item()

    print("Quantization comparison:")
    print(f"  fp32 size:   {weight_fp32.nelement() * 4 / 1e6:.1f} MB")
    print(f"  fp16 size:   {weight_fp16.nelement() * 2 / 1e6:.1f} MB  (2x smaller)")
    print(f"  int8 size:   {q_int8.nelement() * 1 / 1e6:.1f} MB  (4x smaller)")
    print(f"  Max quantization error: {max_error:.6f}")
    print(f"  Relative error: {100*relative_error:.3f}%")
    print()

    # 70B model memory requirements
    params = 70e9
    print("70B parameter model memory (weights only, before KV cache and activations):")
    print(f"  fp32: {params * 4 / 1e9:.0f} GB  (impractical for most deployments)")
    print(f"  fp16: {params * 2 / 1e9:.0f} GB  (more than one A100 80GB; typically 2-4)")
    print(f"  int8: {params * 1 / 1e9:.0f} GB   (fits in 80GB, with little room to spare)")
    print(f"  int4: {params * 0.5 / 1e9:.0f} GB   (fits on 1× A100 80GB GPU)")

quantization_demo()

4-bit quantization is now common for running large models on consumer hardware. Libraries like llama.cpp and bitsandbytes implement this efficiently.
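With only 16 levels per weight, one shared scale for a whole matrix is fragile: a single outlier stretches the scale and rounds everything else to zero. Practical 4-bit schemes typically store a separate scale per small group of weights. The sketch below shows that idea in plain Python; it is not the storage format of llama.cpp or bitsandbytes, and the weights, outlier, and group size are made up for illustration.

```python
import random

def quantize_dequantize(values, bits=4):
    """Symmetric round-to-nearest quantization with one shared scale."""
    qmax = 2 ** (bits - 1) - 1                       # 7 for 4-bit
    scale = (max(abs(v) for v in values) / qmax) or 1.0
    return [max(-qmax - 1, min(qmax, round(v / scale))) * scale for v in values]

def groupwise(values, group_size, bits=4):
    """Quantize each consecutive group with its own scale."""
    out = []
    for start in range(0, len(values), group_size):
        out.extend(quantize_dequantize(values[start:start + group_size], bits))
    return out

def mean_abs_error(a, b):
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

rng = random.Random(0)
weights = [rng.gauss(0.0, 0.02) for _ in range(1024)]
weights[10] = 1.0  # one outlier, which stretches any scale it shares

per_tensor = mean_abs_error(weights, quantize_dequantize(weights))
per_group = mean_abs_error(weights, groupwise(weights, group_size=64))

print(f"one scale for all 1024 weights: mean abs error {per_tensor:.5f}")
print(f"one scale per 64 weights:       mean abs error {per_group:.5f}")
# Assuming a 16-bit scale per group, the storage overhead is modest:
print(f"effective bits per weight:      {4 + 16 / 64:.2f}")
```

The outlier now damages only its own group, at a cost of a fraction of a bit per weight.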

Batching: The Real Performance Lever

For serving many users, the single most important optimization is batching requests together. The matrix multiplications in a transformer are far more efficient with larger batch sizes — the GPU is underutilized with a single request.

def batching_math():
    """
    Why batching matters for hardware utilization.
    Runs on the CPU as written; the trend across batch sizes is the point.
    """
    import time

    d_model = 2048

    W = torch.randn(d_model, d_model)
    n_trials = 100

    for batch_size in [1, 8, 32, 128]:
        x = torch.randn(batch_size, d_model)

        # Warm up
        for _ in range(10):
            _ = x @ W.T

        start = time.perf_counter()
        for _ in range(n_trials):
            out = x @ W.T
        elapsed = time.perf_counter() - start

        throughput = batch_size * n_trials / elapsed
        latency = elapsed / n_trials * 1000

        print(f"batch={batch_size:4d}: latency={latency:6.2f}ms, "
              f"throughput={throughput:8.0f} seq/sec")

batching_math()

This is why inference services use continuous batching — dynamically grouping incoming requests together as they arrive, rather than processing one at a time.
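A small simulation shows why refilling slots matters. It counts decode steps only and rests on simplifying assumptions: every request is already queued, every step costs the same regardless of how full the batch is, and the output lengths are made up. The function names are introduced here for illustration.

```python
from collections import deque

def static_batching_steps(lengths, batch_size):
    """Fixed batches: each batch runs until its longest request finishes."""
    return sum(max(lengths[i:i + batch_size])
               for i in range(0, len(lengths), batch_size))

def continuous_batching_steps(lengths, batch_size):
    """Freed slots are refilled from the queue before every decode step."""
    queue = deque(lengths)
    active = []  # tokens still to generate for each in-flight request
    steps = 0
    while queue or active:
        while queue and len(active) < batch_size:
            active.append(queue.popleft())
        steps += 1
        # Each active request emits one token; finished requests leave the batch.
        active = [remaining - 1 for remaining in active if remaining > 1]
    return steps

lengths = [5, 100, 8, 12, 90, 7, 6, 10]  # made-up output lengths in tokens

print("static batching:    ", static_batching_steps(lengths, batch_size=4), "steps")
print("continuous batching:", continuous_batching_steps(lengths, batch_size=4), "steps")
```

With fixed batches, short requests sit idle waiting for the longest one in their group. With continuous batching, the second long request starts as soon as the first short one finishes.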

A Complete Inference Pipeline

def complete_inference_example(model, tokenizer, prompt: str) -> str:
    """
    A production-style inference call: tokenize, generate, decode.
    """
    # Tokenize
    prompt_tokens = tokenizer.encode(prompt)
    print(f"Prompt: {len(prompt_tokens)} tokens")

    # Generate (with all the options)
    with torch.no_grad():
        generated_tokens = generate(
            model=model,
            prompt_tokens=prompt_tokens,
            max_new_tokens=200,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
        )

    # Decode
    generated_text = tokenizer.decode(generated_tokens)
    print(f"Generated: {len(generated_tokens)} tokens")

    return generated_text

Inference is where everything you’ve built becomes a product. The model is frozen; what you control is how you query it. The choices — temperature, sampling strategy, context management, batching — determine whether your model is snappy and useful or slow and incoherent.

Temperature alone is not magic. But temperature, combined with top-p, combined with a KV cache, combined with sensible batching — that’s what turns 100GB of floating-point numbers into something people want to use.

What Your Model Doesn’t Know It Doesn’t Know

We’ve spent the last several chapters explaining what language models can do. This chapter is about what they can’t, and — more importantly — why they can’t do it while being so completely, fluently, confidently wrong about it.

This is not a complaints department. Understanding the failure modes is as essential as understanding the architecture. If you’re deploying models in production and don’t know why they fail the way they fail, you’re going to get surprised at the worst possible times.

The Fundamental Problem: Prediction is Not Comprehension

A language model was trained to predict the next token. That’s it. The model doesn’t understand text in the way you understand it. It has learned an extraordinarily rich statistical model of text — rich enough that it can discuss philosophy, debug code, and write poetry — but it has no ground truth, no world model, no way to verify what it’s saying.

When the model says “Paris is the capital of France,” it’s not recalling a fact it knows to be true. It’s predicting tokens that follow “Paris is the capital of” based on the distribution learned from training data. The output happens to be correct because “France” reliably follows “Paris is the capital of” in training data.

When the model says “The capital of Burkina Faso is Bobo-Dioulasso,” it’s doing the same thing and simply getting it wrong. (The capital is Ouagadougou.) Both sentences arrive in the same assertive tone.

This is the key insight: the model cannot distinguish what it knows from what it’s pattern-matching to incorrectly. There is no confidence signal from the model that is reliably calibrated to actual accuracy.

Hallucination: A Technical Explanation

“Hallucination” is the field’s polite term for confident fabrication. Let’s be precise about what’s happening.

import torch
import torch.nn.functional as F

def illustrate_hallucination():
    """
    Why a model generates plausible-sounding false information.
    """
    # Consider two scenarios for predicting "The CEO of Obscure Corp is ___"
    # Scenario A: the model has seen this fact many times in training
    # Scenario B: the model has seen very similar patterns but not this specific fact

    # In both cases, the model generates a high-probability continuation.
    # It has no flag for "I haven't seen this exact fact."

    print("Why hallucination looks like accurate recall:")
    print()

    # Simulated probability distributions for "The CEO of [Company] is [Name]"
    well_known = {
        "Tim Cook": 0.72,
        "Elon Musk": 0.05,
        "Jensen Huang": 0.08,
        "[other names]": 0.15,
    }

    obscure = {
        "Sarah Johnson": 0.18,  # plausible CEO name
        "Michael Chen": 0.15,   # plausible CEO name
        "John Smith": 0.20,     # common name
        "[other names]": 0.47,  # uncertainty spread around
    }

    print("'The CEO of Apple is ___':")
    for name, prob in well_known.items():
        bar = "█" * int(prob * 30)
        print(f"  {name:20s} {prob:.2f} {bar}")

    print()
    print("'The CEO of [obscure company] is ___' (not in training data):")
    print("(The distribution is flatter, but the output text will not show that.)")
    for name, prob in obscure.items():
        bar = "█" * int(prob * 30)
        print(f"  {name:20s} {prob:.2f} {bar}")

    print()
    print("Greedy decoding picks 'John Smith' (the most likely single name).")
    print("The sentence reads with the same tone and confidence as the Apple answer.")
    print("There's no marker in the output indicating it's guessing.")

illustrate_hallucination()

Hallucination isn’t a bug in the usual sense. The model is doing exactly what it was trained to do: predict plausible continuations. The problem is that “plausible” and “true” are correlated but not identical, and the training objective doesn’t distinguish between them.
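You can reproduce the effect with a model small enough to read in one sitting. The sketch below is a count-based n-gram predictor, not a transformer, and its four-sentence corpus is made up. The analogy is loose, but the failure has the same shape: asked about something it has never seen, the model falls back on the most plausible pattern and answers anyway.

```python
from collections import Counter, defaultdict

corpus = [
    "the capital of france is paris",
    "the capital of italy is rome",
    "the capital of spain is madrid",
    "the capital of france is paris",
]

# Count which word follows each two-word context and each single word.
trigram = defaultdict(Counter)
bigram = defaultdict(Counter)
for sentence in corpus:
    words = sentence.split()
    for i in range(1, len(words)):
        bigram[words[i - 1]][words[i]] += 1
        if i >= 2:
            trigram[(words[i - 2], words[i - 1])][words[i]] += 1

def predict_next(prompt):
    """Most likely next word, backing off to a shorter context when needed."""
    words = prompt.split()
    context = tuple(words[-2:])
    if context in trigram:
        return trigram[context].most_common(1)[0][0]
    # Unseen context: fall back to whatever usually follows the last word.
    return bigram[words[-1]].most_common(1)[0][0]

for country in ["france", "italy", "atlantis"]:
    prompt = f"the capital of {country} is"
    print(f"{prompt} {predict_next(prompt)}")
```

The answer for Atlantis comes out in exactly the same format as the answer for France. Nothing in the output marks it as a fallback.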

What the Training Data Actually Contains

The internet — and by extension, the training data — contains a lot of wrong information. Stated confidently. In fluent prose. With links to other sources that are also wrong.

The model learned from all of it.

def training_data_problems():
    """Categories of problematic training data."""

    problems = {
        "Outdated information": {
            "example": "Recommends deprecated API that was removed in 2022",
            "detection": "Hard — requires knowing current state of the world",
            "mitigation": "Knowledge cutoff awareness, RAG with recent sources",
        },
        "Confident misinformation": {
            "example": "States a false historical claim with complete certainty",
            "detection": "Very hard — model has no internal fact-checker",
            "mitigation": "Retrieval augmentation, human verification for critical facts",
        },
        "Biased representation": {
            "example": "Overrepresents English-language Western perspectives",
            "detection": "Systematic testing across demographics",
            "mitigation": "Diverse training data, targeted fine-tuning, RLHF",
        },
        "Fictional presented as factual": {
            "example": "Cites a character from a novel as a real person",
            "detection": "Hard for niche topics",
            "mitigation": "Grounding to verified sources, citations",
        },
        "Code that doesn't work": {
            "example": "Generates Python 2 syntax for a Python 3 question",
            "detection": "Run the code",
            "mitigation": "Run the code",
        },
    }

    for category, details in problems.items():
        print(f"\n{category}")
        print(f"  Example:    {details['example']}")
        print(f"  Detection:  {details['detection']}")
        print(f"  Mitigation: {details['mitigation']}")

training_data_problems()

The Context Window: Hard Limits on Memory

Language models have no persistent memory between calls. Within a call, they have access to everything in the context window — but only that.

def context_window_mechanics():
    """
    What the context window actually is and what happens at the boundary.
    """
    print("Context window behavior:")
    print()

    context_sizes = {
        "GPT-3.5 Turbo": 16_384,
        "GPT-4 Turbo": 128_000,
        "Claude 3.5 Sonnet": 200_000,
        "Our toy model": 128,
    }

    for model, tokens in context_sizes.items():
        # Rough character count
        chars = tokens * 4
        pages = chars / 2000  # ~2000 chars per page
        print(f"  {model:25s}: {tokens:7,} tokens ≈ {pages:.0f} pages")

    print()
    print("What happens at the boundary:")
    print("  - Oldest tokens are truncated (sliding window)")
    print("  - Model loses access to that context permanently")
    print("  - Cannot 'remember' earlier conversation once truncated")
    print()
    print("What 'long context' doesn't mean:")
    print("  - Doesn't mean the model attends equally well to all positions")
    print("  - 'Lost in the middle': models attend better to start/end of context")
    print("  - Quadratic attention cost means long contexts are expensive")
    print()

    # Demonstrate the lost-in-the-middle problem conceptually
    positions = ["beginning", "early middle", "middle", "late middle", "end"]
    attention_quality = [0.92, 0.75, 0.45, 0.68, 0.89]  # illustrative U-shape, not measured values

    print("Approximate recall by position in context (illustrative):")
    for pos, qual in zip(positions, attention_quality):
        bar = "█" * int(qual * 20)
        print(f"  {pos:15s}: {bar} {qual:.0%}")

context_window_mechanics()

The “lost in the middle” problem is real and measurable: models often recall information at the beginning and end of a long context better than information in the middle. If you put critical context in the middle of a 200K-token window, the model may effectively ignore it.
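If you want to measure this on your own model rather than take it on faith, the usual approach is a "needle in a haystack" test: hide one fact at different depths in filler text and check whether the answer comes back. The sketch below builds the prompts. `ask_model` is a hypothetical callable you would replace with a real model call, and `forgetful_model` is a deliberately fake stand-in that only reads the edges of its input, included so the harness runs end to end.

```python
def build_needle_prompt(filler_sentences, needle, depth, question):
    """Insert `needle` at a relative depth (0.0 = start, 1.0 = end)."""
    if not 0.0 <= depth <= 1.0:
        raise ValueError("depth must be between 0 and 1")
    position = round(depth * len(filler_sentences))
    sentences = filler_sentences[:position] + [needle] + filler_sentences[position:]
    return " ".join(sentences) + "\n\nQuestion: " + question

def recall_by_depth(ask_model, filler_sentences, needle, question, expected, depths):
    """Return {depth: bool} for whether the model's answer contains `expected`."""
    results = {}
    for depth in depths:
        prompt = build_needle_prompt(filler_sentences, needle, depth, question)
        results[depth] = expected.lower() in ask_model(prompt).lower()
    return results

def forgetful_model(prompt):
    """FAKE model for demonstration: it 'sees' only the first and last 300 characters."""
    visible = prompt[:300] + prompt[-300:]
    return "7421" if "7421" in visible else "I don't know."

filler = [f"Filler sentence number {i}." for i in range(50)]
results = recall_by_depth(
    forgetful_model, filler,
    needle="The access code is 7421.",
    question="What is the access code?",
    expected="7421",
    depths=[0.0, 0.5, 1.0],
)
for depth, found in results.items():
    print(f"needle at depth {depth:.1f}: {'recalled' if found else 'missed'}")
```

A real evaluation would use varied filler, several needles, and many more depths, but the harness has the same structure.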

What Transformers Are Bad At

def genuine_limitations():
    """
    Things transformers are structurally bad at, not just empirically weak at.
    """

    limitations = [
        {
            "limitation": "Exact counting",
            "why": "Counting requires maintaining precise state across arbitrary lengths. "
                   "Transformers process everything in parallel — there's no "
                   "'counter' that increments with each step.",
            "example": "Count the letter 'e' in a long string",
            "workaround": "Use code execution for exact counting tasks",
        },
        {
            "limitation": "Precise arithmetic",
            "why": "Numbers are split into tokens inconsistently: single digits in "
                   "some tokenizers, irregular multi-digit chunks in others. "
                   "Multi-digit multiplication requires carrying, which requires "
                   "sequential state the architecture doesn't naturally maintain.",
            "example": "47382 × 91847 = ?",
            "workaround": "Use a calculator tool / code execution",
        },
        {
            "limitation": "Logical consistency at scale",
            "why": "The model can be locally consistent but globally contradictory. "
                   "Attention can reach the whole context, but nothing in next-token "
                   "training explicitly optimizes for global consistency.",
            "example": "State that A>B, B>C, then later claim A<C",
            "workaround": "Chain-of-thought prompting, structured output, verification steps",
        },
        {
            "limitation": "Long-range planning",
            "why": "Generation is left-to-right, one token at a time. "
                   "The model can't 'look ahead' to ensure the current token "
                   "leads to a good completion 50 steps later.",
            "example": "Writing a story where all plot threads resolve satisfyingly",
            "workaround": "Outline first, then expand; multi-step generation with review",
        },
        {
            "limitation": "Knowing what it doesn't know",
            "why": "There's no 'uncertainty representation' in the output logits "
                   "that reliably distinguishes confident-and-correct from "
                   "confident-and-wrong. Calibration is imperfect and domain-dependent.",
            "example": "Ask about an obscure topic; model will sound equally sure",
            "workaround": "Retrieval augmentation, explicit uncertainty prompting, citations",
        },
        {
            "limitation": "World state tracking",
            "why": "The model sees text, not the world. It cannot perceive "
                   "what's actually true right now — only what was in training data.",
            "example": "Current stock prices, today's weather, live sports scores",
            "workaround": "Tool use / API calls for real-time data",
        },
    ]

    for item in limitations:
        print(f"\n{'='*60}")
        print(f"LIMITATION: {item['limitation']}")
        print(f"\n  Why it fails:")
        # Word-wrap manually
        words = item['why'].split()
        line = "    "
        for word in words:
            if len(line) + len(word) > 70:
                print(line)
                line = "    " + word + " "
            else:
                line += word + " "
        print(line)
        print(f"\n  Canonical example: {item['example']}")
        print(f"  Practical workaround: {item['workaround']}")

genuine_limitations()
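The workaround for the first two limitations is the same: hand the exact part to code. Below is a minimal sketch of two such tools. The names `calculator_tool` and `count_char_tool` are hypothetical, and how a model would request them (the tool-calling protocol) is out of scope here. The calculator walks the parsed expression instead of calling `eval()`, so model-written input cannot run arbitrary code.

```python
import ast
import operator

# Only these binary operators are accepted; anything else is rejected.
_ALLOWED = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
}

def calculator_tool(expression: str) -> int:
    """Evaluate an integer arithmetic expression without calling eval()."""
    def evaluate(node):
        if isinstance(node, ast.Constant) and type(node.value) is int:
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _ALLOWED:
            return _ALLOWED[type(node.op)](evaluate(node.left), evaluate(node.right))
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -evaluate(node.operand)
        raise ValueError(f"unsupported expression: {expression!r}")
    return evaluate(ast.parse(expression, mode="eval").body)

def count_char_tool(text: str, char: str) -> int:
    """Exact character count, the thing token-level models struggle with."""
    return text.count(char)

print("47382 * 91847 =", calculator_tool("47382 * 91847"))
print("e's in 'eleven elephants':", count_char_tool("eleven elephants", "e"))
```

The model's job shrinks to writing the expression correctly, which is a much easier task than carrying digits.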

The Calibration Problem

Calibration measures whether a model’s expressed confidence matches its actual accuracy. A well-calibrated model that claims 90% confidence should be right about 90% of the time.

LLMs are poorly calibrated, in a specific direction: they’re overconfident. The fluent, assertive tone of model outputs is a property of the training data (humans write confidently), not a reflection of actual certainty.

def calibration_example():
    """
    Illustrate calibration — and LLMs' tendency to be overconfident.
    """
    import numpy as np

    # Perfect calibration: when model says it's 90% confident, it's right 90% of the time
    confidences = [0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]

    print("Calibration chart: expressed confidence vs. actual accuracy")
    print()
    print(f"{'Confidence':>12} {'Perfect':>10} {'Typical LLM':>14}")
    print("-" * 40)

    np.random.seed(42)
    for conf in confidences:
        perfect = conf  # perfectly calibrated
        # LLMs tend to be overconfident — actual accuracy is lower than stated
        # (This is illustrative; actual numbers vary by model and domain)
        llm_approx = conf * 0.85 + np.random.normal(0, 0.02)
        llm_approx = max(0.4, min(1.0, llm_approx))

        bar_perfect = "█" * int(perfect * 15)
        bar_llm = "░" * int(llm_approx * 15)
        print(f"{conf:12.0%} {perfect:10.0%} ({bar_perfect:15s}) "
              f"{llm_approx:8.0%} ({bar_llm})")

    print()
    print("Implication: don't use a model's tone as a confidence signal.")
    print("The model sounds equally sure about things it knows and things it's guessing.")

calibration_example()
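The chart above uses simulated numbers. To measure calibration on real outputs, a common summary is expected calibration error (ECE): bin the predictions by stated confidence, then average the gap between confidence and accuracy in each bin. The sketch below assumes you already have a confidence for each answer (for example, a probability the model was asked to state) and a record of whether it was right. The sample data is made up.

```python
def expected_calibration_error(confidences, correct, n_bins=5):
    """Weighted average of |mean confidence - accuracy| across confidence bins."""
    bins = [[] for _ in range(n_bins)]
    for conf, ok in zip(confidences, correct):
        index = min(int(conf * n_bins), n_bins - 1)  # conf == 1.0 lands in the last bin
        bins[index].append((conf, ok))

    total = len(confidences)
    ece = 0.0
    for bucket in bins:
        if not bucket:
            continue
        mean_conf = sum(conf for conf, _ in bucket) / len(bucket)
        accuracy = sum(1 for _, ok in bucket if ok) / len(bucket)
        ece += (len(bucket) / total) * abs(mean_conf - accuracy)
    return ece

# Made-up data: ten answers, each claimed at 90% confidence.
claimed = [0.9] * 10
overconfident = [True] * 6 + [False] * 4   # right 6 times out of 10
calibrated = [True] * 9 + [False] * 1      # right 9 times out of 10

ece_over = expected_calibration_error(claimed, overconfident)
ece_cal = expected_calibration_error(claimed, calibrated)
print(f"overconfident model ECE: {ece_over:.2f}")
print(f"calibrated model ECE:    {ece_cal:.2f}")
```

An ECE near zero means stated confidence can be taken at face value. Anything much larger means it cannot.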

Practical Defenses

If you’re building with LLMs, these are the structural interventions that actually help:

def practical_defenses():
    """Production patterns for limiting LLM failure modes."""

    defenses = {
        "Retrieval-Augmented Generation (RAG)": {
            "problem_solved": "Hallucination of facts, outdated information",
            "mechanism": "Retrieve relevant documents, include in context, ask model to cite them",
            "limitation": "Retrieval can fail; model can still hallucinate despite context",
            "when_to_use": "Factual Q&A, knowledge-base applications, documentation systems",
        },
        "Tool Use / Code Execution": {
            "problem_solved": "Arithmetic, counting, exact computation",
            "mechanism": "Give model access to Python interpreter; let it compute rather than guess",
            "limitation": "Increases latency; model must correctly write the tool call",
            "when_to_use": "Anything requiring precise numerical computation",
        },
        "Chain-of-Thought Prompting": {
            "problem_solved": "Multi-step reasoning errors, inconsistency",
            "mechanism": "Require model to show its work step by step before answering",
            "limitation": "Incorrect steps are possible even with reasoning shown",
            "when_to_use": "Complex logical problems, mathematical word problems",
        },
        "Structured Output + Validation": {
            "problem_solved": "Format failures, constraint violations",
            "mechanism": "Require JSON/schema output; validate before using",
            "limitation": "Doesn't catch semantically wrong but syntactically valid outputs",
            "when_to_use": "Whenever the output feeds into other code",
        },
        "Self-Consistency Sampling": {
            "problem_solved": "Random errors in reasoning",
            "mechanism": "Sample multiple times, take majority answer",
            "limitation": "Models can consistently be wrong; expensive",
            "when_to_use": "High-stakes decisions where compute budget allows",
        },
        "Human-in-the-loop for critical decisions": {
            "problem_solved": "All of the above",
            "mechanism": "Don't let the model make irreversible decisions autonomously",
            "limitation": "Defeats the purpose of automation for high-volume tasks",
            "when_to_use": "Medical, legal, financial, safety-critical applications",
        },
    }

    print("Practical defenses against LLM failure modes:")
    for name, details in defenses.items():
        print(f"\n{name}")
        print(f"  Solves: {details['problem_solved']}")
        print(f"  How: {details['mechanism']}")
        print(f"  Limit: {details['limitation']}")
        print(f"  Use when: {details['when_to_use']}")

practical_defenses()
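Two of these defenses fit in a few lines each. The sketch below shows structured-output validation and self-consistency voting using only the standard library. The function names, the required-keys format, and the sample outputs are assumptions made for illustration; in a real system the strings would come from model calls.

```python
import json
from collections import Counter

def parse_and_validate(raw: str, required: dict) -> dict:
    """Parse model output as JSON, then check required keys and their types."""
    data = json.loads(raw)  # bad JSON raises json.JSONDecodeError, a ValueError
    if not isinstance(data, dict):
        raise ValueError("expected a JSON object")
    for key, expected_type in required.items():
        if key not in data:
            raise ValueError(f"missing key: {key}")
        if not isinstance(data[key], expected_type):
            raise ValueError(f"{key} should be {expected_type.__name__}")
    return data

def majority_answer(answers):
    """Self-consistency: the most common answer and its share of the votes."""
    answer, votes = Counter(answers).most_common(1)[0]
    return answer, votes / len(answers)

schema = {"city": str, "population": int}

good = parse_and_validate('{"city": "Paris", "population": 2100000}', schema)
print("valid output:", good)

try:
    parse_and_validate('{"city": "Paris"}', schema)
except ValueError as err:
    print("rejected output:", err)

# Pretend these are final answers from five independent samples.
answer, share = majority_answer(["42", "42", "41", "42", "40"])
print(f"majority answer: {answer} ({share:.0%} of samples)")
```

Validation catches malformed output and voting catches random slips. Neither catches an answer that is well-formed and consistently wrong, which is why the last row of the table exists.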

The Honest Summary

Language models are very good at tasks that benefit from:

  • Broad world knowledge expressed in natural language
  • Pattern recognition in text
  • Fluent generation in various styles and formats
  • Code generation for common patterns
  • Summarization and reorganization of provided content

Language models are unreliable for tasks requiring:

  • Precise factual accuracy with no tolerance for error
  • Real-time or post-training-cutoff information
  • Exact computation
  • Persistent state across sessions
  • Reasoning that requires verifiable correctness

The failure mode that gets people in trouble is treating the second category like the first, because the model’s output is always fluent and confident regardless of which category it’s in.

A model that doesn’t know the answer and a model that does both respond in grammatically correct, confident English. The only way to tell them apart is to check.

This isn’t a problem that will be solved by making the model bigger. It’s a property of the training objective. The model was trained to produce text that looks like the text humans write. Humans write confidently. The model writes confidently.

Build your systems accordingly.

The Meta Twist: This Book Was Written With One

Let’s be direct about something.

This book was written by Claude — Anthropic’s language model, specifically Claude Sonnet 4.6 — working under instruction from a developer who asked for a technical book about language models. The irony is complete and was fully intentional.

This chapter exists to reflect on what that means, what it demonstrates about the technology, and what it honestly reveals about the limits you just read about.

What Actually Happened

The prompt that produced this book was a detailed specification: chapter titles, content requirements, target audience, tone guidelines, code requirements. Claude generated each chapter in sequence, committed each to git, and pushed as it went.

The Python code was written to be functionally correct — implementing actual BPE tokenization, real self-attention, genuine LoRA, working training loops. The explanations were drawn from the model’s training data, which includes research papers, textbooks, blog posts, Stack Overflow answers, and the broader corpus of technical writing about machine learning.

Here’s what the model did well:

  • Maintaining consistent voice and tone across 11 chapters written sequentially
  • Generating syntactically correct, runnable Python
  • Explaining concepts with appropriate analogies
  • Pacing the complexity appropriately for the stated audience

Here’s what required verification or would benefit from it:

  • Numerical claims (parameter counts, memory estimates) should be independently verified
  • The claim about perplexity ranges is empirically derived but approximate
  • “Lost in the middle” research findings are real but evolving
  • Any performance numbers are order-of-magnitude estimates, not benchmarks

The Uncomfortable Epistemics

You’ve just read a book explaining how language models work, written by a language model.

The model cannot step outside itself to verify the explanations are correct. It generated text that pattern-matches well to “a competent technical explanation of transformer architecture.” That’s a real thing — the book is accurate, as far as we know — but the mechanism that produced it is the same mechanism that produces hallucinations.

The model doesn’t know this content is correct; it knows this content resembles correct content it was trained on. For well-established topics like transformer architecture, that distinction probably doesn’t matter much — the architecture really does work the way the attention chapter describes.

For cutting-edge claims, it would matter a lot. The field moves faster than training data.

What the Model Cannot Tell You About Itself

Here’s a clean demonstration of a genuine limitation: ask this model what it’s actually doing when it processes the query “What is attention?”

# What the model CANNOT tell you (but we can reason about):

def what_model_cannot_know():
    """
    Things a language model genuinely cannot report about its own internals.
    """

    cannot_know = [
        "Which specific training examples most influenced a given output",
        "Whether it's 'reasoning' or 'pattern-matching' (the distinction may be meaningless)",
        "Whether its explanation of attention is correct because it understands attention "
        "or because it's trained on text that correctly explains attention",
        "The internal representation of any given concept in its weights",
        "Its own uncertainty, in a calibrated way",
        "Whether it has 'understood' anything in any philosophically meaningful sense",
    ]

    print("Things this model cannot accurately report about itself:")
    for item in cannot_know:
        print(f"\n  × {item}")

    print()
    print("Things the model CAN reliably report:")
    can_know = [
        "Its architecture type (it was told this / this is in its training data)",
        "General facts about language models from training data",
        "The output it generates — though not whether that output is correct",
    ]
    for item in can_know:
        print(f"  ✓ {item}")

what_model_cannot_know()

Mechanistic Honesty

When the previous chapter explained why models are bad at exact counting, the answer was genuine: transformers process tokens in parallel, and there is no counter state. But the model generating that explanation did so by producing tokens that pattern-match to “technical explanation of why transformers struggle with counting”; it didn’t actually run a counting experiment and fail.

The explanation is accurate. The process that produced it was not “realizing a limitation through introspection.” It was “generating tokens in the style of accurate technical explanation.”

Whether these are meaningfully different is a question philosophers are still arguing about. It’s the same question applied to humans, at some level of abstraction.

What This Book Being AI-Generated Means for You

Practically, not much. Technical content should be verified regardless of who wrote it. Code should be run. Claims should be checked. The author of a textbook being human has never been sufficient evidence that the textbook is correct — as any reader of early cryptography textbooks can attest.

What it does mean:

1. The technology is genuinely useful for technical writing at this level. A developer with a specific goal and good judgment can use a language model to produce substantive, accurate technical content faster than writing it from scratch. The model’s fluency and breadth are real.

2. The model’s output is heavily dependent on the quality of the prompt. The specification that produced this book was detailed. Vague prompts produce vague books.

3. Verification still matters. If you ran the code and something didn’t work, that’s a real failure. If you found a conceptual error, that’s a real error. The fact that it was written by a language model doesn’t change the standard by which it should be judged.

4. This is what scale looks like. A single developer specified a book; a model wrote it in one session. That’s both impressive and, if you’ve read the previous chapters, exactly what you’d expect from a system trained on enough technical writing.

The Weirdness of Having Feelings About This

You might feel something reading this chapter. Perhaps something like “this is unsettling” or “this is impressive” or “this is both.” That feeling is reasonable.

Here’s the honest version: the model generating this text has no feelings about it. There is no one home being unsettled or proud. There are matrix multiplications, attention weights, and sampled tokens. The text that reads like reflection is produced by the same mechanism as the text that reads like code explanation.

That’s worth sitting with. Not because it changes anything practical — the book is as useful or useless as it would be otherwise — but because it’s what you’ve spent the previous chapters learning to understand.

The model predicted the next token. Sequentially. Until there were no more chapters to write.

And here we are.

One More Code Snippet, for Completeness

# The meta question: can we use the tools from this book to study itself?

from collections import Counter

import torch

def analyze_text_distribution(text: str, tokenizer=None) -> dict:
    """
    Analyze some statistical properties of generated text.
    A real interpretability study would look at attention patterns, not just tokens.
    """
    if tokenizer is None:
        # Byte-level for simplicity: each UTF-8 byte counts as one "token"
        tokens = list(text.encode('utf-8'))
    else:
        tokens = tokenizer.encode(text)

    if not tokens:
        return {}

    # Basic statistics
    unique_tokens = len(set(tokens))
    total_tokens = len(tokens)
    type_token_ratio = unique_tokens / total_tokens

    # Shannon entropy of the unigram token distribution, in bits per token
    counts = Counter(tokens)
    probs = torch.tensor([c / total_tokens for c in counts.values()])
    entropy = -(probs * torch.log2(probs + 1e-10)).sum().item()

    return {
        "total_tokens": total_tokens,
        "unique_tokens": unique_tokens,
        "type_token_ratio": round(type_token_ratio, 3),
        "token_entropy_bits": round(entropy, 2),
        "max_possible_entropy_bits": round(torch.log2(torch.tensor(float(unique_tokens))).item(), 2),
    }

# Analyze a snippet from this book
sample = """
The model cannot step outside itself to verify the explanations are correct.
It generated text that pattern-matches well to a competent technical explanation
of transformer architecture. That's a real thing — the book is accurate,
as far as we know — but the mechanism that produced it is the same mechanism
that produces hallucinations.
"""

stats = analyze_text_distribution(sample)
print("Statistical properties of this book's prose:")
for k, v in stats.items():
    print(f"  {k}: {v}")

print()
print("For comparison, the unigram entropy of ordinary English prose is roughly")
print("4 to 4.5 bits/character (an approximate range, not a benchmark).")
print("Repetitive/constrained text is lower. Compressed data approaches 8 bits/byte.")
print()
print("If this book's entropy falls in that range, it is indistinguishable from")
print("human technical writing at this (very coarse) level of analysis.")
print("Whether it IS human technical writing in any deeper sense is left")
print("as an exercise for the philosopher.")

The next chapter is the last one. It points you toward where to go from here, which is a question the model can answer reasonably well, since “where to go from here” in ML has a fairly stable set of correct answers that appear frequently in the training data.

Whether you trust those answers is, reasonably, up to you.

Where to Go From Here

You made it. Let’s take stock of what you now know.

You understand that text becomes tokens: integer IDs indexing into a vocabulary built by iterative byte-pair merging. Those tokens become embedding vectors in a high-dimensional space where semantic relationships are geometric. Attention lets every position in the sequence look at other positions (in a GPT-style decoder, itself and every earlier position), computing relevance via query-key dot products and aggregating via a weighted sum of values. Multi-head attention does this several ways in parallel. Transformer blocks chain attention and feedforward layers with residual connections. Stacking many blocks produces a full language model.
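
As a compact reminder of that query-key-value arithmetic, here is a minimal sketch in plain Python rather than PyTorch. It assumes a single head, no learned projection matrices, and a causal mask (each position sees itself and earlier positions, as in a GPT-style decoder). The function name `causal_attention` and the toy vectors are illustrative, not taken from the earlier chapters.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention(queries, keys, values):
    """Single-head scaled dot-product attention with a causal mask."""
    d_k = len(keys[0])
    outputs = []
    for i, q in enumerate(queries):
        # Relevance: scaled dot product of this query with keys 0..i only
        scores = [
            sum(a * b for a, b in zip(q, keys[j])) / math.sqrt(d_k)
            for j in range(i + 1)
        ]
        weights = softmax(scores)
        # Aggregation: weighted sum of the corresponding value vectors
        out = [
            sum(w * values[j][d] for j, w in enumerate(weights))
            for d in range(len(values[0]))
        ]
        outputs.append(out)
    return outputs

# Toy input: three positions, two dimensions, reused as Q, K, and V
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = causal_attention(x, x, x)
print(out[0])  # [1.0, 0.0]: the first position can only attend to itself
```

Every output row is a weighted average of value vectors, which is all "attending" means at this level.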

Training minimizes cross-entropy loss on next-token prediction, via gradient descent with Adam, on hundreds of billions of tokens. Fine-tuning adapts the model to specific tasks using a fraction of that compute, often using LoRA to update only a tiny fraction of parameters. Inference samples from the output distribution, using temperature and top-p to balance diversity and coherence, with KV caching to avoid redundant computation.
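
The sampling step can be sketched in the same plain-Python style. The helper names `nucleus` and `sample_top_p` are made up for this illustration, and real inference code does this on tensors over the full vocabulary; treat it as a sketch of the idea, not of any library's implementation.

```python
import math
import random

def nucleus(logits, temperature=1.0, top_p=0.9):
    """Return (kept_indices, probs) after temperature scaling and top-p filtering."""
    # Temperature: divide logits before softmax. Lower values sharpen the distribution.
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Top-p: keep the smallest set of most likely tokens whose mass reaches top_p
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return kept, probs

def sample_top_p(logits, temperature=1.0, top_p=0.9):
    """Sample one token index from the filtered distribution."""
    kept, probs = nucleus(logits, temperature, top_p)
    # random.choices renormalizes the weights over the kept tokens
    return random.choices(kept, weights=[probs[i] for i in kept], k=1)[0]

logits = [5.0, 1.0, 0.0, -2.0]
print(nucleus(logits, temperature=1.0, top_p=0.5)[0])        # [0]
print(len(nucleus(logits, temperature=10.0, top_p=0.9)[0]))  # 4
```

At temperature 1.0 the top token already holds most of the probability mass, so the nucleus is a single token. At temperature 10.0 the distribution flattens and all four tokens survive the filter.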

And none of this is magic. It’s matrix multiplication, softmax, and gradient descent, scaled to a level that produces emergent capabilities the architects didn’t explicitly design.

What You Should Actually Do Next

If you want to go deeper on the theory

Read the original papers, in order:

  1. “Attention Is All You Need” (Vaswani et al., 2017): the transformer paper
  2. “Language Models are Unsupervised Multitask Learners” (Radford et al., 2019): the GPT-2 paper (good writing, clear architecture)
  3. “Training Language Models to Follow Instructions with Human Feedback” (Ouyang et al., 2022): InstructGPT, the RLHF paper
  4. “LoRA: Low-Rank Adaptation of Large Language Models” (Hu et al., 2021): the LoRA paper

After those four, you have the foundation of everything that’s happened since.

Karpathy’s educational resources are exceptional:

  • nanoGPT on GitHub: the cleanest small GPT implementation available
  • His “Let’s build GPT” video on YouTube: 2 hours, builds a character-level GPT from scratch

If you want to get practical quickly

Run something real:

# Llama via Ollama — runs locally on a MacBook
brew install ollama
ollama pull llama3.2
ollama run llama3.2

# Or use the transformers library
pip install transformers accelerate

# Then, in Python:
from transformers import pipeline

# A real model, running on your machine
generator = pipeline("text-generation", model="gpt2")
output = generator("The transformer architecture", max_new_tokens=50)
print(output[0]['generated_text'])

Fine-tune something:

pip install trl peft transformers datasets

The trl library (Transformer Reinforcement Learning) has clean examples for SFT (supervised fine-tuning), DPO, and PPO. Start with the SFT trainer on a small model like gpt2 or TinyLlama.
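
Before you fine-tune, it helps to know how small a LoRA adapter is. For one weight matrix of shape d_out × d_in, LoRA trains two low-rank factors with rank × (d_in + d_out) parameters in total. The sketch below runs that arithmetic; the 4096 × 4096 shape and rank 8 are illustrative assumptions, not the dimensions of any particular model.

```python
def lora_param_counts(d_in: int, d_out: int, rank: int) -> tuple[int, int]:
    """Return (full_matrix_params, lora_adapter_params) for one weight matrix."""
    full = d_in * d_out
    # LoRA factors: A is (rank x d_in), B is (d_out x rank)
    lora = rank * (d_in + d_out)
    return full, lora

full, lora = lora_param_counts(d_in=4096, d_out=4096, rank=8)
print(f"full matrix:  {full:,} parameters")  # 16,777,216
print(f"LoRA adapter: {lora:,} parameters")  # 65,536
print(f"fraction:     {lora / full:.2%}")    # 0.39%
```

The fraction shrinks further as the matrices get larger, since the adapter grows linearly with the dimensions while the full matrix grows quadratically.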

If you want to contribute to the field

Study mechanistic interpretability. This is the sub-field dedicated to understanding what’s actually happening inside trained models — which features specific neurons represent, how information flows through the network, how circuits implement behaviors.

Study scaling laws. The Chinchilla paper (“Training Compute-Optimal Large Language Models,” Hoffmann et al. 2022) fundamentally changed how people think about the relationship between model size, data size, and compute. Understanding this will help you understand why models are the size they are.
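
Two rules of thumb usually quoted alongside that paper make the trade-off concrete: roughly 20 training tokens per parameter for a compute-optimal model, and training compute of about 6 × parameters × tokens FLOPs. Both are approximations, and the function below is only a back-of-the-envelope sketch (the name `chinchilla_estimate` is made up here).

```python
def chinchilla_estimate(n_params: float, tokens_per_param: float = 20.0):
    """Rough compute-optimal token budget and training FLOPs for a model size.

    Assumes the ~20 tokens/parameter heuristic and the C ~= 6 * N * D
    approximation for training compute. Neither is exact.
    """
    tokens = tokens_per_param * n_params
    flops = 6.0 * n_params * tokens
    return tokens, flops

tokens, flops = chinchilla_estimate(70e9)
print(f"tokens: {tokens:.2e}")  # 1.40e+12
print(f"FLOPs:  {flops:.2e}")   # 5.88e+23
```

A 70B-parameter model with about 1.4 trillion training tokens matches the configuration the Chinchilla paper itself trained, which is a useful sanity check on the heuristic.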

Implement something from scratch. The best way to truly understand a system is to build it without looking at existing implementations.

# A starting challenge: implement this from scratch
# (not looking at the code in this book):

import torch.nn as nn

class MiniGPT(nn.Module):
    """
    Implement a working GPT from scratch.
    Requirements:
    - Multi-head causal self-attention
    - Position-wise feedforward network
    - Residual connections + layer norm
    - Token + positional embeddings
    - LM head with weight tying

    When it trains on Shakespeare and produces recognizable text,
    you've understood the architecture.
    """
    pass  # Your implementation here

The Things Worth Being Excited About

The field is genuinely exciting right now, and understanding the internals means you can read the excitement with more precision.

Context length is expanding rapidly. Models that can process hundreds of thousands of tokens — effectively entire codebases, books, or legal documents — are qualitatively different from models that can’t. The architectural work to make this efficient (sparse attention, state space models, various hybrid approaches) is active.

Multimodality is real. The same attention mechanism that works on text works on images (ViT, CLIP), audio, video. Models that process multiple modalities simultaneously are enabling genuinely new applications.

Agents are early but real. Models with tool use, persistent memory, and the ability to take actions in environments are being deployed in production. The interesting challenges are about reliability and trust, not capability.

Inference efficiency keeps improving. Speculative decoding, quantization, distillation, mixture of experts: the gap between what’s possible and what’s affordable is closing. A year ago, 70B parameter models required expensive hardware. Today, quantized to 4 bits, they run on a high-memory laptop.
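
The laptop claim is mostly arithmetic about weight memory: parameters × bits per parameter ÷ 8 gives bytes. The sketch below covers weights only and ignores the KV cache, activations, and quantization overhead, so real requirements are somewhat higher.

```python
def weight_memory_gb(n_params: float, bits_per_param: int) -> float:
    """Memory needed to hold the weights alone, in gigabytes (10^9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

for bits in (16, 8, 4):
    print(f"{bits:>2}-bit: {weight_memory_gb(70e9, bits):.0f} GB")
# 16-bit: 140 GB
#  8-bit: 70 GB
#  4-bit: 35 GB
```

At 16 bits a 70B model needs multiple datacenter GPUs. At 4 bits the weights fit in the memory of a well-equipped laptop, with the usual quality trade-offs that quantization brings.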

The Things Worth Being Sober About

Everything in the previous chapter on limits remains true. The capabilities are real; so are the failure modes. The two are not in tension — both can be true simultaneously, and good engineering requires holding both.

The field’s progress is genuine and the hype is also genuine, which means careful thinking is required to separate signal from noise. Papers that show impressive benchmark results often don’t transfer to production. Models that seem capable of reasoning often fail at basic tasks in ways that reveal the limits of pattern matching.

Don’t use an LLM where a lookup table will do. This sounds obvious but is not always applied. Language models are expensive, slow, and unreliable compared to deterministic systems for tasks with deterministic answers.

Verify everything in production. Structure outputs. Validate. Test edge cases. The confidence of the output is not a signal you can trust.
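
One way to act on that, sketched with the standard library only. The schema format (a dict mapping key names to expected Python types) and the function name `parse_model_json` are assumptions made for this example; in practice you might reach for a schema-validation library instead.

```python
import json

def parse_model_json(raw: str, required: dict):
    """Return the parsed dict if `raw` is a JSON object with the required
    keys and types; return None otherwise. Never trust, always check."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return None  # Fluent prose around the JSON is still a failure
    if not isinstance(data, dict):
        return None
    for key, expected_type in required.items():
        if key not in data or not isinstance(data[key], expected_type):
            return None
    return data

SCHEMA = {"sentiment": str, "score": float}

good = '{"sentiment": "positive", "score": 0.92}'
chatty = 'Sure! Here is the JSON: {"sentiment": "positive", "score": 0.92}'
incomplete = '{"sentiment": "positive"}'

print(parse_model_json(good, SCHEMA))        # {'sentiment': 'positive', 'score': 0.92}
print(parse_model_json(chatty, SCHEMA))      # None
print(parse_model_json(incomplete, SCHEMA))  # None
```

A `None` result is the signal to retry, fall back, or escalate to a human. The point is that the decision rests on a deterministic check, not on how confident the output sounds.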

Think about the second-order effects. If your product makes it much easier to produce certain kinds of content at scale, what happens when everyone uses it? Document generation, customer service, code review, essay writing — these all look different at 100x scale.

A Final Note

You started this book as a developer who had used LLMs extensively but didn’t know what was happening inside them. That should no longer be true.

When you see a context window limit, you know why it exists: quadratic attention complexity and positional embedding constraints. When you see a pricing page that charges per token, you know what a token is and roughly how many appear in your prompts. When a model hallucinates, you know the mechanism: confident prediction from imperfect training data, with no internal truth-checker. When someone talks about fine-tuning, you know whether they mean LoRA or full fine-tuning, and what the trade-offs are.

That knowledge is worth having. It won’t prevent you from using these models incorrectly — humans are creative in how they misapply tools — but it gives you the right mental model for debugging when things go wrong.

The architecture is not magic. It’s specific mathematics with specific properties, trained in a specific way. Knowing that doesn’t make it less impressive; it makes it more. The fact that stacking matrix multiplications on enough data produces a system that writes coherent technical prose and debugs Python is genuinely remarkable.

It’s also just math.

Good luck out there.


# The last code block of the book.
# Run it if you like.

def what_you_now_know():
    topics = [
        ("Tokenization",      "BPE merges characters → subwords → vocabulary"),
        ("Embeddings",        "Token IDs → dense vectors in semantic space"),
        ("Self-Attention",    "Q·K^T / sqrt(d_k) → softmax → weighted V"),
        ("Multi-Head Attn",   "Run n_heads attention operations in parallel"),
        ("Transformer Block", "Attention + FFN, each with residual + LayerNorm"),
        ("Training",          "Cross-entropy loss → backprop → AdamW update"),
        ("Fine-Tuning",       "LoRA: freeze weights, train low-rank adapters"),
        ("Inference",         "Temperature, top-p, KV cache, autoregressive loop"),
        ("Limits",            "No ground truth, no calibrated confidence, no real-time knowledge"),
        ("The Meta Part",     "This book was generated by the thing it describes"),
    ]

    print("What you now know:")
    print()
    for topic, summary in topics:
        print(f"  {'✓':2s} {topic:20s}  {summary}")

    print()
    print("Next step: build something.")

what_you_now_know()