
Efficient Transformers II: Knowledge Distillation & Fine-Tuning

· 11 min read
Harshil Shah

This two-part post looks at how to make state-of-the-art NLP more efficient by exploring modifications to the popular but computationally demanding Transformer-based language modelling techniques.

The previous post:

  • Explained why the Transformer’s self-attention mechanism has a high computational workload.
  • Presented alternative attention mechanisms which are more efficient to run without significantly compromising performance.

This post will:

  • Explore methods which train small models to reproduce the outputs of large models.
  • Explain how to fine-tune language models efficiently.
  • Provide our recommendations for scenarios in which to use the different efficient Transformer approaches.

The previous post included a brief history of semantic representation learning in NLP, and an overview of how the Transformer’s self-attention mechanism works. We suggest that readers who are unfamiliar with these topics read those sections first. This post also uses some of the notation introduced in the previous post.

Knowledge distillation

Knowledge distillation is an approach to building more efficient Transformers in which small models (students) are trained by encouraging them to reproduce the outputs of large models (teachers). It initially gained popularity on classification tasks in computer vision, but has since been successfully applied in several domains, including NLP. The typical workflow is:

  1. Train a large model using generic labelled data.
  2. Train a small model to mimic the large model using task-specific unlabelled data (and task-specific labelled data, if available).

Although this process still involves training a large model, this is a one-off cost. The more frequent task of making predictions will be done by the small model, which is significantly more efficient to run. As a result, knowledge distillation is a particularly popular technique for running machine learning in hardware constrained environments, e.g. on mobile devices.

tip

It is worth considering that a small model could simply be trained (from scratch) on the same data used to train the large one. However, the small model may not have the capacity to learn representations of the same quality as the large one. With knowledge distillation, the small model instead learns from the large model’s predicted probabilities, which typically encode more information than the class label alone. This allows the small model to learn richer representations, and in many scenarios it ends up with better predictive power than if it were trained from scratch.

Consider a document $x$ and class label $y$, with the class label belonging to one of $C$ categories (i.e. $y \in \{1,2,\dots,C\}$). Denote the probability with which a large model $f$ predicts that document $x$ has class label $y=c$ as $p(c;f(x))$. This probability is usually computed using a function of the form:

\begin{align} p(c;f(x)) = \operatorname{softmax}(f(x))_{c} \end{align}

where $f(x)$ is the output of a neural network (e.g. a Transformer) which takes $x$ as input. The large model is trained using the following maximum likelihood objective:

\begin{align} J_{f} = \sum_{c=1}^{C} \mathbb{I}(c;y) \log p(c;f(x)) \end{align}

where $\mathbb{I}(c;y) = 1 \text{ if } y = c \text{ else } 0$.
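
As a concrete illustration, the maximum likelihood objective $J_f$ is (up to sign and batch averaging) the familiar cross-entropy loss. Below is a minimal PyTorch sketch with toy values; the post itself is framework-agnostic, so the library choice and numbers are illustrative:

```python
import torch
import torch.nn.functional as F

# Toy example: logits f(x) for a batch of 2 documents over C = 3 classes.
teacher_logits = torch.tensor([[2.0, 0.5, -1.0],
                               [0.1, 1.5, 0.3]])
labels = torch.tensor([0, 1])  # the true class y for each document

# The objective above: sum over classes of I(c; y) * log p(c; f(x)),
# averaged over the batch.
log_probs = F.log_softmax(teacher_logits, dim=-1)
manual_objective = log_probs[torch.arange(2), labels].mean()

# The same quantity (negated, since losses are minimised) via the built-in loss.
ce_loss = F.cross_entropy(teacher_logits, labels)
assert torch.allclose(-manual_objective, ce_loss)
```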

The small model $g$ can be trained to reproduce the probabilities predicted by the large model $f$ using an objective of the form:

\begin{align} J_{g} = \sum_{c=1}^{C} \operatorname{obj}[f(x)_{c},g(x)_{c}] \end{align}

Examples of objective functions include:

  • A maximum likelihood style objective: $J_{g} = \sum_{c=1}^{C} p(c;f(x)) \log p(c;g(x))$
    • This is equivalent to minimising the KL divergence between $p(c;f)$ and $p(c;g)$.
  • The negative mean squared error (MSE) between logits: $J_{g} = -\frac{1}{C} \sum_{c=1}^{C} [f(x)_{c} - g(x)_{c}]^{2}$

If task-specific labelled data are available when training the small model, the supervised objective and the distillation objective are combined using a weighted average:

\begin{align} J_{g} = \alpha \sum_{c=1}^{C} \operatorname{obj}[f(x)_{c},g(x)_{c}] + (1 - \alpha) \sum_{c=1}^{C} \mathbb{I}(c;y) \log p(c;g(x)) \end{align}
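
To make this concrete, here is a minimal PyTorch sketch of the two distillation objectives and their weighted combination with the supervised objective. The function and argument names are illustrative, and everything is written as a loss to be minimised (i.e. the negative of the objectives above):

```python
import torch.nn.functional as F

def distillation_loss(teacher_logits, student_logits, labels=None, alpha=0.5,
                      objective="kl"):
    """Weighted combination of a distillation term and (optionally) the
    supervised term, written as losses to be minimised."""
    if objective == "kl":
        # Maximum likelihood style objective: cross-entropy between the teacher's
        # predicted probabilities and the student's log-probabilities.
        soft_targets = F.softmax(teacher_logits, dim=-1)
        distill = -(soft_targets * F.log_softmax(student_logits, dim=-1)).sum(-1).mean()
    else:
        # MSE between logits (the negative of the objective above, so it is minimised).
        distill = F.mse_loss(student_logits, teacher_logits)

    if labels is None:
        return distill

    supervised = F.cross_entropy(student_logits, labels)
    return alpha * distill + (1.0 - alpha) * supervised
```

In practice, the teacher’s logits would be computed with its parameters frozen (e.g. under `torch.no_grad()`), so that only the student receives gradient updates.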

Task-specific distillation

The biLSTM-SOFT model performs task-specific distillation by attempting to reproduce the predictions of an already fine-tuned BERT model on classification tasks. For the student model, it uses a single-layer bidirectional LSTM. Although this is a recurrent model, it is still quick to run because it only has one layer.

The distillation objective is the negative MSE between the student’s and teacher’s logits. It is over 400x faster to run than the BERT model it is distilled from, but performs 4–7 accuracy/F1 points worse (depending on the task).
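The sketch below shows what this setup might look like in PyTorch. The exact architecture details of biLSTM-SOFT (embedding size, hidden size, how the LSTM states are pooled) are not described here, so those choices are illustrative assumptions:

```python
import torch
import torch.nn as nn

class BiLSTMStudent(nn.Module):
    """Single-layer bidirectional LSTM student; sizes here are illustrative."""
    def __init__(self, vocab_size, num_classes, embed_dim=300, hidden_dim=300):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=1,
                            bidirectional=True, batch_first=True)
        self.classifier = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, token_ids):
        embedded = self.embedding(token_ids)
        _, (hidden, _) = self.lstm(embedded)                 # hidden: (2, batch, hidden_dim)
        pooled = torch.cat([hidden[0], hidden[1]], dim=-1)   # concatenate both directions
        return self.classifier(pooled)                       # student logits g(x)

# Distillation step: MSE between student and (precomputed) teacher logits.
student = BiLSTMStudent(vocab_size=30522, num_classes=2)
token_ids = torch.randint(0, 30522, (8, 128))                # dummy batch of token ids
teacher_logits = torch.randn(8, 2)                           # stand-in for a fine-tuned BERT's logits
loss = nn.functional.mse_loss(student(token_ids), teacher_logits)
```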

Distilling during pre-training

So far, this post has presented knowledge distillation in the context of supervised learning, as this is the setting in which it is most commonly used. However, DistilBERT performs knowledge distillation at both the language model pre-training and fine-tuning stages.

As explained in the previous post’s Background section, BERT is pre-trained using masked language modelling; DistilBERT treats the missing words as the class labels, and uses the maximum likelihood style distillation objective function. It uses BERT’s predicted probability distributions for the missing words as the soft targets for the distillation objective. The authors also add a cosine embedding objective, which encourages the small model to align the directions of its embeddings with those produced by BERT.
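
A simplified PyTorch sketch of these three pre-training objectives is shown below. Details such as the relative loss weights, the softmax temperature, and exactly which token positions each loss is restricted to are omitted or assumed here:

```python
import torch
import torch.nn.functional as F

def distilbert_style_loss(teacher_logits, student_logits,
                          teacher_hidden, student_hidden, mlm_labels):
    """Sketch of the three pre-training objectives described above;
    the relative weights (here equal) are an assumption."""
    # 1. Distillation: match the teacher's predicted distribution over the vocabulary.
    soft_targets = F.softmax(teacher_logits, dim=-1)
    distill = -(soft_targets * F.log_softmax(student_logits, dim=-1)).sum(-1).mean()

    # 2. Masked language modelling on the hard labels (-100 marks unmasked positions).
    mlm = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)),
                          mlm_labels.view(-1), ignore_index=-100)

    # 3. Cosine embedding loss: align the directions of student and teacher hidden states.
    flat_student = student_hidden.view(-1, student_hidden.size(-1))
    flat_teacher = teacher_hidden.view(-1, teacher_hidden.size(-1))
    target = torch.ones(flat_student.size(0))
    cosine = F.cosine_embedding_loss(flat_student, flat_teacher, target)

    return distill + mlm + cosine
```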

DistilBERT adopts the same basic architecture as BERT, but has half as many layers and is approximately 38% faster to run. When distilled during pre-training only, it retains 97% of BERT’s performance. The authors also found that performing task-specific distillation during fine-tuning (using a BERT model which had also been fine-tuned on the same task) gave an additional boost to performance.

Exploiting the Transformer architecture

TinyBERT is an approach that is similar to DistilBERT in that it performs knowledge distillation at both the language model pre-training and fine-tuning stages. However, TinyBERT directly takes knowledge from intermediate representations of BERT (not just the final outputs) by specifically exploiting features of the Transformer architecture.

As with DistilBERT, TinyBERT adopts the same architecture as BERT but with fewer layers. First, a mapping is defined from each layer of the student model to a layer of the teacher model, i.e. each student layer is associated with one teacher layer. Then, depending on the student layer, it uses one of three distillation objective functions (sketched in code after the list below):

  • Embedding layer
    • Minimises the MSE between the student’s and teacher’s embedding matrices.
  • Attention layers
    • Minimises the MSE between the student and teacher attention matrices (A\mathbf{A} in Equation (4) in the previous post) plus the MSE between the student’s and teacher’s outputs of the feedforward layers which follow the self-attention operation.
  • Final (prediction) layer
    • Uses the maximum likelihood style distillation objective to try to match the student’s and teacher’s predicted probabilities. This is the same as DistilBERT.
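
A rough sketch of how these layer-wise objectives might be combined is shown below, assuming the student’s and teacher’s embeddings, attention matrices, hidden states and logits have already been collected into dictionaries. All names, and the learned projection matrices used when the student’s hidden size differs from the teacher’s, are illustrative:

```python
import torch.nn.functional as F

def tinybert_layer_losses(student, teacher, layer_map, proj):
    """Sketch of the layer-wise distillation objectives described above.
    `layer_map` maps each student layer index to a teacher layer index;
    `proj` holds learned projection matrices (illustrative names)."""
    # Embedding layer: MSE between the (projected) embedding matrices.
    loss = F.mse_loss(student["embeddings"] @ proj["emb"], teacher["embeddings"])

    # Attention and hidden-state losses for each mapped pair of layers.
    for s_layer, t_layer in layer_map.items():
        loss = loss + F.mse_loss(student["attentions"][s_layer],
                                 teacher["attentions"][t_layer])
        loss = loss + F.mse_loss(student["hidden_states"][s_layer] @ proj["hidden"],
                                 teacher["hidden_states"][t_layer])

    # Prediction layer: maximum likelihood style objective on the logits.
    soft_targets = F.softmax(teacher["logits"], dim=-1)
    loss = loss - (soft_targets * F.log_softmax(student["logits"], dim=-1)).sum(-1).mean()
    return loss
```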

Compared to the BERT model it is distilled from, TinyBERT is 2x–9x faster to run (depending on the number of layers and hidden state sizes). It consistently outperforms DistilBERT on a wide range of tasks, indicating that the distillation objectives at each layer are helpful compared to the final layer alone.

Fine-tuning

As discussed in the Background section of the previous post, the final step of training large language models is usually to fine-tune them on the task of interest. Although this stage can be relatively quick to run, it still generally involves updating all of the parameters of the model. This means that the hardware requirements are the same as for the pre-training stages. Given that the fine-tuning step is typically run separately for each task, this is still an expensive stage of the training process. Therefore, another line of research looks at reducing the number of parameters that need to be updated during fine-tuning.

Fine-tuning a subset of the weights

One way to avoid having to update all of the parameters of the model is simply to freeze some of the layers. Lee et al. perform an empirical study of the effectiveness of this approach. They find, with a 12-layer BERT model, that freezing the first 9 layers and only fine-tuning the final 3 reaches at least 90% of the performance of full fine-tuning on most tasks. However, freezing the entire language model and simply training the final prediction layer performs significantly worse across all tasks.
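
For example, with the Hugging Face transformers library (a common implementation choice, not one the post prescribes), freezing the first 9 layers of a 12-layer BERT model might look like this:

```python
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Freeze the embeddings and the first 9 encoder layers.
for param in model.bert.embeddings.parameters():
    param.requires_grad = False
for layer in model.bert.encoder.layer[:9]:
    for param in layer.parameters():
        param.requires_grad = False

# Only the last 3 encoder layers (plus the pooler and classification head) are updated.
trainable = [name for name, param in model.named_parameters() if param.requires_grad]
```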

Only updating the bias terms

Most operations in most neural network architectures involve multiplying an input by a matrix and then adding a bias term. The way in which these operations are composed is what defines the architecture.

In Transformers, the bias terms (e.g. the b\mathbf{b} terms in Equations (1)–(3) in the previous post) represent less than 0.1% of the total parameters. Therefore BitFit proposes to only update these during fine-tuning, and to freeze the rest of the parameters. With limited labelled training data, BitFit performs competitively against (and sometimes better than) fine-tuning the entire model. With large training data sets, it performs only slightly worse than full fine-tuning.
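
A BitFit-style parameter freeze is straightforward to sketch in PyTorch: mark only the bias terms (and, typically, the task-specific classification head) as trainable. Reusing the `model` from the previous snippet:

```python
import torch

# Freeze everything except the bias terms and the classification head.
for name, param in model.named_parameters():
    param.requires_grad = name.endswith("bias") or name.startswith("classifier")

# The optimiser then only sees the small set of trainable parameters.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-4)
```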

Inserting small trainable networks into the Transformer

Another family of parameter-efficient fine-tuning methods freezes the entire pre-trained language model but introduces a small set of additional parameters which are trained for the task of interest. Adapters do this by inserting two 2-layer feedforward networks within each of the Transformer layers. They are inserted directly before and after the existing feedforward network which follows the self-attention mechanism. The two layers perform the following operations:

  • The first layer down-projects the Transformer hidden state to a low-dimensional vector, and applies a nonlinearity.
  • The second layer up-projects the low-dimensional vector back to the Transformer hidden state size.

The idea behind this is that inserting learnable parameters throughout the Transformer architecture (rather than just training the final prediction layer) allows the model to adjust its internal representations in the same way that fine-tuning does, but in a much more efficient way. Adapter tuning is only ~0.4% worse than full fine-tuning, with only 3.6% as many trained parameters.
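
A minimal sketch of such an adapter module in PyTorch is shown below; the bottleneck size and choice of nonlinearity are illustrative assumptions:

```python
import torch
import torch.nn as nn

class Adapter(nn.Module):
    """Bottleneck adapter: down-project, nonlinearity, up-project, plus a
    residual connection. The bottleneck size here is illustrative."""
    def __init__(self, hidden_size, bottleneck_size=64):
        super().__init__()
        self.down = nn.Linear(hidden_size, bottleneck_size)
        self.up = nn.Linear(bottleneck_size, hidden_size)
        self.activation = nn.GELU()

    def forward(self, hidden_states):
        # The residual connection keeps the frozen Transformer's representation
        # intact when the adapter's contribution is small.
        return hidden_states + self.up(self.activation(self.down(hidden_states)))

# During fine-tuning, only the adapters (and the task head) are trained;
# the pre-trained Transformer weights stay frozen.
adapter = Adapter(hidden_size=768)
hidden = torch.randn(8, 128, 768)   # (batch, sequence length, hidden size)
adapted = adapter(hidden)
```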

Recommendations

These two posts have covered several different efficient Transformer methods. When might it be appropriate to use them?

  • If you are working with particularly long documents, then we recommend using an efficient attention mechanism, such as those covered in the previous post. This will reduce both the pre-training and inference times. Performance may even be better than using the full attention mechanism.
  • If your bottleneck is memory (i.e. even a batch of short documents causes out-of-memory errors), try using a distilled model (e.g. TinyBERT). With fewer layers and lower-dimensional representations, this will reduce the model’s memory footprint while also being faster to train.
    • However, bear in mind that distilled models still require a pre-trained full-sized model to learn from, and they generally work better when distilled both during pre-training and fine-tuning.
  • If you have limited labelled data at the fine-tuning stage, using a fine-tuning method which freezes most/all of the language model parameters (e.g. Adapters) mitigates the risk of overfitting and should typically perform better than fine-tuning the entire model.
  • If you have lots of supervised downstream tasks, we suggest using an efficient fine-tuning method. This will reduce the number of parameters which need to be trained and stored for each task.
    • Although the efficient fine-tuning methods may be faster to train than full fine-tuning, they generally won’t improve inference speed (i.e. getting predictions from a trained model). In fact, some of them may be slightly slower for inference, due to the insertion of additional layers.

Summary

This two-part series looks at how to make state-of-the-art natural language processing (NLP) widely accessible by exploring efficient alternatives to the popular but computationally demanding Transformer-based language modelling techniques.

The previous post:

  • Explained why the Transformer’s self-attention mechanism has a high computational workload.
  • Presented alternative attention mechanisms which are more efficient to run without significantly compromising performance.

This post:

  • Discussed knowledge distillation techniques, which train small, efficient models by encouraging them to reproduce the outputs of their larger counterparts.
  • Explored efficient methods to fine-tune language models by reducing the number of parameters that are updated.
  • Provided our recommendations for scenarios in which to use the different efficient Transformer approaches.

If you want to try Re:infer at your company, sign up for a free trial or book a demo.