18 Systematically Improving a Model
It would be nice to train a model once, evaluate it & deploy it, and then move on to the next task. However, this is not how an excellent model is built. To build an excellent model we have to acquire an in-depth understanding of what the model does well, what it does poorly, and we need to monitor this over time. And we need to repeat this process many times. We do this by using the model and looking at our data.
There is a lot of information on the internet about improving ML models. Here we collect the tips & tricks for text classification that we can personally vouch for.
As you work through these steps, use your training and testing data only. Do not open your validation set until the very end. The validation set cannot be used for model selection or in hyper-parameter tuning - that is what the training & testing sets are for. Once you have selected your final model, use the validation set to confirm that your model performs similarly on unseen data as it did in training/testing.
If you are forced to open your validation set, e.g. to root out inconsistent labels or data problems, that validation set can now go into your training/testing data, and you need to label a new validation set.
With that in mind, let’s move on to practical steps for systematically improving our models.
Data & Discrete Variables
We are going to keep saying this because it is so important: start with data. Everything else you do will have a severely-reduced and limited effect without enough data, without enough variance within your data, or without the patterns you want the model to learn.
And yet it is common when starting out in Data Science to focus on model/algorithm selection, hyper-parameter tuning, optimisers, dropout & regularisation, or anything else except the thing we should be focusing on: the data. Whilst those things are important, keep in mind that the data determines what the model can learn, whereas algorithms, hyper-parameters etc. determine how efficiently the model learns what’s in the data.
With that in mind, we should prioritise data strategy above everything else. This means spending the majority of our time collecting, cleaning, and labelling data. Barring elementary mistakes and training run errors, normally your model is not as good as it could be because:
- The task is too difficult for the model you are using
- You do not have enough data
- There are problems within your data (e.g. inconsistent labels, lack of balance between groups or labels)
When thinking about whether you have enough data, keep in mind that a shortage is likely to exist both at the overall dataset level and within the groups present in your data, with some groups needing more data than others. You can assess whether a group needs more data as a function of 1) how well the model performs on that group, and 2) how frequent the group is expected to be when you deploy the model.
If the model performs extremely well on a group, you should prioritise finding data for other groups. If a model is performing poorly, but the expected frequency of that group is low, prioritise finding data for other groups. Where possible we want to target the highest impact areas.
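The prioritisation rule above can be sketched in a few lines. This is a minimal illustration, not a prescribed formula: the `collection_priority` score (error rate multiplied by expected frequency), the `GroupStats` container, and the example numbers are all assumptions made for the sketch.

```python
from dataclasses import dataclass


@dataclass
class GroupStats:
    name: str
    f1: float                  # model performance on this group, between 0 and 1
    expected_frequency: float  # expected share of deployed inputs, between 0 and 1


def collection_priority(group: GroupStats) -> float:
    """Hypothetical score: how badly the model does, weighted by how often it matters."""
    return (1.0 - group.f1) * group.expected_frequency


def rank_groups(groups):
    """Return groups ordered from highest to lowest data-collection priority."""
    return sorted(groups, key=collection_priority, reverse=True)


# Illustrative numbers only, not measurements.
groups = [
    GroupStats("short posts", f1=0.95, expected_frequency=0.50),
    GroupStats("long posts", f1=0.60, expected_frequency=0.30),
    GroupStats("rare topic", f1=0.40, expected_frequency=0.02),
]

for group in rank_groups(groups):
    print(f"{group.name}: priority={collection_priority(group):.3f}")
```

With these made-up numbers the poorly performing but frequent group ranks first, while the rare group ranks last even though its F1 is the lowest.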
If you do not yet have discrete variables to group your data by, you need to get some.
Without a meaningful discrete variable to slice your data with, you will really struggle to find systematic routes to improving your model. Given that we tend to be working with text, there are a number of things you could try.
Ultimately you should determine groups from the data. What did you notice about your data as you were labelling it? If you noticed that long posts tended to be a certain way, consider breaking your data into buckets of string length, or word count, and calculating metrics (accuracy, precision, recall, F1) for each bucket.
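As a rough sketch of the bucketing idea, the snippet below splits samples by word count and calculates the metrics per bucket. The bucket edges (20 and 100 words), the function names, and the assumption of binary 0/1 labels are ours; adjust them to what you observed while labelling.

```python
from collections import defaultdict


def word_count_bucket(text, edges=(20, 100)):
    """Map a text to a coarse length bucket. The edges are arbitrary assumptions."""
    n_words = len(text.split())
    if n_words < edges[0]:
        return "short"
    if n_words < edges[1]:
        return "medium"
    return "long"


def binary_metrics(y_true, y_pred):
    """Accuracy, precision, recall and F1 for 0/1 labels (1 = positive class)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    correct = sum(1 for t, p in pairs if t == p)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "n": len(pairs),
        "accuracy": correct / len(pairs) if pairs else 0.0,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


def metrics_by_bucket(texts, y_true, y_pred):
    """Group samples by word-count bucket and score each bucket separately."""
    grouped = defaultdict(lambda: ([], []))
    for text, truth, pred in zip(texts, y_true, y_pred):
        truths, preds = grouped[word_count_bucket(text)]
        truths.append(truth)
        preds.append(pred)
    return {bucket: binary_metrics(t, p) for bucket, (t, p) in grouped.items()}
```

The same `metrics_by_bucket` pattern works for any discrete variable: swap `word_count_bucket` for a function that returns a topic or cluster label.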
String length is unlikely to be the killer slice of your data. It’s more likely that breaking your data into groups which have semantic differences will give you the information you need to improve your model. This suggests topic modelling or clustering. Once you have your topics/clusters, calculate metrics and identify which groups need more data.
Modelling Outputs
If your model still isn’t performing as you expect or need it to, you’re going to need to look at more data - just of a slightly different kind. Instead of looking at the inputs, we’re going to focus on the model’s outputs.
Modelling Outputs - Loss
Apart from the classification metrics, we can extract the logits (or softmax’d logits) and the loss [1] for each training sample. If a sample has high loss during training, it is a strong signal that something is wrong with that sample, or elsewhere in your data, so pay close attention. Double-check the label and make sure it is correct. If it is correct, look at other similar samples in your data and check whether they are correct too. If they are all correct, then you need to collect more of them.
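To make this concrete, here is a minimal sketch that computes a per-sample loss from logits and sorts samples so the highest-loss ones are reviewed first. It assumes the loss is standard cross-entropy and that each sample is a `(text, logits, label)` tuple; your training library may expose per-sample losses directly, in which case use those instead.

```python
import math


def cross_entropy(logits, label):
    """Cross-entropy of one sample, computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


def rank_by_loss(samples):
    """samples: list of (text, logits, label). Returns (loss, text, label), highest loss first."""
    scored = [(cross_entropy(logits, label), text, label) for text, logits, label in samples]
    return sorted(scored, key=lambda row: row[0], reverse=True)
```

The samples at the top of `rank_by_loss` are the ones whose labels are worth double-checking first.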
- Why do we need a loss function?
- What is the relationship between the loss function and the evaluation metrics?
- Why can’t we just use evaluation metrics as loss functions?
Modelling Outputs - Logits & Uncertainty
If we have two classification classes, the distance between the logits for each class will tell us approximately how confident the model is in its prediction. A low distance indicates uncertainty, which we can understand as the model struggling, for one reason or another, to classify the sample. If we have our groups in place, we can look at the balance of logits by groups to identify areas our model is struggling to classify.
Like the loss, we should pay close attention to the balance between the logits (or the softmax’d logits). We aim to find patterns in our data to understand which slices the model is having the most difficulty with. We do this by:
- sorting our data by uncertainty
- checking labels of the most uncertain samples
- finding other similar samples and checking those labels too.
Similarly to inspecting loss, if we do not find contradictory labels then we most likely need to find more samples to help the model learn.
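The steps above can be sketched as follows. We assume each sample carries at least two logits and measure uncertainty as the gap between the two largest logits; the function names and the `(text, logits)` and `(group, logits)` tuple layouts are illustrative choices, not a fixed interface.

```python
from collections import defaultdict


def logit_margin(logits):
    """Gap between the two largest logits; a small gap means an uncertain prediction."""
    top_two = sorted(logits, reverse=True)[:2]
    return top_two[0] - top_two[1]


def most_uncertain(samples, k=10):
    """samples: list of (text, logits). Returns the k samples with the smallest margin."""
    return sorted(samples, key=lambda sample: logit_margin(sample[1]))[:k]


def mean_margin_by_group(samples):
    """samples: iterable of (group, logits). Returns {group: mean margin}."""
    totals = defaultdict(lambda: [0.0, 0])
    for group, logits in samples:
        totals[group][0] += logit_margin(logits)
        totals[group][1] += 1
    return {group: total / count for group, (total, count) in totals.items()}
```

Groups with a low mean margin are candidates for label checks and, if the labels hold up, for more data.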
Helping the model learn to classify samples with high uncertainty will tend to increase recall, because it helps the model learn a new pattern, increasing the range of relevant cases it’s able to deal with.
Modelling Outputs - Logits & Certainty
In contrast to the samples where uncertainty is high, we can use the samples for which our model is confident to acquire more training data and have a positive impact on performance. At first this may seem counter-intuitive, but provided they are correct, the model can still learn new things from its own predictions.
Harvesting confident predictions will tend to increase precision. We are helping the model to be more certain of its predictions, and because high certainty tends to correlate with correctness, this tends to reduce the false positive rate.
However!
When selecting samples from our model’s confident predictions, we need to be careful that we are not just filling our corpus with the same patterns & duplicates or near duplicates of our training data.
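One way to guard against this is sketched below: keep only predictions above a confidence threshold, and drop any candidate that is too similar to a text we already have. The thresholds (`0.95` and `0.8`) and the word-set Jaccard similarity are assumptions chosen for illustration; a real pipeline might use embeddings or a proper near-duplicate detector. The harvested labels are still the model's own predictions and should be spot-checked before they enter the corpus.

```python
import math


def top_prediction(logits):
    """Return (predicted label index, softmax probability of that label)."""
    m = max(logits)
    total = sum(math.exp(x - m) for x in logits)
    return logits.index(m), 1.0 / total


def word_set(text):
    """Lower-cased set of whitespace-separated words."""
    return set(text.lower().split())


def jaccard(a, b):
    """Jaccard similarity of two word sets (1.0 means identical vocabularies)."""
    if not a and not b:
        return 1.0
    return len(a & b) / len(a | b)


def harvest(candidates, training_texts, min_confidence=0.95, max_similarity=0.8):
    """candidates: list of (text, logits). Returns confident, novel (text, label) pairs."""
    seen = [word_set(text) for text in training_texts]
    kept = []
    for text, logits in candidates:
        label, confidence = top_prediction(logits)
        if confidence < min_confidence:
            continue  # not confident enough to harvest
        words = word_set(text)
        if any(jaccard(words, other) >= max_similarity for other in seen):
            continue  # duplicate or near duplicate of something we already have
        kept.append((text, label))
        seen.append(words)  # also de-duplicate within the harvested batch
    return kept
```

The pairwise comparison is quadratic in the number of texts, which is fine for a sketch but will need replacing for a large corpus.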
Hyper-parameters
Once we have acquired an in-depth understanding of what our model does well versus what it does poorly, and exhausted all available data sources, we can start to investigate some of the more arcane methods for improving model performance, like hyper-parameter tuning. Whereas more data [2] may take a model from 60% to 95% accuracy, there is no guarantee that tuning hyper-parameters will have any positive impact on performance. That said, it is not uncommon to squeeze out a couple of percentage points.
The main hyper-parameters you should look to tune are:
- # of epochs
- learning rate
- regularisation with weight decay & dropout
- batch size
For weight decay & dropout, check each library’s implementation, as you may need to edit multiple parameters.
For text classification, regularisation through weight decay and dropout is often vital for reducing overfitting. We can usually trade some of our performance on the training set for generalisation, measured by an increase in performance on the testing set.
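A simple way to explore these hyper-parameters is an exhaustive grid search, sketched below. The `train_and_score` callable is hypothetical: we assume you supply a function that trains a model with the given parameters and returns its score on the testing set. The grid values are placeholders, not recommendations.

```python
from itertools import product


def grid_search(train_and_score, grid):
    """Try every combination in `grid`; return (score, params) pairs, best first.

    train_and_score: hypothetical caller-supplied function, params dict -> float.
    grid: dict mapping a parameter name to the list of values to try.
    """
    names = list(grid)
    results = []
    for values in product(*(grid[name] for name in names)):
        params = dict(zip(names, values))
        results.append((train_and_score(params), params))
    results.sort(key=lambda result: result[0], reverse=True)
    return results


# Placeholder values for illustration only.
example_grid = {
    "epochs": [2, 4],
    "learning_rate": [1e-5, 3e-5],
    "weight_decay": [0.0, 0.01],
    "dropout": [0.1, 0.3],
    "batch_size": [16, 32],
}
```

The number of training runs grows multiplicatively with each parameter (32 runs for the grid above), so it is worth keeping the grid small and isolating one or two parameters at a time.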
18.1 Stepping into the Great Unknown
Beyond these systematic steps, there is more advanced (but shaky) ground that we can explore. However, it becomes increasingly hard to predict which of these steps will be fruitful, so we need to isolate them, experiment with each one, and react to the experimental data.
- Was training stable at the start or do I need to add some warm up?
- Do the training & testing loss curves suggest under- or overfitting? If you’re not slightly overfitting, train for longer!
- Will using a bigger model significantly impact performance?
- Am I using the right activation function? Would switching to ReLU, or Leaky ReLU have an impact? What about Sigmoid?
- Is my loss function optimal?
- Should I be prioritising Precision or Recall for one class over another? Could I optimise the probability cut-off for classifications to achieve this?
- What optimiser am I using, can I do better than the default (AdamW)?
- Is my learning rate too high? How does the learning rate scheduler work in my optimiser?
- Is batch size having any impact on learning?
- Should I be early stopping? Am I saving the best model?
- Would ensembling help? [3]
- Is there any post-processing I can do to move the needle?
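To illustrate the question about probability cut-offs in the list above, here is a minimal sketch that sweeps candidate thresholds and picks the one with the best recall among those meeting a precision target. It assumes binary 0/1 labels and positive-class probabilities; `min_precision=0.9` is an arbitrary example target, and the sweep should be run on the testing set rather than the validation set.

```python
def precision_recall_at(threshold, probs, y_true):
    """Precision and recall when predicting positive for probs >= threshold."""
    preds = [1 if p >= threshold else 0 for p in probs]
    tp = sum(1 for t, p in zip(y_true, preds) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, preds) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, preds) if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


def pick_threshold(probs, y_true, min_precision=0.9):
    """Return (threshold, precision, recall) with the highest recall that meets
    the precision target, or None if no candidate threshold meets it."""
    best = None
    for threshold in sorted(set(probs)):
        precision, recall = precision_recall_at(threshold, probs, y_true)
        if precision >= min_precision and (best is None or recall > best[2]):
            best = (threshold, precision, recall)
    return best
```

To prioritise recall instead, the same sweep can be flipped: require a minimum recall and keep the threshold with the highest precision.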
Training neural networks is an art form, not a science. Aim to learn as you experiment.
18.2 Monitoring & Continuous Improvement
Modern machine learning models are excellent at identifying and compressing patterns in their training data. A good model will compress enough patterns that it can start to interpolate between them, which allows it to deal with ‘similar but slightly different’ novel data. However, input data will change over time, and the further the distribution of the new input data drifts from the training data’s distribution, the worse the model will perform. If we want our model to maintain performance over time, we have to label new data and re-train the model with all previous data plus the new data. To save time, we should aim to select the new data directly from the projects we used the model on and the samples we hand labelled.
One way to do this is to combine labelling new data with our sampling of model outputs. For any project or dataset that we use the model on, we simply have to verify that it is performing as expected. This means looking at our data - taking samples of our model’s predictions and confirming or rejecting the predictions. We can then score our model on our samples, and approximate how well the model is calibrated for the dataset.
In the process you will find samples that are definitely correct, definitely incorrect, and borderline. Correct/confirm the labels and then save them somewhere that they can be added to the corpus for future training and evaluation runs.
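When scoring the model on a hand-verified sample, it helps to report an interval rather than a single number, because small samples give noisy estimates. The sketch below uses the Wilson score interval, which is our choice of technique rather than something the process above prescribes; `z=1.96` corresponds to roughly 95% confidence.

```python
import math


def wilson_interval(correct, total, z=1.96):
    """Wilson score interval for the model's accuracy on a verified sample.

    correct: number of predictions confirmed as correct.
    total: number of predictions checked.
    """
    if total <= 0:
        raise ValueError("total must be positive")
    p = correct / total
    denominator = 1 + z * z / total
    centre = (p + z * z / (2 * total)) / denominator
    half_width = z * math.sqrt(p * (1 - p) / total + z * z / (4 * total * total)) / denominator
    return centre - half_width, centre + half_width


low, high = wilson_interval(correct=90, total=100)
print(f"sample accuracy 0.90, plausible range {low:.3f} to {high:.3f}")
```

If the lower end of the range is below what the project needs, that is a signal to check more samples or to label more data from that project before trusting the model on it.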
[1] Visit the model_evaluation document for a primer on loss functions.
[2] A lot more data.
[3] ‘Bad practices’ like training the model with 5 different seeds have a ‘fairly reliable’ positive impact on performance.