---
title: Customizing what happens in `fit()` with TensorFlow
date-created: 2020/04/15
last-modified: 2023/06/27
description: Overriding the training step of the Model class with TensorFlow.
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Customizing what happens in `fit()` with TensorFlow}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

## Introduction

When you're doing supervised learning, you can use `fit()` and everything works smoothly.

When you need to take control of every little detail, you can write your own training loop entirely from scratch.

But what if you need a custom training algorithm, yet still want to benefit from the convenient features of `fit()`, such as callbacks, built-in distribution support, or step fusing?

A core principle of Keras is **progressive disclosure of complexity**. You should always be able to get into lower-level workflows in a gradual way. You shouldn't fall off a cliff if the high-level functionality doesn't exactly match your use case. You should be able to gain more control over the small details while retaining a commensurate amount of high-level convenience.

When you need to customize what `fit()` does, you should **override the training step function of the `Model` class**. This is the function that is called by `fit()` for every batch of data. You will then be able to call `fit()` as usual -- and it will be running your own learning algorithm.

Note that this pattern does not prevent you from building models with the Functional API. You can do this whether you're building `Sequential` models, Functional API models, or subclassed models.

Let's see how that works.

## Setup

``` r
library(reticulate)
library(tensorflow, exclude = c("set_random_seed", "shape"))
library(keras3)
```

## A first simple example

Let's start from a simple example:

- We create a new class that subclasses `Model`.
- We just override the method `train_step(data)`.
- We return a named list mapping metric names (including the loss) to their current value.

The input argument `data` is what gets passed to `fit()` as training data:

- If you pass arrays, by calling `fit(x, y, ...)`, then `data` will be `list(x, y)`.
- If you pass a `tf_dataset`, by calling `fit(dataset, ...)`, then `data` will be what gets yielded by `dataset` at each batch.

In the body of the `train_step()` method, we implement a regular training update, similar to what you are already familiar with. Importantly, **we compute the loss via `self$compute_loss()`**, which wraps the loss function(s) that were passed to `compile()`.

Similarly, we call `metric$update_state(y, y_pred)` on metrics from `self$metrics`, to update the state of the metrics that were passed in `compile()`, and we query results from `self$metrics` at the end to retrieve their current value.

``` r
CustomModel <- new_model_class(
  "CustomModel",
  train_step = function(data) {
    c(x, y = NULL, sample_weight = NULL) %<-% data

    with(tf$GradientTape() %as% tape, {
      y_pred <- self(x, training = TRUE)
      loss <- self$compute_loss(y = y, y_pred = y_pred,
                                sample_weight = sample_weight)
    })

    # Compute gradients
    trainable_vars <- self$trainable_variables
    gradients <- tape$gradient(loss, trainable_vars)

    # Update weights
    self$optimizer$apply(gradients, trainable_vars)

    # Update metrics (includes the metric that tracks the loss)
    for (metric in self$metrics) {
      if (metric$name == "loss")
        metric$update_state(loss)
      else
        metric$update_state(y, y_pred)
    }

    # Return a named list mapping metric names to current value
    metrics <- lapply(self$metrics, function(m) m$result())
    metrics <- setNames(metrics, sapply(self$metrics, function(m) m$name))
    metrics
  }
)
```

Let's try this out:

``` r
# Construct and compile an instance of CustomModel
inputs <- keras_input(shape = 32)
outputs <- layer_dense(inputs, 1)
model <- CustomModel(inputs, outputs)
model |> compile(optimizer = "adam", loss = "mse", metrics = "mae")

# Just use `fit` as usual
x <- random_normal(c(1000, 32))
y <- random_normal(c(1000, 1))
model |> fit(x, y, epochs = 3)
```

```
## Epoch 1/3
## 32/32 - 1s - 23ms/step - loss: 3.2271 - mae: 1.4339
## Epoch 2/3
## 32/32 - 0s - 1ms/step - loss: 2.9034 - mae: 1.3605
## Epoch 3/3
## 32/32 - 0s - 1ms/step - loss: 2.6272 - mae: 1.2960
```

## Going lower-level

Naturally, you could just skip passing a loss function in `compile()`, and instead do everything *manually* in `train_step`. Likewise for metrics.

Here's a lower-level example that only uses `compile()` to configure the optimizer:

- We start by creating `Metric` instances to track our loss and a MAE score (in `initialize()`).
- We implement a custom `train_step()` that updates the state of these metrics (by calling `update_state()` on them), then queries them (via `result()`) to return their current average value, to be displayed by the progress bar and to be passed to any callback.
- Note that we would need to call `reset_state()` on our metrics between each epoch! Otherwise calling `result()` would return an average since the start of training, whereas we usually work with per-epoch averages. Thankfully, the framework can do that for us: just list any metric you want to reset in the `metrics` property of the model. The model will call `reset_state()` on any object listed here at the beginning of each `fit()` epoch or at the beginning of a call to `evaluate()`.

``` r
CustomModel <- new_model_class(
  "CustomModel",
  initialize = function(...) {
    super$initialize(...)
    self$loss_tracker <- metric_mean(name = "loss")
    self$mae_metric <- metric_mean_absolute_error(name = "mae")
    self$loss_fn <- loss_mean_squared_error()
  },
  train_step = function(data) {
    c(x, y = NULL, sample_weight = NULL) %<-% data

    with(tf$GradientTape() %as% tape, {
      y_pred <- self(x, training = TRUE)
      loss <- self$loss_fn(y, y_pred, sample_weight = sample_weight)
    })

    # Compute gradients
    trainable_vars <- self$trainable_variables
    gradients <- tape$gradient(loss, trainable_vars)

    # Update weights
    self$optimizer$apply(gradients, trainable_vars)

    # Compute our own metrics
    self$loss_tracker$update_state(loss)
    self$mae_metric$update_state(y, y_pred)

    # Return a named list mapping metric names to current value
    list(
      loss = self$loss_tracker$result(),
      mae = self$mae_metric$result()
    )
  },
  metrics = mark_active(function() {
    # We list our `Metric` objects here so that `reset_state()` can be
    # called automatically at the start of each epoch
    # or at the start of `evaluate()`.
    list(self$loss_tracker, self$mae_metric)
  })
)

# Construct and compile an instance of CustomModel
inputs <- keras_input(shape = 32)
outputs <- layer_dense(inputs, 1)
model <- CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model |> compile(optimizer = "adam")

# Just use `fit` as usual
x <- random_normal(c(1000, 32))
y <- random_normal(c(1000, 1))
model |> fit(x, y, epochs = 3)
```

```
## Epoch 1/3
## 32/32 - 1s - 22ms/step - loss: 2.5170 - mae: 1.2923
## Epoch 2/3
## 32/32 - 0s - 1ms/step - loss: 2.2689 - mae: 1.2241
## Epoch 3/3
## 32/32 - 0s - 1ms/step - loss: 2.0578 - mae: 1.1633
```

## Supporting `sample_weight` & `class_weight`

You may have noticed that our first basic example passed `sample_weight` to `compute_loss()` but left it out of the metric updates.
If you want to support the `fit()` arguments `sample_weight` and `class_weight`, you'd simply do the following:

- Unpack `sample_weight` from the `data` argument
- Pass it to `compute_loss` & `update_state` (of course, you could also just apply it manually if you don't rely on `compile()` for losses & metrics)
- That's it.

``` r
CustomModel <- new_model_class(
  "CustomModel",
  train_step = function(data) {
    c(x, y = NULL, sample_weight = NULL) %<-% data

    with(tf$GradientTape() %as% tape, {
      y_pred <- self(x, training = TRUE)
      loss <- self$compute_loss(y = y, y_pred = y_pred,
                                sample_weight = sample_weight)
    })

    # Compute gradients
    trainable_vars <- self$trainable_variables
    gradients <- tape$gradient(loss, trainable_vars)

    # Update weights
    self$optimizer$apply_gradients(zip_lists(gradients, trainable_vars))

    # Update metrics (includes the metric that tracks the loss)
    for (metric in self$metrics) {
      if (metric$name == "loss") {
        metric$update_state(loss)
      } else {
        metric$update_state(y, y_pred, sample_weight = sample_weight)
      }
    }

    # Return a named list mapping metric names to current value
    metrics <- lapply(self$metrics, function(m) m$result())
    metrics <- setNames(metrics, sapply(self$metrics, function(m) m$name))
    metrics
  }
)

# Construct and compile an instance of CustomModel
inputs <- keras_input(shape = 32)
outputs <- layer_dense(inputs, units = 1)
model <- CustomModel(inputs, outputs)
model |> compile(optimizer = "adam", loss = "mse", metrics = "mae")

# You can now use the sample_weight argument
x <- random_normal(c(1000, 32))
y <- random_normal(c(1000, 1))
sw <- random_normal(c(1000, 1))
model |> fit(x, y, sample_weight = sw, epochs = 3)
```

```
## Epoch 1/3
## 32/32 - 1s - 26ms/step - loss: 0.1681 - mae: 1.3434
## Epoch 2/3
## 32/32 - 0s - 9ms/step - loss: 0.1394 - mae: 1.3364
## Epoch 3/3
## 32/32 - 0s - 1ms/step - loss: 0.1148 - mae: 1.3286
```

## Providing your own evaluation step

What if you want to do the same for calls to `evaluate()`?
Then you would override `test_step` in exactly the same way. Here's what it looks like:

``` r
CustomModel <- new_model_class(
  "CustomModel",
  test_step = function(data) {
    # Unpack the data
    c(x, y = NULL, sw = NULL) %<-% data

    # Compute predictions
    y_pred <- self(x, training = FALSE)

    # Updates the metrics tracking the loss
    self$compute_loss(y = y, y_pred = y_pred, sample_weight = sw)

    # Update the metrics.
    for (metric in self$metrics) {
      if (metric$name != "loss") {
        metric$update_state(y, y_pred, sample_weight = sw)
      }
    }

    # Return a named list mapping metric names to current value.
    # Note that it will include the loss (tracked in self$metrics).
    metrics <- lapply(self$metrics, function(m) m$result())
    metrics <- setNames(metrics, sapply(self$metrics, function(m) m$name))
    metrics
  }
)

# Construct an instance of CustomModel
inputs <- keras_input(shape = 32)
outputs <- layer_dense(inputs, 1)
model <- CustomModel(inputs, outputs)
model |> compile(loss = "mse", metrics = "mae")

# Evaluate with our custom test_step
x <- random_normal(c(1000, 32))
y <- random_normal(c(1000, 1))
model |> evaluate(x, y)
```

```
## 32/32 - 0s - 10ms/step - loss: 0.0000e+00 - mae: 1.3871
```

```
## $loss
## [1] 0
##
## $mae
## [1] 1.387149
```

## Wrapping up: an end-to-end GAN example

Let's walk through an end-to-end example that leverages everything you just learned.

Let's consider:

- A generator network meant to generate 28x28x1 images.
- A discriminator network meant to classify 28x28x1 images into two classes ("fake" and "real").
- One optimizer for each.
- A loss function to train the discriminator.
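
The discriminator in this example ends in a single dense unit with no activation, so its outputs are raw logits, and the loss used to train it is binary cross-entropy computed from logits. As a rough illustration of what such a loss computes, here is a minimal base-R sketch. The helper `bce_from_logits()` is hypothetical and not part of keras3; it assumes the loss is averaged over the batch and uses the standard numerically stable form.

``` r
# Hypothetical helper (not a keras3 function): numerically stable binary
# cross-entropy on raw logits, averaged over the batch:
#   mean(max(x, 0) - x * z + log(1 + exp(-|x|)))
bce_from_logits <- function(labels, logits) {
  mean(pmax(logits, 0) - logits * labels + log1p(exp(-abs(logits))))
}

logits <- c(-2, 0, 3)  # raw discriminator outputs (made-up values)
labels <- c(0, 1, 1)   # target classes (made-up values)
loss <- bce_from_logits(labels, logits)

# Cross-check against the textbook definition with an explicit sigmoid
p <- 1 / (1 + exp(-logits))
reference <- -mean(labels * log(p) + (1 - labels) * log(1 - p))
stopifnot(isTRUE(all.equal(loss, reference)))
```

Working on logits avoids evaluating `log()` on a sigmoid output that has saturated to exactly 0 or 1, which is one reason to leave the final activation off the discriminator and let the loss apply it.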
``` r
# Create the discriminator
discriminator <-
  keras_model_sequential(name = "discriminator", input_shape = c(28, 28, 1)) |>
  layer_conv_2d(filters = 64, kernel_size = c(3, 3),
                strides = c(2, 2), padding = "same") |>
  layer_activation_leaky_relu(negative_slope = 0.2) |>
  layer_conv_2d(filters = 128, kernel_size = c(3, 3),
                strides = c(2, 2), padding = "same") |>
  layer_activation_leaky_relu(negative_slope = 0.2) |>
  layer_global_max_pooling_2d() |>
  layer_dense(units = 1)

# Create the generator
latent_dim <- 128
generator <-
  keras_model_sequential(name = "generator", input_shape = latent_dim) |>
  layer_dense(7 * 7 * 128) |>
  layer_activation_leaky_relu(negative_slope = 0.2) |>
  layer_reshape(target_shape = c(7, 7, 128)) |>
  layer_conv_2d_transpose(filters = 128, kernel_size = c(4, 4),
                          strides = c(2, 2), padding = "same") |>
  layer_activation_leaky_relu(negative_slope = 0.2) |>
  layer_conv_2d_transpose(filters = 128, kernel_size = c(4, 4),
                          strides = c(2, 2), padding = "same") |>
  layer_activation_leaky_relu(negative_slope = 0.2) |>
  layer_conv_2d(filters = 1, kernel_size = c(7, 7), padding = "same",
                activation = "sigmoid")
```

Here's a feature-complete GAN class, overriding `compile()` to use its own signature, and implementing the entire GAN algorithm in 17 lines in `train_step`:

``` r
GAN <- Model(
  classname = "GAN",

  initialize = function(discriminator, generator, latent_dim, ...) {
    super$initialize(...)
    self$discriminator <- discriminator
    self$generator <- generator
    self$latent_dim <- as.integer(latent_dim)
    self$d_loss_tracker <- metric_mean(name = "d_loss")
    self$g_loss_tracker <- metric_mean(name = "g_loss")
  },

  compile = function(d_optimizer, g_optimizer, loss_fn, ...) {
    super$compile(...)
    self$d_optimizer <- d_optimizer
    self$g_optimizer <- g_optimizer
    self$loss_fn <- loss_fn
  },

  metrics = active_property(function() {
    list(self$d_loss_tracker, self$g_loss_tracker)
  }),

  train_step = function(real_images) {
    # Sample random points in the latent space
    batch_size <- shape(real_images)[[1]]
    random_latent_vectors <-
      tf$random$normal(shape(batch_size, self$latent_dim))

    # Decode them to fake images
    generated_images <- self$generator(random_latent_vectors)

    # Combine them with real images
    combined_images <- op_concatenate(list(generated_images, real_images))

    # Assemble labels discriminating real from fake images
    labels <- op_concatenate(list(op_ones(c(batch_size, 1)),
                                  op_zeros(c(batch_size, 1))))

    # Add random noise to the labels - important trick!
    labels %<>% `+`(tf$random$uniform(shape(.), maxval = 0.05))

    # Train the discriminator
    with(tf$GradientTape() %as% tape, {
      predictions <- self$discriminator(combined_images)
      d_loss <- self$loss_fn(labels, predictions)
    })
    grads <- tape$gradient(d_loss, self$discriminator$trainable_weights)
    self$d_optimizer$apply_gradients(
      zip_lists(grads, self$discriminator$trainable_weights))

    # Sample random points in the latent space
    random_latent_vectors <-
      tf$random$normal(shape(batch_size, self$latent_dim))

    # Assemble labels that say "all real images"
    misleading_labels <- op_zeros(c(batch_size, 1))

    # Train the generator (note that we should *not* update the weights
    # of the discriminator)!
    with(tf$GradientTape() %as% tape, {
      predictions <- self$discriminator(self$generator(random_latent_vectors))
      g_loss <- self$loss_fn(misleading_labels, predictions)
    })
    grads <- tape$gradient(g_loss, self$generator$trainable_weights)
    self$g_optimizer$apply_gradients(
      zip_lists(grads, self$generator$trainable_weights))

    list(d_loss = d_loss, g_loss = g_loss)
  }
)
```

Let's test-drive it:

``` r
batch_size <- 64
c(c(x_train, .), c(x_test, .)) %<-% dataset_mnist()
all_digits <- op_concatenate(list(x_train, x_test))
all_digits <- op_reshape(all_digits, c(-1, 28, 28, 1))
dataset <- all_digits |>
  tfdatasets::tensor_slices_dataset() |>
  tfdatasets::dataset_map(\(x) op_cast(x, "float32") / 255) |>
  tfdatasets::dataset_shuffle(buffer_size = 1024) |>
  tfdatasets::dataset_batch(batch_size = batch_size)

gan <- GAN(discriminator = discriminator,
           generator = generator,
           latent_dim = latent_dim)

gan |> compile(
  d_optimizer = optimizer_adam(learning_rate = 0.0003),
  g_optimizer = optimizer_adam(learning_rate = 0.0003),
  loss_fn = loss_binary_crossentropy(from_logits = TRUE)
)

# To limit the execution time, we only train on 100 batches. You can train on
# the entire dataset. You will need about 20 epochs to get nice results.
gan |> fit(
  tfdatasets::dataset_take(dataset, 100),
  epochs = 1
)
```

```
## 100/100 - 5s - 54ms/step - d_loss: 0.0000e+00 - g_loss: 0.0000e+00
```

The ideas behind deep learning are simple, so why should their implementation be painful?