Speculative Sampling
In this article, I aim to reinterpret and rephrase the concepts from a video that I recently watched and found inspiring. The topic is Speculative Sampling, a technique for accelerating inference in LLMs. The original credit goes to the video's author.
Speculative Sampling is a technique that speeds up autoregressive language model sampling by using a smaller “draft” model. It is often alternatively called “assisted generation”. On some kinds of data this can give a 2x speedup with no loss in accuracy.
Inefficiency in LLM Inference
To see how it speeds up inference, we first explain why inference in large language models is expensive.
Autoregressive models, by their mathematical formulation, pose a specific challenge when it comes to inference, especially in the context of language models. This drawback is why sampling from LLMs is expensive.
At the heart of autoregressive models is the chain rule of conditional probability. This means that when we estimate the probability of a token, $p(x_t \mid x_{<t})$, it is conditioned on the entire preceding sequence. Thus, for each token we want to generate, a brand-new forward pass over the entire preceding sequence is needed.
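Written out, this is just the chain rule factorization of the sequence probability:

$$
p(x_1, x_2, \ldots, x_T) = \prod_{t=1}^{T} p(x_t \mid x_1, \ldots, x_{t-1})
$$

Each factor on the right needs its own forward pass when we generate tokens one by one.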
For clarity, let’s use an example. Suppose the model is to generate the tokens “I like a” after a given “<context>”. The generation process can be broken down as follows:
- In the first step, the model runs a forward pass on “<context>”, outputs a probability distribution over the next token, and samples “I”.
- Given the autoregressive nature, to predict the next token we have to append “I” to the context and run the forward pass again. The model then samples “like”.
- Following the same logic, generating the next token “a” requires yet another fresh forward pass with the updated context.
So, in this example, we find ourselves invoking the model three times consecutively to generate three new tokens.
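A minimal sketch of this naive decoding loop is shown below. The `model` here is a hypothetical callable that maps a token sequence to per-position logits; it is a stand-in for illustration, not a specific library API.

```python
import torch

def naive_autoregressive_decode(model, tokens, num_new_tokens):
    """Generate tokens one at a time: one full forward pass per new token."""
    for _ in range(num_new_tokens):
        logits = model(tokens)                       # (seq_len, vocab_size)
        probs = torch.softmax(logits[-1], dim=-1)    # only the last position is used
        next_token = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat([tokens, next_token])     # grow the context and repeat
    return tokens
```

Every iteration re-runs the model on the whole (growing) context just to obtain a single new token.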
However, there’s a nuance here. The Transformer architecture (commonly used for these tasks), and its decoder in particular, is designed to take in the entire sequence or context at once. The reason is that, within a Transformer, every position in the input sequence can “attend” to the other positions it is allowed to see. Hence, given a sequence, the Transformer produces a next-token prediction for every position in that sequence in a single forward pass. This is very convenient during training, because we have the ground-truth labels and can compare the model’s predictions at every position against them.
But during inference, a bottleneck emerges. Typically we only care about the prediction at the last position of the sequence. Yet, because of how Transformers work, to get the next token we have to feed the entire sequence in again and run another forward pass. This iterative approach is neither time- nor compute-efficient.
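Here is a hedged illustration of that asymmetry, using a toy stand-in for the decoder (a real Transformer would have attention layers, but the input/output shapes are the point):

```python
import torch
import torch.nn as nn

vocab_size = 100
# Toy stand-in: anything mapping a length-L token sequence to (L, vocab_size) logits.
toy_model = nn.Sequential(nn.Embedding(vocab_size, 32), nn.Linear(32, vocab_size))

tokens = torch.tensor([5, 17, 42])      # e.g. "<context> I like"
logits = toy_model(tokens)              # one forward pass -> shape (3, vocab_size)

# Training: every row is compared against the known next token (all positions used).
# Plain inference: only the last row matters for sampling the next token.
next_token_probs = torch.softmax(logits[-1], dim=-1)
```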
Speculative Sampling
Now that we understand the crux of the problem, let’s see how speculative sampling deals with it.
Since we cannot avoid generating new tokens one iteration at a time, why not run those iterations with a smaller model? By using a faster but less accurate model to generate a candidate continuation, we obtain a so-called draft sequence, which can then play a role similar to the “ground truth” used in training. Specifically, we can have the large model “examine” this draft just as in a training setting, because the whole known draft sequence can be fed to the model at once. We then collect the large model’s outputs at every position and compare them against the draft. If there is a mismatch at some position, we simply throw away the rest of the draft from that point on. In the best case, where the small model got everything right, we only need to invoke the large model once to produce the whole answer, at the expense of calling the small model several times. Even when there is a mistake, we can quickly discard the wrong suffix and draft a new one, since the bottleneck is the large model, not the small one.
This is the rough idea of speculative sampling. Of course, there are implementation details, but the core concept is shown above.
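To make this concrete, below is a minimal sketch of one speculative step, assuming both models expose the same hypothetical interface as before (token sequence in, per-position logits out). For simplicity it drafts and verifies greedily; the published algorithm instead accepts or resamples tokens with a modified rejection-sampling rule so that the output distribution exactly matches the large model’s.

```python
import torch

def speculative_decode_step(draft_model, target_model, tokens, k=4):
    """Draft k tokens with the small model, then verify them with a single
    forward pass of the large model (greedy-verification sketch)."""
    # 1) The small model proposes k tokens autoregressively (cheap per call).
    draft = tokens.clone()
    for _ in range(k):
        next_token = draft_model(draft)[-1].argmax(dim=-1, keepdim=True)
        draft = torch.cat([draft, next_token])

    # 2) One forward pass of the large model scores every draft position at once.
    target_logits = target_model(draft)              # shape: (len(draft), vocab_size)

    # 3) Accept draft tokens left to right while they agree with the large model.
    n = len(tokens)
    accepted = tokens
    for i in range(k):
        target_choice = target_logits[n + i - 1].argmax(dim=-1, keepdim=True)
        if draft[n + i] != target_choice:
            # Mismatch: keep the large model's token and discard the rest of the draft.
            return torch.cat([accepted, target_choice])
        accepted = torch.cat([accepted, draft[n + i].unsqueeze(0)])

    # All k draft tokens accepted; the large model's last position gives one bonus token.
    bonus = target_logits[-1].argmax(dim=-1, keepdim=True)
    return torch.cat([accepted, bonus])
```

In the best case this yields k + 1 tokens per large-model call; in the worst case it still yields one, so correctness never depends on the draft model.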
There are some implicit assumptions in the logic above. For this idea to work, we need the small model to be fast enough while still being reasonably accurate, and we also need the large model to be faster at processing one long context in a single pass than at iteratively running several forward passes over relatively short contexts. The former is justified by experiments on code generation and other tasks. The latter, which is not very intuitive, relies on the following fact mentioned by @Karpathy in his post; I quote it as the explanation:
“This unintuitive fact is because sampling is heavily memory bound: most of the “work” is not doing compute, it is reading in the weights of the transformer from VRAM into on-chip cache for processing”
Therefore, starting a new round of inference, which requires reloading the weights into cache, is more time-consuming than feeding a longer context through the Transformer in one pass, even though the latter involves more computation.
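A rough back-of-envelope calculation (illustrative numbers, not measurements) makes this concrete. A 7B-parameter model in 16-bit precision has to stream about

$$
7 \times 10^{9} \ \text{params} \times 2 \ \text{bytes} \approx 14\ \text{GB}
$$

of weights per forward pass. At roughly 1 TB/s of GPU memory bandwidth, that alone is on the order of 14 ms per pass, whether the pass produces one new token or verifies several draft tokens at once. Verifying a whole draft in a single pass therefore amortizes this fixed memory cost over many tokens.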
This fills in the last piece of the whole picture; I hope the explanation is clear. Again, if you prefer a more intuitive explanation with some drawings, you can opt for the video.