--- title: "Dataset API" output: rmarkdown::html_vignette vignette: > %\VignetteIndexEntry{Dataset API} %\VignetteEngine{knitr::rmarkdown} %\VignetteEncoding{UTF-8} type: docs repo: https://github.com/rstudio/tfestimators menu: main: name: "Dataset API" identifier: "tfestimators-dataset-api" parent: "tfestimators-advanced" weight: 40 --- ```{r setup, include=FALSE} knitr::opts_chunk$set(echo = TRUE) knitr::opts_chunk$set(eval = FALSE) ``` ## Overview We can access the TensorFlow Dataset API via the [tfdatasets](https://tensorflow.rstudio.com/tools/tfdatasets/) package, which enables us to create scalable input pipelines that can be used with **tfestimators**. In this vignette, we demonstrate the capability to stream datasets stored on disk for training by building a classifier on the `iris` dataset. ## Dataset Preparation Let's assume we're given a dataset (which could be arbitrarily large) split into training and validation, and a small sample of the dataset. To simulate this scenario, we'll create a few CSV files as follows: ```{r} set.seed(123) train_idx <- sample(nrow(iris), nrow(iris) * 2/3) iris_train <- iris[train_idx,] iris_validation <- iris[-train_idx,] iris_sample <- iris_train %>% head(10) write.csv(iris_train, "iris_train.csv", row.names = FALSE) write.csv(iris_validation, "iris_validation.csv", row.names = FALSE) write.csv(iris_sample, "iris_sample.csv", row.names = FALSE) ``` ## Estimator Construction We construct the classifier as usual -- see [Estimator Basics](https://tensorflow.rstudio.com/guide/tfestimators/estimator_basics/) for details on feature columns and creating estimators. ```{r} library(tfestimators) response <- "Species" features <- setdiff(names(iris), response) feature_columns <- feature_columns( column_numeric(features) ) classifier <- dnn_classifier( feature_columns = feature_columns, hidden_units = c(16, 32, 16), n_classes = 3, label_vocabulary = c("setosa", "virginica", "versicolor") ) ``` ## Input Function The creation of the input function is similar to the [in-memory case](https://tensorflow.rstudio.com/guide/tfestimators/estimator_basics/#input-functions). However, instead of passing data frames or matrices to `iris_input_fn()`, we pass TensorFlow dataset objects which are internally iterators of the dataset files. ```{r} iris_input_fn <- function(data) { input_fn(data, features = features, response = response) } iris_spec <- csv_record_spec("iris_sample.csv") iris_train <- text_line_dataset( "iris_train.csv", record_spec = iris_spec) %>% dataset_batch(10) %>% dataset_repeat(10) iris_validation <- text_line_dataset( "iris_validation.csv", record_spec = iris_spec) %>% dataset_batch(10) %>% dataset_repeat(1) ``` The `csv_record_spec()` function is a helper function that creates a specification from a sample file; the returned specification is required by the `text_line_dataset()` function to parse the files. There are many [transformations](https://tensorflow.rstudio.com/guide/tfdatasets/introduction/#transformations) available for dataset objects, but here we just demonstrate `dataset_batch()` and `dataset_repeat()` which control the batch size and how many times we iterate through the dataset files, respectively. ## Training and Evaluation Once the input functions and datasets are defined, the training and evaluation interface is exactly the same as in the in-memory case. 
```{r}
history <- train(classifier, input_fn = iris_input_fn(iris_train))
plot(history)

predictions <- predict(classifier, input_fn = iris_input_fn(iris_validation))
predictions

evaluation <- evaluate(classifier, input_fn = iris_input_fn(iris_validation))
evaluation
```

## Learning More

See the documentation for the [tfdatasets](https://tensorflow.rstudio.com/tools/tfdatasets/) package for additional details on using TensorFlow datasets.
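As a closing sketch of where this pattern pays off, the same pipeline can stream a dataset split across several files on disk, since `text_line_dataset()` accepts one or more file paths. The shard file names below are hypothetical; the code assumes each shard has the same header and columns as `iris_sample.csv`, so the `iris_spec` defined earlier still applies.

```{r}
# Sketch: streaming a dataset split across multiple CSV shards.
# The shard file names are hypothetical stand-ins; each shard is assumed
# to share the header/columns of iris_sample.csv so iris_spec can parse it.
shard_files <- c("iris_train_01.csv", "iris_train_02.csv", "iris_train_03.csv")

iris_train_sharded <- text_line_dataset(
  shard_files, record_spec = iris_spec) %>%
  dataset_shuffle(buffer_size = 100) %>%  # illustrative buffer size
  dataset_batch(10) %>%
  dataset_repeat(10)
```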