Distributed training with Keras

Overview

The tf$distribute$Strategy API provides an abstraction for distributing your training across multiple processing units. The goal is to allow users to enable distributed training using existing models and training code, with minimal changes.

This tutorial uses the tf$distribute$MirroredStrategy, which does in-graph replication with synchronous training on many GPUs on one machine. Essentially, it copies all of the model’s variables to each processor. Then, it uses all-reduce to combine the gradients from all processors and applies the combined value to all copies of the model.
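
The replicate-then-combine idea can be sketched in base R without TensorFlow. This is a toy illustration, not how MirroredStrategy is implemented: the one-weight linear model, the two-way split of the batch, and names such as shard_gradient are assumptions made for the example. It shows that when each replica scales its gradient by the global batch size, adding the per-replica gradients reproduces the gradient a single device would compute on the whole batch.

```r
# Toy sketch of synchronous data-parallel training in base R (no TensorFlow).
# Model: y_hat = w * x, loss = mean squared error over the global batch.
x <- c(1, 2, 3, 4)
y <- c(2, 4, 6, 8)
w <- 0.5                      # every replica starts from the same weight
global_batch_size <- length(x)

# Gradient of the summed squared error on one shard, scaled by the GLOBAL
# batch size so that per-replica results can simply be added together.
shard_gradient <- function(x_shard, y_shard, w) {
  sum(2 * (w * x_shard - y_shard) * x_shard) / global_batch_size
}

# Split the global batch across two hypothetical replicas.
shards <- list(1:2, 3:4)
replica_grads <- vapply(
  shards,
  function(idx) shard_gradient(x[idx], y[idx], w),
  numeric(1)
)

# "All-reduce": every replica receives the same combined gradient.
combined_grad <- sum(replica_grads)

# Reference: gradient computed on the whole batch by a single device.
single_device_grad <- mean(2 * (w * x - y) * x)

# Each replica applies the identical update, so the copies stay in sync.
learning_rate <- 0.01
replica_weights <- rep(w, length(shards)) - learning_rate * combined_grad
```

Because every replica applies the same combined gradient to the same starting weight, the copies remain identical after the update, which is the property synchronous training relies on.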

MirroredStrategy is one of several distribution strategies available in TensorFlow core.

Keras API

This example uses the Keras API to build the model and the training loop.

library(tensorflow)
library(keras)
library(tfdatasets)
# used to load the MNIST dataset
library(tfds)

library(purrr)
library(glue)

Download the dataset

Download the MNIST dataset and load it using tfds. This returns a dataset in tfdatasets format.

# if you haven't done it yet:
# tfds::install_tfds()
mnist <- tfds_load("mnist")
info <- summary(mnist)

Define distribution strategy

Create a MirroredStrategy object. It handles distribution and provides a context manager (tf$distribute$MirroredStrategy$scope) inside which you build your model.

strategy <- tf$distribute$MirroredStrategy()
strategy$num_replicas_in_sync

Set up the input pipeline

When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory, and tune the learning rate accordingly.

num_train_examples <- as.integer(info$splits[[2]]$statistics$numExamples)
num_test_examples <- as.integer(info$splits[[1]]$statistics$numExamples)

BUFFER_SIZE <- 10000

BATCH_SIZE_PER_REPLICA <- 64
BATCH_SIZE <- BATCH_SIZE_PER_REPLICA * strategy$num_replicas_in_sync
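
To make the effect of scaling the batch size concrete, the following base R sketch tabulates the global batch size and the resulting number of steps per epoch for a few replica counts. The helper batch_plan, the replica counts, and the base learning rate are hypothetical. The scaled_lr column uses linear scaling (learning rate multiplied by the number of replicas), which is only one common starting heuristic and an assumption here, not a rule from this tutorial; the learning rate still needs tuning.

```r
# Sketch: global batch size and steps per epoch for several replica counts.
# batch_plan() is a hypothetical helper; scaled_lr is a heuristic starting point.
batch_plan <- function(num_replicas,
                       batch_size_per_replica = 64L,
                       num_train_examples = 60000L,  # MNIST training split
                       base_lr = 1e-3) {
  global_batch_size <- batch_size_per_replica * num_replicas
  data.frame(
    num_replicas = num_replicas,
    global_batch_size = global_batch_size,
    # The final batch of an epoch may be smaller, hence the ceiling.
    steps_per_epoch = ceiling(num_train_examples / global_batch_size),
    scaled_lr = base_lr * num_replicas
  )
}

plan <- batch_plan(c(1L, 2L, 4L, 8L))
print(plan)
```

Doubling the number of replicas doubles the global batch size and roughly halves the number of steps needed to pass over the training data once.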

Pixel values, which range from 0 to 255, have to be normalized to the 0-1 range. The training dataset is then shuffled and batched, while the test dataset is only batched. Note that the training data is also cached in memory to improve performance.

train_dataset <- mnist$train %>% 
  dataset_map(function(record) {
    record$image <- tf$cast(record$image, tf$float32) / 255
    record}) %>%
  dataset_cache() %>%
  dataset_shuffle(BUFFER_SIZE) %>% 
  dataset_batch(BATCH_SIZE) %>% 
  dataset_map(unname)

test_dataset <- mnist$test %>% 
  dataset_map(function(record) {
    record$image <- tf$cast(record$image, tf$float32) / 255
    record}) %>%
  dataset_batch(BATCH_SIZE) %>% 
  dataset_map(unname)
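
The two transformations applied above, scaling pixels and grouping records into batches, can be illustrated on plain R objects. This is a sketch on made-up data, not the tfdatasets pipeline itself: raw_pixels and the ten-record example are hypothetical, and the grouping mimics a batch step that keeps a smaller final batch.

```r
# Hypothetical 4 x 2 x 2 block of raw 8-bit pixel values (values are recycled).
raw_pixels <- array(c(0L, 64L, 128L, 255L), dim = c(4, 2, 2))

# Same arithmetic as the dataset_map() step: convert to the 0-1 range.
scaled <- raw_pixels / 255
pixel_range <- range(scaled)

# Group 10 hypothetical record indices into batches of 4.
# The last batch holds the 2 leftover records.
record_ids <- seq_len(10)
batch_size <- 4L
batches <- split(record_ids, ceiling(record_ids / batch_size))
batch_lengths <- unname(lengths(batches))
```

When the number of records is not a multiple of the batch size, the final batch is smaller than the others, which is worth remembering when computing steps per epoch.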

Create the model

Create and compile the Keras model in the context of strategy$scope.

with(strategy$scope(), {
  model <- keras_model_sequential() %>%
    layer_conv_2d(
      filters = 32,
      kernel_size = 3,
      activation = 'relu',
      input_shape = c(28, 28, 1)
    ) %>%
    layer_max_pooling_2d() %>%
    layer_flatten() %>%
    layer_dense(units = 64, activation = 'relu') %>%
    layer_dense(units = 10, activation = 'softmax')

  model %>% compile(
    loss = 'sparse_categorical_crossentropy',
    optimizer = 'adam',
    metrics = 'accuracy'
  )
})

Define the callbacks

The callbacks used here are:

  • TensorBoard: This callback writes a log for TensorBoard which allows you to visualize the graphs.
  • Model Checkpoint: This callback saves the model after every epoch.
  • Learning Rate Scheduler: Using this callback, you can schedule the learning rate to change after every epoch/batch.

For illustrative purposes, add a print callback to display the learning rate.

# Define the checkpoint directory to store the checkpoints
checkpoint_dir <- './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix <- file.path(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
decay <- function(epoch, lr) {
  if (epoch < 3) {
    1e-3
  } else if (epoch >= 3 && epoch < 7) {
    1e-4
  } else {
    1e-5
  }
}
# Callback for printing the LR at the end of each epoch.
PrintLR <- R6::R6Class("PrintLR",
  inherit = KerasCallback,

  public = list(
    # Reads the current learning rate from the global `model` object.
    on_epoch_end = function(epoch, logs = list()) {
      tf$print(glue('\nLearning rate for epoch {epoch} is {as.numeric(model$optimizer$lr)}\n'))
    }
  )
)

print_lr <- PrintLR$new()
callbacks <- list(
    callback_tensorboard(log_dir = '/tmp/logs'),
    callback_model_checkpoint(filepath = checkpoint_prefix, save_weights_only = TRUE),
    callback_learning_rate_scheduler(decay),
    print_lr
)
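
The step schedule defined above can be checked in isolation before training. The sketch below re-creates it under the hypothetical name decay_sketch and evaluates it for the 12 epochs used later in this tutorial. It assumes the scheduler callback passes a zero-based epoch index, which is how the Keras documentation describes it; if your version counts from 1, the boundaries shift by one epoch.

```r
# Stand-alone copy of the step schedule (hypothetical name: decay_sketch).
# The `lr` argument is accepted for compatibility but not used.
decay_sketch <- function(epoch, lr) {
  if (epoch < 3) {
    1e-3
  } else if (epoch < 7) {
    1e-4
  } else {
    1e-5
  }
}

# Assumption: epochs are indexed from 0, so 12 epochs are 0 through 11.
epochs <- 0:11
schedule <- data.frame(
  epoch = epochs,
  lr = vapply(epochs, decay_sketch, numeric(1), lr = NA_real_)
)
print(schedule)
```

Under that assumption the first three epochs train at 1e-3, the next four at 1e-4, and the remaining five at 1e-5.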

Train and evaluate

Now train the model in the usual way, calling fit on the model and passing in the training dataset created earlier in the tutorial. This step is the same whether or not you are distributing the training.

model %>% fit(train_dataset, epochs = 12, callbacks = callbacks)

Listing the checkpoint directory shows that the checkpoints are being saved.

list.files(checkpoint_dir)
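
To relate the listed files to the "ckpt_{epoch}" template, the sketch below expands the template by hand in base R. The helper expand_prefix is hypothetical, and the one-based epoch numbering is an assumption about how the checkpoint callback fills in {epoch}. The files on disk also carry suffixes added by TensorFlow, which this sketch does not try to predict.

```r
checkpoint_dir <- "./training_checkpoints"
checkpoint_prefix <- file.path(checkpoint_dir, "ckpt_{epoch}")

# Hypothetical helper: substitute an epoch number into the template.
expand_prefix <- function(prefix, epoch) {
  sub("{epoch}", as.character(epoch), prefix, fixed = TRUE)
}

# Assumption: the callback numbers epochs from 1 when formatting file names.
expected_prefixes <- vapply(
  1:12,
  function(e) expand_prefix(checkpoint_prefix, e),
  character(1)
)
print(expected_prefixes)
```

If the names in your directory differ, compare them against the template rather than against this list.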

To see how the model performs, load the latest checkpoint and call evaluate on the test data.

model %>% load_model_weights_tf(tf$train$latest_checkpoint(checkpoint_dir))

model %>% evaluate(test_dataset)

# Launch TensorBoard to inspect the logs written by callback_tensorboard().
tensorboard(log_dir = "/tmp/logs")

Export to SavedModel

Export the graph and the variables to the platform-agnostic SavedModel format. After your model is saved, you can load it with or without the scope.

path <- 'saved_model/'
model %>% save_model_tf(path)

Load the model without strategy$scope.

unreplicated_model <- load_model_tf(path)

unreplicated_model %>% compile(
    loss = 'sparse_categorical_crossentropy',
    optimizer = 'adam',
    metrics = 'accuracy')

unreplicated_model %>% evaluate(test_dataset)

Load the model with strategy$scope.

with (strategy$scope(), {
  replicated_model <- load_model_tf(path)

  replicated_model %>% compile(
    loss = 'sparse_categorical_crossentropy',
    optimizer = 'adam',
    metrics = 'accuracy')

  replicated_model %>% evaluate(test_dataset)
})