
Different methods for mitigating overfitting on Neural Networks

Pablo Sánchez

26/05/2021


Using Machine Learning and Deep Learning models to solve scientific problems of greater or lesser complexity is a challenge.

Focusing on neural networks: on the one hand, simple networks with too little capacity will not learn the problem well, producing a model that underfits the data. On the other hand, complex networks with too much capacity will learn it too well, leading to a model that overfits the data.

As a result, in both cases you will get a model that does not generalize well. In this post, you will find some techniques to reduce overfitting and get a more generalized and robust model.

Overfitting

With four parameters I can fit an elephant, and with five I can make him wiggle his trunk.

John von Neumann

In Machine Learning and Deep Learning models, it is possible to test several combinations of parameters and hyperparameters to get a more accurate model. This is a very powerful tool, but a very dangerous one too: the more complex the network is, the more closely it will fit the training data.

Figure: training and validation error as a function of model complexity, illustrating the underfitting and overfitting regions.

Methods to mitigate overfitting

Model simplification

According to the previous plot, decreasing the complexity of the model is one way to deal with overfitting. In neural networks, this means reducing the number of neurons per layer or removing some hidden layers.
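As an illustration, here is a minimal sketch (assuming Keras with TensorFlow 2.x; the layer sizes and input dimension are arbitrary choices) of shrinking an over-sized fully connected network:

    from tensorflow import keras
    from tensorflow.keras import layers

    # Over-sized network: plenty of capacity, prone to overfit a small dataset.
    complex_model = keras.Sequential([
        layers.Dense(256, activation="relu", input_shape=(20,)),
        layers.Dense(256, activation="relu"),
        layers.Dense(128, activation="relu"),
        layers.Dense(1, activation="sigmoid"),
    ])

    # Simplified network: fewer neurons per layer and one hidden layer removed.
    simple_model = keras.Sequential([
        layers.Dense(64, activation="relu", input_shape=(20,)),
        layers.Dense(32, activation="relu"),
        layers.Dense(1, activation="sigmoid"),
    ])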

Regularization

One of the first methods we should try when we need to reduce overfitting in our neural network is regularization. The main idea of this method is to penalize the model for being too complex or for using large values in its weight matrices.

To apply regularization, it is necessary to add an extra term to the loss function. The impact this method has on the model is parameterized by a variable called lambda (\(\lambda\)). There are different ways to penalize the loss function, but the most popular ones are based on the L1 and L2 norms of the weights (called L1 and L2 regularization, respectively). As you can see in the following equations, the higher this parameter, the higher the impact on the loss function.

L1 regularization equation

\( J_{L1}(W,b) = \frac{1}{m} \sum_{i=1}^{m} L(\hat{y}^{(i)}, y^{(i)}) + \lambda \left \| W \right \|_{1} \quad \quad \left \| W \right \|_{1} = \sum_{j=1}^{n_x} \left | W_{j} \right | \)

L2 regularization equation

\( J_{L2}(W,b) = \frac{1}{m} \sum_{i=1}^{m} L(\hat{y}^{(i)}, y^{(i)}) + \lambda \left \| W \right \|_{2}^{2} \quad \quad \left \| W \right \|_{2}^{2} = \sum_{j=1}^{n_x} W_{j}^{2} \)

Where

  • \(m\) is the number of training examples.
  • \(L(\hat{y}^{(i)}, y^{(i)})\) is the loss function between the estimated value \(\hat{y}^{(i)}\) and the real value \(y^{(i)}\) for the i-th training example.
  • \(W\) is the weights matrix and \(n_x\) is the number of weights it contains.
  • \(\lambda\) is the regularization parameter, which controls the strength of the penalty.
  • \(J(W,b)\) is the regularized cost function.
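As a practical sketch (assuming Keras with TensorFlow 2.x; the value 0.01 for \(\lambda\) is only illustrative), the penalty can be attached to a layer so that it is added to the loss automatically:

    from tensorflow import keras
    from tensorflow.keras import layers, regularizers

    model = keras.Sequential([
        # L2 penalty on the hidden-layer weights; 0.01 plays the role of lambda.
        # For L1 regularization, use regularizers.l1(0.01) instead.
        layers.Dense(64, activation="relu", input_shape=(20,),
                     kernel_regularizer=regularizers.l2(0.01)),
        layers.Dense(1, activation="sigmoid"),
    ])

    # The penalty term is added to the cross-entropy loss during training.
    model.compile(optimizer="adam", loss="binary_crossentropy")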

Dropout

Dropout is another very popular regularization technique for neural networks. It is mainly based on the idea of simplifying the model: the objective is to reduce the importance of individual neurons by randomly switching some of them off during the training stage.

It is possible to choose, for each layer, the probability of its neurons being disabled. Due to this randomness, the set of disabled neurons is different on each iteration, resulting in a different network every time. The hyperparameter used for this technique is called the dropout rate.

In the following animation, it is possible to see a graphical representation of dropout on a neural network. In this example, the dropout rate is different in each layer:

  • The dropout rate in the second layer is 0.5, for that reason 3 of 6 neurons are switched off.
  • The dropout rate in the third layer is 0.33, for that reason 2 of 6 neurons are switched off.
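A minimal sketch of the same idea in code (assuming Keras with TensorFlow 2.x; the layer sizes mirror the example above):

    from tensorflow import keras
    from tensorflow.keras import layers

    model = keras.Sequential([
        layers.Dense(6, activation="relu", input_shape=(4,)),
        layers.Dropout(0.5),   # on average half of these 6 activations are zeroed each step
        layers.Dense(6, activation="relu"),
        layers.Dropout(0.33),  # on average a third of these activations are zeroed
        layers.Dense(1, activation="sigmoid"),
    ])

    # Dropout is only active during training; at inference time all neurons are kept.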

Early stopping

Early stopping is another regularization method, used for models trained with iterative procedures such as gradient descent. The main idea of this technique is to stop the training process when the validation error starts to increase: beyond that point the generalization error grows, meaning that the model is starting to overfit.
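In Keras, for instance, this can be sketched with the EarlyStopping callback (the monitored metric and the patience value are illustrative choices, and the training data is assumed to be already defined):

    from tensorflow import keras

    # Stop training when the validation loss has not improved for 5 consecutive
    # epochs, and restore the weights from the best epoch seen so far.
    early_stop = keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=5, restore_best_weights=True
    )

    # model.fit(X_train, y_train, validation_data=(X_val, y_val),
    #           epochs=200, callbacks=[early_stop])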

Data Augmentation

Data Augmentation is a well-known and widely used technique when training Neural Networks on images. In this blog, it is also possible to see applications of data augmentation to financial series, such as Generating Financial Series with Generative Adversarial Networks.

In some cases, the dataset is not big enough or does not have a wide variety of pictures. This method reduces overfitting by enlarging the dataset: different transformations are applied to the existing images to enrich it. The most popular image transformations are:

  • Flipping
  • Rotation
  • Scaling
  • Cropping
  • Translations
  • Adding noise
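As a sketch, these transformations can be applied on the fly with Keras preprocessing layers (assuming TensorFlow 2.x; the ranges and sizes are illustrative):

    from tensorflow import keras
    from tensorflow.keras import layers

    # Random transformations applied to each training batch.
    augmentation = keras.Sequential([
        layers.RandomFlip("horizontal"),         # flipping
        layers.RandomRotation(0.1),              # rotation (up to about 10% of a full turn)
        layers.RandomZoom(0.2),                  # scaling
        layers.RandomTranslation(0.1, 0.1),      # translations
        layers.RandomCrop(height=96, width=96),  # cropping (assumes larger input images)
        layers.GaussianNoise(0.05),              # adding noise (active only during training)
    ])

    # augmented_batch = augmentation(image_batch, training=True)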

For more information about data augmentation, you can read the Nanonets post.
