The Math Behind In-Context Learning
From attention to gradient descent: unraveling how transformers learn from examples
In-context learning (ICL) — a transformer’s ability to adapt its behavior based on examples provided in the input prompt — has become a cornerstone of modern LLM usage. Few-shot prompting, where we provide several examples of a desired task, is particularly effective at showing an LLM what we want it to do. But here’s the interesting part: why can transformers so easily adapt their behavior based on these examples? In this article, I’ll give you an intuitive sense of how transformers might be pulling off this learning trick.
Source: Image by Author (made using dingboard)
This will provide a high-level introduction to potential mechanisms behind in-context learning, which may help us better understand how these models process and adapt to examples.
The core goal of ICL can be framed as follows: given a set of demonstration pairs (x, y), can we show that attention mechanisms can learn or implement an algorithm that forms a hypothesis from these demonstrations and correctly maps new queries to their outputs?
Softmax Attention
Let’s recap the basic softmax attention formula,
Source: Image by Author
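For reference, the standard form, written with queries Q, keys K, values V, and key dimension dₖ, is

Attention(Q, K, V) = softmax(QKᵀ / √dₖ) V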
We’ve all heard how temperature affects model outputs, but what’s actually happening under the hood? The key lies in how we can modify the standard softmax attention with an inverse temperature parameter. This single variable transforms how the model allocates its attention — scaling the attention scores before they go through softmax changes the distribution from soft to increasingly sharp. This would slightly modify the attention formula as,
Source: Image by Author
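With the inverse temperature c scaling the scores before the softmax, this becomes (same notation as above)

Attention(Q, K, V) = softmax(c · QKᵀ / √dₖ) V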
Where c is our inverse temperature parameter. Consider a simple vector z = [2, 1, 0.5]. Let’s see how softmax(c*z) behaves with different values of c:
When c = 0:
softmax(0 * [2, 1, 0.5]) ≈ [0.33, 0.33, 0.33]
All tokens receive equal attention, completely losing the ability to discriminate between similarities.
When c = 1:
softmax([2, 1, 0.5]) ≈ [0.63, 0.23, 0.14]
Attention is distributed in proportion to the similarity scores, maintaining a balance between selection and distribution.
When c = 10000 (near infinite):
softmax(10000 * [2, 1, 0.5]) ≈ [1.00, 0.00, 0.00]
Attention converges to a one-hot vector, focusing entirely on the most similar token, which is exactly what we need for nearest neighbor behavior.
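To make this concrete, here is a minimal numpy sketch of the three cases above (the vector z and the values of c are just this toy example, nothing from a real model):

```python
import numpy as np

def softmax(z):
    z = z - np.max(z)              # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

z = np.array([2.0, 1.0, 0.5])      # toy similarity scores

for c in [0.0, 1.0, 10000.0]:      # uniform, proportional, near one-hot
    print(c, np.round(softmax(c * z), 2))
```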
Now here’s where it gets interesting for in-context learning: as c tends to infinity, our attention mechanism essentially becomes a 1-nearest neighbor search! Think about it: if we attend to all tokens except our query, we are basically finding the closest match among our demonstration examples. This gives us a fresh perspective on ICL: we can view it as implementing a nearest neighbor algorithm over our input-output pairs, all through the mechanics of attention.
But what happens when c is finite? In that case, attention acts more like a Gaussian kernel smoothing algorithm: it weights each token in proportion to the exponential of its similarity to the query.
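Viewed this way, attention over the demonstrations is essentially kernel regression: the prediction for a query is a softmax-weighted average of the demonstration outputs, and as c grows it collapses onto the output of the single closest demonstration. Here is a hedged sketch of that view, with made-up unit-vector inputs and a toy attention_predict helper (not anything from a real transformer):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def attention_predict(x_demo, y_demo, x_query, c=1.0):
    """Predict y for x_query as a softmax-weighted average of the demo outputs."""
    scores = x_demo @ x_query           # dot-product similarity to each demonstration
    weights = softmax(c * scores)       # attention weights over the demonstrations
    return weights @ y_demo             # weighted average of demonstration outputs

# toy demonstrations: unit-vector inputs, scalar outputs
angles = np.array([0.0, 1.0, 2.0])
x_demo = np.stack([np.cos(angles), np.sin(angles)], axis=1)
y_demo = np.array([10.0, 20.0, 30.0])

q_angle = 1.2                           # the query is closest to the second demonstration
x_query = np.array([np.cos(q_angle), np.sin(q_angle)])

print(attention_predict(x_demo, y_demo, x_query, c=1.0))    # smooth kernel-style average
print(attention_predict(x_demo, y_demo, x_query, c=100.0))  # ~20.0, the nearest demo's output
```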
We saw that softmax attention can do nearest neighbor, great, but what is the point of knowing that? Well, if we can say that the transformer can learn a “learning algorithm” (like nearest neighbor, linear regression, etc.), then maybe we can use it in the field of AutoML and just give it a bunch of data and have it find the best model/hyperparameters. Hollmann et al. did something like this: they train a transformer on many synthetic datasets to effectively learn the entire AutoML pipeline. The transformer learns to automatically determine what type of model, hyperparameters, and training approach would work best for any given dataset. When shown new data, it can make predictions in a single forward pass, essentially condensing model selection, hyperparameter tuning, and training into one step.
In 2022, Anthropic released a paper in which they showed evidence that induction heads might constitute the mechanism for ICL. What are induction heads? As stated by Anthropic, “Induction heads are implemented by a circuit consisting of a pair of attention heads in different layers that work together to copy or complete patterns.” Simply put, given a sequence like […, A, B, …, A], an induction head completes it with B, the reasoning being that if A was followed by B earlier in the context, it is likely that A is followed by B again. When you have a sequence like “…A, B…A”, the first attention head copies previous-token information into each position, and the second attention head uses this information to find where A appeared before and predict what came after it (B).
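As a purely illustrative, non-neural analogy (my own toy helper, not how the circuit is actually implemented), the rule an induction head captures looks like this:

```python
def induction_complete(tokens):
    """Toy induction rule: if the last token A appeared earlier in the context
    followed by some token B, predict B again."""
    last = tokens[-1]
    for i in range(len(tokens) - 2, -1, -1):   # scan backwards through the context
        if tokens[i] == last:
            return tokens[i + 1]               # copy whatever followed A last time
    return None                                # A never appeared before

print(induction_complete(["the", "cat", "sat", "on", "the"]))  # -> "cat"
```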
Recently, a lot of research has shown that transformers could be doing ICL through gradient descent (Garg et al. 2022, von Oswald et al. 2023, etc.) by establishing a relation between linear attention and gradient descent. Let’s revisit least squares and gradient descent,
Source: Image by Author
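In the standard least squares setup (my notation, with scalar targets), we fit a weight vector w to the demonstrations by minimizing

L(w) = (1/2n) Σᵢ (wᵀxᵢ − yᵢ)²

and gradient descent takes steps of the form

w ← w − η∇L(w) = w − (η/n) Σᵢ (wᵀxᵢ − yᵢ) xᵢ

Starting from w = 0, a single step gives w = (η/n) Σᵢ yᵢxᵢ, so the prediction for a new query is ŷ_{n+1} = (η/n) Σᵢ (x_{n+1}ᵀxᵢ) yᵢ, a similarity-weighted sum of the demonstration outputs.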
Now let’s see how this links with linear attention.
Linear Attention
Here we treat linear attention as softmax attention minus the softmax operation. The basic linear attention formula,
Source: Image by Author
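In other words (my notation, with any scaling constants absorbed into the weight matrices):

LinearAttention(Q, K, V) = (QKᵀ) V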
Let’s start with a single-layer construction that captures the essence of in-context learning. Imagine we have n training examples (x₁,y₁)…(xₙ,yₙ), and we want to predict y_{n+1} for a new input x_{n+1}.
Source: Image by Author
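In the spirit of the constructions in von Oswald et al., if the demonstration inputs act as keys, their outputs act as values, and x_{n+1} acts as the query, with a weight matrix W inside the dot product, the linear attention prediction takes the form

ŷ_{n+1} = Σᵢ (x_{n+1}ᵀ W xᵢ) yᵢ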
This looks very similar to what we got with gradient descent, except that in linear attention we have an extra weight matrix W. What linear attention is implementing is something known as preconditioned gradient descent (PGD), where instead of the standard gradient step, we modify the gradient with a preconditioning matrix W,
Source: Image by Author
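Concretely (again in my notation), the preconditioned update is

w ← w − ηW∇L(w)

Starting from w = 0, one such step predicts ŷ_{n+1} = (η/n) Σᵢ (x_{n+1}ᵀ W xᵢ) yᵢ for the new query, which matches the linear attention output above up to the constant η/n.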
What we have shown here is that we can construct a weight matrix such that one layer of linear attention will do one step of PGD.
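Here is a minimal numpy sketch of that equivalence, with random toy data and the η/n factor folded into the attention weight matrix (an illustration of the idea rather than the exact construction from the papers):

```python
import numpy as np

rng = np.random.default_rng(0)

n, d = 8, 3
X = rng.normal(size=(n, d))            # demonstration inputs x_1 .. x_n
y = X @ rng.normal(size=d)             # demonstration outputs from a noiseless linear task
x_query = rng.normal(size=d)           # the new query x_{n+1}

eta = 0.1
A = rng.normal(size=(d, d))
P = A @ A.T + np.eye(d)                # a random positive-definite preconditioner

# One step of preconditioned gradient descent on the least squares loss, starting from w = 0
grad_at_zero = -(X.T @ y) / n          # gradient of (1/2n) * sum_i (w.x_i - y_i)^2 at w = 0
w_one_step = -eta * P @ grad_at_zero
pred_pgd = w_one_step @ x_query

# One layer of linear attention: keys = x_i, values = y_i, query = x_{n+1},
# with the weight matrix W absorbing the eta/n factor and the preconditioner
W = (eta / n) * P
scores = X @ W.T @ x_query             # x_{n+1}^T W x_i for each demonstration
pred_attn = scores @ y

print(pred_pgd, pred_attn)             # the two predictions coincide
```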
Conclusion
We saw how attention can implement “learning algorithms”: given lots of demonstrations (x, y), the model learns from these demonstrations to predict the output of any new query. While the exact mechanisms involving multiple attention layers and MLPs are complex, researchers have made progress in understanding how in-context learning works mechanistically. This article provides an intuitive, high-level introduction to help readers understand the inner workings of this emergent ability of transformers.
To read more on this topic, I would suggest the following papers:
In-context Learning and Induction Heads (Olsson et al., 2022)
What Can Transformers Learn In-Context? A Case Study of Simple Function Classes (Garg et al., 2022)
Transformers Learn In-Context by Gradient Descent (von Oswald et al., 2023)
Transformers Learn to Implement Preconditioned Gradient Descent for In-Context Learning (Ahn et al., 2023)
Acknowledgment
This blog post was inspired by coursework from my graduate studies during Fall 2024 at the University of Michigan. While the courses provided the foundational knowledge and motivation to explore these topics, any errors or misinterpretations in this article are entirely my own. This represents my personal understanding and exploration of the material.