Tutorial: Save and Restore Models
Model progress can be saved both during and after training. This means a model can resume where it left off and avoid long training times. Saving also means you can share your model and others can recreate your work. When publishing research models and techniques, most machine learning practitioners share:
- code to create the model, and
- the trained weights, or parameters, for the model
Sharing this data helps others understand how the model works and try it themselves with new data.
Options
There are many different ways to save TensorFlow models—depending on the API you’re using. This guide uses Keras, a high-level API to build and train models in TensorFlow. For other approaches, see the TensorFlow Save and Restore guide or Saving in eager.
Setup
We’ll use the MNIST dataset to train our model to demonstrate saving weights. To speed up these demonstration runs, only use the first 1000 examples:
library(keras)
mnist <- dataset_mnist()
c(train_images, train_labels) %<-% mnist$train
c(test_images, test_labels) %<-% mnist$test
train_labels <- train_labels[1:1000]
test_labels <- test_labels[1:1000]
train_images <- train_images[1:1000, , ] %>%
  array_reshape(c(1000, 28 * 28))
train_images <- train_images / 255

test_images <- test_images[1:1000, , ] %>%
  array_reshape(c(1000, 28 * 28))
test_images <- test_images / 255
Define a model
Let’s build a simple model we’ll use to demonstrate saving and loading weights.
# Returns a short sequential model
create_model <- function() {
  model <- keras_model_sequential() %>%
    layer_dense(units = 512, activation = "relu", input_shape = 784) %>%
    layer_dropout(0.2) %>%
    layer_dense(units = 10, activation = "softmax")

  model %>% compile(
    optimizer = "adam",
    loss = "sparse_categorical_crossentropy",
    metrics = list("accuracy")
  )

  model
}
model <- create_model()
summary(model)
## Model: "sequential"
## ___________________________________________________________________________
## Layer (type) Output Shape Param #
## ===========================================================================
## dense (Dense) (None, 512) 401920
## ___________________________________________________________________________
## dropout (Dropout) (None, 512) 0
## ___________________________________________________________________________
## dense_1 (Dense) (None, 10) 5130
## ===========================================================================
## Total params: 407,050
## Trainable params: 407,050
## Non-trainable params: 0
## ___________________________________________________________________________
Save the entire model
Call save_model_* to save a model's architecture, weights, and training configuration in a single file/folder. This allows you to export a model so it can be used without access to the original code*. Since the optimizer state is recovered, you can resume training from exactly where you left off.
Saving a fully-functional model is very useful: you can load it in TensorFlow.js (HDF5, Saved Model) and then train and run it in web browsers, or convert it to run on mobile devices using TensorFlow Lite (HDF5, Saved Model).
*Custom objects (e.g. subclassed models or layers) require special attention when saving and loading. See the “Saving custom objects” section below.
SavedModel format
The SavedModel format is a way to serialize models. Models saved in this format can be restored using load_model_tf and are compatible with TensorFlow Serving. The SavedModel guide goes into detail about how to serve/inspect the SavedModel. The section below illustrates the steps to save and restore the model.
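The training run below was produced by code omitted from this extract; a minimal sketch, assuming the model is saved to a directory named "model":

model <- create_model()

model %>% fit(train_images, train_labels, epochs = 5, verbose = 2)

# Save the entire model (architecture, weights, optimizer state) as a SavedModel
save_model_tf(model, "model")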
## Train on 1000 samples
## Epoch 1/5
## 1000/1000 - 0s - loss: 1.1809 - accuracy: 0.6680
## Epoch 2/5
## 1000/1000 - 0s - loss: 0.4156 - accuracy: 0.8860
## Epoch 3/5
## 1000/1000 - 0s - loss: 0.2836 - accuracy: 0.9250
## Epoch 4/5
## 1000/1000 - 0s - loss: 0.2241 - accuracy: 0.9370
## Epoch 5/5
## 1000/1000 - 0s - loss: 0.1473 - accuracy: 0.9680
The SavedModel format is a directory containing a protobuf binary and a TensorFlow checkpoint. Inspect the saved model directory:
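# List the contents of the assumed "model" directory from above
list.files("model")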
## [1] "assets" "saved_model.pb" "variables"
Reload a fresh Keras model from the saved model:
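A sketch, reusing the assumed "model" directory:

new_model <- load_model_tf("model")
summary(new_model)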
## Model: "sequential_1"
## ___________________________________________________________________________
## Layer (type) Output Shape Param #
## ===========================================================================
## dense_2 (Dense) (None, 512) 401920
## ___________________________________________________________________________
## dropout_1 (Dropout) (None, 512) 0
## ___________________________________________________________________________
## dense_3 (Dense) (None, 10) 5130
## ===========================================================================
## Total params: 407,050
## Trainable params: 407,050
## Non-trainable params: 0
## ___________________________________________________________________________
HDF5 format
Keras provides a basic saving format using the HDF5 standard.
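The run below comes from training and saving a fresh model; a minimal sketch, assuming the file name "my_model.h5":

model <- create_model()

model %>% fit(train_images, train_labels, epochs = 5, verbose = 2)

# Save the entire model to a single HDF5 file
model %>% save_model_hdf5("my_model.h5")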
## Train on 1000 samples
## Epoch 1/5
## 1000/1000 - 0s - loss: 1.1386 - accuracy: 0.6780
## Epoch 2/5
## 1000/1000 - 0s - loss: 0.4326 - accuracy: 0.8770
## Epoch 3/5
## 1000/1000 - 0s - loss: 0.2874 - accuracy: 0.9310
## Epoch 4/5
## 1000/1000 - 0s - loss: 0.2164 - accuracy: 0.9460
## Epoch 5/5
## 1000/1000 - 0s - loss: 0.1536 - accuracy: 0.9690
Now recreate the model from that file:
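A sketch, reusing the assumed "my_model.h5" file:

new_model <- load_model_hdf5("my_model.h5")
summary(new_model)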
## Model: "sequential_2"
## ___________________________________________________________________________
## Layer (type) Output Shape Param #
## ===========================================================================
## dense_4 (Dense) (None, 512) 401920
## ___________________________________________________________________________
## dropout_2 (Dropout) (None, 512) 0
## ___________________________________________________________________________
## dense_5 (Dense) (None, 10) 5130
## ===========================================================================
## Total params: 407,050
## Trainable params: 407,050
## Non-trainable params: 0
## ___________________________________________________________________________
This technique saves everything:
- The weight values
- The model's configuration (architecture)
- The optimizer configuration
Keras saves models by inspecting the architecture. Currently, it is not able to save TensorFlow optimizers (from tf$train). When using those, you will need to re-compile the model after loading, and you will lose the state of the optimizer.
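If you run into this, one approach (a sketch, reusing the hypothetical "my_model.h5" file from above) is to load the model without compiling it, then re-compile it yourself:

# Load without compiling, then re-compile with a Keras optimizer;
# the previous optimizer state is lost and training starts fresh
new_model <- load_model_hdf5("my_model.h5", compile = FALSE)
new_model %>% compile(
  optimizer = "adam",
  loss = "sparse_categorical_crossentropy",
  metrics = list("accuracy")
)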
Saving custom objects
If you are using the SavedModel format, you can skip this section. The key difference between HDF5 and SavedModel is that HDF5 uses object configs to save the model architecture, while SavedModel saves the execution graph.
Thus, SavedModels are able to save custom objects like subclassed models and custom layers without requiring the original code.
To save custom objects to HDF5, you must do the following:
- Define a get_config method in your object, and optionally a from_config classmethod.
  - get_config() returns a JSON-serializable dictionary of parameters needed to recreate the object.
  - from_config(config) uses the returned config from get_config to create a new object. By default, this function will use the config as initialization arguments.
- Pass the object to the custom_objects argument when loading the model. The argument must be a named list mapping the string class name to the class definition, e.g. load_model_hdf5(path, custom_objects = list("CustomLayer" = CustomLayer)).
See the Writing layers and models from scratch tutorial for examples of custom_objects and get_config.
Save checkpoints during training
It is useful to automatically save checkpoints during and at the end of training. This way you can use a trained model without having to retrain it, or pick up training where you left off, in case the training process was interrupted.
callback_model_checkpoint is a callback that performs this task.
The callback takes a couple of arguments to configure checkpointing. By default, save_weights_only is set to FALSE, which means the complete model is saved, including its architecture and configuration. You can then restore the model as outlined in the previous section.
Here, let's focus on just saving and restoring weights. In the following code snippet, we set save_weights_only to TRUE, so we will need the model definition on restore.
Checkpoint callback usage
Train the model and pass it the callback_model_checkpoint:
checkpoint_path <- "checkpoints/cp.ckpt"

# Create checkpoint callback
cp_callback <- callback_model_checkpoint(
  filepath = checkpoint_path,
  save_weights_only = TRUE,
  verbose = 0
)

model <- create_model()

model %>% fit(
  train_images,
  train_labels,
  epochs = 10,
  validation_data = list(test_images, test_labels),
  callbacks = list(cp_callback),  # pass callback to training
  verbose = 2
)
## Train on 1000 samples, validate on 1000 samples
## Epoch 1/10
## 1000/1000 - 0s - loss: 1.1775 - accuracy: 0.6750 - val_loss: 0.6874 - val_accuracy: 0.7980
## Epoch 2/10
## 1000/1000 - 0s - loss: 0.4144 - accuracy: 0.8810 - val_loss: 0.5366 - val_accuracy: 0.8320
## Epoch 3/10
## 1000/1000 - 0s - loss: 0.2811 - accuracy: 0.9280 - val_loss: 0.4517 - val_accuracy: 0.8610
## Epoch 4/10
## 1000/1000 - 0s - loss: 0.2205 - accuracy: 0.9430 - val_loss: 0.4692 - val_accuracy: 0.8500
## Epoch 5/10
## 1000/1000 - 0s - loss: 0.1520 - accuracy: 0.9690 - val_loss: 0.4084 - val_accuracy: 0.8660
## Epoch 6/10
## 1000/1000 - 0s - loss: 0.1147 - accuracy: 0.9780 - val_loss: 0.3946 - val_accuracy: 0.8680
## Epoch 7/10
## 1000/1000 - 0s - loss: 0.0831 - accuracy: 0.9870 - val_loss: 0.4008 - val_accuracy: 0.8710
## Epoch 8/10
## 1000/1000 - 0s - loss: 0.0607 - accuracy: 0.9970 - val_loss: 0.4056 - val_accuracy: 0.8640
## Epoch 9/10
## 1000/1000 - 0s - loss: 0.0510 - accuracy: 0.9970 - val_loss: 0.4031 - val_accuracy: 0.8720
## Epoch 10/10
## 1000/1000 - 0s - loss: 0.0465 - accuracy: 0.9960 - val_loss: 0.3923 - val_accuracy: 0.8710
Inspect the files that were created:
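# List files in the checkpoint directory (checkpoint_path was defined above)
list.files(dirname(checkpoint_path))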
## [1] "checkpoint" "cp.ckpt.data-00000-of-00001"
## [3] "cp.ckpt.index"
Create a new, untrained model. When restoring a model from weights only, you must have a model with the same architecture as the original. Since it's the same model architecture, we can share weights even though it's a different instance of the model.
Now rebuild a fresh, untrained model, and evaluate it on the test set. An untrained model will perform at chance levels (~10% accuracy):
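# New instance with the same architecture, but random (untrained) weights
fresh_model <- create_model()
fresh_model %>% evaluate(test_images, test_labels, verbose = 0)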
## $loss
## [1] 2.321936
##
## $accuracy
## [1] 0.126
Then load the weights from the latest checkpoint (epoch 10), and re-evaluate:
fresh_model %>% load_model_weights_tf(filepath = checkpoint_path)
fresh_model %>% evaluate(test_images, test_labels, verbose = 0)
## $loss
## [1] 0.3923183
##
## $accuracy
## [1] 0.871
Checkpoint callback options
Alternatively, you can decide to save only the best model, where "best" by default means lowest validation loss. See the documentation for callback_model_checkpoint for further information.
checkpoint_path <- "checkpoints/cp.ckpt"

# Create checkpoint callback
cp_callback <- callback_model_checkpoint(
  filepath = checkpoint_path,
  save_weights_only = TRUE,
  save_best_only = TRUE,
  verbose = 1
)

model <- create_model()

model %>% fit(
  train_images,
  train_labels,
  epochs = 10,
  validation_data = list(test_images, test_labels),
  callbacks = list(cp_callback),  # pass callback to training
  verbose = 2
)
## Train on 1000 samples, validate on 1000 samples
## Epoch 1/10
##
## Epoch 00001: val_loss improved from inf to 0.72178, saving model to checkpoints/cp.ckpt
## 1000/1000 - 0s - loss: 1.1691 - accuracy: 0.6620 - val_loss: 0.7218 - val_accuracy: 0.7760
## Epoch 2/10
##
## Epoch 00002: val_loss improved from 0.72178 to 0.56689, saving model to checkpoints/cp.ckpt
## 1000/1000 - 0s - loss: 0.4227 - accuracy: 0.8850 - val_loss: 0.5669 - val_accuracy: 0.8110
## Epoch 3/10
##
## Epoch 00003: val_loss improved from 0.56689 to 0.51581, saving model to checkpoints/cp.ckpt
## 1000/1000 - 0s - loss: 0.3018 - accuracy: 0.9160 - val_loss: 0.5158 - val_accuracy: 0.8380
## Epoch 4/10
##
## Epoch 00004: val_loss improved from 0.51581 to 0.44739, saving model to checkpoints/cp.ckpt
## 1000/1000 - 0s - loss: 0.2120 - accuracy: 0.9480 - val_loss: 0.4474 - val_accuracy: 0.8540
## Epoch 5/10
##
## Epoch 00005: val_loss did not improve from 0.44739
## 1000/1000 - 0s - loss: 0.1519 - accuracy: 0.9700 - val_loss: 0.4602 - val_accuracy: 0.8510
## Epoch 6/10
##
## Epoch 00006: val_loss improved from 0.44739 to 0.42596, saving model to checkpoints/cp.ckpt
## 1000/1000 - 0s - loss: 0.1257 - accuracy: 0.9750 - val_loss: 0.4260 - val_accuracy: 0.8630
## Epoch 7/10
##
## Epoch 00007: val_loss improved from 0.42596 to 0.40990, saving model to checkpoints/cp.ckpt
## 1000/1000 - 0s - loss: 0.0866 - accuracy: 0.9850 - val_loss: 0.4099 - val_accuracy: 0.8610
## Epoch 8/10
##
## Epoch 00008: val_loss did not improve from 0.40990
## 1000/1000 - 0s - loss: 0.0688 - accuracy: 0.9930 - val_loss: 0.4210 - val_accuracy: 0.8560
## Epoch 9/10
##
## Epoch 00009: val_loss did not improve from 0.40990
## 1000/1000 - 0s - loss: 0.0517 - accuracy: 0.9970 - val_loss: 0.4326 - val_accuracy: 0.8640
## Epoch 10/10
##
## Epoch 00010: val_loss did not improve from 0.40990
## 1000/1000 - 0s - loss: 0.0386 - accuracy: 1.0000 - val_loss: 0.4521 - val_accuracy: 0.8510
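Inspect the files that were created (a sketch, reusing checkpoint_path from above):

list.files(dirname(checkpoint_path))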
## [1] "checkpoint" "cp.ckpt.data-00000-of-00001"
## [3] "cp.ckpt.index"
What are these files?
The above code stores the weights to a collection of checkpoint-formatted files that contain only the trained weights in a binary format. Checkpoints contain:
- One or more shards that contain your model’s weights.
- An index file that indicates which weights are stored in which shard.
If you are only training a model on a single machine, you’ll have one shard with the suffix: .data-00000-of-00001
Manually save the weights
You saw how to load the weights into a model. Manually saving them is just as simple with the save_model_weights_tf function.
# Save the weights
model %>% save_model_weights_tf("checkpoints/cp.ckpt")
# Create a new model instance
new_model <- create_model()
# Restore the weights
new_model %>% load_model_weights_tf("checkpoints/cp.ckpt")
# Evaluate the model
new_model %>% evaluate(test_images, test_labels, verbose = 0)
## $loss
## [1] 0.4520541
##
## $accuracy
## [1] 0.851