Attention sink

More evidence

Intro

I read the induction heads paper a while back while taking the ARENA course. The paper lays out a super interesting mechanistic study of in-context learning and specifically examines induction heads in transformer language models.

While playing around with induction heads in GPT2, I asked myself: if the input that induction heads respond to isn't present, what do they pay attention to? I thought this might be a good question to investigate, and after a quick literature search I stumbled on the attention sink paper and a bunch of other works that made fantastic attempts at answering it.

Guo et al., in "Active-Dormant Attention Heads", investigated the same question but from a different angle. They trained a 3L GPT2-style transformer on a bigram backcopy task and then investigated which heads were heavily involved in the backcopy behaviour. They then showed that these heads were dormant when the bigram backcopy input wasn't present.

While the question has sort of been answered already, I thought it would still be a good exercise to present the thought process I went through while attempting to answer it.

In the remainder of this post, I briefly motivate what an attention head is doing, explain induction heads and how to look for them (with visualizations), and show what happens when the induction heads' input isn't present.

Feel free to skip parts you're familiar with.

What is an attention head doing?

Intuitive overview

In summary, attention heads move information between tokens!

Fig. 1: A simplified view of the transformer. Source: A mathematical framework for transformer circuits.

The residual stream is the main object in the transformer. A way I think of it is that it represents what the model currently thinks about all the tokens in its context, up to a particular layer. To enrich and further refine the representation of the tokens in the context, attention heads move information from earlier tokens in the context to later tokens in the context, and MLP blocks compose information and perform retrieval tasks.
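A minimal sketch of this picture in PyTorch (names are mine, and LayerNorm plus the causal mask are omitted for brevity, so this is an illustration rather than GPT2's actual implementation): both sub-blocks read from the residual stream and add their output back into it.

import torch.nn as nn

class SimplifiedTransformerBlock(nn.Module):
    # a toy block mirroring Fig. 1: everything reads from and writes back into the residual stream
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, resid):  # resid: [batch, seq_len, d_model]
        # attention heads move information between token positions and add it back in
        attn_out, _ = self.attn(resid, resid, resid, need_weights=False)
        resid = resid + attn_out
        # the MLP acts on each position independently and also adds its output back in
        resid = resid + self.mlp(resid)
        return resid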

More concretely

The input to the attention layer is the residual stream (in the first layer, this is the sum of the token embeddings and positional embeddings), with shape [batch_size, seq_len, d_model]. This input is linearly projected using three weight matrices, W_Q, W_K, and W_V, each of shape [d_model, d_model], to produce the query (Q), key (K), and value (V) matrices.

In multi-head attention, Q, K, and V are split into num_heads parts. Each head processes a subspace of the input, with Q and K shaped as [batch_size, num_heads, seq_len, d_k] and V as [batch_size, num_heads, seq_len, d_v], where d_k = d_v = d_model / num_heads.

For each head, attention scores are computed as the dot product of query and key vectors, scaled by 1/√d_k. The scores are passed through a softmax to obtain the attention pattern, which represents the importance of each token relative to others. This pattern is then multiplied by the value vectors to produce the head's output. The outputs of all heads are concatenated and projected using a weight matrix W_O of shape [d_model, d_model] to yield the final attention output, shaped [batch_size, seq_len, d_model].

From Vaswani et al., the attention mechanism is defined as:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

Multi-head attention is expressed as:

$$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O $$

$$ \text{where } \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$

Here, W_i^Q, W_i^K, and W_i^V are head-specific projection matrices.
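To make the shapes above concrete, here's a minimal sketch of the whole computation in PyTorch, with a causal mask as used in decoder-only models like GPT2. The function name and the way the weights are passed in are mine, purely for illustration.

import torch

def multi_head_attention(x, W_Q, W_K, W_V, W_O, num_heads):
    # x: [batch, seq_len, d_model]; each W_*: [d_model, d_model]
    batch, seq_len, d_model = x.shape
    d_k = d_model // num_heads

    def split_heads(t):
        # [batch, seq_len, d_model] -> [batch, num_heads, seq_len, d_k]
        return t.view(batch, seq_len, num_heads, d_k).transpose(1, 2)

    Q, K, V = split_heads(x @ W_Q), split_heads(x @ W_K), split_heads(x @ W_V)

    # attention scores, scaled by 1/sqrt(d_k), with a causal mask so tokens only attend backwards
    scores = (Q @ K.transpose(-2, -1)) / d_k**0.5
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal_mask, float("-inf"))

    pattern = scores.softmax(dim=-1)  # the attention pattern, [batch, num_heads, seq_len, seq_len]
    z = pattern @ V                   # weighted sum of value vectors
    z = z.transpose(1, 2).reshape(batch, seq_len, d_model)  # concatenate heads
    return z @ W_O                    # project back into the residual stream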

Further questions

The next logical question is: how does each attention head, across all the layers, know what sort of information to pay attention to? During pre-training, the goal is to optimize the next-token prediction objective w.r.t. the parameters of the model over the language domain. It stands to reason that over the course of many steps of gradient descent, each attention head learns to pay attention to some pattern (semantic or syntactic) in the language data, and this pattern, when learned, contributes to a lower loss.

And indeed, numerous papers have explored this assumption.

In both decoder-only and encoder-decoder transformers, attention heads have been discovered that specialize in attending to different parts of speech, as well as other linguistic properties such as direct objects of verbs, noun determiners, etc.

Interesting mechanisms that further enable LLMs to act autoregressively have also been discovered, such as Copy Suppression heads and Induction heads.

A logical conclusion of the above paragraphs is that what attention heads pay attention to is input-specific. This raises the question: what does an attention head pay attention to when its input isn't present?

Induction heads

I'll present a super simplified explanation of Induction heads here, but to better understand them mechanistically, Callum McDougall wrote a quite interesting explainer blog which I invite readers to check out. The paper also goes into a lot more detail on things I only mention in passing, such as the presence of previous-token heads and the role of the QK/OV circuit.

Assume arbitrary tokens A and B. Then assume a sequence of tokens with A followed by B and then some other arbitrary tokens. The next time the model sees A, i.e. [A B ... A], B turns out to be one of the most likely next tokens.

Anthropic researchers found this phenomenon in transformers as small as 2 layers. One of the conclusions is that the model has learnt to increase the logit on B if the last token in the sequence is A, and indeed, it's theorized that induction heads are one of the mechanisms behind in-context learning.

For this to be true, there has to be a previous-token head. This ensures that the first occurrence of [B] pays attention to the first occurrence of [A], and the \(W_V\) matrix copies A into the subspace of B. Then, when A occurs in the context again, for some head \(\hat{h}\), the second occurrence of A pays attention to the first occurrence of B, sees that A is in the residual stream of B, copies B to the residual stream of the second occurrence of A, and increases its logit. This head \(\hat{h}\) is an induction head.
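As a quick behavioural check of the [A B ... A] story, one can feed GPT2 a prompt of this shape with transformer_lens and look at the next-token distribution. The prompt and the choice of A = " Mr", B = " Dursley" are mine, purely for illustration.

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

# "A B ... A" with A = " Mr" and B = " Dursley"
prompt = "When Mr Dursley left for work, Mr"
logits = model(model.to_tokens(prompt))
probs = logits[0, -1].softmax(dim=-1)

top_probs, top_tokens = probs.topk(5)
for p, tok in zip(top_probs.tolist(), top_tokens.tolist()):
    print(f"{model.tokenizer.decode([tok])!r}: {p:.3f}")
# if the induction mechanism is at work, " Dursley" should rank highly
# even though it is a rare token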

Identifying Induction heads in GPT2

Inputs

We sample N = 25 random tokens from the vocabulary of a transformer language model and repeat the sequence once along the sequence axis. This becomes the input to the transformer.

The input to the transformer is a matrix of shape [1, 2 * N + 1], i.e. batch is 1 and the sequence length is 2 * N + 1 because we prepend the BOS token to the repeated sequence. After the embedding layer, the residual stream has shape [1, 2 * N + 1, d_model], where d_model is the embedding dimension of the transformer language model.
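A minimal sketch of how this input can be constructed with transformer_lens (variable names are mine):

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

N = 25
# sample N random token ids from the model's vocabulary
rand_tokens = torch.randint(0, model.cfg.d_vocab, (1, N))
# prepend the BOS token and repeat the random tokens once: shape [1, 2 * N + 1]
bos = torch.full((1, 1), model.tokenizer.bos_token_id, dtype=torch.long)
input_tokens = torch.cat([bos, rand_tokens, rand_tokens], dim=-1)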

Pass this sequence of randomly repeated tokens into GPT2 and cache the activations. This can be done easily by loading the model with transformer lens and running
_, cache = model.run_with_cache(input_tokens)

Metric

Assume we have some head h at some layer l; its attention pattern is defined as $$ \text{A}^{l, h} = \text{softmax}\left(\frac{Q_{l, h}K^T_{l, h}}{\sqrt{d_k}}\right) $$

We define the induction score for head h in layer l as a measure of how much attention a token in the second repeat (at position i + N) pays to the token that immediately followed its first occurrence (at position i + 1), i.e. the token it should predict next. It's represented as:

$$I(l, h) = \frac{1}{N} \sum_{i = 1}^{N} A^{l, h} [i + N, i + 1] $$

Identifying induction heads

Retrieve the attention pattern from the cache and for each head in each layer, calculate the induction score as defined above.
def induction_head_detector(cache, cfg) -> list:
    induction_heads = []
    for layer_idx in range(cfg.n_layers):
        for head_idx in range(cfg.n_heads):
            # fetch the attention pattern at this layer and head;
            # the cache still has a batch dimension, so index batch 0 first
            attn_pattern = cache["pattern", layer_idx][0, head_idx]  # [seq_len, seq_len]
            # recover N, the length of the random token sequence (seq_len = 2 * N + 1)
            rand_tok_seq_len = (attn_pattern.shape[-1] - 1) // 2
            # induction score: mean attention from each second-repeat token
            # to the token just after its first occurrence (the 1 - N offset diagonal)
            score = attn_pattern.diagonal(-rand_tok_seq_len + 1).mean()
            # filter with a threshold of 0.4
            if score.item() >= 0.4:
                induction_heads.append((layer_idx, head_idx))
    return induction_heads
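Assuming the model and cache from the snippets above are in scope, the detector is called with the model's config (a usage sketch):

induction_heads = induction_head_detector(cache, model.cfg)
print(induction_heads)  # [(layer_idx, head_idx), ...] for heads with induction score >= 0.4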

Results

Below is a visual map of the induction heads present in GPT2.

Below is an interactive visualization of the attention patterns for the induction heads identified above.
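The interactive visualization can be produced with circuitsvis, roughly as sketched below (torch, the model, cache, input_tokens and induction_heads from the previous snippets are assumed to be in scope):

import torch
import circuitsvis as cv

# stack the attention patterns of the identified heads on the repeated-token input
patterns = torch.stack([cache["pattern", l][0, h] for l, h in induction_heads])
cv.attention.attention_patterns(
    tokens=model.to_str_tokens(input_tokens[0]),
    attention=patterns,
    attention_head_names=[f"L{l}H{h}" for l, h in induction_heads],
)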

What happens if the induction input isn't present?

Inputs

Load a tiny subset of the 10K pile dataset. For the purpose of this experiment, I used batch = 1 and sequence_length = 128.

A forward pass is also run on this input and the activations are cached, as in the case above.

For the induction heads identified in the section above, we simply visualize their attention patterns on this input.
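A sketch of this step, assuming the NeelNanda/pile-10k dataset on the Hugging Face hub and the variables from the previous snippets (any natural-language text sample would work just as well):

import torch
import circuitsvis as cv
from datasets import load_dataset

dataset = load_dataset("NeelNanda/pile-10k", split="train")
text_tokens = model.to_tokens(dataset[0]["text"])[:, :128]  # batch = 1, seq_len = 128

_, text_cache = model.run_with_cache(text_tokens)

# same visualization as before, but with the natural-language cache
cv.attention.attention_patterns(
    tokens=model.to_str_tokens(text_tokens[0]),
    attention=torch.stack([text_cache["pattern", l][0, h] for l, h in induction_heads]),
    attention_head_names=[f"L{l}H{h}" for l, h in induction_heads],
)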

Results

As can be observed, these heads all pay an overwhelming amount of attention to the first token.

Closing thoughts

Guo et al. observed that not only the first token but other special tokens as well get an overwhelming amount of attention in dormant cases.

They showed further evidence of this phenomenon by confirming that the value vectors of these tokens were much smaller than those of other tokens (which lends evidence to the fact that the information being written back to the residual stream is not of huge consequence), and that the residual stream norm for these tokens was relatively small as well.
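A rough sanity check of the small-value-vector observation can be run on our own natural-language cache (this is my own quick check under the assumptions of the snippets above, not a reproduction of Guo et al.'s setup):

# compare the value-vector norm at the first (BOS) position with the rest,
# for each identified induction head
for layer_idx, head_idx in induction_heads:
    v = text_cache["v", layer_idx][0, :, head_idx]  # [seq_len, d_head]
    norms = v.norm(dim=-1)
    print(
        f"L{layer_idx}H{head_idx}: |v_first| = {norms[0].item():.2f}, "
        f"mean |v_rest| = {norms[1:].mean().item():.2f}"
    )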

This, however, isn't the only explanation for the first-token/special-token phenomenon observed in attention heads. Federico et al. (personally, I enjoy Federico's papers, and especially his interview on the MLST podcast) have a paper where they also investigate why attention sinks exist and present an alternative explanation. TLDR: they serve the purpose of preventing mode collapse.