Checkpoints
The phrase “Saving a TensorFlow model” typically means one of two things:
- Checkpoints, OR
- SavedModel.
Checkpoints capture the exact value of all parameters (tf$Variable objects) used by a model. Checkpoints do not contain any description of the computation defined by the model and thus are typically only useful when the source code that will use the saved parameter values is available.
The SavedModel format, on the other hand, includes a serialized description of the computation defined by the model in addition to the parameter values (checkpoint). Models in this format are independent of the source code that created the model. They are thus suitable for deployment via TensorFlow Serving, TensorFlow Lite, TensorFlow.js, or programs in other programming languages (the C, C++, Java, Go, Rust, C#, etc. TensorFlow APIs).
This guide covers APIs for writing and reading checkpoints.
Setup
Saving models through the Keras API
The respective Keras guide explains how to use the Keras API to save and restore complete models as well as model weights.
Calling save_weights effectively results in saving a TensorFlow checkpoint:
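As a minimal sketch of that call, the snippet below saves the weights of a small stand-in model and lists the files that appear. The sequential model and the `easy_checkpoint` prefix are illustrative assumptions, not part of the original guide, and the sketch assumes a TensorFlow 2.x / Keras 2 installation, where a path without an `.h5` suffix selects the TensorFlow checkpoint format.

```r
library(tensorflow)
library(keras)

# A tiny stand-in model (assumption); input_shape is given so the
# weights are created immediately rather than on the first call.
net <- keras_model_sequential() %>%
  layer_dense(units = 5, input_shape = 1)

# Hypothetical checkpoint prefix inside a temporary directory.
prefix <- file.path(tempdir(), "easy_checkpoint")

# No ".h5" suffix, so Keras 2 writes the TensorFlow checkpoint format.
net$save_weights(prefix)

# Lists the files that share the prefix (an index file plus data shards).
list.files(tempdir(), pattern = "^easy_checkpoint")
```

The prefix itself is not a file; the files on disk share it as a common stem.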
More customization is available through lower-level TensorFlow methods.
Writing checkpoints
The persistent state of a TensorFlow model is stored in tf$Variable objects. These can be constructed directly, but are often created through high-level APIs like Keras layers or models.
The easiest way to manage variables is by attaching them to Python objects, then referencing those objects.
Subclasses of tf$train$Checkpoint, tf$keras$layers$Layer, and tf$keras$Model automatically track variables assigned to their attributes. The following example constructs a simple linear model, then writes checkpoints which contain values for all of the model’s variables.
Manual checkpointing
Setup
To help demonstrate all the features of tf$train$Checkpoint, define a toy dataset and an optimization step:
library(tfdatasets)

# Toy dataset: 10 scalar inputs, each paired with a 5-element label row,
# repeated 10 times and served in batches of 2.
toy_dataset <- function() {
  inputs <- matrix(0:9, ncol = 1) %>% tf$cast(tf$float32)
  labels <- matrix(0:49, ncol = 5, byrow = TRUE) %>% tf$cast(tf$float32)
  tensor_slices_dataset(list(x = inputs, y = labels)) %>%
    dataset_repeat(10) %>%
    dataset_batch(2)
}

# One optimization step: compute the mean absolute error under a gradient
# tape, apply the gradients, and return the loss.
train_step <- function(model, example, optimizer) {
  with(tf$GradientTape() %as% tape, {
    output <- model(example$x)
    loss <- tf$reduce_mean(tf$abs(output - example$y))
  })
  variables <- model$trainable_variables
  gradients <- tape$gradient(loss, variables)
  optimizer$apply_gradients(purrr::transpose(list(gradients, variables)))
  loss
}
Create the checkpoint objects
To manually make a checkpoint you will need a tf$train$Checkpoint object, where the objects you want to checkpoint are set as attributes on the object.
A tf$train$CheckpointManager can also be helpful for managing multiple checkpoints.
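The following is a minimal sketch of how the two objects fit together. It uses a plain tf$Variable as a stand-in for model state; the variable `w`, the directory name, and the loop count are illustrative assumptions. The manager is asked to keep only the three most recent checkpoints.

```r
library(tensorflow)

# Stand-in for model state (assumption): a single float32 variable.
w <- tf$Variable(c(1, 2, 3), dtype = tf$float32)

# Objects to checkpoint are passed as named arguments and become attributes.
ckpt <- tf$train$Checkpoint(step = tf$Variable(1L), w = w)

# Hypothetical, fresh directory for this demonstration.
ckpt_dir <- tempfile("ckpt_manager_demo")
manager <- tf$train$CheckpointManager(ckpt, ckpt_dir, max_to_keep = 3L)

# Save five times; the manager deletes older checkpoints as it goes.
for (i in 1:5) {
  ckpt$step$assign_add(1L)
  manager$save()
}

# Only the most recent checkpoint prefixes remain.
basename(manager$checkpoints)
```

Passing the variables by name matters: those names become the edges that are later used to match saved values to objects on restore.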
Train and checkpoint the model
The following training loop creates an instance of the model and of an optimizer, then gathers them into a tf$train$Checkpoint object. It calls the training step in a loop on each batch of data, and periodically writes checkpoints to disk.
library(glue)
library(tfautograph)

# Note: ckpt and opt are not arguments; they are looked up in the
# environment where this function is defined.
train_and_checkpoint <- autograph(function(model, manager, dataset) {
  # Restore the most recent checkpoint, if there is one.
  ckpt$restore(manager$latest_checkpoint)
  if (!is.null(manager$latest_checkpoint)) {
    tf$print(glue("Restored from {manager$latest_checkpoint}\n"))
  } else {
    tf$print("Initializing from scratch.\n")
  }

  for (example in dataset()) {
    loss <- train_step(model, example, opt)
    ckpt$step$assign_add(1)
    # Write a checkpoint every 10 steps.
    if (ckpt$step %% 10 == 0) {
      save_path <- manager$save()
      tf$print(glue("Saved checkpoint for step {as.numeric(ckpt$step)}: {as.character(save_path)}"))
      tf$print(glue("loss: {round(as.numeric(loss), 2)}"))
    }
  }
})

train_and_checkpoint(net, manager, toy_dataset)
Restore and continue training
After the first training cycle you can pass a new model and manager, but pick up training exactly where you left off:
opt <- optimizer_adam(0.1)
net <- create_model(5)
ckpt <- tf$train$Checkpoint(step = tf$Variable(1), optimizer = opt, net = net)
manager <- tf$train$CheckpointManager(ckpt, './tf_ckpts', max_to_keep = 3)
train_and_checkpoint(net, manager, toy_dataset)
The tf$train$CheckpointManager object deletes old checkpoints. Above it’s configured to keep only the three most recent checkpoints.
manager$checkpoints # List the three remaining checkpoints
#> [1] "./tf_ckpts/ckpt-18" "./tf_ckpts/ckpt-19" "./tf_ckpts/ckpt-20"
These paths, e.g. ‘./tf_ckpts/ckpt-10’, are not files on disk. Instead they are prefixes for an index file and one or more data files which contain the variable values. These prefixes are grouped together in a single checkpoint file (‘./tf_ckpts/checkpoint’) where the CheckpointManager saves its state.
list.files("./tf_ckpts/")
#> [1] "checkpoint" "ckpt-18.data-00000-of-00001"
#> [3] "ckpt-18.index" "ckpt-19.data-00000-of-00001"
#> [5] "ckpt-19.index" "ckpt-20.data-00000-of-00001"
#> [7] "ckpt-20.index"
TensorFlow matches variables to checkpointed values by traversing a directed graph with named edges, starting from the object being loaded. Edge names typically come from attribute names in objects, for example the “l1” in self$l1 <- layer_dense(units = size). tf$train$Checkpoint uses its keyword argument names, as in the “step” in tf$train$Checkpoint(step = ...).
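To make the edge naming concrete, here is a small sketch that saves a checkpoint with two named arguments and then reads back the checkpoint keys. The argument names `step` and `offset`, and the temporary path, are illustrative assumptions.

```r
library(tensorflow)

# The keyword argument names ("step", "offset") become edge names.
root <- tf$train$Checkpoint(
  step   = tf$Variable(1L),
  offset = tf$Variable(c(0.5, 1.5), dtype = tf$float32)
)

# Hypothetical checkpoint prefix in a fresh temporary directory.
path <- root$save(file.path(tempfile("edge_names"), "ckpt"))

# Each entry returned by list_variables is a (key, shape) pair;
# keep only the keys.
keys <- vapply(
  tf$train$list_variables(path),
  function(entry) entry[[1]],
  character(1)
)
print(keys)
```

Each key starts with the edge name that leads from the root object to the variable, which is what allows a differently constructed program to find the same value as long as it reuses the same names.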
The dependency graph from the example above looks like this:
In the graph, the optimizer is shown in red, regular variables in blue, and optimizer slot variables in orange. The other nodes, for example the one representing the tf$train$Checkpoint, are black.
Slot variables are part of the optimizer’s state, but are created for a specific variable. For example, the m edges above correspond to momentum, which the Adam optimizer tracks for each variable. Slot variables are only saved in a checkpoint if the variable and the optimizer would both be saved, thus the dashed edges.
Calling restore() on a tf$train$Checkpoint object queues the requested restorations, restoring variable values as soon as there’s a matching path from the Checkpoint object. For example, we can load just the bias from the model we defined above by reconstructing one path to it through the network and the layer.
to_restore <- tf$Variable(tf$zeros(list(5L)))
as.numeric(to_restore) # All zeros
#> [1] 0 0 0 0 0
fake_layer <- tf$train$Checkpoint(bias = to_restore)
fake_net <- tf$train$Checkpoint(l1 = fake_layer)
new_root <- tf$train$Checkpoint(net = fake_net)
status <- new_root$restore(tf$train$latest_checkpoint('./tf_ckpts/'))
as.numeric(to_restore) # We get the restored value now
#> [1] 0.0308778 1.0408185 1.9840808 2.9422886 3.9628119
The dependency graph for these new objects is a much smaller subgraph of the larger checkpoint we wrote above. It includes only the bias and a save counter that tf$train$Checkpoint uses to number checkpoints.
restore() returns a status object, which has optional assertions. All of the objects we’ve created in our new Checkpoint have been restored, so status$assert_existing_objects_matched() passes.
status$assert_existing_objects_matched()
#> <tensorflow.python.training.tracking.util.CheckpointLoadStatus>
There are many objects in the checkpoint which haven’t matched, including the layer’s kernel and the optimizer’s variables. status$assert_consumed() only passes if the checkpoint and the program match exactly, and would throw an exception here.
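The difference between the two assertions can be sketched in isolation. The example below uses its own small checkpoint with two variables, `a` and `b` (illustrative names, not from the guide), and restores only `a`. The Python exception raised by assert_consumed() surfaces as an R error, so it is caught with tryCatch().

```r
library(tensorflow)

# Save a checkpoint that contains two variables.
saved <- tf$train$Checkpoint(a = tf$Variable(1), b = tf$Variable(2))
path <- saved$save(file.path(tempfile("status_demo"), "ckpt"))

# Restore into a program that only knows about "a".
partial <- tf$train$Checkpoint(a = tf$Variable(0))
status <- partial$restore(path)

# Every object in the new program found a match, so this passes.
invisible(status$assert_existing_objects_matched())

# "b" exists in the checkpoint but not in the program, so this fails.
result <- tryCatch(
  {
    status$assert_consumed()
    "consumed"
  },
  error = function(e) "not fully consumed"
)
result
```

In practice, assert_existing_objects_matched() is the more forgiving check for partial loads, while assert_consumed() is the stricter check for a program that is supposed to mirror the checkpoint exactly.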
Delayed restorations
Layer objects in TensorFlow may delay the creation of variables to their first call, when input shapes are available. For example the shape of a Dense layer’s kernel depends on both the layer’s input and output shapes, and so the output shape required as a constructor argument is not enough information to create the variable on its own. Since calling a Layer also reads the variable’s value, a restore must happen between the variable’s creation and its first use.
To support this idiom, tf$train$Checkpoint queues restores which don’t yet have a matching variable.
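A minimal, self-contained sketch of this queuing behavior follows. It does not reuse the objects from the sections above; the variable name `v`, its values, and the temporary path are illustrative assumptions. The restore is requested before any matching variable exists, and the value arrives once the variable is attached under the saved name.

```r
library(tensorflow)

# Write a small checkpoint holding one variable under the edge name "v".
saved <- tf$train$Checkpoint(v = tf$Variable(c(1, 2, 3), dtype = tf$float32))
path <- saved$save(file.path(tempfile("delayed_restore"), "ckpt"))

# Request the restore on an empty Checkpoint: nothing matches "v" yet,
# so the restoration is queued.
root <- tf$train$Checkpoint()
status <- root$restore(path)

delayed <- tf$Variable(tf$zeros(list(3L)))
before <- as.numeric(delayed)  # not restored yet

# Attaching the variable under the matching name triggers the queued restore.
root$v <- delayed
after <- as.numeric(delayed)

before
after
```

This is the same mechanism that lets a Keras layer create its kernel on the first call and still receive the checkpointed value before that value is read.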
Manually inspecting checkpoints
tf$train$list_variables lists the checkpoint keys and shapes of variables in a checkpoint. Checkpoint keys are paths in the graph displayed above.
tf$train$list_variables(tf$train$latest_checkpoint('./tf_ckpts/'))
#> [[1]]
#> [[1]][[1]]
#> [1] "_CHECKPOINTABLE_OBJECT_GRAPH"
#>
#> [[1]][[2]]
#> list()
#>
#>
#> [[2]]
#> [[2]][[1]]
#> [1] "net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE"
#>
#> [[2]][[2]]
#> [1] 5
#>
#>
#> [[3]]
#> [[3]][[1]]
#> [1] "net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE"
#>
#> [[3]][[2]]
#> [1] 5
#>
#>
#> [[4]]
#> [[4]][[1]]
#> [1] "net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE"
#>
#> [[4]][[2]]
#> [1] 5
#>
#>
#> [[5]]
#> [[5]][[1]]
#> [1] "net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE"
#>
#> [[5]][[2]]
#> [1] 1 5
#>
#>
#> [[6]]
#> [[6]][[1]]
#> [1] "net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE"
#>
#> [[6]][[2]]
#> [1] 1 5
#>
#>
#> [[7]]
#> [[7]][[1]]
#> [1] "net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE"
#>
#> [[7]][[2]]
#> [1] 1 5
#>
#>
#> [[8]]
#> [[8]][[1]]
#> [1] "optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE"
#>
#> [[8]][[2]]
#> list()
#>
#>
#> [[9]]
#> [[9]][[1]]
#> [1] "optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE"
#>
#> [[9]][[2]]
#> list()
#>
#>
#> [[10]]
#> [[10]][[1]]
#> [1] "optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE"
#>
#> [[10]][[2]]
#> list()
#>
#>
#> [[11]]
#> [[11]][[1]]
#> [1] "optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE"
#>
#> [[11]][[2]]
#> list()
#>
#>
#> [[12]]
#> [[12]][[1]]
#> [1] "optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE"
#>
#> [[12]][[2]]
#> list()
#>
#>
#> [[13]]
#> [[13]][[1]]
#> [1] "save_counter/.ATTRIBUTES/VARIABLE_VALUE"
#>
#> [[13]][[2]]
#> list()
#>
#>
#> [[14]]
#> [[14]][[1]]
#> [1] "step/.ATTRIBUTES/VARIABLE_VALUE"
#>
#> [[14]][[2]]
#> list()
List and dictionary tracking
As with direct attribute assignments like self$l1 <- layer_dense(units = size), assigning lists and dictionaries to attributes will track their contents.
save <- tf$train$Checkpoint()
save$listed <- list(tf$Variable(1))
save$listed$append(tf$Variable(2))
save$mapped <- dict(one = save$listed[0])
reticulate::py_set_item(save$mapped, "two", save$listed[1])
#save$mapped['two'] <- save$listed[1]
save_path <- save$save('./tf_list_example')
restore <- tf$train$Checkpoint()
v2 <- tf$Variable(0)
0 == as.numeric(v2) # Not restored yet
#> [1] TRUE
restore$mapped <- dict(two = v2)
restore$restore(save_path)
#> <tensorflow.python.training.tracking.util.CheckpointLoadStatus>
2 == as.numeric(v2)
#> [1] TRUE
You may notice wrapper objects for lists and dictionaries. These wrappers are checkpointable versions of the underlying data structures. Just like the attribute-based loading, these wrappers restore a variable’s value as soon as it’s added to the container.
restore$listed <- list()
print(restore$listed) # ListWrapper([])
#> ListWrapper([])
v1 <- tf$Variable(0.)
restore$listed$append(v1) # Restores v1, from restore() in the previous cell
1. == as.numeric(v1)
#> [1] TRUE
The same tracking is automatically applied to subclasses of tf$keras$Model, and may be used, for example, to track lists of layers.