Overfitting

Overfitting means creating a model that matches (memorizes) the training set so closely that the model fails to make correct predictions on new data. An overfit model is analogous to an invention that performs well in the lab but is worthless in the real world.

In Figure 11, imagine that each geometric shape represents a tree's position in a square forest. The blue diamonds mark the locations of healthy trees, while the orange circles mark the locations of sick trees.

[Figure 11: about 60 dots, half healthy trees and half sick trees. The healthy trees cluster mainly in the northeast quadrant, though a few stray into the northwest quadrant; the sick trees cluster mainly in the southeast quadrant, though a few spill into other quadrants.]
Figure 11. Training set: locations of healthy and sick trees in a square forest.

 

Mentally draw any shapes—lines, curves, ovals...anything—to separate the healthy trees from the sick trees. Then consider one possible separation, shown in Figure 12.

The complex shapes shown in Figure 12 successfully categorized all but two of the trees. If we think of the shapes as a model, then this is a fantastic model.

Or is it? A truly excellent model successfully categorizes new examples. Figure 13 shows what happens when that same model makes predictions on new examples from the test set:

[Figure 13: a new batch of healthy and sick trees overlaid on the model from Figure 12. The model miscategorizes many of the trees.]
Figure 13. Test set: a complex model for distinguishing sick from healthy trees.

 

So, the complex model shown in Figure 12 did a great job on the training set but a pretty bad job on the test set. This is a classic case of a model overfitting to the training set data.
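To make this gap between training and test performance concrete, here is a minimal sketch (not part of the course) that trains two classifiers on synthetic "forest" data: an unconstrained decision tree that can effectively memorize the training set, and a depth-limited tree that cannot. The data, noise level, and model choices are all illustrative assumptions, not the model behind Figures 12 and 13.

```python
# A minimal sketch of overfitting on synthetic data (all values are
# illustrative assumptions). Requires NumPy and scikit-learn.
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(seed=42)

# Synthetic stand-in for Figure 11: (x, y) tree positions with noisy labels,
# where "healthy" trees mostly sit in the northeast quadrant.
positions = rng.uniform(0, 1, size=(600, 2))
labels = ((positions[:, 0] > 0.5) & (positions[:, 1] > 0.5)).astype(int)
noisy = rng.random(600) < 0.15            # some trees don't follow the pattern
labels = np.where(noisy, 1 - labels, labels)

X_train, X_test, y_train, y_test = train_test_split(
    positions, labels, test_size=0.25, random_state=0)

# An unconstrained tree can carve out a boundary as complex as it likes.
complex_model = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# A depth-limited tree is forced to learn a much simpler boundary.
simple_model = DecisionTreeClassifier(max_depth=2, random_state=0).fit(
    X_train, y_train)

for name, model in [("complex", complex_model), ("simple", simple_model)]:
    print(f"{name}: train accuracy = {model.score(X_train, y_train):.2f}, "
          f"test accuracy = {model.score(X_test, y_test):.2f}")
```

The unconstrained tree typically scores near 1.0 on the training set but noticeably lower on the test set, while the simpler tree scores about the same on both, mirroring the pattern in Figures 12 and 13.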

Fitting, overfitting, and underfitting

A model must make good predictions on new data. That is, you're aiming to create a model that "fits" new data.

As you've seen, an overfit model makes excellent predictions on the training set but poor predictions on new data. An underfit model doesn't even make good predictions on the training data. If an overfit model is like a product that performs well in the lab but poorly in the real world, then an underfit model is like a product that doesn't even do well in the lab.

[Figure 14: a Cartesian plot whose x-axis is labeled "quality of predictions on training set" and whose y-axis is labeled "quality of predictions on real-world data." A curve rises from the origin and then falls. The lower-left portion of the curve is labeled "underfit models," the lower-right portion is labeled "overfit models," and the peak is labeled "fit models."]
Figure 14. Underfit, fit, and overfit models.

 

Generalization is the opposite of overfitting. That is, a model that generalizes well makes good predictions on new data. Your goal is to create a model that generalizes well to new data.

Detecting overfitting

The following curves help you detect overfitting:

  • loss curves
  • generalization curves

A loss curve plots a model's loss against the number of training iterations. A graph that shows two or more loss curves is called a generalization curve. The following generalization curve shows two loss curves:

[Figure 15: the loss curve for the training set declines steadily. The loss curve for the validation set also declines at first, but starts to rise after a certain number of iterations.]
Figure 15. A generalization curve that strongly implies overfitting.

 

Notice that the two loss curves behave similarly at first and then diverge. That is, after a certain number of iterations, loss declines or holds steady (converges) for the training set, but increases for the validation set. This suggests overfitting.

In contrast, a generalization curve for a well-fit model shows two loss curves that have similar shapes.
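If you track loss during training, plotting the two curves takes only a few lines. The sketch below assumes a Keras-style `history` object whose `history` dictionary contains per-epoch `loss` and `val_loss` values; any pair of per-iteration loss lists would work the same way.

```python
# A minimal sketch of plotting a generalization curve (training loss and
# validation loss on the same axes). Assumes matplotlib is installed and that
# `history` comes from a Keras-style model.fit(..., validation_data=...) call.
import matplotlib.pyplot as plt

def plot_generalization_curve(history):
    """Plot training and validation loss against the training epoch."""
    epochs = range(1, len(history.history["loss"]) + 1)
    plt.plot(epochs, history.history["loss"], label="training loss")
    plt.plot(epochs, history.history["val_loss"], label="validation loss")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.legend()
    plt.show()
```

If the validation curve bottoms out and starts climbing while the training curve keeps falling, the model is probably starting to overfit, and stopping training around that point is a common response.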

What causes overfitting?

Very broadly speaking, overfitting is caused by one or both of the following problems:

  • The training set doesn't adequately represent real-world data (or the validation set or test set).
  • The model is too complex.

Generalization conditions

A model trains on a training set, but the real test of a model's worth is how well it makes predictions on new examples, particularly on real-world data. While developing a model, your test set serves as a proxy for real-world data. Training a model that generalizes well implies the following dataset conditions:

  • Examples must be independently and identically distributed, which is a fancy way of saying that your examples can't influence each other.
  • The dataset is stationary, meaning the dataset doesn't change significantly over time.
  • The dataset partitions have the same distribution. That is, the examples in the training set are statistically similar to the examples in the validation set, test set, and real-world data.
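The third condition is often addressed by shuffling the dataset before splitting it. The following sketch shows one way to do that with pandas; the `examples.csv` filename and the 70/15/15 split are illustrative assumptions.

```python
# A minimal sketch of shuffling a dataset before partitioning it, so that the
# training, validation, and test sets are more likely to share the same
# distribution. The filename and split fractions are assumptions.
import pandas as pd

df = pd.read_csv("examples.csv")

# Shuffle every row; a fixed random_state keeps the split reproducible.
shuffled = df.sample(frac=1.0, random_state=17).reset_index(drop=True)

n = len(shuffled)
train_end = int(0.70 * n)
validation_end = int(0.85 * n)

train_set = shuffled.iloc[:train_end]                      # first 70%
validation_set = shuffled.iloc[train_end:validation_end]   # next 15%
test_set = shuffled.iloc[validation_end:]                  # final 15%
```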

Explore the preceding conditions through the following exercises.

Exercises: Check your understanding

Consider the following dataset partitions:

[A horizontal bar divided into three pieces: 70% of the bar is the training set, 15% the validation set, and 15% the test set.]

What should you do to ensure that the examples in the training set have a similar statistical distribution to the examples in the validation set and the test set?

  • Shuffle the examples in the dataset extensively before partitioning them.
    Yes. Good shuffling of examples makes partitions much more likely to be statistically similar.
  • Sort the examples from earliest to most recent.
    If the examples in the dataset are not stationary, then sorting makes the partitions less similar.
  • Do nothing. Given enough examples, the law of averages naturally ensures that the distributions will be statistically similar.
    Unfortunately, this is not the case. The examples in certain sections of the dataset may differ from those in other sections.
A streaming service is developing a model to predict the popularity of potential new television shows for the next three years. The streaming service plans to train the model on a dataset containing hundreds of millions of examples, spanning the previous ten years. Will this model encounter a problem?

  • Probably. Viewers' tastes change in ways that past behavior can't predict.
    Yes. Viewer tastes are not stationary. They constantly change.
  • Definitely not. The dataset is large enough to make good predictions.
    Unfortunately, viewers' tastes are nonstationary.
  • Probably not. Viewers' tastes change in predictably cyclical ways. Ten years of data will enable the model to make good predictions on future trends.
    Although certain aspects of entertainment are somewhat cyclical, a model trained from past entertainment history will almost certainly have trouble making predictions about the next few years.
A model aims to predict the time it takes for people to walk a mile based on weather data (temperature, dew point, and precipitation) collected over one year in a city whose weather varies significantly by season. Can you build and test a model from this dataset, even though the weather readings change dramatically by season?

  • Yes
    Yes, it is possible to build and test a model from this dataset. You just have to ensure that the data is partitioned equally, so that data from all four seasons is distributed equally into the different partitions.
  • No
    Assuming this dataset contains enough examples of temperature, dew point, and precipitation, then you can build and test a model from this dataset. You just have to ensure that the data is partitioned equally, so that data from all four seasons is distributed equally into the different partitions.

Challenge exercise

You are creating a model that predicts the ideal date for riders to buy a train ticket for a particular route. For example, the model might recommend that users buy their ticket on July 8 for a train that departs July 23. The train company updates prices hourly, basing their updates on a variety of factors but mainly on the current number of available seats. That is:

  • If a lot of seats are available, ticket prices are typically low.
  • If very few seats are available, ticket prices are typically high.
Your model exhibits low loss on the validation set and the test set but sometimes makes terrible predictions on real-world data. Why?