In this post, we discuss how a technique known as warm-starting can be used to save computational resources and improve generalizability when training deep learning models.
Developing modern deep learning applications requires massive amounts of computational and human resources. The weights of the resulting deep learning models represent the fruit of these precious resources. Throwing these weights away amounts to lighting money on fire! The intrinsic value of learned model weights motivates the use of warm-starting, or initializing a model with weights from a previously trained model. Though simple in concept, warm-starting has powerful and far-reaching applications, including transfer learning as well as other, more nuanced use cases that should be part of every DL engineer's playbook.
Training a neural network normally begins with initializing model weights to random values. As an alternative strategy, we can initialize weights by copying them from a previously trained model. This warm-starting approach enables us to start training from a better initial point on the loss surface and often learn better models. In so doing, warm-starting leverages prior computation to dramatically reduce the time required to train a model.
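To make the idea concrete, here is a minimal sketch in plain Python. It uses a toy linear model and synthetic data instead of a deep network (an assumption made for brevity), and the helper `fit` is hypothetical. A "model A" is fit on one data snapshot; then a model for a slightly drifted second snapshot is trained twice, once from random weights and once warm-started from model A's weights.

```python
import random

def fit(w, b, data, lr=0.1, tol=1e-4, max_steps=10_000):
    """Gradient descent on mean squared error for y ~ w*x + b.

    Returns the final (w, b) and the number of update steps taken
    before the training loss dropped below `tol`.
    """
    n = len(data)
    for step in range(max_steps):
        loss = grad_w = grad_b = 0.0
        for x, y in data:
            err = (w * x + b) - y
            loss += err * err / n
            grad_w += 2 * err * x / n
            grad_b += 2 * err / n
        if loss < tol:
            return w, b, step
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b, max_steps

xs = [i / 50 - 1 for i in range(101)]            # 101 points in [-1, 1]
data_a = [(x, 3 * x + 2.0) for x in xs[0::2]]     # "old" snapshot (synthetic)
data_b = [(x, 3 * x + 2.2) for x in xs[1::2]]     # "new" snapshot, slightly drifted

rng = random.Random(0)
# Model A: trained from random weights on dataset A.
w_a, b_a, _ = fit(rng.gauss(0, 1), rng.gauss(0, 1), data_a)

# Option 1: random initialization on dataset B.
_, _, cold_steps = fit(rng.gauss(0, 1), rng.gauss(0, 1), data_b)
# Option 2: warm-start from model A's weights on dataset B.
_, _, warm_steps = fit(w_a, b_a, data_b)

print(f"random init: {cold_steps} steps, warm start: {warm_steps} steps")
```

The warm-started run needs fewer steps here because model A's weights already sit close to the optimum for the new snapshot; how large the savings are for a real network depends on how similar the two tasks are.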
If you’re well read on machine learning, you’ve probably heard about transfer learning, an approach where a model trained on a source domain is exploited to improve generalizability on a target domain. You may think of this approach as a specific example of the general concept of warm-starting, where the weights of one model are used as the initialization point for training another model. In some cases, data scientists may only use pre-trained weights for the earlier layers in the network and instantiate a new top layer corresponding to the (different) classes being learned. They may choose to only train the weights in this new layer, freezing the pre-loaded ones (leading to even more savings in computational cost!).
For example, when building a network to classify dog breeds, one might consider warm-starting from a publicly available model (e.g., InceptionV3, ResNet50) that has been trained on the public ImageNet dataset. Even though the target model will use different classification targets, there is enough overlap between the two image classification tasks that an existing model's weights make a good starting point. Not only does this leverage the thousands of hours of training done on the original model, but it also lets your model draw on what was learned from ImageNet's much larger collection of images, which helps in situations where you otherwise might not have enough training examples.
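The freeze-the-backbone pattern described above can be sketched in plain Python. This is a toy illustration under stated assumptions: the "pretrained" source model is just randomly generated weights standing in for a real network, the target task is a synthetic binary problem rather than dog breeds, and all names are hypothetical. The feature weights are copied and never updated; only a freshly created head is trained. Because the features are frozen, they can be computed once up front, which is where the extra compute savings come from.

```python
import copy
import math
import random

rng = random.Random(0)

# Hypothetical "pretrained" source model: a hidden layer acting as the
# feature extractor (8 units, 2 inputs) plus a head for 10 source classes.
source_model = {
    "features": [[rng.gauss(0, 1) for _ in range(2)] for _ in range(8)],
    "head": [[rng.gauss(0, 1) for _ in range(8)] for _ in range(10)],
}

def extract(features, x):
    """Feature extractor: ReLU(W @ x) for a single input vector x."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in features]

def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

# Warm-start the target model: copy the feature weights, but give it a
# freshly initialized head for the new (binary) task.
target_model = {
    "features": copy.deepcopy(source_model["features"]),  # frozen below
    "head": [0.0] * 8,
    "bias": 0.0,
}

# Toy target task: label a 2-D point 1 if it lies above the line y = x.
points = [(rng.uniform(-1, 1), rng.uniform(-1, 1)) for _ in range(200)]
data = [(p, 1.0 if p[1] > p[0] else 0.0) for p in points]

# The features are frozen, so they are computed once and reused every step.
feats = [(extract(target_model["features"], x), y) for x, y in data]

def log_loss(head, bias):
    total = 0.0
    for h, y in feats:
        p = sigmoid(sum(v * hj for v, hj in zip(head, h)) + bias)
        p = min(max(p, 1e-12), 1 - 1e-12)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(feats)

initial_loss = log_loss(target_model["head"], target_model["bias"])
lr, n = 0.5, len(feats)
for _ in range(200):
    # Gradient descent on the head only; target_model["features"] is never touched.
    grad_head, grad_bias = [0.0] * 8, 0.0
    for h, y in feats:
        z = sum(v * hj for v, hj in zip(target_model["head"], h)) + target_model["bias"]
        err = sigmoid(z) - y  # d(log loss)/dz
        for j in range(8):
            grad_head[j] += err * h[j] / n
        grad_bias += err / n
    target_model["head"] = [v - lr * g for v, g in zip(target_model["head"], grad_head)]
    target_model["bias"] -= lr * grad_bias
final_loss = log_loss(target_model["head"], target_model["bias"])
print(f"log loss: {initial_loss:.3f} -> {final_loss:.3f}")
```

In a real framework the same idea is usually expressed by loading pretrained weights, replacing the final layer, and excluding the backbone parameters from the optimizer.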
While warm-starting is a great way to reuse work across different tasks, it can be equally powerful within a single deep learning task, given the iterative nature of model development. Typically, a data scientist will run an experiment, adjust something (e.g., model architecture, data preprocessing algorithm), and then run another experiment, often throwing away all previous work. By instead warm-starting each subsequent experiment, the canny engineer can begin it with the weights from a previously trained model.
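When the change between experiments is architectural, not every old weight fits the new model. A common approach, sketched here under simplifying assumptions, is to copy only tensors whose name and shape both match and leave the rest freshly initialized. Weights are plain nested lists and the layer names (`conv1`, `conv2`, `fc`) are hypothetical; treat this as an illustration of the matching rule, not any framework's API. For reference, PyTorch's `load_state_dict(strict=False)` tolerates missing or unexpected keys but still rejects shape mismatches, which is why a filter like this is often applied first.

```python
import copy

def shape(tensor):
    """Shape of a nested-list 'tensor', e.g. [[0.0] * 3] * 2 -> (2, 3)."""
    dims = []
    while isinstance(tensor, list):
        dims.append(len(tensor))
        tensor = tensor[0] if tensor else None
    return tuple(dims)

def warm_start(new_params, old_params):
    """Copy each old tensor whose name and shape match a new one.

    Tensors that were added or resized keep their fresh initialization.
    Returns the names that were loaded and the names that were skipped.
    """
    loaded, skipped = [], []
    for name, value in new_params.items():
        old = old_params.get(name)
        if old is not None and shape(old) == shape(value):
            new_params[name] = copy.deepcopy(old)
            loaded.append(name)
        else:
            skipped.append(name)
    return loaded, skipped

def tensor(rows, cols, fill):
    return [[fill] * cols for _ in range(rows)]

# Previous experiment (1.0 stands in for trained values).
old = {"conv1": tensor(16, 3, 1.0), "conv2": tensor(32, 16, 1.0), "fc": tensor(10, 32, 1.0)}
# Next experiment: conv2 widened from 32 to 64 outputs, so fc's input grows too.
new = {"conv1": tensor(16, 3, 0.0), "conv2": tensor(64, 16, 0.0), "fc": tensor(10, 64, 0.0)}

loaded, skipped = warm_start(new, old)
print("loaded:", loaded, "skipped:", skipped)
```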
Here are some specific scenarios where warm-starting may be helpful:
Most deep learning frameworks provide some support for warm-starting. TensorFlow has a WarmStartSettings class. Most other frameworks, such as Keras and PyTorch, include functions for saving and loading models. However, it falls to the engineer to create and maintain a database of models, as well as to explicitly call the necessary functions to store and load them.
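To give a sense of the bookkeeping that falls to the engineer, here is a minimal, hypothetical checkpoint catalog using only the Python standard library. The `CheckpointCatalog` class is invented for illustration, and weights are stored as JSON purely to keep the sketch self-contained; real projects would store framework checkpoints (for example via `torch.save` or Keras `model.save`) and record only paths and metadata in the index.

```python
import json
import tempfile
from pathlib import Path

class CheckpointCatalog:
    """Hypothetical minimal catalog: one file per checkpoint plus a JSON index."""

    def __init__(self, root):
        self.root = Path(root)
        self.root.mkdir(parents=True, exist_ok=True)
        self.index_path = self.root / "index.json"
        # Reload the index if a previous session already wrote one.
        self.index = json.loads(self.index_path.read_text()) if self.index_path.exists() else []

    def save(self, experiment, step, weights, metrics):
        path = self.root / f"{experiment}-step{step}.json"
        path.write_text(json.dumps(weights))
        self.index.append({"experiment": experiment, "step": step,
                           "path": str(path), "metrics": metrics})
        self.index_path.write_text(json.dumps(self.index, indent=2))
        return path

    def best(self, metric, experiment=None):
        """Entry with the lowest value of `metric`, optionally within one experiment."""
        candidates = [e for e in self.index
                      if experiment is None or e["experiment"] == experiment]
        return min(candidates, key=lambda e: e["metrics"][metric])

    def load(self, entry):
        return json.loads(Path(entry["path"]).read_text())

root = tempfile.mkdtemp()
catalog = CheckpointCatalog(root)
catalog.save("exp-1", 1000, {"w": [0.1, 0.2]}, {"val_error": 0.31})
catalog.save("exp-1", 2000, {"w": [0.3, 0.4]}, {"val_error": 0.27})
catalog.save("exp-2", 1000, {"w": [0.5, 0.6]}, {"val_error": 0.29})

best = CheckpointCatalog(root).best("val_error")  # reopened: index persisted on disk
init_weights = catalog.load(best)                  # warm-start point for the next run
```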
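Moreover, brute-force storage of every model's weights quickly becomes infeasible, since each experiment can produce many checkpoints and each checkpoint can be hundreds of megabytes.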
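A back-of-the-envelope estimate shows how quickly this adds up. The numbers below are assumptions for illustration: ResNet-50's roughly 25.6 million parameters stored as float32 (4 bytes each), and a hypothetical team running 200 experiments that each keep 20 checkpoints. Saving optimizer state for resumable training multiplies the cost; the factor of 3 approximates the weights plus Adam's two moment estimates.

```python
def checkpoint_storage_gb(num_params, checkpoints_per_run, runs,
                          bytes_per_value=4, copies_per_param=1):
    """Disk needed to keep every checkpoint, in gigabytes (1 GB = 1e9 bytes).

    bytes_per_value=4 assumes float32; copies_per_param=3 approximates
    weights plus Adam's first and second moment estimates.
    """
    total_bytes = num_params * bytes_per_value * copies_per_param * checkpoints_per_run * runs
    return total_bytes / 1e9

RESNET50_PARAMS = 25_600_000  # roughly 25.6M parameters
weights_only = checkpoint_storage_gb(RESNET50_PARAMS, checkpoints_per_run=20, runs=200)
with_adam = checkpoint_storage_gb(RESNET50_PARAMS, 20, 200, copies_per_param=3)
print(f"{weights_only:.1f} GB weights only, {with_adam:.1f} GB with Adam state")
```

Under these assumptions the total is about 409.6 GB for weights alone and about 1.2 TB with optimizer state, for a single mid-sized architecture.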
By systematically cataloging the metadata and checkpoints associated with all previous experiments, Determined AI’s Productivity Engine for Deep Learning (PEDL) eliminates this work, effectively automating warm-starting. Not only does it periodically checkpoint models during training, but it also manages these checkpoints, marking stale ones for deletion and retaining relevant ones. PEDL’s web-based UI includes single-click “Continue Training” functionality, which allows an engineer to warm-start an experiment using the weights, hyperparameters, and optimizer state from any prior experiment (regardless of who on the team developed it).
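PEDL's actual retention rules are not described in this post, so the following is only a generic illustration of what "marking stale checkpoints for deletion" can mean. This hypothetical policy keeps the newest checkpoint (so training can resume) plus the `keep_best` checkpoints with the lowest validation error, and marks everything else stale.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Checkpoint:
    step: int
    val_error: float

def split_stale(checkpoints, keep_best=2):
    """Return (retain, stale) under a hypothetical keep-latest-plus-best-k policy."""
    if not checkpoints:
        return [], []
    latest = max(checkpoints, key=lambda c: c.step)
    best = sorted(checkpoints, key=lambda c: c.val_error)[:keep_best]
    retain = {latest, *best}  # a set, so a checkpoint that is both is kept once
    stale = [c for c in checkpoints if c not in retain]
    return sorted(retain, key=lambda c: c.step), stale

history = [Checkpoint(1000, 0.35), Checkpoint(2000, 0.29), Checkpoint(3000, 0.27),
           Checkpoint(4000, 0.28), Checkpoint(5000, 0.30)]
retain, stale = split_stale(history)
print("retain:", [c.step for c in retain], "stale:", [c.step for c in stale])
```

Keeping the latest checkpoint alongside the best ones matters for warm-starting with optimizer state: the best checkpoint is the natural initialization for a new experiment, while the latest one is what a simple "resume" needs.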
Imagine you are part of a music streaming organization that is running a recommendation model A in production trained on a snapshot of user listening history A. You are given a new snapshot of user listening history B and tasked with using it to train a fresher model. You have a constrained amount of computational resources to train this model and a limited amount of time before this snapshot becomes stale. One option is to train a model from randomly initialized weights on dataset B, or a concatenation of dataset A and dataset B. Another option is to warm-start from the weights of model A and train on the updated dataset B.
To simulate this scenario, we conducted an experiment in which we split the CIFAR-10 image classification dataset in half to represent datasets A and B. We first train model A on dataset A, reaching a validation error of 27.88% after 16 epochs of training. Next, we simultaneously train three models. One model is trained on dataset B with its weights initialized to those of model A (Warm Start), using PEDL’s “Continue Training” functionality. Another model is trained on dataset B with randomly initialized weights (Random Initialization). A final model is trained on the concatenation of datasets A and B with randomly initialized weights (Random Initialization (Full Data)). All experiments use a standard 6-layer CNN architecture that is commonly used in computer vision tasks, with a batch size of 32 for 25,000 batches, or 16 epochs.
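The data setup can be sketched as follows, assuming (the post does not say) that the standard 50,000-image CIFAR-10 training split is what was halved. The snippet makes a reproducible random split of example indices and converts a batch budget into epochs; the helper names are hypothetical. The epoch count depends on which dataset a run iterates over: 25,000 batches of 32 examples is 16 passes over 50,000 images but 32 passes over a 25,000-image half.

```python
import random

CIFAR10_TRAIN_SIZE = 50_000  # images in the standard CIFAR-10 training split

def split_in_half(n, seed=0):
    """Shuffle indices 0..n-1 with a fixed seed and cut them into two halves."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    half = n // 2
    return idx[:half], idx[half:]

def epochs(num_batches, batch_size, dataset_size):
    """Number of full passes over a dataset for a given batch budget."""
    return num_batches * batch_size / dataset_size

a_idx, b_idx = split_in_half(CIFAR10_TRAIN_SIZE)
print(len(a_idx), len(b_idx))
print(epochs(25_000, 32, len(b_idx)))             # run on dataset B only
print(epochs(25_000, 32, CIFAR10_TRAIN_SIZE))     # run on A and B concatenated
```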
The validation performance over elapsed time (GPU minutes) is graphed above in Figure 3. As expected, the model that used warm-started weights consistently achieves a lower validation error throughout training than the randomly initialized models. Given the same amount of GPU resources, the warm-started experiment achieves a validation error of 23.49%, compared to 27.04% for the randomly initialized experiment on dataset B and 25.77% for the randomly initialized experiment on the concatenation of datasets A and B.
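Put in relative terms, the gap can be computed directly from the reported final validation errors; the dictionary and helper below are just a convenience for the arithmetic.

```python
results = {  # final validation error (%) reported above
    "Warm Start": 23.49,
    "Random Initialization": 27.04,
    "Random Initialization (Full Data)": 25.77,
}

def relative_reduction(baseline, improved):
    """Fraction of the baseline error removed by the improved model."""
    return (baseline - improved) / baseline

warm = results["Warm Start"]
for name in ("Random Initialization", "Random Initialization (Full Data)"):
    base = results[name]
    print(f"vs {name}: {base - warm:.2f} points, {relative_reduction(base, warm):.1%} relative")
```

That works out to 3.55 points (about 13.1% relative) against random initialization on dataset B, and 2.28 points (about 8.8% relative) against random initialization on the full data.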
Want to start experimenting with PEDL’s warm-starting support? Drop us an email to get started.