diff --git a/README.rst b/README.rst
index 28659dc4da..9cb81c3211 100644
--- a/README.rst
+++ b/README.rst
@@ -752,12 +752,25 @@ Preparing the TensorFlow training script
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 Your TensorFlow training script must be a **Python 2.7** source file. The current supported TensorFlow
-versions are **1.6.0 (default)**, **1.5.0**, and **1.4.1**. This training script **must contain** the following functions:
+versions are **1.6.0 (default)**, **1.5.0**, and **1.4.1**. The SageMaker TensorFlow docker image
+uses this script by calling specifically-named functions from this script.
+
+The training script **must contain** the following:
+
+- Exactly one of the following:
+
+  - ``model_fn``: defines the model that will be trained.
+  - ``keras_model_fn``: defines the ``tf.keras`` model that will be trained.
+  - ``estimator_fn``: defines the ``tf.estimator.Estimator`` that will train the model.

-- ``model_fn``: defines the model that will be trained.
 - ``train_input_fn``: preprocess and load training data.
 - ``eval_input_fn``: preprocess and load evaluation data.
-- ``serving_input_fn``: defines the features to be passed to the model during prediction.
+
+In addition, it may optionally contain:
+
+- ``serving_input_fn``: Defines the features to be passed to the model during prediction. **Important:**
+  this function is used only during training, but is required to deploy the model resulting from training
+  in a SageMaker endpoint.

 Creating a ``model_fn``
 ^^^^^^^^^^^^^^^^^^^^^^^
@@ -793,6 +806,8 @@ The ``model_fn`` must accept four positional arguments:
    your TensorFlow training script. You can use this to pass hyperparameters to
    your training script.

+The ``model_fn`` must return a ``tf.estimator.EstimatorSpec``.
+
 Example of a complete ``model_fn``
 ''''''''''''''''''''''''''''''''''

@@ -875,9 +890,9 @@ The basic skeleton for the ``train_input_fn`` looks like this:
         # Logic to the following:
         # 1. Reads the **training** dataset files located in training_dir
         # 2. Preprocess the dataset
-        # 3. Return 1) a mapping of feature columns to Tensors with
+        # 3. Return 1) a dict of feature names to Tensors with
         #    the corresponding feature data, and 2) a Tensor containing labels
-        return feature_cols, labels
+        return features, labels

 An ``eval_input_fn`` follows the same format:

@@ -887,9 +902,12 @@ An ``eval_input_fn`` follows the same format:
         # Logic to the following:
         # 1. Reads the **evaluation** dataset files located in training_dir
         # 2. Preprocess the dataset
-        # 3. Return 1) a mapping of feature columns to Tensors with
+        # 3. Return 1) a dict of feature names to Tensors with
         #    the corresponding feature data, and 2) a Tensor containing labels
-        return feature_cols, labels
+        return features, labels
+
+**Note:** For TensorFlow 1.4 and 1.5, ``train_input_fn`` and ``eval_input_fn`` may also return a no-argument
+function which returns the tuple ``features, labels``. This is no longer supported for TensorFlow 1.6 and up.

 Example of a complete ``train_input_fn`` and ``eval_input_fn``
 ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
@@ -922,14 +940,9 @@ More details on how to create input functions can be find in `Building Input Fun
 Creating a ``serving_input_fn``
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-``serving_input_fn`` is used to define the shapes and types of the inputs
-the model accepts when the model is exported for Tensorflow Serving. ``serving_input_fn`` is called
-at the end of model training and is not called during inference. (If you'd like to preprocess inference data,
-please see ``input_fn``). This function has the following purposes:
+``serving_input_fn`` is used to define the shapes and types of the inputs the model accepts when the model is exported for Tensorflow Serving. It is optional, but required for deploying the trained model to a SageMaker endpoint.

-- To add placeholders to the graph that the serving system will feed with inference requests.
-- To add any additional ops needed to convert data from the input format into the feature Tensors
-  expected by the model.
+``serving_input_fn`` is called at the end of model training and is **not** called during inference. (If you'd like to preprocess inference data, please see **Overriding input preprocessing with an input_fn**).

 The basic skeleton for the ``serving_input_fn`` looks like this:

@@ -939,8 +952,10 @@ The basic skeleton for the ``serving_input_fn`` looks like this:
         # Logic to the following:
         # 1. Defines placeholders that TensorFlow serving will feed with inference requests
         # 2. Preprocess input data
-        # 3. Returns a tf.estimator.export.ServingInputReceiver object, which packages the placeholders
-          and the resulting feature Tensors together.
+        # 3. Returns a tf.estimator.export.ServingInputReceiver or tf.estimator.export.TensorServingInputReceiver,
+        #    which packages the placeholders and the resulting feature Tensors together.
+
+**Note:** For TensorFlow 1.4 and 1.5, ``serving_input_fn`` may also return a no-argument function which returns a ``tf.estimator.export.ServingInputReceiver`` or ``tf.estimator.export.TensorServingInputReceiver``. This is no longer supported for TensorFlow 1.6 and up.

 Example of a complete ``serving_input_fn``
 ''''''''''''''''''''''''''''''''''''''''''
@@ -1137,15 +1152,23 @@ These hyperparameters are used by TensorFlow to fine tune the training.
 You need to add them inside the hyperparameters dictionary in the
 ``TensorFlow`` estimator constructor.

+**All versions**
+
 - ``save_summary_steps (int)`` Save summaries every this many steps.
 - ``save_checkpoints_secs (int)`` Save checkpoints every this many seconds. Can not be specified with ``save_checkpoints_steps``.
 - ``save_checkpoints_steps (int)`` Save checkpoints every this many steps. Can not be specified with ``save_checkpoints_secs``.
 - ``keep_checkpoint_max (int)`` The maximum number of recent checkpoint files to keep. As new files are created, older files are deleted. If None or 0, all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent checkpoint files are kept.)
 - ``keep_checkpoint_every_n_hours (int)`` Number of hours between each checkpoint to be saved. The default value of 10,000 hours effectively disables the feature.
 - ``log_step_count_steps (int)`` The frequency, in number of global steps, that the global step/sec will be logged during training.
+
+**TensorFlow 1.6 and up**
+
+- ``start_delay_secs (int)`` See docs for this parameter in `tf.estimator.EvalSpec `_.
+- ``throttle_secs (int)`` See docs for this parameter in `tf.estimator.EvalSpec `_.
+
+**TensorFlow 1.4 and 1.5**
+
 - ``eval_metrics (dict)`` ``dict`` of string, metric function. If `None`, default set is used. This should be ``None`` if the ``estimator`` is `tf.estimator.Estimator `_. If metrics are provided they will be *appended* to the default set.
-- ``train_monitors (list)`` A list of monitors to pass during training.
-- ``eval_hooks (list)`` A list of `SessionRunHook` hooks to pass during evaluation.
 - ``eval_delay_secs (int)`` Start evaluating after waiting for this many seconds.
 - ``continuous_eval_throttle_secs (int)`` Do not re-evaluate unless the last evaluation was started at least this many seconds ago.
 - ``min_eval_frequency (int)`` The minimum number of steps between evaluations. Of course, evaluation does not occur if no new snapshot is available, hence, this is the minimum. If 0, the evaluation will only happen after training. If None, defaults to 1000.
@@ -1398,7 +1421,7 @@ This process looks like this:

 The common functionality can be extended by the addiction of the following two functions to your training script:

-Overriding input precessing with an ``input_fn``
+Overriding input preprocessing with an ``input_fn``
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 An example of ``input_fn`` for the content-type "application/python-pickle" can be seen below:
@@ -1417,7 +1440,7 @@ An example of ``input_fn`` for the content-type "application/python-pickle" can
             # if the content type is not supported.
             pass

-Overriding output precessing with an ``output_fn``
+Overriding output postprocessing with an ``output_fn``
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 An example of ``output_fn`` for the accept type "application/python-pickle" can be seen below:
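
The last two hunks end before the README's own ``input_fn`` and ``output_fn`` examples, so the renamed sections are hard to judge from the diff alone. As a rough illustration only (not the README's code), the pair might look like the sketch below. The argument names and the two-argument signatures are assumptions, as is raising ``ValueError`` for unsupported types; only the "application/python-pickle" content type comes from the diff.

```python
import pickle

# Content/accept type named in the README context lines above.
PICKLE_TYPE = "application/python-pickle"


def input_fn(serialized_input, content_type):
    """Turn a raw request body into model input (assumed signature)."""
    if content_type == PICKLE_TYPE:
        # Unpickling is only safe when the caller is trusted.
        return pickle.loads(serialized_input)
    # Assumption: reject anything else instead of silently returning None.
    raise ValueError("Unsupported content type: {}".format(content_type))


def output_fn(prediction_result, accepts):
    """Turn a prediction into a response body (assumed signature)."""
    if accepts == PICKLE_TYPE:
        return pickle.dumps(prediction_result)
    raise ValueError("Unsupported accept type: {}".format(accepts))
```

Raising on an unknown type, rather than falling through to ``pass`` as the context line in the diff does, is a deliberate choice in this sketch: it makes an unsupported request fail loudly.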