We can access the TensorFlow Dataset API via the tfdatasets package, which enables us to create scalable input pipelines that can be used with tfestimators. In this vignette, we demonstrate the capability to stream datasets stored on disk for training by building a classifier on the iris dataset.
Let’s assume we’re given a dataset (which could be arbitrarily large) split into training and validation, and a small sample of the dataset. To simulate this scenario, we’ll create a few CSV files as follows:
set.seed(123)
train_idx <- sample(nrow(iris), nrow(iris) * 2/3)
iris_train <- iris[train_idx,]
iris_validation <- iris[-train_idx,]
iris_sample <- head(iris_train, 10)
write.csv(iris_train, "iris_train.csv", row.names = FALSE)
write.csv(iris_validation, "iris_validation.csv", row.names = FALSE)
write.csv(iris_sample, "iris_sample.csv", row.names = FALSE)
We construct the classifier as usual – see Estimator Basics for details on feature columns and creating estimators.
library(tfestimators)
response <- "Species"
features <- setdiff(names(iris), response)
feature_columns <- feature_columns(
column_numeric(features)
)
classifier <- dnn_classifier(
feature_columns = feature_columns,
hidden_units = c(16, 32, 16),
n_classes = 3,
label_vocabulary = c("setosa", "virginica", "versicolor")
)
The creation of the input function is similar to the in-memory case. However, instead of passing data frames or matrices to iris_input_fn(), we pass TensorFlow dataset objects, which internally iterate over the dataset files.
iris_input_fn <- function(data) {
input_fn(data, features = features, response = response)
}
library(tfdatasets)

iris_spec <- csv_record_spec("iris_sample.csv")
iris_train <- text_line_dataset(
"iris_train.csv", record_spec = iris_spec) %>%
dataset_batch(10) %>%
dataset_repeat(10)
iris_validation <- text_line_dataset(
"iris_validation.csv", record_spec = iris_spec) %>%
dataset_batch(10) %>%
dataset_repeat(1)
The csv_record_spec() function is a helper function that creates a specification from a sample file; the returned specification is required by the text_line_dataset() function to parse the files. There are many transformations available for dataset objects, but here we just demonstrate dataset_batch() and dataset_repeat(), which control the batch size and how many times we iterate through the dataset files, respectively.
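As an illustration of chaining additional transformations, a shuffle step could be added to the training pipeline with dataset_shuffle(). This is only a sketch; the buffer_size of 100 is an arbitrary value chosen for illustration, not part of the example above:

iris_train <- text_line_dataset(
  "iris_train.csv", record_spec = iris_spec) %>%
  # shuffle records within a 100-element buffer before batching (illustrative value)
  dataset_shuffle(buffer_size = 100) %>%
  dataset_batch(10) %>%
  dataset_repeat(10)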
Once the input functions and datasets are defined, the training and evaluation interface is exactly the same as in the in-memory case.
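For example, a minimal sketch of that interface, assuming the classifier, input function, and dataset objects defined above, might look like this:

# train against the streaming training dataset
history <- train(classifier, input_fn = iris_input_fn(iris_train))

# evaluate and predict against the streaming validation dataset
evaluation <- evaluate(classifier, input_fn = iris_input_fn(iris_validation))
predictions <- predict(classifier, input_fn = iris_input_fn(iris_validation))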
See the documentation for the tfdatasets package for additional details on using TensorFlow datasets.