Writing Custom Keras Models
Overview
In addition to sequential models and models created with the functional API, you may also define models by implementing a custom call() (forward pass) operation.
To create a custom Keras model, you call the keras_model_custom() function, passing it an R function which in turn returns another R function that implements the custom call() (forward pass) operation. The R function you pass receives the model as its single argument (named self in the example below), which you use to attach layers and to reach the underlying Keras model object.
Typically, you’ll wrap your call to keras_model_custom() in yet another function that lets callers easily instantiate your custom model.
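The three nested functions described above (a wrapper, the function passed to keras_model_custom(), and the call() function it returns) can be sketched as follows. This is a minimal sketch, assuming the keras R package (which provides keras_model_custom()) running on a TensorFlow 2.x backend; the wrapper keras_model_linear() and its units argument are hypothetical names chosen for illustration.

```r
library(keras)

# Hypothetical wrapper: the outermost function callers use to create the model
keras_model_linear <- function(units, name = NULL) {
  # keras_model_custom() receives an R function that takes the model (self)
  keras_model_custom(name = name, function(self) {
    # Runs once, when the model is created: build layers and attach them to self
    self$dense <- layer_dense(units = units)

    # The returned function implements the forward pass (call())
    function(inputs, mask = NULL, training = FALSE) {
      self$dense(inputs)
    }
  })
}

model <- keras_model_linear(units = 3)

# Call the model directly on a small float32 batch (4 samples, 5 features)
x <- k_constant(matrix(runif(20), nrow = 4, ncol = 5))
y <- as.array(model(x))
dim(y)  # 4 rows (samples) by 3 columns (units)
```

The model can be called like a function before it is compiled; its weights are created on that first forward pass.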
Creating a Custom Model
This example implements a simple custom model: a multilayer perceptron with optional dropout and batch normalization:
library(keras)
keras_model_simple_mlp <- function(num_classes,
                                   use_bn = FALSE, use_dp = FALSE,
                                   name = NULL) {
  # define and return a custom model
  keras_model_custom(name = name, function(self) {
    # create layers we'll need for the call (this code executes once)
    self$dense1 <- layer_dense(units = 32, activation = "relu")
    self$dense2 <- layer_dense(units = num_classes, activation = "softmax")
    if (use_dp)
      self$dp <- layer_dropout(rate = 0.5)
    if (use_bn)
      self$bn <- layer_batch_normalization(axis = -1)
    # implement call (this code executes during training & inference)
    function(inputs, mask = NULL, training = FALSE) {
      x <- self$dense1(inputs)
      if (use_dp)
        x <- self$dp(x)
      if (use_bn)
        x <- self$bn(x)
      self$dense2(x)
    }
  })
}
Note that we include a name parameter so that users can optionally provide a human-readable name for the model.
Note also that when we create layers to be used in our forward pass, we assign them to fields of the self object so that Keras tracks them, and their weights, as part of the model.
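Both notes can be checked on a small model. This is a minimal sketch, assuming the keras R package with a TensorFlow 2.x backend; keras_model_dense_bn() is a hypothetical model. Because both of its layers are assigned to self, their weights show up in the model's weight lists once the first forward pass has created them.

```r
library(keras)

# Hypothetical model: one dense layer followed by batch normalization,
# both assigned to self so that Keras tracks them
keras_model_dense_bn <- function(units, name = NULL) {
  keras_model_custom(name = name, function(self) {
    self$dense <- layer_dense(units = units)
    self$bn <- layer_batch_normalization(axis = -1)
    function(inputs, mask = NULL, training = FALSE) {
      self$bn(self$dense(inputs))
    }
  })
}

model <- keras_model_dense_bn(units = 8, name = "dense_bn")
model$name  # the human-readable name supplied by the caller

# Weights are created on the first forward pass (10 samples, 4 features)
invisible(model(k_constant(matrix(runif(40), nrow = 10, ncol = 4))))

# dense: kernel + bias; batch norm: gamma, beta, moving mean, moving variance
length(model$weights)            # 6
length(model$trainable_weights)  # 4 (the moving statistics are not trainable)
```

A layer held only in a local R variable, rather than on self, would generally not be tracked this way, so its weights would not appear in these lists.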
In call() (the function you return), you may specify custom losses by calling self$add_loss(). You can also use self$ to access any other members of the Keras model you need, or even to add fields to the model.
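As an illustration of self$add_loss(), here is a minimal sketch, assuming the keras R package with a TensorFlow 2.x backend. The model keras_model_penalized() and its penalty argument are hypothetical; the extra term is a simple activity penalty on the hidden layer's output. After a forward pass the term is listed in model$losses, and during fit() Keras adds such terms to the compiled loss.

```r
library(keras)

# Hypothetical model that registers an extra loss term from within call()
keras_model_penalized <- function(units, penalty = 0.01, name = NULL) {
  keras_model_custom(name = name, function(self) {
    self$dense <- layer_dense(units = units, activation = "relu")
    function(inputs, mask = NULL, training = FALSE) {
      x <- self$dense(inputs)
      # Activity penalty: scaled mean of the squared activations
      self$add_loss(k_mean(k_square(x)) * penalty)
      x
    }
  })
}

model <- keras_model_penalized(units = 4)

# One forward pass (3 samples, 4 features) records the loss term
invisible(model(k_constant(matrix(runif(12), nrow = 3, ncol = 4))))
length(model$losses)  # one loss term from the last forward pass
```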
Using a Custom Model
To use a custom model, just call your model’s high-level wrapper function. For example:
library(keras)
# create the model
model <- keras_model_simple_mlp(num_classes = 10, use_dp = TRUE)
# configure the model for training
model %>% compile(
  loss = 'categorical_crossentropy',
  optimizer = optimizer_rmsprop(),
  metrics = c('accuracy')
)
# generate dummy data
data <- matrix(runif(1000 * 100), nrow = 1000, ncol = 100)
labels <- matrix(round(runif(1000, min = 0, max = 9)), nrow = 1000, ncol = 1)
# convert labels to categorical one-hot encoding
one_hot_labels <- to_categorical(labels, num_classes = 10)
# train the model
model %>% fit(data, one_hot_labels, epochs = 10, batch_size = 32)
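Once trained, a custom model can be used for inference like any other Keras model. This is a minimal sketch, assuming the keras R package with a TensorFlow 2.x backend; to keep the block self-contained it defines a smaller hypothetical model, keras_model_tiny_mlp(), trained briefly on random data, so the predictions themselves carry no meaning.

```r
library(keras)

# Compact hypothetical stand-in for a custom MLP
keras_model_tiny_mlp <- function(num_classes, name = NULL) {
  keras_model_custom(name = name, function(self) {
    self$dense1 <- layer_dense(units = 16, activation = "relu")
    self$dense2 <- layer_dense(units = num_classes, activation = "softmax")
    function(inputs, mask = NULL, training = FALSE) {
      self$dense2(self$dense1(inputs))
    }
  })
}

model <- keras_model_tiny_mlp(num_classes = 10)
model %>% compile(loss = "categorical_crossentropy",
                  optimizer = optimizer_rmsprop())

# Dummy data: 200 samples, 20 features, integer labels 0..9
data <- matrix(runif(200 * 20), nrow = 200, ncol = 20)
labels <- sample(0:9, 200, replace = TRUE)
model %>% fit(data, to_categorical(labels, num_classes = 10),
              epochs = 2, batch_size = 32, verbose = 0)

# Class probabilities: one row per sample, one column per class
probs <- model %>% predict(data[1:5, ])

# Most likely class per row, shifted back to 0-based labels
predicted <- max.col(probs) - 1
```

Because the last layer uses a softmax activation, each row of probs sums to 1, and max.col() picks the column with the highest probability.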