Vision Transformer with BatchNorm: Optimizing the depth
How integrating BatchNorm in a standard Vision transformer architecture leads to faster convergence and a more stable network
Comparing Vision Transformer without and with BatchNorm at various depths.
Introduction
The Vision Transformer (ViT) is the first purely self-attention-based architecture for image classification tasks. While ViTs can outperform CNN-based architectures, they typically require pre-training on very large datasets to do so. In an attempt to look for modifications of the ViT which may lead to faster training and inference — especially in the context of medium-to-small input data sizes — I began exploring, in a previous article, ViT-type models which integrate Batch Normalization (BatchNorm) in their architecture. BatchNorm is known to make a deep neural network converge faster: a network with BatchNorm achieves higher accuracy than the baseline model when trained over the same number of epochs, which in turn speeds up training. BatchNorm also acts as an efficient regularizer for the network, and allows a model to be trained with a higher learning rate. The main goal of this article is to investigate whether introducing BatchNorm can lead to similar effects in a Vision Transformer.
For the sake of concreteness, I will focus on a model where a BatchNorm layer is introduced in the Feedforward Network (FFN) within the transformer encoder of the ViT, and the LayerNorm preceding the FFN is omitted. Everywhere else in the transformer — including the self-attention module — one continues to use LayerNorm. I will refer to this version of ViT as ViTBNFFN — Vision Transformer with BatchNorm in the Feedforward Network. I will train and test this model on the MNIST dataset with image augmentations and compare the Top-1 accuracy of the model with that of the standard ViT over a number of epochs. I will choose identical architectural configuration for the two models (i.e. identical width, depth, patch size and so on) so that one can effectively isolate the effect of the BatchNorm layer.
Here’s a quick summary of the main findings:
- For a reasonable choice of hyperparameters (learning rate and batch size), ViTBNFFN converges faster than ViT, provided the transformer depth (i.e. the number of layers in the encoder) is sufficiently large.
- As one increases the learning rate, ViTBNFFN turns out to be more stable than ViT, especially at larger depths.
I will open with a brief discussion on BatchNorm in a deep neural network, illustrating some of the properties mentioned above using a concrete example. I will then discuss in detail the architecture of the model ViTBNFFN. Finally, I will take a deep dive into the numerical experiments that study the effects of BatchNorm in the Vision Transformer.
The Dataset: MNIST with Image Augmentation
Let us begin by introducing the augmented MNIST dataset which I will use for all the numerical experiments described in this article. The training and test datasets are given by the function get_datasets_mnist() as shown in Code Block 1.
[Code Block 1: the get_datasets_mnist() function defining the training and test datasets.]
The important part of the code is the set of image augmentations applied to the training data (a sketch of the full function follows the list below). I have introduced three different transformations:
- RandomRotation(degrees=20): A random rotation of the image, with the range of rotation in degrees being (-20, 20).
- RandomAffine(degrees=0, translate=(0.2, 0.2)): A random affine transformation, where the specification translate=(a, b) implies that the horizontal and vertical shifts are sampled randomly in the intervals [-image_width × a, image_width × a] and [-image_height × b, image_height × b] respectively. The degrees=0 statement deactivates rotation, since we have already taken it into account via the random rotation. One could also include a scale transformation here, but we implement it using the zoom out operation instead.
- RandomZoomOut(0, (2.0, 3.0), p=0.2): A random zoom out transformation, which randomly samples a float r from the interval (2.0, 3.0) and outputs an image with output_width = input_width × r and output_height = input_height × r. The float p is the probability that the zoom operation is performed. This transformation is followed by a Resize transformation so that the final image is again 28 × 28.
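For concreteness, here is a minimal sketch of such a function, assuming torchvision's v2 transforms (torchvision ≥ 0.16); the function body is illustrative rather than a verbatim copy of Code Block 1.

```python
import torch
from torchvision import datasets
from torchvision.transforms import v2

def get_datasets_mnist():
    # Augmentations are applied only to the training data
    train_transform = v2.Compose([
        v2.RandomRotation(degrees=20),                      # rotate by an angle in (-20, 20) degrees
        v2.RandomAffine(degrees=0, translate=(0.2, 0.2)),   # random horizontal/vertical shifts
        v2.RandomZoomOut(0, side_range=(2.0, 3.0), p=0.2),  # zoom out with probability 0.2
        v2.Resize((28, 28)),                                # bring the image back to 28 x 28
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ])
    test_transform = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ])
    train_data = datasets.MNIST(root="data", train=True, download=True, transform=train_transform)
    test_data = datasets.MNIST(root="data", train=False, download=True, transform=test_transform)
    return train_data, test_data
```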
Batch Normalization in a Deep Neural Network
Let us give a quick review of how BatchNorm improves the performance of a deep neural network. Suppose zᵃᵢ denotes the input for a given layer of a deep neural network, where a is the batch index which runs from a=1,…, Nₛ and i is the feature index running from i=1,…, C. The BatchNorm operation then involves the following steps:
1. For a given feature index i, one first computes the mean and the variance over the batch of size Nₛ, i.e.
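$$\mu_i = \frac{1}{N_s}\sum_{a=1}^{N_s} z^{a}_{i}\,, \qquad \sigma_i^2 = \frac{1}{N_s}\sum_{a=1}^{N_s}\left(z^{a}_{i}-\mu_i\right)^2 .$$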
2. One normalizes the input using the mean and variance computed above (with ϵ being a small positive number):
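$$\hat{z}^{a}_{i} = \frac{z^{a}_{i} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}}\,.$$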
3. Finally, one shifts and rescales the normalized input for every feature i:
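$$y^{a}_{i} = \gamma_i\, \hat{z}^{a}_{i} + \beta_i\,,$$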
where there is no summation over the index i, and the parameters (γᵢ, βᵢ) are trainable.
Consider a deep neural network for classifying the MNIST dataset. I will choose a network consisting of 3 fully-connected hidden layers, with 100 activations each, where each hidden layer is endowed with a sigmoid activation function. The last hidden layer feeds into a classification layer with 10 activations corresponding to the 10 classes of the MNIST dataset. The input to this neural network is a 2d tensor of shape b × 28² — where b is the batch size and each 28 × 28 MNIST image is reshaped into a 28²-dimensional vector. At the input, the feature index therefore runs from i=1, …, 28², while at each hidden layer it runs from i=1, …, 100.
This model is similar to the one discussed in the original BatchNorm paper — I will refer to this model as DNN_d3. One may consider a version of this model where one adds a BatchNorm layer before the sigmoid activation function in each hidden layer. Let us call the resultant model DNNBN_d3. The idea is to understand how the introduction of the BatchNorm layer affects the performance of the network.
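Here is a minimal sketch of the two models; the helper function and its naming are my own, and only the layer widths and the placement of the BatchNorm layer follow the description above.

```python
import torch.nn as nn

def make_dnn(with_batchnorm: bool) -> nn.Sequential:
    layers = [nn.Flatten()]                     # (b, 1, 28, 28) -> (b, 784)
    in_dim = 28 * 28
    for _ in range(3):                          # three hidden layers of width 100
        layers.append(nn.Linear(in_dim, 100))
        if with_batchnorm:
            layers.append(nn.BatchNorm1d(100))  # BatchNorm before the sigmoid activation
        layers.append(nn.Sigmoid())
        in_dim = 100
    layers.append(nn.Linear(100, 10))           # classification layer: 10 MNIST classes
    return nn.Sequential(*layers)

DNN_d3 = make_dnn(with_batchnorm=False)
DNNBN_d3 = make_dnn(with_batchnorm=True)
```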
To do this, let us now train and test the two models on the MNIST dataset described above, with CrossEntropyLoss() as the loss function and the Adam optimizer, for 15 epochs. For a learning rate lr=0.01 and a training batch size of 100 (we choose a test batch size of 5000), the test accuracy and the training loss for the models are given in Figure 1.
Figure 1. Test Accuracy (left) and Training Loss (right) for the two models over 15 epochs with lr=0.01.
Evidently, the introduction of BatchNorm makes the network converge faster — DNNBN achieves a higher test accuracy and lower training loss. BatchNorm can therefore speed up training.
What happens if one increases the learning rate? Generally speaking, a high learning rate might lead to gradients blowing up or vanishing, which would render the training unstable. In particular, larger learning rates lead to larger layer parameters, which in turn give larger gradients during backpropagation. BatchNorm, however, ensures that backpropagation through a layer is not affected by a rescaling of the layer parameters (see Section 3.3 of the original BatchNorm paper for more details). This makes the network significantly more resistant to instabilities arising from a high learning rate.
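Concretely, with W the weight matrix of a layer, u its input and a a scalar (the notation of Section 3.3 of the BatchNorm paper), the normalization absorbs a rescaling of the weights:

$$\mathrm{BN}\big((aW)u\big) = \mathrm{BN}(Wu)\,, \qquad \frac{\partial\,\mathrm{BN}\big((aW)u\big)}{\partial u} = \frac{\partial\,\mathrm{BN}(Wu)}{\partial u}\,, \qquad \frac{\partial\,\mathrm{BN}\big((aW)u\big)}{\partial (aW)} = \frac{1}{a}\,\frac{\partial\,\mathrm{BN}(Wu)}{\partial W}\,.$$

In particular, larger weights lead to smaller, not larger, weight gradients.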
To demonstrate this explicitly for the models at hand, let us train them at a much higher learning rate lr=0.1 — the test accuracy and the training losses for the models in this case are given in Figure 2.
Figure 2. Test Accuracy (left) and Training Loss (right) for the two models over 15 epochs with lr=0.1.
The high learning rate manifestly renders the DNN unstable. The model with BatchNorm, however, is perfectly well-behaved! A more instructive way to visualize this behavior is to plot the accuracy curves for the two learning rates in a single graph, as shown in Figure 3.
Figure 3. The accuracy curves at two different learning rates for DNN_d3 (left) and DNNBN_d3(right).
While the model DNN_d3 stops training at the high learning rate, the impact on the performance of DNNBN_d3 is significantly milder. BatchNorm therefore allows one to train a model at a higher learning rate, providing yet another way to speed up training.
The Model ViTBNFFN: BatchNorm in the FeedForward Network
Let us begin by briefly reviewing the architecture of the standard Vision Transformer for image classification tasks, as shown in the schematic diagram of Figure 4. For more details, I refer the reader to my previous article or one of the many excellent reviews of the topic in Towards Data Science.
Figure 4. Schematic representation of the ViT architecture.
Functionally, the architecture of the Vision Transformer may be divided into three main components:
1. Embedding layer: This layer maps an image to a “sentence” — a sequence of tokens, where each token is a vector of dimension dₑ (the embedding dimension). Given an image of size h × w and c color channels, one first splits it into patches of size p × p and flattens them — this gives (h × w)/p² flattened patches (or tokens) of dimension dₚ = p² × c, which are then mapped to vectors of dimension dₑ using a learnable linear transformation. To this sequence of tokens, one adds a learnable token — the CLS token — which is isolated at the end for the classification task. Schematically, one has:
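(in a notation of my own choosing, with E the learnable embedding map and N − 1 = (h × w)/p² the number of patches)

$$\underbrace{\text{image}}_{h\times w\times c}\;\longrightarrow\;\big(x_1,\dots,x_{N-1}\big),\;\; x_k\in\mathbb{R}^{d_p}\;\;\longrightarrow\;\;\big(x_{\text{cls}},\,E x_1,\dots,E x_{N-1}\big),\;\; x_{\text{cls}},\,E x_k\in\mathbb{R}^{d_e}\,.$$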
Finally, to this sequence of tokens, one adds a learnable tensor of the same shape which encodes the positional embedding information. The resultant sequence of tokens is fed into the transformer encoder. The input to the encoder is therefore a 3d tensor of shape b × N × dₑ — where b is the batch size, N is the number of tokens including the CLS token, and dₑ is the embedding dimension.
2. Transformer encoder : The transformer encoder maps the sequence of tokens to another sequence of tokens with the same number and the same shape. In other words, it maps the input 3d tensor of shape b × N × dₑ to another 3d tensor of the same shape. The encoder can have L distinct layers (defined as the depth of the transformer) where each layer is made up of two sub-modules as shown in Figure 5— the multi-headed self-attention (MHSA) and the FeedForward Network (FFN).
Figure 5. Sub-modules of the transformer encoder.
The MHSA module implements a non-linear map from the 3d tensor of shape b × N × dₑ to a 3d tensor of the same shape, which is then fed into the FFN as shown in Figure 5. This is where information from different tokens gets mixed via the self-attention map. The configuration of the MHSA module is fixed by the number of heads nₕ and the head dimension dₕ.
The FFN is a deep neural network with two linear layers and a GELU activation in the middle as shown in Figure 6.
Figure 6. The FFN module inside a layer of the transformer encoder.
The input to this sub-module is a 3d tensor of shape b × N × dₑ. The linear layer on the left transforms it to a 3d tensor of shape b × N × d_mlp, where d_mlp is the hidden dimension of the network. Following the non-linear GELU activation, the tensor is mapped back to a tensor of the original shape by the second linear layer.
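As a point of reference, a minimal PyTorch sketch of this standard FFN sub-module could look as follows; the class name and the dropout placement are my assumptions.

```python
import torch.nn as nn

class FeedForward(nn.Module):
    def __init__(self, d_e: int, d_mlp: int, dropout: float = 0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_e, d_mlp),   # (b, N, d_e) -> (b, N, d_mlp)
            nn.GELU(),
            nn.Linear(d_mlp, d_e),   # back to (b, N, d_e)
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
```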
3. MLP Head: The MLP Head is a fully-connected network that maps the output of the transformer encoder — a 3d tensor of shape b × N × dₑ — to a 2d tensor of shape b × d_num, where d_num is the number of classes in the given image classification task. This is done by first isolating the CLS token from the input tensor and then putting it through the fully-connected network.
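A minimal sketch of the MLP head, assuming it consists of a single linear layer acting on the CLS token (as it does in the experiment below); the class name is my own.

```python
import torch
import torch.nn as nn

class MLPHead(nn.Module):
    def __init__(self, d_e: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(d_e, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        cls_token = x[:, 0]        # isolate the CLS token: (b, N, d_e) -> (b, d_e)
        return self.fc(cls_token)  # (b, num_classes)
```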
The model ViTBNFFN has the same architecture as described above with two differences. Firstly, one introduces a BatchNorm Layer in the FFN of the encoder between the first linear layer and the GELU activation as shown in Figure 7. Secondly, one removes the LayerNorm preceding the FFN in the standard ViT encoder (see Figure 5 above).
Figure 7. The FFN submodule for the ViTBNFFN model.
Since the linear layers act on the third (last) dimension of the input tensor, it is this dimension that plays the role of the feature dimension of the BatchNorm; at the point where the BatchNorm sits (after the first linear layer) the tensor has shape b × N × d_mlp, so the number of features is d_mlp. The PyTorch implementation of the new feedforward network is given in Code Block 2.
[Code Block 2: the PyTorch implementation of the feedforward network with BatchNorm.]
The built-in BatchNorm classes in PyTorch always take the first index of a tensor as the batch index and the second index as the feature index. Therefore, one needs to permute the 3d tensor of shape b × N × d_mlp to the shape b × d_mlp × N before applying BatchNorm, and permute it back to b × N × d_mlp afterwards. In addition, I have used the 2d BatchNorm class (since it is slightly faster than the 1d BatchNorm), which requires promoting the 3d tensor to a 4d tensor of shape b × d_mlp × N × 1 and reducing it back afterwards. One can use the 1d BatchNorm class instead without changing any of the results presented in this section.
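Putting these pieces together, here is a minimal sketch of the feedforward network with BatchNorm; the class name and the exact tensor manipulations are my assumptions about what Code Block 2 contains.

```python
import torch
import torch.nn as nn

class FeedForwardBN(nn.Module):
    def __init__(self, d_e: int, d_mlp: int, dropout: float = 0.0):
        super().__init__()
        self.linear1 = nn.Linear(d_e, d_mlp)
        self.bn = nn.BatchNorm2d(d_mlp)   # 2d BatchNorm over the hidden (feature) dimension
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(d_mlp, d_e)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (b, N, d_e)
        x = self.linear1(x)                    # (b, N, d_mlp)
        x = x.permute(0, 2, 1).unsqueeze(-1)   # (b, d_mlp, N, 1): feature index second, as BatchNorm expects
        x = self.bn(x)
        x = x.squeeze(-1).permute(0, 2, 1)     # back to (b, N, d_mlp)
        x = self.activation(x)
        x = self.linear2(x)                    # (b, N, d_e)
        return self.dropout(x)
```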
The Experiment
With a fixed learning rate and batch size, I will train and test the two models — ViT and ViTBNFFN — on the augmented MNIST dataset for 10 epochs and compare the Top-1 accuracies on the validation dataset. Since we are interested in understanding the effects of BatchNorm, we will have to compare the two models with identical configurations. The experiment will be repeated at different depths of the transformer encoder keeping the rest of the model configuration unchanged. The specific configuration for the two models that I use in this experiment is given as follows :
- Embedding layer: An MNIST image is a grey-scale image of size 28 × 28. The patch size is p = 7, which implies that the number of tokens is 16 + 1 = 17 including the CLS token. The embedding dimension is dₑ = 64.
- Transformer encoder: The MHSA submodule has nₕ = 8 heads with head dimension dₕ = 64. The hidden dimension of the FFN is d_mlp = 128. The depth of the encoder will be the only variable parameter in this architecture.
- MLP head: The MLP head will simply consist of a linear layer.
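For orientation, here is the shared configuration collected in one place; the keyword names and the ViT/ViTBNFFN constructors in the commented lines are illustrative assumptions rather than the author's exact code.

```python
# Shared configuration for both models (keyword names are illustrative).
common_config = dict(
    image_size=28,   # MNIST images are 28 x 28
    patch_size=7,    # 16 patches + 1 CLS token = 17 tokens
    num_classes=10,
    dim=64,          # embedding dimension d_e
    heads=8,         # number of attention heads n_h
    dim_head=64,     # head dimension d_h
    mlp_dim=128,     # hidden dimension of the FFN
    dropout=0.0,
    emb_dropout=0.0,
)

# Hypothetical instantiation at encoder depth d (constructors assumed, not shown here):
# model_vit = ViT(depth=d, **common_config)
# model_bn  = ViTBNFFN(depth=d, **common_config)
```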
The training and testing batch sizes will be fixed at 100 and 5000 respectively for all the epochs, with CrossEntropyLoss() as the loss function and Adam optimizer. The dropout parameters are set to zero in both the embedding layer as well as the encoder. I have used the NVIDIA L4 Tensor Core GPU available at Google Colab for all the runs, which have been recorded using the tracking feature of MLFlow.
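A minimal sketch of the training and evaluation loop implied by this setup; the helper name and its structure are my own, and the MLFlow tracking is omitted.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def run_experiment(model, train_data, test_data, lr, epochs=10,
                   device="cuda" if torch.cuda.is_available() else "cpu"):
    # Batch sizes fixed as in the text: 100 for training, 5000 for testing
    train_loader = DataLoader(train_data, batch_size=100, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=5000)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
        # Top-1 accuracy on the test set after each epoch
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                preds = model(images).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        print(f"epoch {epoch + 1}: top-1 accuracy = {correct / total:.4f}")
```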
Let us start by training and testing the models at the learning rate lr= 0.003. Figure 8 below summarizes the four graphs which plot the accuracy curves of the two models at depths d=4, 5, 6 and 7 respectively. In these graphs, the notation ViT_dn (ViTBNFFN_dn) denotes ViT (ViTBNFFN) with depth of the encoder d=n and the rest of the model configuration being the same as specified above.
Figure 8. Comparison of the accuracy curves of the two models at lr=0.003 for depths 4,5,6 and 7.
For d= 4 and d= 5 (the top row of graphs), the accuracies of the two models are comparable — for d=4 (top left) ViT does somewhat better, while for d=5 (top right) ViTBNFFN surpasses ViT marginally. For d < 4, the accuracies remain comparable. However, for d=6 and d=7 (the bottom row of graphs), ViTBNFFN does significantly better than ViT. One can check that this qualitative feature remains the same for any depth d ≥ 6.
Let us repeat the experiment at a slightly higher learning rate lr = 0.005. The accuracy curves of the two models at depths d=1, 2, 3 and 4 respectively are summarized in Figure 9.
Figure 9. Comparison of the accuracy curves of the two models at lr=0.005 for depths 1,2,3 and 4.
For d= 1 and d= 2 (the top row of graphs), the accuracies of the two models are comparable — for d=1 ViT does somewhat better, while for d=2 they are almost indistinguishable. For d=3 (bottom left), ViTBNFFN achieves a slightly higher accuracy than ViT. For d=4 (bottom right), however, ViTBNFFN does significantly better than ViT and this qualitative feature remains the same for any depth d ≥ 4.
Therefore, for a reasonable choice of learning rate and batch size, ViTBNFFN converges significantly faster than ViT beyond a critical depth of the transformer encoder. For the range of hyperparameters I consider, it seems that this critical depth gets smaller with increasing learning rate at a fixed batch size.
For the deep neural network example, we saw that the impact of a high learning rate is significantly milder on the network with BatchNorm. Is there something analogous that happens for a Vision Transformer? This is addressed in Figure 10. Here each graph plots the accuracy curves of a given model at a given depth for two different learning rates lr=0.003 and lr=0.005. The first column of graphs corresponds to ViT for d=2, 3 and 4 (top to bottom) while the second column corresponds to ViTBNFFN for the same depths.
Figure 10. Accuracy curves for ViT and ViTBNFFN for two learning rates at different depths.
Consider d=2 — given by the top row of graphs — ViT and ViTBNFFN are comparably impacted as one increases the learning rate. For d = 3 — given by the second row of graphs — the difference is significant. ViT achieves a much lower accuracy at the higher learning rate — the accuracy drops from about 91% to around 78% at the end of epoch 10. On the other hand, for ViTBNFFN, the accuracy at the end of epoch 10 drops from about 92% to about 90%. This qualitative feature remains the same at higher depths too — see the bottom row of graphs which corresponds to d=4. Therefore, the impact of the higher learning rate on ViTBNFFN looks significantly milder for sufficiently large depth of the transformer encoder.
Conclusion
In this article, I have studied the effects of introducing a BatchNorm layer inside the FeedForward Network of the transformer encoder in a Vision Transformer. Comparing the models on an augmented MNIST dataset, there are two main lessons that one may draw. Firstly, for a transformer of sufficient depth and for a reasonable choice of hyperparameters, the model with BatchNorm achieves significantly higher accuracy compared to the standard ViT. This faster convergence can greatly speed up training. Secondly, similar to our intuition for deep neural networks, the Vision Transformer with BatchNorm is more resilient to a higher learning rate, if the encoder is sufficiently deep.
Thanks for reading! If you have made it to the end of the article and enjoyed it, please leave claps and/or comments and follow me for more content! Unless otherwise stated, all images and graphs used in this article were generated by the author.