Run Hooks
Overview
SessionRunHooks are useful for tracking training, reporting progress, requesting early stopping, and more. Users can attach an arbitrary number of hooks to an estimator. SessionRunHooks follow the observer pattern and are notified at the following points:
- when a session starts being used
- before a call to session.run()
- after a call to session.run()
- when the session is closed
A SessionRunHook encapsulates a piece of reusable, composable computation that can piggyback on a call to MonitoredSession.run(). A hook can add ops, tensors, or feeds to the run call and, when the run call finishes successfully, receives the outputs it requested. Hooks may add ops to the graph inside hook.begin(); the graph is finalized after the begin() method is called.
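To make these notification points concrete, here is a minimal sketch of a hook that simply prints a message at each point. It uses session_run_hook(), which is covered in the Custom Run Hooks section below; the only assumption beyond what that section documents is that a before_run callback which requests no tensors can simply return NULL.

library(tfestimators)

# a hook that logs each notification point in the session lifecycle
logging_hook <- session_run_hook(
  after_create_session = function(session, coord) {
    cat("session created and about to be used\n")
  },
  before_run = function(context) {
    cat("before a call to session.run()\n")
    NULL  # assumption: no extra tensors requested for this run
  },
  after_run = function(context, values) {
    cat("after a call to session.run()\n")
  },
  end = function(session) {
    cat("session is about to be closed\n")
  }
)

A hook like this is attached in the same way as the built-in hooks below, by passing it in the hooks argument of train(), for example hooks = list(logging_hook).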
Built-in Run Hooks
There are a few pre-defined SessionRunHooks available, for example:
Method | Description |
---|---|
hook_checkpoint_saver() | Saves checkpoints every N steps or seconds. |
hook_global_step_waiter() | Delays execution until the global step reaches wait_until_step. |
hook_history_saver() | Saves metrics history. |
hook_logging_tensor() | Prints the given tensors once every N local steps or once every N seconds. |
hook_nan_tensor() | NaN loss monitor. |
hook_progress_bar() | Creates and updates a progress bar. |
hook_step_counter() | Steps per second monitor. |
hook_stop_at_step() | Requests a stop at a specified step. |
hook_summary_saver() | Saves summaries every N steps. |
For example, we can attach hook_progress_bar() to create and update a progress bar during the model training process:
fcs <- feature_columns(column_numeric("drat"))
input <- input_fn(mtcars, response = "mpg", features = c("drat", "cyl"), batch_size = 8L)

lr <- linear_regressor(
  feature_columns = fcs
) %>% train(
  input_fn = input,
  steps = 2,
  hooks = list(
    hook_progress_bar()
  ))
Training 2/2 [======================================] - ETA: 0s - loss: 3136.10
Another example is to use hook_history_saver() to save the training history every 2 training steps, as follows:
lr <- linear_regressor(feature_columns = fcs)

training_history <- train(
  lr,
  input_fn = input,
  steps = 4,
  hooks = list(
    hook_history_saver(every_n_step = 2)
  ))
train() returns the saved training metrics history:
> training_history
mean_losses total_losses steps
1 343.9690 2751.752 2
2 419.7618 3358.094 4
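Because an arbitrary number of hooks can be attached, several built-in hooks can also be combined in a single call. As a small sketch, reusing the fcs and input objects defined above, we can attach both the progress bar and the history saver at once:

lr <- linear_regressor(feature_columns = fcs)

training_history <- train(
  lr,
  input_fn = input,
  steps = 4,
  hooks = list(
    hook_progress_bar(),
    hook_history_saver(every_n_step = 2)
  ))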
Custom Run Hooks
Users can also create custom run hooks by defining the behaviors of the hook in different phases of a session.
We can implement a custom run hook by passing a list of callback functions to session_run_hook(). It accepts the following optional parameters, each of which can be overridden with a custom function:
- begin: a function with signature function(), to be called once before using the session.
- after_create_session: a function with signature function(session, coord), to be called once the new TensorFlow session has been created.
- before_run: a function with signature function(run_context), to be called before each run.
- after_run: a function with signature function(run_context, run_values), to be called after each run.
- end: a function with signature function(session), to be called at the end of the session.
For example, let's implement the hook_history_saver() that we showed in the previous section. We first initialize an iter_count variable to track the number of steps that have been run, and increment it inside after_run() after each session.run() call. Inside before_run(), we use the run context to request the tensors in the "losses" collection, so that after_run() can access their evaluated values via values$results$losses. Finally, we calculate the mean of the raw losses and append it to a global variable named mean_losses_history that accumulates the list of mean losses.
mean_losses_history <<- NULL

hook_history_saver_custom <- function(every_n_step) {
  iter_count <<- 0
  session_run_hook(
    before_run = function(context) {
      session_run_args(
        losses = context$session$graph$get_collection("losses")
      )
    },
    after_run = function(context, values) {
      iter_count <<- iter_count + 1
      print(paste0("Running step: ", iter_count))
      if (iter_count %% every_n_step == 0) {
        raw_losses <- values$results$losses[[1]]
        mean_losses_history <<- c(mean_losses_history, mean(raw_losses))
      }
    }
  )
}
Next, we can attach this hook to our estimator:
lr <- linear_regressor(
  feature_columns = fcs
) %>% train(
  input_fn = input,
  steps = 4,
  hooks = list(
    hook_history_saver_custom(every_n_step = 1)
  ))
[1] "Running step: 1"
[1] "Running step: 2"
[1] "Running step: 3"
[1] "Running step: 4"
We saved the losses history at every step. Let’s check the list of losses:
> mean_losses_history
[1] 415.8088 452.2128 376.7346 331.6045
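Since mean_losses_history is an ordinary R vector, it can also be inspected or plotted with base R, for example:

# plot the saved mean losses in the order in which they were saved
plot(seq_along(mean_losses_history), mean_losses_history,
     type = "b", xlab = "save point", ylab = "mean loss")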