This section details the training pipeline.
Optimizing the input pipeline
Summary: The causes and interventions of input-bound pipelines are highly task-dependent. Use a profiler and look out for common issues.
Use an appropriate profiler, such as one of the following, to diagnose input-bound pipelines:
- Perfetto for JAX.
- TensorFlow Profiler for TensorFlow.
Ultimately, the specific causes and interventions are highly task-dependent. Broader engineering considerations (for example, minimizing disk footprint) may warrant accepting worse input pipeline performance.
The following are common causes of input-bound pipelines:
- Data are not colocated with the training process, causing I/O latency. For example, reading training data over a network might cause I/O latency.
- Expensive online data preprocessing. Consider preprocessing once offline and saving the results.
- Unintentional synchronization barriers that interfere with data pipeline prefetching. For example, synchronizing metrics between the device and host in CommonLoopUtils can introduce such a barrier.
We suggest the following interventions for input-bound pipelines:
- Instrument the input pipeline to prefetch examples (for example, with tf.data.Dataset.prefetch).
- Remove unused features and metadata from each example as early in the pipeline as possible.
- Increase the number of replicated jobs that generate examples for the input pipeline, for example, by using the tf.data service.
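The prefetching intervention can be illustrated without any framework. The following is a minimal sketch that assumes a plain Python iterator as the data source. A background thread fills a bounded buffer while the training loop consumes from it, which is conceptually what tf.data.Dataset.prefetch does. The `prefetch` function and its helper classes are hypothetical and are not part of any library.

```python
import queue
import threading
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


class _End:
    """Marks the end of the source iterator."""


class _ProducerError:
    """Wraps an exception raised while reading from the source."""

    def __init__(self, exc: BaseException) -> None:
        self.exc = exc


def prefetch(source: Iterable[T], buffer_size: int = 2) -> Iterator[T]:
    """Yields items from `source` while a background thread keeps up to
    `buffer_size` items ready, so the consumer rarely waits on I/O."""
    buf: queue.Queue = queue.Queue(maxsize=buffer_size)

    def producer() -> None:
        try:
            for item in source:
                buf.put(item)  # Blocks while the buffer is full.
        except Exception as exc:  # Forward errors instead of hanging the consumer.
            buf.put(_ProducerError(exc))
            return
        buf.put(_End())

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if isinstance(item, _End):
            return
        if isinstance(item, _ProducerError):
            raise item.exc
        yield item
```

In a training loop, you would write `for batch in prefetch(load_batches(), buffer_size=4): train_step(batch)`, where `load_batches` and `train_step` are hypothetical placeholders. The bounded queue caps memory use. In real pipelines, prefer the framework's own prefetching.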
Evaluating model performance
Summary: Run evaluation at larger batch sizes than training. Run evaluations at regular step intervals, not regular time intervals.
Evaluation settings
You can use the following settings to evaluate the performance of your models:
- Online evaluation: Collect metrics when the model is serving predictions in a production environment. Online evaluation generally provides the most realistic assessment of model quality because it matches the way the model will be used.
- Offline evaluation: Collect metrics when the model is run on offline training, validation, or test sets representative of the production environment. Depending on the problem, offline evaluation can be fairly involved and computationally expensive.
- Periodic evaluations: Collect metrics during model training that might be a proxy for the offline evaluation, and/or on a subset of the data used in offline evaluation. Periodic evaluations are the most practical and economical choice but may not fully represent the production environment. Aim to use an expedient proxy of the offline evaluation, without sacrificing the reliability of the signal received during training.
Setting up periodic evaluations
We recommend running periodic evaluations during training for the following reasons:
- To monitor training progress in real time.
- To facilitate retrospective model checkpoint selection.
- To examine the training curves at the end of training.
The simplest configuration is to perform both training and periodic evaluations within the same compute instance, periodically alternating between training and evaluation. In this case, the batch size used for evaluation should be at least as large as the batch size used for training. Evaluation doesn't need to keep model activations around for the backward pass, which lowers the memory required per example.
Perform periodic evaluations at regular step intervals, not time intervals. Evaluating based on time intervals can make it harder to interpret the training curves, especially when training may suffer from preemptions of the training jobs, network latency issues, and so on.
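The step-based trigger can be sketched as follows. This is a minimal illustration, and `train_step` and `evaluate` are hypothetical callbacks. The evaluation condition depends only on the step counter, so evaluation points line up across runs regardless of wall-clock speed. The `start_step` parameter assumes the step counter is restored from a checkpoint after a preemption.

```python
from typing import Callable, Dict


def train_with_periodic_eval(
    num_steps: int,
    eval_every_steps: int,
    train_step: Callable[[int], None],
    evaluate: Callable[[int], float],
    start_step: int = 0,
) -> Dict[int, float]:
    """Trains until `num_steps`, evaluating every `eval_every_steps` steps.

    `start_step` is the step restored from a checkpoint (0 for a fresh run).
    """
    history: Dict[int, float] = {}
    for step in range(start_step + 1, num_steps + 1):
        train_step(step)
        # Trigger on the step counter, never on elapsed time, so curves from
        # runs with preemptions or slow hosts remain directly comparable.
        if step % eval_every_steps == 0 or step == num_steps:
            history[step] = evaluate(step)
    return history
```

For example, `train_with_periodic_eval(10, 4, lambda s: None, lambda s: float(s))` evaluates at steps 4, 8, and 10 (the final step is always evaluated).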
Periodicity in validation and test metrics (when using a shuffled train/validation/test split) can indicate implementation bugs such as:
- Test data overlapping with training data.
- Training data not being properly shuffled.
Evaluating at regular step intervals can make these issues easier to catch.
Partial batches can occur when the size of the evaluation set is not divisible by the batch size. Ensure that padded examples are weighted correctly so that they don't bias the loss (for example, when computing the average loss over a batch as a weighted average over examples). Often, you can give padded examples a weight of zero.
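The following minimal sketch shows zero-weight padding with NumPy. It assumes that `loss_fn` returns one loss value per example. The `pad_batch` and `mean_eval_loss` helpers are hypothetical. The key detail is that the loop accumulates weighted sums and divides once at the end, rather than averaging per-batch means.

```python
from typing import Callable, Tuple

import numpy as np


def pad_batch(
    x: np.ndarray, y: np.ndarray, batch_size: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Pads a (possibly partial) batch to `batch_size`; padded rows get weight 0."""
    n = x.shape[0]
    pad = batch_size - n
    x_pad = np.concatenate([x, np.zeros((pad,) + x.shape[1:], dtype=x.dtype)])
    y_pad = np.concatenate([y, np.zeros((pad,) + y.shape[1:], dtype=y.dtype)])
    weights = np.concatenate([np.ones(n), np.zeros(pad)])
    return x_pad, y_pad, weights


def mean_eval_loss(
    loss_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
    x: np.ndarray,
    y: np.ndarray,
    batch_size: int,
) -> float:
    """Weighted mean of per-example losses over the whole evaluation set."""
    total_loss, total_weight = 0.0, 0.0
    for start in range(0, len(x), batch_size):
        xb, yb, wb = pad_batch(
            x[start:start + batch_size], y[start:start + batch_size], batch_size
        )
        per_example = loss_fn(xb, yb)
        # Accumulate sums (not per-batch means) so padding never counts.
        total_loss += float(np.sum(per_example * wb))
        total_weight += float(np.sum(wb))
    return total_loss / total_weight


x = np.arange(10, dtype=np.float64)
y = np.zeros(10)
print(mean_eval_loss(lambda xb, yb: (xb - yb) ** 2, x, y, batch_size=4))  # 28.5
```

Here, 28.5 is the true mean squared error over the 10 real examples. Averaging the three unweighted padded-batch means would instead give 23.75.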
Save sufficient information per evaluation to support offline analysis. Ideally, save predictions on a selection of individual examples, since they can be invaluable for debugging. Generating artifacts like SavedModels simplifies ad hoc model inspection after evaluation jobs finish.
Choosing a sample for periodic evaluation
The periodic evaluation job might not run fast enough to compute metrics on the full offline evaluation set in a reasonable amount of time. This problem often necessitates sampling data for periodic evaluation. When constructing a sampled dataset, consider issues with sample size and special concerns in imbalanced datasets.
Sample size
Check that the performance computed on the sampled dataset used by the periodic job matches the performance on the whole offline evaluation set; that is, ensure that there is no skew between the sampled dataset and the full dataset.
The dataset you use for periodic evaluation should be both of the following:
- Small enough to easily generate model predictions over its entirety.
- Large enough to do both of the following:
  - Accurately measure improvements to the model; that is, measurements shouldn't be overwhelmed by label noise.
  - Accommodate multiple such evaluations across trials in sequence and still produce accurate estimates. That is, large enough to avoid adaptively "fitting" to the validation set over time in a way that doesn't generalize to a held-out test set. However, this consideration is rarely a practical concern.
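One rough way to reason about "large enough" for a metric like accuracy is the binomial standard error, sqrt(p(1 - p) / n). This approximation assumes independent examples. The sketch below also includes a heuristic skew check between the sampled set and the full offline set. Both the helper names and the 2-standard-error threshold are illustrative assumptions, not a prescribed test.

```python
import math


def accuracy_std_error(accuracy: float, num_examples: int) -> float:
    """Binomial standard error of an accuracy measured on `num_examples` examples."""
    return math.sqrt(accuracy * (1.0 - accuracy) / num_examples)


def min_examples_for_std_error(accuracy: float, target_std_error: float) -> int:
    """Smallest sample size whose standard error is at most `target_std_error`."""
    return math.ceil(accuracy * (1.0 - accuracy) / target_std_error ** 2)


def sample_is_consistent(
    sample_acc: float, full_acc: float, num_sample_examples: int, z: float = 2.0
) -> bool:
    """Heuristic skew check: is the sampled-set accuracy within `z` standard
    errors of the full offline evaluation set's accuracy?"""
    return abs(sample_acc - full_acc) <= z * accuracy_std_error(full_acc, num_sample_examples)
```

At 90% accuracy on 1,000 examples, the standard error is about 0.0095. A half-point change between checkpoints is therefore within noise. Halving that standard error to 0.005 needs roughly 3,600 examples.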
Imbalanced datasets
For imbalanced datasets, performance on rare minority classes is often noisy. For datasets with only a small number of minority examples, log the number of examples predicted correctly to get more insight into accuracy improvements. For example, a 0.05 improvement in sensitivity sounds exciting, but was it just due to one more example being predicted correctly?
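A minimal sketch of reporting sensitivity alongside raw counts follows. It assumes integer class labels, with the minority class marked as `positive`. The helper names are hypothetical. With 20 positive examples, one extra correct prediction moves sensitivity by exactly 0.05.

```python
from typing import Sequence, Tuple


def sensitivity_with_counts(
    labels: Sequence[int], preds: Sequence[int], positive: int = 1
) -> Tuple[float, int, int]:
    """Returns (sensitivity, true_positives, num_positives)."""
    num_positives = sum(1 for y in labels if y == positive)
    if num_positives == 0:
        raise ValueError("no positive examples in labels")
    true_positives = sum(
        1 for y, p in zip(labels, preds) if y == positive and p == positive
    )
    return true_positives / num_positives, true_positives, num_positives


def format_sensitivity(labels: Sequence[int], preds: Sequence[int], positive: int = 1) -> str:
    sens, tp, n = sensitivity_with_counts(labels, preds, positive)
    return f"sensitivity={sens:.3f} ({tp}/{n} positives correct)"


labels = [1] * 20 + [0] * 80
before = [1] * 15 + [0] * 85
after = [1] * 16 + [0] * 84
print(format_sensitivity(labels, before))  # sensitivity=0.750 (15/20 positives correct)
print(format_sensitivity(labels, after))   # sensitivity=0.800 (16/20 positives correct)
```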
Saving checkpoints and retrospectively selecting the best checkpoint
Summary: Run training for a fixed number of steps and retrospectively choose the best checkpoint from the run.
Most deep learning frameworks support model checkpointing. That is, the current state of the model is periodically saved to disk. Checkpointing allows the training job to be resilient to compute instance interruptions. The best checkpoint is often not the last checkpoint, particularly when the validation set performance does not continue to increase over time but rather fluctuates about a particular value.
Set up the pipeline to keep track of the N best checkpoints seen so far during training. At the end of training, model selection simply means choosing the best checkpoint. We call this approach retrospective optimal checkpoint selection. Supporting prospective early stopping is usually not necessary, because you're pre-specifying a trial budget and are preserving the N best checkpoints seen so far.
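Tracking the N best checkpoints can be done with a small min-heap keyed on the validation metric. The following is a minimal, framework-agnostic sketch. It assumes that higher metric values are better and that checkpoints are identified by a path string. The `BestCheckpointTracker` class is hypothetical. Your checkpointing library may offer its own retention options.

```python
import heapq
from typing import List, Tuple


class BestCheckpointTracker:
    """Keeps the `max_to_keep` checkpoints with the highest validation metric."""

    def __init__(self, max_to_keep: int) -> None:
        self.max_to_keep = max_to_keep
        # Min-heap of (metric, step, path): the worst kept checkpoint is on top.
        self._heap: List[Tuple[float, int, str]] = []

    def update(self, metric: float, step: int, path: str) -> List[str]:
        """Records a checkpoint and returns paths that dropped out of the top N."""
        heapq.heappush(self._heap, (metric, step, path))
        evicted = []
        while len(self._heap) > self.max_to_keep:
            _, _, worst_path = heapq.heappop(self._heap)
            evicted.append(worst_path)
        return evicted

    def best(self) -> Tuple[float, int, str]:
        """Retrospective selection: the best checkpoint seen during the run."""
        return max(self._heap)
```

The caller can delete the evicted paths to bound disk usage. At the end of training, `best()` performs the retrospective selection.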
Setting up experiment tracking
Summary: When tracking different experiments, track a number of essentials, like the best performance of a checkpoint in the study, and a short description of the study.
We recommend tracking experiment results in a spreadsheet. Our spreadsheets often contain the following columns:
- Study name
- A link to wherever the config for the study is stored.
- Notes or a short description of the study.
- Number of trials run
- Performance on the validation set of the best checkpoint in the study.
- Specific reproduction commands or notes on what unsubmitted changes were necessary to launch training.
Find a convenient tracking system that captures at least the information listed above. Untracked experiments might as well not exist.
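If a spreadsheet is the tracking system, a CSV with one row per study can be imported into most spreadsheet tools. The following is a minimal sketch. The `StudyRecord` fields mirror the columns listed above, and the field names and example values are hypothetical.

```python
import csv
import io
from dataclasses import asdict, dataclass, fields
from typing import TextIO


@dataclass
class StudyRecord:
    study_name: str
    config_link: str
    description: str
    num_trials: int
    best_validation_metric: float
    reproduction_notes: str


def append_study(record: StudyRecord, out: TextIO) -> None:
    """Appends one study row, writing the header first if the stream is empty."""
    writer = csv.DictWriter(out, fieldnames=[f.name for f in fields(StudyRecord)])
    if out.tell() == 0:
        writer.writeheader()
    writer.writerow(asdict(record))


sheet = io.StringIO()
append_study(
    StudyRecord(
        study_name="lr_sweep_v1",  # Hypothetical example values.
        config_link="configs/lr_sweep_v1.py",
        description="Learning rate sweep at batch size 256",
        num_trials=20,
        best_validation_metric=0.761,
        reproduction_notes="No unsubmitted changes",
    ),
    sheet,
)
print(sheet.getvalue())
```

For an on-disk file, open it with `open(path, "a", newline="")`, as the `csv` module documentation recommends.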
Batch normalization implementation details
Summary: Nowadays, you can often replace batch normalization with LayerNorm, but in cases where you cannot do that replacement, there are tricky details when changing the batch size or number of hosts.
Batch normalization normalizes activations using their mean and variance over the current batch. However, in the multi-device setting, these statistics differ on each device unless explicitly synchronized. Anecdotal reports (mostly on ImageNet) indicate that calculating these normalizing statistics using only ~64 examples actually works better in practice. (See the description of Ghost Batch Normalization in Train longer, generalize better: closing the generalization gap in large batch training of neural networks.) Decoupling the total batch size and the number of examples used to calculate batch norm statistics is particularly useful for batch size comparisons.
Ghost batch normalization implementations don't always correctly handle the case where the per-device batch size is greater than the virtual batch size. In this case, you'd need to subsample the batch on each device to get the proper number of batch norm statistic examples.
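A minimal NumPy sketch of the ghost batch normalization forward pass follows. It assumes `(batch, features)` activations on a single device and omits the learnable scale and offset and the running averages. Decoupling the statistics size from the batch size is handled here by splitting the per-device batch into several ghost batches of `virtual_batch_size` and normalizing each with its own statistics. The subsampling approach described above would instead compute statistics from a single slice of `virtual_batch_size` examples. The function name is hypothetical.

```python
import numpy as np


def ghost_batch_norm(
    x: np.ndarray, virtual_batch_size: int = 64, eps: float = 1e-5
) -> np.ndarray:
    """Normalizes each ghost batch of `virtual_batch_size` examples separately.

    x: per-device activations of shape (batch, features).
    """
    batch, features = x.shape
    if batch % virtual_batch_size != 0:
        raise ValueError("per-device batch size must be a multiple of virtual_batch_size")
    # (num_ghosts, virtual_batch_size, features): statistics use only
    # `virtual_batch_size` examples, however large the per-device batch is.
    ghosts = x.reshape(batch // virtual_batch_size, virtual_batch_size, features)
    mean = ghosts.mean(axis=1, keepdims=True)
    var = ghosts.var(axis=1, keepdims=True)
    return ((ghosts - mean) / np.sqrt(var + eps)).reshape(batch, features)
```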
Exponential moving averages (EMAs) used in test mode batch normalization are just a linear combination of training statistics. Therefore, you only need to synchronize these EMAs before saving them in checkpoints. However, some common implementations of batch normalization don't synchronize these EMAs and only save the EMA from the first device.
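The linearity argument can be checked numerically. The sketch below assumes that every device starts from the same initial EMA and uses the same decay. Under those assumptions, averaging the per-device EMAs once (for example, right before checkpointing) gives the same result as synchronizing the batch statistics at every step. The helper names are hypothetical. Note that this equivalence concerns the averaged per-device statistics. It does not turn averaged per-device variances into a global variance.

```python
from typing import List, Tuple

import numpy as np


def ema_update(ema: np.ndarray, batch_stat: np.ndarray, decay: float) -> np.ndarray:
    return decay * ema + (1.0 - decay) * batch_stat


def sync_emas(per_device_emas: List[np.ndarray]) -> np.ndarray:
    """Cross-device mean of the EMAs, done once before saving a checkpoint."""
    return np.mean(np.stack(per_device_emas), axis=0)


def compare_sync_strategies(
    seed: int = 0, num_devices: int = 4, num_steps: int = 100, decay: float = 0.99
) -> Tuple[np.ndarray, np.ndarray]:
    rng = np.random.default_rng(seed)
    local = [np.zeros(3) for _ in range(num_devices)]  # Unsynchronized EMAs.
    synced_every_step = np.zeros(3)
    for _ in range(num_steps):
        stats = rng.normal(size=(num_devices, 3))  # Per-device batch statistics.
        local = [ema_update(e, s, decay) for e, s in zip(local, stats)]
        synced_every_step = ema_update(synced_every_step, stats.mean(axis=0), decay)
    return sync_emas(local), synced_every_step


synced_once, synced_each_step = compare_sync_strategies()
print(np.allclose(synced_once, synced_each_step))  # True
```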
Considerations for multi-host pipelines
Summary: For logging, evals, RNGs, checkpointing, and data sharding, multi-host training can make it very easy to introduce bugs!
Do the following for multi-host pipelines:
- Ensure that the pipeline is only logging and checkpointing on one host.
- Synchronize batch normalization statistics across hosts before evaluating or checkpointing.
- Shard data files across hosts since that usually improves performance.
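File sharding and single-host logging can both key off a host index. The following is a minimal sketch. It assumes that `host_id` and `num_hosts` come from your framework (in JAX, `jax.process_index()` and `jax.process_count()`) and that there are at least as many roughly equal-sized files as hosts. The helper names are hypothetical.

```python
from typing import Dict, List, Sequence


def shard_files(files: Sequence[str], host_id: int, num_hosts: int) -> List[str]:
    """Round-robin sharding: host i reads files i, i + num_hosts, i + 2 * num_hosts, ..."""
    if not 0 <= host_id < num_hosts:
        raise ValueError(f"host_id {host_id} out of range for {num_hosts} hosts")
    return list(files[host_id::num_hosts])


def is_chief(host_id: int) -> bool:
    """Only host 0 logs and writes checkpoints."""
    return host_id == 0


def maybe_log(host_id: int, step: int, metrics: Dict[str, float]) -> bool:
    """Logs on the chief host only; returns whether anything was logged."""
    if not is_chief(host_id):
        return False
    print(f"step {step}: {metrics}")
    return True
```

Each host reads a disjoint subset of files, and together the subsets cover every file exactly once.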
Critical: Ensure that you have RNG seeds that are the same across hosts (for model initialization) and seeds that are different across hosts (for data shuffling and preprocessing). Make sure to mark each seed's intended scope clearly.
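The following is a minimal sketch of deriving both kinds of seeds from one base seed with NumPy's `SeedSequence`. The `make_rngs` helper is hypothetical. Spawning gives statistically independent child streams: one shared initialization stream and one data stream per host. In JAX, an analogous pattern is to use one `jax.random.PRNGKey(seed)` for initialization and to derive the data key with `jax.random.fold_in(key, jax.process_index())`.

```python
from typing import Tuple

import numpy as np


def make_rngs(
    base_seed: int, host_id: int, num_hosts: int
) -> Tuple[np.random.Generator, np.random.Generator]:
    """Returns (init_rng, data_rng) for this host.

    init_rng: identical on every host (model initialization).
    data_rng: different on every host (shuffling and preprocessing).
    """
    root = np.random.SeedSequence(base_seed)
    init_seq, data_root = root.spawn(2)  # Separate streams for init vs. data.
    data_seq = data_root.spawn(num_hosts)[host_id]  # One data stream per host.
    return np.random.default_rng(init_seq), np.random.default_rng(data_seq)
```

Each host calls `make_rngs(base_seed, host_id, num_hosts)` with the same `base_seed`. Every call rebuilds the root sequence from scratch, so the result doesn't depend on call order.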