Transformer Clearly Explained: Attention Is All You Need! — 2017
Before the Transformer, RNNs (recurrent neural networks), LSTMs (long short-term memory, a variant of RNN) and gated RNNs were firmly established as state-of-the-art approaches to sequence modeling and transduction problems such as language modeling and machine translation. But the sequential computation constraint of RNNs makes them hard to train on large datasets, and this constraint is an obvious bottleneck to pushing these models further.
The paper “Attention Is All You Need”, whose model is usually just called the “Transformer”, eschews recurrence and instead relies entirely on an attention mechanism to draw global dependencies between input and output. This allows the model to parallelize computation during training, which makes it scalable to large datasets and able to make good use of modern highly parallel hardware like GPUs. The Transformer opened a new door for the AI era and accelerated the evolution of AI, especially large language models. The famous BERT (Bidirectional Encoder Representations from Transformers) and GPT (Generative Pre-trained Transformer) language models are both variants of this Transformer model.
When I first read the “Attention Is All You Need” paper, I had a lot of questions, like:
- What is the Attention Mechanism? Why can a Dot-Product represent attention?
- Why Multi-Head Attention?
- What is the purpose of the Feed-Forward Networks in both the encoder and decoder units?
- What do “Queries”, “Keys” and “Values” mean when calculating the Dot-Product Attention?
- What is the purpose of residual connections?
- …
In this blog, I will share my answers to these questions.
Machine Translation Task
In this paper, the authors use English-to-German and English-to-French translation tasks to evaluate the Transformer model in terms of quality and training speed.
The machine translation process takes a sentence in the source language as input, and the model outputs a sentence in the target language.
The standard WMT 2014 English-German dataset, consisting of about 4.5 million sentence pairs, was used to train the English-to-German translation model. For English-to-French, the authors used the significantly larger WMT 2014 English-French dataset, consisting of 36M sentences, and split tokens into a 32000 word-piece vocabulary.
The Sequential Nature of RNN
Before we dive into the Transformer model, let’s briefly see how an RNN handles this translation task. We refer to the paper “Sequence to Sequence Learning with Neural Networks — 2014”, which introduced a deep neural network model that can handle language translation tasks; the model is built on LSTMs (long short-term memory, a variant of RNN). The following diagram, taken from that paper, shows the basic procedure of translating the source language into the target language with RNN models.
A, B, C in the diagram are words of the source language (English); W, X, Y, Z are words of the target language (French); <EOS> means “End Of Sequence”. You don’t need to fully understand how the RNN works in this translation task; the only thing you need to know is that in an RNN, tokens (words) can only be input and handled sequentially, one by one, during training, so that the model captures the order of the sequence and the token dependencies. Input A and compute, input B and compute, input C and compute, input <EOS>, compute and predict that the next word is W, then feed the last predicted word W back in as additional input and predict the next word X, and so on. This sequential nature of RNNs limits the training speed, especially when the dataset is large.
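To make the sequential constraint concrete, here is a minimal Python sketch of such an encoder-decoder loop. It is not the paper’s actual code; `lstm_step`, `predict_next` and the token values are hypothetical placeholders, and the point is only that every step depends on the previous one, so the steps cannot be parallelized.

```python
def translate(source_tokens, lstm_step, predict_next, eos_token, max_len=50):
    # Hypothetical RNN encoder-decoder loop; every step waits on the previous state.
    state = None
    for token in source_tokens:          # encode: strictly one token per step, in order
        state = lstm_step(state, token)

    output, prev = [], eos_token
    for _ in range(max_len):             # decode: each step consumes the previous prediction
        state = lstm_step(state, prev)
        prev = predict_next(state)
        if prev == eos_token:
            break
        output.append(prev)
    return output
```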
In the paper “Sequence to Sequence Learning with Neural Networks — 2014”, the authors used 2 different LSTMs: one for the input sequence and another for the output sequence. The LSTM that handles the input sequence is the “Encoder”, and the LSTM that generates the output sequence is the “Decoder”.
Transformer Architecture
Ok, let’s move to the Transformer model architecture. The left part of the above diagram is the Encoder unit, and the right part is the Decoder unit. The encoder maps an input sequence of symbol representations (x1, …, xn) to a sequence of continuous representations z = (z1, …, zn). Given z, the decoder then generates an output sequence (y1, …, ym) of symbols one element at a time. At each step the model is auto-regressive, consuming the previously generated symbols as additional input when generating the next (the “Outputs (shifted right)” part of the above diagram).
Encoder and Decoder Stacks
The encoder is composed of a stack of N=6 identical layers, and each layer has two sub-layers: a multi-head self-attention mechanism and a position-wise fully connected feed-forward network. The authors employed a residual connection around each of the two sub-layers, followed by layer normalization. That is, the output of each sub-layer is LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer itself. To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension dmodel = 512.
Let me explain a little bit about the purpose of the residual connection: 1) it prevents the vanishing gradient problem (see this video for more details about the vanishing and exploding gradient problems in deep neural networks); 2) it preserves information flow, which means the network can easily pass important features forward without forcing them through transformations; 3) it helps avoid the degradation problem, where adding more layers can lead to higher training error; etc. See the paper “Deep Residual Learning for Image Recognition” for more details about residual connections.
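As a quick illustration, here is a minimal NumPy sketch of the “add & norm” step LayerNorm(x + Sublayer(x)). It is only a sketch of the idea: the learned gain and bias parameters of a real layer normalization are omitted, and the sub-layer is a stand-in lambda rather than actual attention or FFN code.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # normalize each token's feature vector to zero mean and unit variance
    # (a real LayerNorm also has learned gain and bias parameters, omitted here)
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, keepdims=True)
    return (x - mean) / (std + eps)

def add_and_norm(x, sublayer):
    # residual connection: the sub-layer only needs to learn a "correction" to x
    return layer_norm(x + sublayer(x))

x = np.random.randn(10, 512)              # 10 tokens, d_model = 512
out = add_and_norm(x, lambda t: 0.1 * t)  # stand-in for an attention/FFN sub-layer
print(out.shape)                          # (10, 512)
```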
The decoder is also composed of a stack of N=6 identical layers, and each layer has three sub-layers: besides the two sub-layers that are the same as in the encoder, an additional Masked Multi-Head Attention layer is added. Residual connections are also employed around each sub-layer in the decoder. See the architecture diagram above.
Emm, what is attention? What is Multi-Head Attention?
What is Attention?
In RNN models, tokens (words) are input sequentially during training; in this way, the RNN models can learn the dependencies between the current token and the previous tokens in the same sentence. Some variants of RNN, like BiLSTM (Bidirectional LSTM) models, add an additional LSTM layer that handles the sequential tokens in reverse order, to learn the dependencies between the current token and the tokens after it in the same sentence. See the paper “Neural Machine Translation by Jointly Learning to Align and Translate” for more details.
In the Transformer model, the dependencies between tokens (words) in the same sentence are captured by the “attention” mechanism. Attention simply means: when handling a specific token (word), pay attention to those tokens in the same sentence that might have dependencies with the current token. In the example “It is in this spirit that a majority of American governments have passed new laws since 2009 making the registration or voting more difficult.”, when the encoder is handling the token “making”, it should pay attention to tokens like “more”, “difficult”, “laws”, “2009” and “making” (itself), and largely ignore the other tokens.
Word Embeddings and Cosine Similarity
But how do we know that we should pay attention to those tokens when handling the token “making”? Let’s talk about the word embedding technique first.
Computers and models can only recognize and compute numbers, so we need a way to represent different words with numbers and then feed these numbers to models. An intuitive idea for representing words with numbers is allocating a unique number to each word, like 1 = “me”, 2 = “you”, … etc. But in this way, there is no notion of similarity between words. To resolve this issue, word embedding techniques came out.
Word embeddings are techniques to represent words as numbers in a multi-dimensional space (a vector). Word embeddings capture semantic relationships between words, allowing models to understand and represent words in a continuous vector space where similar words are close to each other. “Efficient Estimation of Word Representations in Vector Space — 2013” (Jeff Dean is one of the authors) introduced an advanced word representation method in which similar words tend to be close to each other in the vector space, and words can also have multiple degrees of similarity (for example, nouns can have multiple word endings, and if we search for similar words in a subspace of the original vector space, it is possible to find words that have similar endings). Moreover, using a word offset technique where simple algebraic operations are performed on the word vectors, it was shown, for example, that vector(”King”) - vector(”Man”) + vector(”Woman”) results in a vector that is closest to the vector representation of the word “Queen”.
Let’s come back to the original question: “How do we know which tokens in the same sentence we should pay attention to when handling a specific token?” From the above description of word embeddings, we know that words that have semantic relationships (words with similar meanings, two words that often appear together in the same sentence, etc.) are close to each other in the vector space. So we can simply transform the attention question into “finding tokens that have semantic relationships with the given token in the vector space”. Cosine similarity is a way to measure how close two vectors are in the vector space. The formula of cosine similarity is:
Cosine Similarity = (A · B) / (||A|| * ||B||)
where “·” represents the Dot Product and “||A||” and “||B||” are the magnitudes (lengths) of vectors A and B respectively.
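Here is a small NumPy illustration of this formula. The 4-dimensional “embeddings” below are made up purely for demonstration; real word vectors have hundreds of dimensions and learned values.

```python
import numpy as np

def cosine_similarity(a, b):
    # (A · B) / (||A|| * ||B||)
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

# toy embeddings, invented for illustration only
making    = np.array([0.9, 0.1, 0.3, 0.7])
difficult = np.array([0.8, 0.2, 0.4, 0.6])
banana    = np.array([-0.5, 0.9, -0.7, 0.1])

print(cosine_similarity(making, difficult))  # close to 1: the vectors point in a similar direction
print(cosine_similarity(making, banana))     # much lower: semantically unrelated
```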
Scaled Dot-Product Attention
And then, let’s move to the Scaled Dot-Product Attention described in the above diagram. The function of Scaled Dot-Product Attention is:
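Attention(Q, K, V) = softmax(QK^T / √dk) V
This is Equation 1 in the paper, where Q, K and V are the matrices of queries, keys and values, and dk is the dimension of the keys.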
Let’s use the encoder’s Self-Attention Unit as an example to explain how this attention function works. The attention function has 3 inputs: Queries(Q), Keys(K) and Values(V). Queries are the representation of the current token for which the model wants to focus on relevant tokens in the input sequence. Keys are the representation of each token in the sequence that can be matched against queries. Values are the representation of the information that the model ultimately extracts or aggregates when attention is applied.
Let’s use the example from the above “Attention Visualizations” diagram, “It is in this spirit that a majority of American governments have passed new laws since 2009 making the registration or voting more difficult.”, to illustrate more clearly what queries, keys and values mean. For the 1st self-attention layer in the encoder, Q is a matrix that contains the vector representations of all the input sequence tokens, with additional encoded position information (using a matrix product instead of token-by-token vector dot products parallelizes the attention computation for all tokens). Q is multiplied with the transposed keys matrix K^T, and the result is the attention weight matrix. The element ij in the attention weight matrix indicates how much attention should be paid to the jth token when handling the ith token (it is the dot-product similarity between the ith and jth token vectors, the same quantity that appears in the numerator of the cosine similarity). In this way, we know which tokens we should pay attention to when handling the token “making” in this example.
Then each weight in the attention weight matrix is divided by √dk, where dk is the dimension of the queries and keys, and a softmax function is applied to each row of the matrix to obtain the weights on the values. The purpose of dividing by √dk is to keep the dot products from growing too large, which would push the softmax into regions where its gradients are extremely small. Softmax is a normalization method that transforms the weights in a row into scores that add up to 1 (essentially a way to turn the weights into a probability distribution over the tokens); see the wiki for more details about softmax. After we get the normalized attention weight matrix, we multiply it with the values matrix V to produce the output of the scaled dot-product attention function. Remember, there are 6 identical layers in both the encoder and the decoder: the input of the encoder’s 2nd~6th self-attention layers is the output of the encoder’s previous layer.
We can see that the attention mechanism parallelizes the computation over all tokens as matrix operations, without any of the sequential steps RNN models require. This is a key point of the Transformer model.
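Here is a minimal NumPy sketch of the scaled dot-product attention computation described above, with toy dimensions and no learned projections; it only shows the matrix math, not a trainable layer.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)           # attention weight matrix, shape (n, n)
    weights = softmax(scores, axis=-1)        # each row sums to 1
    return weights @ V                        # weighted sum of the value vectors

# toy example: 4 tokens with 8-dimensional representations; in the encoder's
# self-attention, Q, K and V all come from the same (position-encoded) token matrix
np.random.seed(0)
X = np.random.randn(4, 8)
print(scaled_dot_product_attention(X, X, X).shape)   # (4, 8)
```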
Multi-Head Attention
The structure of Multi-Head Attention is:
The function of Multi-Head Attention is:
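MultiHead(Q, K, V) = Concat(head1, …, headh)W^O, where headi = Attention(QWi^Q, KWi^K, VWi^V)
Here the Wi^Q, Wi^K, Wi^V and W^O are learned projection matrices, as given in the paper.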
“Instead of performing a single attention function with dmodel-dimensional keys, values and queries, we found it beneficial to linearly project the queries, keys and values h times with different, learned linear projections to dk, dk and dv dimensions, respectively. On each of these projected versions of queries, keys and values we then perform the attention function in parallel, yielding dv-dimensional output values. These are concatenated and once again projected, resulting in the final values as depicted in the diagram above.”
The paper explains the purpose of using multi-head attention as follows: “Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.” In other words, different “heads” can focus on information from different representation subspaces at different positions, which a single, averaged attention head would inhibit.
“In this work we employ h = 8 parallel attention layers, or heads. For each of these we use dk = dv = dmodel(512)/h = 64. Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.”
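A compact NumPy sketch of this idea. The projection matrices are random and untrained, purely for checking shapes, and the attention helper is the same scaled dot-product attention as in the earlier sketch; a real implementation learns all of these weights.

```python
import numpy as np

d_model, h = 512, 8
d_k = d_model // h            # 64 dimensions per head

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # same scaled dot-product attention as in the earlier sketch
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1]), axis=-1) @ V

def multi_head_attention(X, W_q, W_k, W_v, W_o):
    heads = []
    for i in range(h):
        # project the tokens into the i-th d_k-dimensional subspace
        heads.append(attention(X @ W_q[i], X @ W_k[i], X @ W_v[i]))
    # concatenate the h heads back to d_model and apply the final projection
    return np.concatenate(heads, axis=-1) @ W_o

rng = np.random.default_rng(0)
W_q = rng.standard_normal((h, d_model, d_k))
W_k = rng.standard_normal((h, d_model, d_k))
W_v = rng.standard_normal((h, d_model, d_k))
W_o = rng.standard_normal((h * d_k, d_model))

X = rng.standard_normal((10, d_model))                     # 10 tokens
print(multi_head_attention(X, W_q, W_k, W_v, W_o).shape)   # (10, 512)
```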
Both the encoder and the decoder have Multi-Head Attention units. As shown in the last section, the input of the encoder’s 1st-layer Multi-Head Attention is the vector representations of the input sequence tokens with encoded position information. For the 2nd to 6th layers, the Keys, Values and Queries come from the output of the previous layer in the encoder.
For the decoder’s Multi-Head Attention, the Queries (Q) come from the output of the decoder’s Masked Multi-Head Attention sub-layer, and the memory Keys (K) and Values (V) come from the output of the encoder unit, as shown in the “Transformer Model Architecture” diagram.
The Masked Multi-Head Attention layer in the decoder is similar to the Multi-Head Attention layer in the encoder, except that: 1) its Q, K and V come from the representations of the previously generated tokens; 2) its masking mechanism ensures that the predictions for position i can depend only on the known outputs at positions less than i.
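The paper implements this masking inside the scaled dot-product attention by setting illegal (future) connections to −∞ before the softmax. A minimal NumPy sketch of such a causal mask, using a large negative number in place of −∞:

```python
import numpy as np

def causal_mask(n):
    # True entries above the diagonal mark "future" positions that must be hidden
    return np.triu(np.ones((n, n), dtype=bool), k=1)

def mask_scores(scores):
    # set future connections to a very negative value so softmax gives them ~0 weight
    masked = scores.copy()
    masked[causal_mask(scores.shape[0])] = -1e9
    return masked

print(mask_scores(np.zeros((4, 4))))
# row i keeps finite scores only for positions j <= i
```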
As shown in the “Transformer Model Architecture” diagram, a learned linear transformation and a softmax function are used to convert the decoder output into predicted next-token probabilities.
Positional Encoding
In order for the parallelized Transformer model to make use of the order of the sequence, the authors inject some information about the relative and absolute positions of the tokens in the sequence. The “Positional Encoding” has the same dimension as the embeddings and is added to the input/output embeddings at the bottoms of the encoder and decoder stacks. The function of this Positional Encoding is:
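PE(pos, 2i) = sin(pos / 10000^(2i/dmodel))
PE(pos, 2i+1) = cos(pos / 10000^(2i/dmodel))
These are the two sinusoidal formulas given in the paper.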
Here pos is the position and i is the dimension. That is, each dimension of the positional encoding corresponds to a sinusoid. The wavelengths form a geometric progression from 2π to 10000 · 2π. I drew 3 diagrams to show what the PE looks like for different pos and i. The 1st diagram shows PE(pos, 0) and PE(pos, 1), the 2nd shows PE(pos, 6) and PE(pos, 7), and the 3rd shows PE(pos, 12) and PE(pos, 13). We can see that when the dimension i is small, the wavelength is relatively short, and when the dimension i is large, the wavelength is relatively long. The authors chose this sinusoidal positional encoding because they hypothesized it would allow the model to easily learn to attend by relative positions, since for any fixed offset k, PE(pos+k) can be represented as a linear function of PE(pos). Meanwhile, the sinusoidal positional encoding may allow the model to extrapolate to sequence lengths longer than the ones encountered during training.
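A short NumPy sketch that builds this positional encoding matrix; plotting its columns is how diagrams like the ones above can be reproduced.

```python
import numpy as np

def positional_encoding(max_len, d_model):
    pos = np.arange(max_len)[:, None]                  # (max_len, 1)
    i = np.arange(d_model // 2)[None, :]               # (1, d_model / 2)
    angles = pos / np.power(10000, 2 * i / d_model)    # wavelength grows with the dimension index
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)                       # even dimensions use sine
    pe[:, 1::2] = np.cos(angles)                       # odd dimensions use cosine
    return pe

pe = positional_encoding(max_len=100, d_model=512)
print(pe.shape)   # (100, 512); this matrix is added element-wise to the token embeddings
```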
Position-wise Feed-Forward Networks
“In addition to attention sub-layers, each of the layers in our encoder and decoder contains a fully connected feed-forward network, which is applied to each position separately and identically. This consists of two linear transformations with a ReLU (refer to ReLU in my AlexNet blog) activation in between.
FFN(x) = max(0, xW1 + b1)W2 + b2
While the linear transformations are the same across different positions, they use different parameters from layer to layer. Another way of describing this is as two convolutions with kernel size 1. The dimensionality of input and output is dmodel = 512, and the inner-layer has dimensionality dff = 2048.”
The FFN is applied independently to each token in the sequence to introduce non-linearity (ReLU) and to allow the model to learn deeper transformations of each token’s representation, independently of the other tokens. The Multi-Head Attention layer focuses on capturing attention between tokens in the sequence, while the FFN focuses on learning these deeper transformations.
The input and output dimensions of the FFN are d = 512, aligned with dmodel, and the hidden layer dimension d = 2048 is wider than dmodel to allow for richer transformations. W1, W2, b1 and b2 are shared across all positions within a layer, but differ from layer to layer, allowing each layer to learn different transformations.
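A minimal NumPy sketch of this position-wise FFN, again with random, untrained weights just to check the shapes:

```python
import numpy as np

d_model, d_ff = 512, 2048

def position_wise_ffn(X, W1, b1, W2, b2):
    hidden = np.maximum(0, X @ W1 + b1)   # first linear transformation followed by ReLU
    return hidden @ W2 + b2               # second linear transformation back to d_model

rng = np.random.default_rng(0)
W1, b1 = rng.standard_normal((d_model, d_ff)), np.zeros(d_ff)
W2, b2 = rng.standard_normal((d_ff, d_model)), np.zeros(d_model)

X = rng.standard_normal((10, d_model))               # 10 token representations
print(position_wise_ffn(X, W1, b1, W2, b2).shape)    # (10, 512): each position is transformed independently
```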
Why Self-Attention: Analysis of Computational Complexity, Computation Parallelization, Long-range Dependencies
The paper compares various aspects of self-attention layers to recurrent and convolutional layers. “Complexity per Layer” in the following table denotes the computational complexity of each layer type. “Sequential Operations” means the minimum number of sequential operations required; O(1) means most of the computation can be parallelized, which is faster during training than layer types with O(n) sequential operations. “Maximum Path Length” denotes the path length between long-range dependencies in the network. Learning long-range dependencies is a key challenge in many sequence transduction tasks. One key factor affecting the ability to learn such dependencies is the length of the paths forward and backward signals have to traverse in the network. The shorter these paths between any combination of positions in the input and output sequences, the easier it is to learn long-range dependencies.
The table above summarizes the comparison. n is the sequence length, d (= 512 in this paper) is the representation dimension, k is the kernel size of convolutions, and r is the size of the neighborhood in restricted self-attention.
From the table we can see that from the “Sequential Operations” and “Maximum Path Length” perspectives, self-attention beats the recurrent and convolutional layers. From the computational complexity perspective, if n is less than d, which is true in most cases (it is rare for a sentence to have more than 512 tokens), self-attention beats the recurrent and convolutional layers again. To improve computational performance for tasks involving very long sequences, self-attention could be restricted to considering only a neighborhood of size r in the input sequence centered around the respective output position; this would increase the maximum path length to O(n/r). The authors did not evaluate this approach in the paper.
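A quick back-of-the-envelope check of the dominant complexity terms for a typical sentence length (n = 50, d = 512, values chosen only for illustration) makes the n < d argument concrete:

```python
n, d = 50, 512                 # typical sentence length vs. representation dimension

self_attention = n * n * d     # O(n^2 * d) term for a self-attention layer
recurrent      = n * d * d     # O(n * d^2) term for a recurrent layer

print(self_attention)          # 1,280,000
print(recurrent)               # 13,107,200 -> roughly 10x more work when n < d
```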
Training: Optimizer
The Adam optimizer with β1 = 0.9, β2 = 0.98 and ϵ = 10−9 is used, and the learning rate is varied over the course of training according to the formula below. Adam is an adaptive learning rate method, meaning it adjusts the learning rate for each parameter during training. This makes it more efficient than standard gradient descent, which uses a fixed learning rate for all parameters.
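lrate = dmodel^(-0.5) · min(step_num^(-0.5), step_num · warmup_steps^(-1.5))
This is the learning rate schedule given in the paper.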
This corresponds to increasing the learning rate linearly for the first warmup_steps (4000) training steps, and decreasing it thereafter proportionally to the inverse square root of the step number. The learning rate curve looks like:
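For reference, this schedule is easy to compute directly; plotting lrate(step) over the training steps produces the curve above.

```python
d_model, warmup_steps = 512, 4000

def lrate(step):
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

for step in (100, 1000, 4000, 40000, 100000):
    print(step, lrate(step))
# rises linearly to a peak of ~7e-4 at step 4000, then decays as 1/sqrt(step)
```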
Training: Regularization
Two major regularization methods, Residual Dropout and Label Smoothing, are employed in the Transformer model. Dropout is a widely used regularization technique to prevent overfitting; it introduces some randomness into the neural network structure to prevent the model from relying too heavily on any specific part of the network. See my last blog on AlexNet for more details. Label smoothing is a simple yet effective regularization technique that can significantly improve the performance of neural networks on various tasks, especially in image classification and natural language processing. The basic idea of label smoothing is to improve the generalization and robustness of the model by softening the target labels (in our case, the target-language tokens). Instead of assigning a probability of 1.0 to the correct class and 0.0 to all others in the target labels, label smoothing assigns a slightly reduced probability to the correct class and distributes the remaining probability mass among the incorrect classes.
Dropout is applied to the output of each sub-layer, before it is added to the sub-layer input and normalized. This means each element of a sub-layer’s output has a probability of being dropped. In addition, dropout is also applied to the sums of the embeddings and the positional encodings in both the encoder and decoder stacks. A rate of Pdrop = 0.1 was used during training.
In this paper, the authors employed label smoothing with value ϵls = 0.1, which means that for the target-language tokens in the training sentence pairs, the correct token receives a probability of about 0.9 and the remaining probability mass is spread over the rest of the vocabulary. See section 7 of the paper “Rethinking the Inception Architecture for Computer Vision” for more details about label smoothing.
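A tiny NumPy sketch of what label smoothing does to a one-hot target. The 6-token vocabulary is a toy example, and the formula used is the common (1 − ε)·one_hot + ε/K recipe, which matches the behavior described above.

```python
import numpy as np

def smooth_labels(one_hot, eps=0.1):
    vocab_size = one_hot.shape[-1]
    # the correct class keeps 1 - eps, and eps is spread uniformly over the vocabulary
    return one_hot * (1.0 - eps) + eps / vocab_size

target = np.zeros(6)
target[2] = 1.0                 # toy vocabulary of 6 tokens; the correct token is index 2
print(smooth_labels(target))
# roughly [0.017, 0.017, 0.917, 0.017, 0.017, 0.017]
```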
Results:
The Transformer achieves better BLEU scores than previous state-of-the-art models on the English-to-German and English-to-French newstest2014 tests at a fraction of the training cost.
The Generalization of Transformer for Other Tasks
In order to evaluate how well the Transformer generalizes to other tasks, the authors performed experiments on English constituency parsing. The major differences compared to the English-to-German task are that the authors trained a 4-layer (instead of 6-layer) Transformer with dmodel = 1024 (instead of 512), beam size 21 (instead of 4) and alpha = 0.3 (instead of 0.6). The experiments showed that, despite the lack of task-specific tuning, the Transformer performs surprisingly well, yielding better results than all previously reported models with the exception of the Recurrent Neural Network Grammar. This means the Transformer model generalizes well and can handle a variety of different tasks.