Distributed training with Keras
Overview
The tf$distribute$Strategy API provides an abstraction for distributing your training across multiple processing units. The goal is to allow users to enable distributed training using existing models and training code, with minimal changes.
This tutorial uses tf$distribute$MirroredStrategy, which does in-graph replication with synchronous training on many GPUs on one machine. Essentially, it copies all of the model’s variables to each processor. Then, it uses all-reduce to combine the gradients from all processors and applies the combined value to all copies of the model.
MirroredStrategy is one of several distribution strategies available in TensorFlow core.
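Before building the model, it can help to confirm which devices are visible to TensorFlow; by default MirroredStrategy mirrors across all of them, and it also accepts an explicit devices argument. A minimal sketch, assuming the tensorflow R package is attached and at least two GPUs are visible:
library(tensorflow)

# GPUs visible to TensorFlow; MirroredStrategy mirrors across all of them by default
tf$config$list_physical_devices("GPU")

# Optionally restrict mirroring to specific devices
strategy_two_gpus <- tf$distribute$MirroredStrategy(devices = c("/gpu:0", "/gpu:1"))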
Keras API
This example uses the keras API to build the model and training loop.
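The code in this tutorial assumes the following packages are attached (the tfds package provides tfds_load()); a minimal setup block:
library(tensorflow)
library(keras)
library(tfdatasets)
library(tfds)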
Download the dataset
Download the MNIST dataset and load it using tfds. This returns a dataset in tfdatasets format.
# if you haven't done it yet:
# tfds::install_tfds()
mnist <- tfds_load("mnist")
info <- summary(mnist)
Define distribution strategy
Create a MirroredStrategy object. It handles distribution and provides a context manager (tf$distribute$MirroredStrategy$scope) to build your model inside.
strategy <- tf$distribute$MirroredStrategy()
strategy$num_replicas_in_sync
Set up the input pipeline
When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits in GPU memory, and tune the learning rate accordingly.
num_train_examples <- as.integer(info$splits[[2]]$statistics$numExamples)
num_test_examples <- as.integer(info$splits[[1]]$statistics$numExamples)
BUFFER_SIZE <- 10000
BATCH_SIZE_PER_REPLICA <- 64
BATCH_SIZE <- BATCH_SIZE_PER_REPLICA * strategy$num_replicas_in_sync
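For example, on a machine with two GPUs the global batch size works out to 64 * 2 = 128 examples per training step; a quick check of the arithmetic, assuming two replicas are in sync:
# 64 examples per replica * 2 replicas = 128 examples per optimizer step
BATCH_SIZE_PER_REPLICA * 2L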
Pixel values, which are 0-255, have to be normalized to the 0-1 range. Furthermore, we shuffle and batch the train and test datasets. Notice we are also keeping an in-memory cache of the training data to improve performance.
train_dataset <- mnist$train %>%
  dataset_map(function(record) {
    record$image <- tf$cast(record$image, tf$float32) / 255
    record
  }) %>%
  dataset_cache() %>%
  dataset_shuffle(BUFFER_SIZE) %>%
  dataset_batch(BATCH_SIZE) %>%
  dataset_map(unname)
test_dataset <- mnist$test %>%
  dataset_map(function(record) {
    record$image <- tf$cast(record$image, tf$float32) / 255
    record
  }) %>%
  dataset_batch(BATCH_SIZE) %>%
  dataset_map(unname)
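To sanity-check the pipeline, you can pull a single batch and inspect its shape. A quick sketch, assuming the reticulate package is available (it is installed alongside tensorflow); batch[[1]] holds the images and batch[[2]] the labels, since dataset_map(unname) dropped the record names:
library(reticulate)

batch <- train_dataset %>% as_iterator() %>% iter_next()
dim(batch[[1]])  # global batch size x 28 x 28 x 1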
Create the model
Create and compile the Keras model in the context of strategy$scope.
with(strategy$scope(), {
  model <- keras_model_sequential() %>%
    layer_conv_2d(
      filters = 32,
      kernel_size = 3,
      activation = 'relu',
      input_shape = c(28, 28, 1)
    ) %>%
    layer_max_pooling_2d() %>%
    layer_flatten() %>%
    layer_dense(units = 64, activation = 'relu') %>%
    layer_dense(units = 10, activation = 'softmax')

  model %>% compile(
    loss = 'sparse_categorical_crossentropy',
    optimizer = 'adam',
    metrics = 'accuracy'
  )
})
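Only variable creation needs to happen inside the scope; calling fit and evaluate outside of it, as done later in this tutorial, works as usual. You can also inspect the mirrored model like any other Keras model:
summary(model)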
Define the callbacks
The callbacks used here are:
- TensorBoard: This callback writes a log for TensorBoard which allows you to visualize the graphs.
- Model Checkpoint: This callback saves the model after every epoch.
- Learning Rate Scheduler: Using this callback, you can schedule the learning rate to change after every epoch/batch.
For illustrative purposes, add a print callback to display the learning rate.
# Define the checkpoint directory to store the checkpoints
checkpoint_dir <- './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix <- file.path(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
decay <- function(epoch, lr) {
  if (epoch < 3) {
    1e-3
  } else if (epoch >= 3 && epoch < 7) {
    1e-4
  } else {
    1e-5
  }
}
# Callback for printing the LR at the end of each epoch.
PrintLR <- R6::R6Class("PrintLR",
  inherit = KerasCallback,
  public = list(
    losses = NULL,
    on_epoch_end = function(epoch, logs = list()) {
      tf$print(glue::glue(
        '\nLearning rate for epoch {epoch} is {as.numeric(model$optimizer$lr)}\n'
      ))
    }
  )
)
print_lr <- PrintLR$new()
callbacks <- list(
  callback_tensorboard(log_dir = '/tmp/logs'),
  callback_model_checkpoint(filepath = checkpoint_prefix, save_weights_only = TRUE),
  callback_learning_rate_scheduler(decay),
  print_lr
)
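Keras passes a 0-based epoch index to the learning rate scheduler, so over the 12 epochs trained below the decay function gives three epochs at 1e-3, four at 1e-4, and five at 1e-5. A quick check of the schedule (the lr argument is unused by decay):
# Learning rate for epoch indices 0 through 11
sapply(0:11, decay, lr = NA)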
Train and evaluate
Now, train the model in the usual way, calling fit on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.
model %>% fit(train_dataset, epochs = 12, callbacks = callbacks)
As you can see below, the checkpoints are getting saved.
list.files(checkpoint_dir)
To see how the model performs, load the latest checkpoint and call evaluate on the test data.
model %>% load_model_weights_tf(tf$train$latest_checkpoint(checkpoint_dir))
model %>% evaluate(test_dataset)
tensorboard(log_dir = "/tmp/logs")
Export to SavedModel
Export the graph and the variables to the platform-agnostic SavedModel format. After your model is saved, you can load it with or without the scope.
path <- 'saved_model/'
model %>% save_model_tf(path)
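Listing the export directory shows the standard SavedModel layout (an assets folder, a variables folder, and saved_model.pb):
list.files(path)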
Load the model without strategy$scope.
unreplicated_model <- load_model_tf(path)
unreplicated_model %>% compile(
  loss = 'sparse_categorical_crossentropy',
  optimizer = 'adam',
  metrics = 'accuracy'
)
unreplicated_model %>% evaluate(test_dataset)
Load the model with strategy$scope.
with(strategy$scope(), {
  replicated_model <- load_model_tf(path)

  replicated_model %>% compile(
    loss = 'sparse_categorical_crossentropy',
    optimizer = 'adam',
    metrics = 'accuracy'
  )

  replicated_model %>% evaluate(test_dataset)
})