DeepSeek Technical Analysis — (3) Multi-Token Prediction
Background
This is the 3rd blog of my DeepSeek model technical analysis series. For the full background, please refer to the 1st blog of this series, “DeepSeek Technical Analysis — (1) MoE”. If you want to skip this blog and jump to the topic of this DeepSeek series that interests you, here is the blog list:
- Mixture-of-Experts, which reduced the training cost and improved inference efficiency.
- Multi-Head Latent Attention, which reduced the KV cache for the attention part.
- Multi-Token Prediction, which improved the performance (accuracy) of the model.
- DualPipe, which improved the computation-to-communication ratio and efficiency of large-scale GPU clusters.
- FP8 Training, which reduced the training cost further through low-precision training.
- DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning.
In the last 2 blogs, I explained Mixture-of-Experts (MoE) and Multi-Head Latent Attention (MLA) respectively. MoE reduced the training cost significantly by reducing the number of parameters activated per token; for example, DeepSeek-V3 has 671B total parameters but activates only 37B per token. MLA reduced the KV cache size by 93.3% (compared to the original Multi-Head Attention) and boosted inference speed several-fold.
In this blog, I’ll focus on another technique adopted by DeepSeek (starting from V3): Multi-Token Prediction, which improves the performance (accuracy) of the model.
Next-Token Prediction
Large language models such as GPT and Llama are trained with a next-token prediction loss. These models learn from a large text corpus x1, . . . , xT by performing a next-token prediction task. Formally, the learning objective is to minimize the cross-entropy loss:

$$\mathcal{L}_1 = -\sum_{t} \log P_\theta(x_{t+1} \mid x_{t:1})$$

where Pθ is the large language model under training, trained to maximize the probability of x_{t+1} as the next future token, given the history of past tokens x_{t:1} = x_t, . . . , x_1.
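To make the objective concrete, here is a minimal PyTorch sketch of this loss, where `model` is a hypothetical stand-in for any decoder-only LM that returns per-position logits of shape (batch, seq_len, vocab_size):

```python
import torch.nn.functional as F

def next_token_loss(model, tokens):
    """Standard next-token cross-entropy: position t predicts token x_{t+1}."""
    # tokens: (batch, T) integer ids x_1 ... x_T
    logits = model(tokens[:, :-1])   # logits: (batch, T-1, vocab_size)
    targets = tokens[:, 1:]          # targets shifted one step: x_2 ... x_T
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # (batch*(T-1), vocab_size)
        targets.reshape(-1),                  # (batch*(T-1),)
    )
```

The one-position shift between inputs and targets is exactly what makes each position t predict x_{t+1}.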
Multi-Token Prediction
The 2024 paper Better & Faster Large Language Models via Multi-token Prediction generalizes the above by implementing a multi-token prediction task: at each position of the training corpus, the model is instructed to predict n future tokens at once, in parallel, using independent output heads on top of a shared trunk. This translates into the cross-entropy loss:

$$\mathcal{L}_n = -\sum_{t} \log P_\theta(x_{t+n:t+1} \mid x_{t:1}) = -\sum_{t} \sum_{i=1}^{n} \log P_\theta(x_{t+i} \mid x_{t:1})$$

where the second equality holds because the n output heads predict their tokens independently, given the shared representation of x_{t:1}.
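The sketch below illustrates this parallel structure under my own assumptions: `trunk` is a hypothetical shared transformer that maps token ids to hidden states, and each of the n linear heads predicts one of the n future tokens, following the loss above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ParallelMTP(nn.Module):
    """Shared trunk + n independent output heads, in the style of the 2024 paper."""
    def __init__(self, trunk, hidden_dim, vocab_size, n_future=4):
        super().__init__()
        self.trunk = trunk  # hypothetical shared transformer trunk
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_dim, vocab_size) for _ in range(n_future)]
        )

    def forward(self, tokens):
        # tokens: (batch, T) ids; hidden: (batch, T, hidden_dim)
        hidden = self.trunk(tokens)
        T = tokens.size(1)
        loss = 0.0
        for i, head in enumerate(self.heads):
            # Head i predicts x_{t+i+1}; trim positions that have no target.
            logits = head(hidden[:, : T - (i + 1)])
            targets = tokens[:, i + 1 :]
            loss = loss + F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
            )
        return loss / len(self.heads)
```

Note that all heads read from the same hidden states, so the extra heads add little compute on top of the trunk.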
The paper ran several experiments to find the optimal n (how many future tokens to predict), and verified that multi-token prediction can improve performance (accuracy) over next-token prediction when training language models for generative or reasoning tasks.
The paper also demonstrated that multi-token prediction leads to qualitative changes in model capabilities and generalization behaviors. A likely reason is that multi-token prediction mitigates the distributional discrepancy between training-time teacher forcing and inference-time autoregressive generation.
Multi-Token Prediction in DeepSeek
DeepSeek-V3 adopted the main idea of Multi-Token Prediction from the above paper, but with a key change: it sequentially predicts additional tokens and keeps the complete causal chain at each prediction depth.
Instead of the parallel structure in the original paper, DeepSeek uses a chain-structure Multi-Token Prediction. The input tokens [t1, t2, t3, t4] go through the main model’s transformer blocks and then through the main model’s output head to produce the next predicted token t5. Meanwhile, the representations of the input tokens [t1, t2, t3, t4] (the output of the main model’s transformer blocks) are passed to the MTP module and combined with the embeddings of the shifted input tokens [t2, t3, t4, t5 (newly predicted)] to produce the additional token t6, and so on. In DeepSeek-V3, the MTP depth is 1, i.e., the model predicts the next 2 tokens: the exact next token plus one additional token.
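Below is a minimal sketch of how one such chain-structure MTP module might look. The `embed`, `transformer_block`, and `out_head` arguments are hypothetical stand-ins; per the DeepSeek-V3 report, the embedding layer and output head are shared with the main model, and each depth merges the previous depth’s hidden states with the next tokens’ embeddings via RMSNorm and a linear projection before its own transformer block:

```python
import torch
import torch.nn as nn

class MTPModule(nn.Module):
    """One chain-structure MTP module (depth k), in the style of DeepSeek-V3."""
    def __init__(self, hidden_dim, embed, transformer_block, out_head):
        super().__init__()
        # nn.RMSNorm requires PyTorch >= 2.4; LayerNorm also works for a sketch.
        self.norm_h = nn.RMSNorm(hidden_dim)   # normalizes previous-depth hiddens
        self.norm_e = nn.RMSNorm(hidden_dim)   # normalizes shifted-token embeddings
        self.proj = nn.Linear(2 * hidden_dim, hidden_dim)  # merging projection
        self.embed = embed                     # embedding shared with the main model
        self.block = transformer_block         # this depth's own transformer block
        self.out_head = out_head               # output head shared with the main model

    def forward(self, prev_hidden, shifted_tokens):
        # prev_hidden: (batch, T, dim) from the main model or the previous depth
        # shifted_tokens: (batch, T) token ids one position further in the sequence
        e = self.embed(shifted_tokens)
        h = torch.cat([self.norm_h(prev_hidden), self.norm_e(e)], dim=-1)
        h = self.block(self.proj(h))           # causal chain preserved at this depth
        return h, self.out_head(h)             # hiddens for next depth + logits
```

Because each depth conditions on the previous depth’s hidden states rather than predicting all future tokens from one representation, the complete causal chain is kept at every prediction depth.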
In the DeepSeek-V3 Technical Report, the authors demonstrated that Multi-Token Prediction improves performance in most cases.
My Comments
Does Multi-Token Prediction improve performance in all cases? The 2024 paper Better & Faster Large Language Models via Multi-token Prediction demonstrated that multi-token prediction may introduce regressions on multiple-choice and likelihood-based benchmarks. The MMLU (Massive Multitask Language Understanding) regression (67.5 -> 66.6) reported for DeepSeek with MTP matches this conclusion.