Transfer learning with tfhub
This tutorial classifies movie reviews as positive or negative using the text of the review. This is an example of binary — or two-class — classification, an important and widely applicable kind of machine learning problem.
We’ll use the IMDB dataset that contains the text of 50,000 movie reviews from the Internet Movie Database. These are split into 25,000 reviews for training and 25,000 reviews for testing. The training and testing sets are balanced, meaning they contain an equal number of positive and negative reviews.
We will use Keras to build and train the model, tfhub for transfer learning, and tfds to load the IMDB dataset.
Let’s start by loading the required libraries.
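A minimal sketch of the setup, assuming the keras, tfhub, and tfds packages named above plus tfdatasets, which supplies the dataset_* helpers used further down:
library(keras)       # building and training the model
library(tfhub)       # layer_hub() for pre-trained embeddings from TensorFlow Hub
library(tfds)        # tfds_load() for the IMDB dataset
library(tfdatasets)  # dataset_batch() and dataset_shuffle() helpers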
Download the IMDB dataset
The IMDB dataset is available as imdb_reviews on TensorFlow Datasets (tfds). The version that comes packaged with Keras is already pre-processed, so it’s not useful for this tutorial.
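For contrast, here is a quick, optional sketch of the Keras-packaged version: it ships as integer word indices rather than raw review text, which is why it cannot be fed to a text-embedding layer.
# the Keras-packaged IMDB is already tokenized into integer word indices
imdb_keras <- dataset_imdb(num_words = 10000)
str(imdb_keras$train$x[[1]])  # word indices (integers), not a review string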
The following code downloads the IMDB dataset to your machine:
imdb <- tfds_load(
  "imdb_reviews:1.0.0",
  split = list("train[:60%]", "train[-40%:]", "test"),
  as_supervised = TRUE
)
summary(imdb)
## ── Large Movie Review Dataset.
## This is a dataset for binary sentiment classification containing
## ❯ Name: imdb_reviews
## ❯ Version: 1.0.0
## ❯ URLs: http://ai.stanford.edu/~amaas/data/sentiment/
## ❯ Size: 80.2 MiB
## ❯ Splits:
## — test (25000 examples)
## — train (25000 examples)
## — unsupervised (50000 examples)
## ❯ Schema:
## — label [] INT
## — text [] BYTES
tfds_load returns a TensorFlow Dataset, an abstraction that represents a sequence of elements, in which each element consists of one or more components.
To access individual elements of a Dataset, you can use:
first <- imdb[[1]] %>%
  dataset_batch(1) %>% # used to get only the first example
  reticulate::as_iterator() %>%
  reticulate::iter_next()
str(first)
## List of 2
## $ :tf.Tensor([b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."], shape=(1,), dtype=string)
## $ :tf.Tensor([0], shape=(1,), dtype=int64)
We will see next that Keras knows how to extract elements from TensorFlow Datasets automatically, making this a much more memory-efficient alternative to loading the entire dataset into RAM before passing it to Keras.
Build the model
The neural network is created by stacking layers—this requires three main architectural decisions:
- How to represent the text?
- How many layers to use in the model?
- How many hidden units to use for each layer?
In this example, the input data consists of sentences. The labels to predict are either 0 or 1.
One way to represent the text is to convert sentences into embedding vectors. We can use a pre-trained text embedding as the first layer, which will have three advantages:
- we don’t have to worry about text preprocessing,
- we can benefit from transfer learning,
- the embedding has a fixed size, so it’s simpler to process.
For this example we will use a pre-trained text embedding model from TensorFlow Hub called google/tf2-preview/gnews-swivel-20dim/1.
There are three other pre-trained models to test for the sake of this tutorial; swapping one in is sketched right after the list:
- google/tf2-preview/gnews-swivel-20dim-with-oov/1 - same as google/tf2-preview/gnews-swivel-20dim/1, but with 2.5% vocabulary converted to OOV buckets. This can help if vocabulary of the task and vocabulary of the model don’t fully overlap.
- google/tf2-preview/nnlm-en-dim50/1 - A much larger model with ~1M vocabulary size and 50 dimensions.
- google/tf2-preview/nnlm-en-dim128/1 - Even larger model with ~1M vocabulary size and 128 dimensions.
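Swapping in one of these models only requires changing the handle passed to layer_hub(). For example, a sketch with the 50-dimensional nnlm embedding (the module is downloaded the first time it is used):
# same layer_hub() usage as below, just a different module handle
nnlm_layer <- layer_hub(handle = "https://tfhub.dev/google/tf2-preview/nnlm-en-dim50/1")
nnlm_layer(first[[1]])  # output shape (1, 50) instead of (1, 20)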
Let’s first create a Keras layer that uses a TensorFlow Hub model to embed the sentences, and try it out on a couple of input examples. Note that no matter the length of the input text, the output shape of the embeddings is (num_examples, embedding_dimension).
embedding_layer <- layer_hub(handle = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1")
embedding_layer(first[[1]])
## tf.Tensor(
## [[ 1.765786 -3.882232 3.9134233 -1.5557289 -3.3362343 -1.7357955
## -1.9954445 1.2989551 5.081598 -1.1041286 -2.0503852 -0.72675157
## -0.65675956 0.24436149 -3.7208383 2.0954835 2.2969332 -2.0689783
## -2.9489717 -1.1315987 ]], shape=(1, 20), dtype=float32)
Let’s now build the full model:
model <- keras_model_sequential() %>%
  layer_hub(
    handle = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1",
    input_shape = list(),
    dtype = tf$string,
    trainable = TRUE
  ) %>%
  layer_dense(units = 16, activation = "relu") %>%
  layer_dense(units = 1, activation = "sigmoid")
summary(model)
## Model: "sequential"
## ___________________________________________________________________________
## Layer (type) Output Shape Param #
## ===========================================================================
## keras_layer_1 (KerasLayer) (None, 20) 400020
## ___________________________________________________________________________
## dense (Dense) (None, 16) 336
## ___________________________________________________________________________
## dense_1 (Dense) (None, 1) 17
## ===========================================================================
## Total params: 400,373
## Trainable params: 400,373
## Non-trainable params: 0
## ___________________________________________________________________________
The layers are stacked sequentially to build the classifier:
- The first layer is a TensorFlow Hub layer. This layer uses a pre-trained SavedModel to map a sentence into its embedding vector. The pre-trained text embedding model that we are using (google/tf2-preview/gnews-swivel-20dim/1) splits the sentence into tokens, embeds each token, and then combines the embeddings. The resulting dimensions are: (num_examples, embedding_dimension).
- This fixed-length output vector is piped through a fully-connected (Dense) layer with 16 hidden units.
- The last layer is densely connected with a single output node. Using the sigmoid activation function, this value is a float between 0 and 1, representing a probability, or confidence level.
Let’s compile the model.
Loss function and optimizer
A model needs a loss function and an optimizer for training. Since this is a binary classification problem and the model outputs a probability (a single-unit layer with a sigmoid activation), we’ll use the binary_crossentropy loss function.
This isn’t the only choice of loss function; you could, for instance, choose mean_squared_error. But, generally, binary_crossentropy is better for dealing with probabilities—it measures the “distance” between probability distributions, or in our case, between the ground-truth distribution and the predictions.
Later, when we are exploring regression problems (say, to predict the price of a house), we will see how to use another loss function called mean squared error.
Now, configure the model to use an optimizer and a loss function:
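A minimal compile call could look like this; the adam optimizer is an assumption (any Keras optimizer would work), while the loss and the accuracy metric follow from the discussion above and the training log below:
model %>% compile(
  optimizer = "adam",            # assumed; not dictated by the text
  loss = "binary_crossentropy",  # suited to sigmoid probability outputs
  metrics = "accuracy"           # matches the metric reported during training
)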
Train the model
Train the model for 20 epochs in mini-batches of 512 samples. This is 20 iterations over all samples in the dataset. While training, monitor the model’s loss and accuracy on the 10,000 samples from the validation set:
model %>%
  fit(
    imdb[[1]] %>% dataset_shuffle(10000) %>% dataset_batch(512),
    epochs = 20,
    validation_data = imdb[[2]] %>% dataset_batch(512),
    verbose = 2
  )
## Epoch 1/20
## 30/30 - 3s - loss: 0.8245 - accuracy: 0.5371 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
## Epoch 2/20
## 30/30 - 2s - loss: 0.6803 - accuracy: 0.5983 - val_loss: 0.6597 - val_accuracy: 0.6131
## Epoch 3/20
## 30/30 - 2s - loss: 0.6277 - accuracy: 0.6553 - val_loss: 0.6107 - val_accuracy: 0.6708
## Epoch 4/20
## 30/30 - 3s - loss: 0.5771 - accuracy: 0.7061 - val_loss: 0.5655 - val_accuracy: 0.7175
## Epoch 5/20
## 30/30 - 3s - loss: 0.5298 - accuracy: 0.7503 - val_loss: 0.5250 - val_accuracy: 0.7506
## Epoch 6/20
## 30/30 - 3s - loss: 0.4901 - accuracy: 0.7849 - val_loss: 0.4887 - val_accuracy: 0.7748
## Epoch 7/20
## 30/30 - 3s - loss: 0.4484 - accuracy: 0.8079 - val_loss: 0.4583 - val_accuracy: 0.7945
## Epoch 8/20
## 30/30 - 4s - loss: 0.4121 - accuracy: 0.8288 - val_loss: 0.4290 - val_accuracy: 0.8144
## Epoch 9/20
## 30/30 - 3s - loss: 0.3824 - accuracy: 0.8455 - val_loss: 0.4071 - val_accuracy: 0.8246
## Epoch 10/20
## 30/30 - 3s - loss: 0.3523 - accuracy: 0.8587 - val_loss: 0.3852 - val_accuracy: 0.8376
## Epoch 11/20
## 30/30 - 3s - loss: 0.3263 - accuracy: 0.8709 - val_loss: 0.3682 - val_accuracy: 0.8442
## Epoch 12/20
## 30/30 - 3s - loss: 0.3019 - accuracy: 0.8829 - val_loss: 0.3539 - val_accuracy: 0.8496
## Epoch 13/20
## 30/30 - 3s - loss: 0.2847 - accuracy: 0.8923 - val_loss: 0.3458 - val_accuracy: 0.8512
## Epoch 14/20
## 30/30 - 3s - loss: 0.2622 - accuracy: 0.9007 - val_loss: 0.3320 - val_accuracy: 0.8591
## Epoch 15/20
## 30/30 - 3s - loss: 0.2476 - accuracy: 0.9111 - val_loss: 0.3236 - val_accuracy: 0.8625
## Epoch 16/20
## 30/30 - 3s - loss: 0.2309 - accuracy: 0.9167 - val_loss: 0.3175 - val_accuracy: 0.8663
## Epoch 17/20
## 30/30 - 3s - loss: 0.2145 - accuracy: 0.9230 - val_loss: 0.3130 - val_accuracy: 0.8677
## Epoch 18/20
## 30/30 - 3s - loss: 0.2023 - accuracy: 0.9288 - val_loss: 0.3086 - val_accuracy: 0.8698
## Epoch 19/20
## 30/30 - 3s - loss: 0.1894 - accuracy: 0.9349 - val_loss: 0.3057 - val_accuracy: 0.8711
## Epoch 20/20
## 30/30 - 3s - loss: 0.1804 - accuracy: 0.9382 - val_loss: 0.3037 - val_accuracy: 0.8724
Evaluate the model
Let’s see how the model performs on the test set. Two values are returned: the loss (a number that represents our error; lower values are better) and the accuracy.
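A sketch of the evaluation call that produces the output below, assuming the third element of imdb holds the test split requested in tfds_load() above:
model %>% evaluate(imdb[[3]] %>% dataset_batch(512), verbose = 0)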
## $loss
## [1] 0.3144187
##
## $accuracy
## [1] 0.86488
This fairly naive approach achieves an accuracy of about 87%. With more advanced approaches, the model should get closer to 95%.
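As with the embedding layer earlier, the trained model can also be called directly on a tensor of review strings to inspect individual predictions; a quick sketch using the first example extracted above:
# sigmoid output for the first review shown earlier; values near 1 mean "positive"
model(first[[1]])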