eager_image_captioning
This is the companion code to the post “Attention-based Image Captioning with Keras” on the TensorFlow for R blog.
https://blogs.rstudio.com/tensorflow/posts/2018-09-17-eager-captioning
library(keras)
use_implementation("tensorflow")
library(tensorflow)
tfe_enable_eager_execution(device_policy = "silent")
np <- import("numpy")
library(tfdatasets)
library(purrr)
library(stringr)
library(glue)
library(rjson)
library(rlang)
library(dplyr)
library(magick)
maybecat <- function(context, x) {
if (debugshapes) {
name <- enexpr(x)
dims <- paste0(dim(x), collapse = " ")
cat(context, ": shape of ", name, ": ", dims, "\n", sep = "")
}
}
debugshapes <- FALSE
restore_checkpoint <- FALSE
saved_features_exist <- FALSE
use_session_with_seed(7777,
disable_gpu = FALSE,
disable_parallel_cpu = FALSE)
annotation_file <- "train2014/annotations/captions_train2014.json"
image_path <- "train2014/train2014"
annotations <- fromJSON(file = annotation_file)
annot_captions <- annotations[[4]]
# 414113
num_captions <- length(annot_captions)
all_captions <- vector(mode = "list", length = num_captions)
all_img_names <- vector(mode = "list", length = num_captions)
for (i in seq_len(num_captions)) {
caption <-
paste0("<start> ", annot_captions[[i]][["caption"]], " <end>")
image_id <- annot_captions[[i]][["image_id"]]
full_coco_image_path <-
sprintf("train2014/train2014/COCO_train2014_%012d.jpg", image_id)
all_img_names[[i]] <- full_coco_image_path
all_captions[[i]] <- caption
}
num_examples <- 30000
if (!saved_features_exist) {
random_sample <- sample(1:num_captions, size = num_examples)
train_indices <-
sample(random_sample, size = length(random_sample) * 0.8)
validation_indices <-
setdiff(random_sample, train_indices)
saveRDS(random_sample,
paste0("random_sample_", num_examples, ".rds"))
saveRDS(train_indices,
paste0("train_indices_", num_examples, ".rds"))
saveRDS(validation_indices,
paste0("validation_indices_", num_examples, ".rds"))
} else {
random_sample <-
readRDS(paste0("random_sample_", num_examples, ".rds"))
train_indices <-
readRDS(paste0("train_indices_", num_examples, ".rds"))
validation_indices <-
readRDS(paste0("validation_indices_", num_examples, ".rds"))
}
sample_captions <- all_captions[random_sample]
sample_images <- all_img_names[random_sample]
train_captions <- all_captions[train_indices]
train_images <- all_img_names[train_indices]
validation_captions <- all_captions[validation_indices]
validation_images <- all_img_names[validation_indices]
load_image <- function(image_path) {
img <- tf$read_file(image_path) %>%
tf$image$decode_jpeg(channels = 3) %>%
tf$image$resize_images(c(299L, 299L)) %>%
tf$keras$applications$inception_v3$preprocess_input()
list(img, image_path)
}
image_model <- application_inception_v3(include_top = FALSE,
weights = "imagenet")
if (!saved_features_exist) {
preencode <- unique(sample_images) %>% unlist() %>% sort()
num_unique <- length(preencode)
batch_size_4save <- 1
image_dataset <- tensor_slices_dataset(preencode) %>%
dataset_map(load_image) %>%
dataset_batch(batch_size_4save)
save_iter <- make_iterator_one_shot(image_dataset)
save_count <- 0
until_out_of_range({
if (save_count %% 100 == 0) {
cat("Saving feature:", save_count, "of", num_unique, "\n")
}
save_count <- save_count + batch_size_4save
batch_4save <- save_iter$get_next()
img <- batch_4save[[1]]
path <- batch_4save[[2]]
batch_features <- image_model(img)
batch_features <- tf$reshape(batch_features,
list(dim(batch_features)[1],-1L, dim(batch_features)[4]))
for (i in 1:dim(batch_features)[1]) {
p <- path[i]$numpy()$decode("utf-8")
np$save(p,
batch_features[i, ,]$numpy())
}
})
}
top_k <- 5000
tokenizer <- text_tokenizer(num_words = top_k,
oov_token = "<unk>",
filters = '!"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')
tokenizer$fit_on_texts(sample_captions)
train_captions_tokenized <-
tokenizer %>% texts_to_sequences(train_captions)
validation_captions_tokenized <-
tokenizer %>% texts_to_sequences(validation_captions)
tokenizer$word_index
tokenizer$word_index["<unk>"]
tokenizer$word_index["<pad>"] <- 0
tokenizer$word_index["<pad>"]
word_index_df <- data.frame(
word = tokenizer$word_index %>% names(),
index = tokenizer$word_index %>% unlist(use.names = FALSE),
stringsAsFactors = FALSE
)
word_index_df <- word_index_df %>% arrange(index)
decode_caption <- function(text) {
paste(map(text, function(number)
word_index_df %>%
filter(index == number) %>%
select(word) %>%
pull()),
collapse = " ")
}
caption_lengths <-
map(all_captions[1:num_examples], function(c)
str_split(c, " ")[[1]] %>% length()) %>% unlist()
fivenum(caption_lengths)
max_length <- fivenum(caption_lengths)[5]
train_captions_padded <-
pad_sequences(
train_captions_tokenized,
maxlen = max_length,
padding = "post",
truncating = "post"
)
validation_captions_padded <-
pad_sequences(
validation_captions_tokenized,
maxlen = max_length,
padding = "post",
truncating = "post"
)
length(train_images)
dim(train_captions_padded)
batch_size <- 10
buffer_size <- num_examples
embedding_dim <- 256
gru_units <- 512
vocab_size <- top_k
features_shape <- 2048
attention_features_shape <- 64
train_images_4checking <- train_images[c(4, 10, 30)]
train_captions_4checking <- train_captions_padded[c(4, 10, 30),]
validation_images_4checking <- validation_images[c(7, 10, 12)]
validation_captions_4checking <-
validation_captions_padded[c(7, 10, 12),]
map_func <- function(img_name, cap) {
p <- paste0(img_name$decode("utf-8"), ".npy")
img_tensor <- np$load(p)
img_tensor <- tf$cast(img_tensor, tf$float32)
list(img_tensor, cap)
}
train_dataset <-
tensor_slices_dataset(list(train_images, train_captions_padded)) %>%
dataset_map(function(item1, item2)
tf$py_func(map_func, list(item1, item2), list(tf$float32, tf$int32))) %>%
# dataset_shuffle(buffer_size) %>%
dataset_batch(batch_size)
cnn_encoder <-
function(embedding_dim,
name = NULL) {
keras_model_custom(name = name, function(self) {
self$fc <-
layer_dense(units = embedding_dim, activation = "relu")
function(x, mask = NULL) {
# input shape: (batch_size, 64, features_shape)
# shape after fc: (batch_size, 64, embedding_dim)
maybecat("encoder input", x)
x <- self$fc(x)
maybecat("encoder output", x)
x
}
})
}
attention_module <-
function(gru_units,
name = NULL) {
keras_model_custom(name = name, function(self) {
self$W1 = layer_dense(units = gru_units)
self$W2 = layer_dense(units = gru_units)
self$V = layer_dense(units = 1)
function(inputs, mask = NULL) {
features <- inputs[[1]]
hidden <- inputs[[2]]
# features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)
# hidden shape == (batch_size, gru_units)
# hidden_with_time_axis shape == (batch_size, 1, gru_units)
hidden_with_time_axis <- k_expand_dims(hidden, axis = 2)
maybecat("attention module", features)
maybecat("attention module", hidden)
maybecat("attention module", hidden_with_time_axis)
# score shape == (batch_size, 64, 1)
score <-
self$V(k_tanh(self$W1(features) + self$W2(hidden_with_time_axis)))
# attention_weights shape == (batch_size, 64, 1)
attention_weights <- k_softmax(score, axis = 2)
# context_vector shape after sum == (batch_size, embedding_dim)
context_vector <-
k_sum(attention_weights * features, axis = 2)
maybecat("attention module", score)
maybecat("attention module", attention_weights)
maybecat("attention module", context_vector)
list(context_vector, attention_weights)
}
})
}
rnn_decoder <-
function(embedding_dim,
gru_units,
vocab_size,
name = NULL) {
keras_model_custom(name = name, function(self) {
self$gru_units <- gru_units
self$embedding <-
layer_embedding(input_dim = vocab_size, output_dim = embedding_dim)
self$gru <- if (tf$test$is_gpu_available()) {
layer_cudnn_gru(
units = gru_units,
return_sequences = TRUE,
return_state = TRUE,
recurrent_initializer = 'glorot_uniform'
)
} else {
layer_gru(
units = gru_units,
return_sequences = TRUE,
return_state = TRUE,
recurrent_initializer = 'glorot_uniform'
)
}
self$fc1 <- layer_dense(units = self$gru_units)
self$fc2 <- layer_dense(units = vocab_size)
self$attention <- attention_module(self$gru_units)
function(inputs, mask = NULL) {
x <- inputs[[1]]
features <- inputs[[2]]
hidden <- inputs[[3]]
maybecat("decoder", x)
maybecat("decoder", features)
maybecat("decoder", hidden)
c(context_vector, attention_weights) %<-% self$attention(list(features, hidden))
# x shape after passing through embedding == (batch_size, 1, embedding_dim)
x <- self$embedding(x)
maybecat("decoder x after embedding", x)
# x shape after concatenation == (batch_size, 1, 2 * embedding_dim)
x <-
k_concatenate(list(k_expand_dims(context_vector, 2), x))
maybecat("decoder x after concat", x)
# passing the concatenated vector to the GRU
c(output, state) %<-% self$gru(x)
maybecat("decoder output after gru", output)
maybecat("decoder state after gru", state)
# shape == (batch_size, 1, gru_units)
x <- self$fc1(output)
maybecat("decoder output after fc1", x)
# x shape == (batch_size, gru_units)
x <- k_reshape(x, c(-1, dim(x)[[3]]))
maybecat("decoder output after reshape", x)
# output shape == (batch_size, vocab_size)
x <- self$fc2(x)
maybecat("decoder output after fc2", x)
list(x, state, attention_weights)
}
})
}
encoder <- cnn_encoder(embedding_dim)
decoder <- rnn_decoder(embedding_dim, gru_units, vocab_size)
optimizer = tf$train$AdamOptimizer()
cx_loss <- function(y_true, y_pred) {
mask <- 1 - k_cast(y_true == 0L, dtype = "float32")
loss <-
tf$nn$sparse_softmax_cross_entropy_with_logits(labels = y_true, logits =
y_pred) * mask
tf$reduce_mean(loss)
}
get_caption <-
function(image) {
attention_matrix <-
matrix(0, nrow = max_length, ncol = attention_features_shape)
# shape=(1, 299, 299, 3)
temp_input <- k_expand_dims(load_image(image)[[1]], 1)
# shape=(1, 8, 8, 2048),
img_tensor_val <- image_model(temp_input)
# shape=(1, 64, 2048)
img_tensor_val <- k_reshape(img_tensor_val,
list(dim(img_tensor_val)[1],-1, dim(img_tensor_val)[4]))
# shape=(1, 64, 256)
features <- encoder(img_tensor_val)
dec_hidden <- k_zeros(c(1, gru_units))
dec_input <-
k_expand_dims(list(word_index_df[word_index_df$word == "<start>", "index"]))
result <- ""
for (t in seq_len(max_length - 1)) {
c(preds, dec_hidden, attention_weights) %<-%
decoder(list(dec_input, features, dec_hidden))
attention_weights <- k_reshape(attention_weights, c(-1))
attention_matrix[t, ] <- attention_weights %>% as.double()
pred_idx = tf$multinomial(exp(preds), num_samples = 1)[1, 1] %>% as.double()
pred_word <-
word_index_df[word_index_df$index == pred_idx, "word"]
if (pred_word == "<end>") {
result <-
paste(result, pred_word)
attention_matrix <-
attention_matrix[1:length(str_split(result, " ")[[1]]), , drop = FALSE]
return (list(str_trim(result), attention_matrix))
} else {
result <-
paste(result, pred_word)
dec_input <- k_expand_dims(list(pred_idx))
}
}
list(str_trim(result), attention_matrix)
}
plot_attention <-
function(attention_matrix,
image_name,
result,
epoch) {
image <-
image_read(image_name) %>% image_scale("299x299!")
result <- str_split(result, " ")[[1]] %>% as.list()
# attention_matrix shape: nrow = max_length, ncol = attention_features_shape
for (i in 1:length(result)) {
att <- attention_matrix[i, ] %>% np$resize(tuple(8L, 8L))
dim(att) <- c(8, 8, 1)
att <- image_read(att) %>% image_scale("299x299") %>%
image_annotate(
result[[i]],
gravity = "northeast",
size = 20,
color = "white",
location = "+20+40"
)
overlay <-
image_composite(att, image, operator = "blend", compose_args = "30")
image_write(
overlay,
paste0(
"attention_plot_epoch_",
epoch,
"_img_",
image_name %>% basename() %>% str_sub(16,-5),
"_word_",
i,
".png"
)
)
}
}
check_sample_captions <-
function(epoch, mode, plot_attention) {
images <- switch(mode,
training = train_images_4checking,
validation = validation_images_4checking)
captions <- switch(mode,
training = train_captions_4checking,
validation = validation_captions_4checking)
cat("\n", "Sample checks on ", mode, " set:", "\n", sep = "")
for (i in 1:length(images)) {
c(result, attention_matrix) %<-% get_caption(images[[i]])
real_caption <-
decode_caption(captions[i,]) %>% str_remove_all(" <pad>")
cat("\nReal caption:", real_caption, "\n")
cat("\nPredicted caption:", result, "\n")
if (plot_attention)
plot_attention(attention_matrix, images[[i]], result, epoch)
}
}
checkpoint_dir <- "./checkpoints_captions"
checkpoint_prefix <- file.path(checkpoint_dir, "ckpt")
checkpoint <-
tf$train$Checkpoint(optimizer = optimizer,
encoder = encoder,
decoder = decoder)
if (restore_checkpoint) {
checkpoint$restore(tf$train$latest_checkpoint(checkpoint_dir))
}
num_epochs <- 20
if (!restore_checkpoint) {
for (epoch in seq_len(num_epochs)) {
cat("Starting epoch:", epoch, "\n")
total_loss <- 0
progress <- 0
train_iter <- make_iterator_one_shot(train_dataset)
until_out_of_range({
progress <- progress + 1
if (progress %% 10 == 0)
cat("-")
batch <- iterator_get_next(train_iter)
loss <- 0
img_tensor <- batch[[1]]
target_caption <- batch[[2]]
dec_hidden <- k_zeros(c(batch_size, gru_units))
dec_input <-
k_expand_dims(rep(list(word_index_df[word_index_df$word == "<start>", "index"]), batch_size))
with(tf$GradientTape() %as% tape, {
features <- encoder(img_tensor)
for (t in seq_len(dim(target_caption)[2] - 1)) {
c(preds, dec_hidden, weights) %<-%
decoder(list(dec_input, features, dec_hidden))
loss <- loss + cx_loss(target_caption[, t], preds)
dec_input <- k_expand_dims(target_caption[, t])
}
})
total_loss <-
total_loss + loss / k_cast_to_floatx(dim(target_caption)[2])
variables <- c(encoder$variables, decoder$variables)
gradients <- tape$gradient(loss, variables)
optimizer$apply_gradients(purrr::transpose(list(gradients, variables)),
global_step = tf$train$get_or_create_global_step())
})
cat(paste0(
"\n\nTotal loss (epoch): ",
epoch,
": ",
(total_loss / k_cast_to_floatx(buffer_size)) %>% as.double() %>% round(4),
"\n"
))
checkpoint$save(file_prefix = checkpoint_prefix)
check_sample_captions(epoch, "training", plot_attention = FALSE)
check_sample_captions(epoch, "validation", plot_attention = FALSE)
}
}
epoch <- num_epochs
check_sample_captions(epoch, "training", plot_attention = TRUE)
check_sample_captions(epoch, "validation", plot_attention = TRUE)