Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 23 additions & 185 deletions doc/frameworks/mxnet/using_mxnet.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,6 @@ To train an MXNet model by using the SageMaker Python SDK:
Prepare an MXNet Training Script
================================

.. warning::
The structure for training scripts changed starting at MXNet version 1.3.
Make sure you refer to the correct section of this README when you prepare your script.
For information on how to upgrade an old script to the new format, see `"Updating your MXNet training script" <#updating-your-mxnet-training-script>`__.

For versions 1.3 and higher
---------------------------
Your MXNet training script must be compatible with Python 2.7 or 3.6.

The training script is very similar to a training script you might run outside of Amazon SageMaker, but you can access useful properties about the training environment through various environment variables, including the following:

* ``SM_MODEL_DIR``: A string that represents the path where the training job writes the model artifacts to.
Expand Down Expand Up @@ -89,119 +80,8 @@ If you want to use, for example, boolean hyperparameters, you need to specify ``

For more on training environment variables, please visit `SageMaker Containers <https://github.com/aws/sagemaker-containers>`_.

For versions 1.2 and lower
--------------------------

Your MXNet training script must be compatible with Python 2.7 or 3.5.
The script must contain a function named ``train``, which Amazon SageMaker invokes to run training.
You can include other functions as well, but it must contain a ``train`` function.

When you run your script on Amazon SageMaker via the ``MXNet`` estimator, Amazon SageMaker injects information about the training environment into your training function via Python keyword arguments.
You can choose to take advantage of these by including them as keyword arguments in your train function. The full list of arguments is:

- ``hyperparameters (dict[string,string])``: The hyperparameters passed
to an Amazon SageMaker TrainingJob that runs your MXNet training script. You
can use this to pass hyperparameters to your training script.
- ``input_data_config (dict[string,dict])``: The Amazon SageMaker TrainingJob
InputDataConfig object, that's set when the Amazon SageMaker TrainingJob is
created. This is discussed in more detail below.
- ``channel_input_dirs (dict[string,string])``: A collection of
directories containing training data. When you run training, you can
partition your training data into different logical "channels".
Depending on your problem, some common channel ideas are: "train",
"test", "evaluation" or "images',"labels".
- ``output_data_dir (str)``: A directory where your training script can
write data that is moved to Amazon S3 after training is complete.
- ``num_gpus (int)``: The number of GPU devices available on your
training instance.
- ``num_cpus (int)``: The number of CPU devices available on your training instance.
- ``hosts (list[str])``: The list of host names running in the
Amazon SageMaker Training Job cluster.
- ``current_host (str)``: The name of the host executing the script.
When you use Amazon SageMaker for MXNet training, the script is run on each
host in the cluster.

A training script that takes advantage of all arguments would have the following definition:

.. code:: python

def train(hyperparameters, input_data_config, channel_input_dirs, output_data_dir,
num_gpus, num_cpus, hosts, current_host)

You don't have to use all the arguments.
Arguments you don't care about can be ignored by including ``**kwargs``.

.. code:: python

# Only work with hyperparameters and num_gpus, and ignore all other hyperparameters
def train(hyperparameters, num_gpus, **kwargs)

.. note::
**Writing a training script that imports correctly:**
When Amazon SageMaker runs your training script, it imports it as a Python module and then invokes ``train`` on the imported module.
Consequently, you should not include any statements that won't execute successfully in Amazon SageMaker when your module is imported.
For example, don't attempt to open any local files in top-level statements in your training script.

If you want to run your training script locally by using the Python interpreter, use a ``___name__ == '__main__'`` guard.
For more information, see https://stackoverflow.com/questions/419163/what-does-if-name-main-do.

Save the Model
^^^^^^^^^^^^^^

Just as you enable training by defining a ``train`` function in your training script, you enable model saving by defining a ``save`` function in your script.
If your script includes a ``save`` function, Amazon SageMaker invokes it with the return value of ``train``.
Model saving is a two-step process.
First, return the model you want to save from ``train``.
Then, define your model-serialization logic in ``save``.

Amazon SageMaker provides a default implementation of ``save`` that works with MXNet Module API ``Module`` objects.
If your training script does not define a ``save`` function, then the default ``save`` function is invoked on the return value of your ``train`` function.

The default serialization system generates three files:

- ``model-shapes.json``: A JSON list, containing a serialization of the
``Module`` ``data_shapes`` property. Each object in the list contains
the serialization of one ``DataShape`` in the returned ``Module``.
Each object has a ``name`` property, containing the ``DataShape``
name and a ``shape`` property, which is a list of that dimensions for
the shape of that ``DataShape``. For example:

.. code:: javascript

[
{"name":"images", "shape":[100, 1, 28, 28]},
{"name":"labels", "shape":[100, 1]}
]

- ``model-symbol.json``: The MXNet ``Module`` ``Symbol`` serialization,
produced by invoking ``save`` on the ``symbol`` property of the
``Module`` being saved.
- ``modle.params``: The MXNet ``Module`` parameters, produced by
invoking ``save_params`` on the ``Module`` being saved.

You can provide your own save function. This is useful if you are not working with the ``Module`` API or you need special processing.

To provide your own save function, define a ``save`` function in your training script:

.. code:: python

def save(model, model_dir)

The function should take two arguments:

- ``model``: This is the object that is returned from your ``train`` function.
You may return an object of any type from ``train``;
you do not have to return ``Module`` or ``Gluon`` API specific objects.
If your ``train`` function does not return an object, ``model`` is set to ``None``.
- ``model_dir``: This is the string path on the Amazon SageMaker training host where you save your model.
Files created in this directory are accessible in Amazon S3 after your Amazon SageMaker Training Job completes.

After your ``train`` function completes, Amazon SageMaker invokes ``save`` with the object returned from ``train``.

.. note::
**How to save Gluon models with Amazon SageMaker:**
If your train function returns a Gluon API ``net`` object as its model, you need to write your own ``save`` function and serialize the ``net`` parameters.
Saving ``net`` parameters is covered in the `Serialization section <http://gluon.mxnet.io/chapter03_deep-neural-networks/serialization.html>`__ of the collaborative Gluon deep-learning book `"The Straight Dope" <http://gluon.mxnet.io/index.html>`__.
If you want to use MXNet 1.2 or lower, see `an older version of this page <https://sagemaker.readthedocs.io/en/v1.61.0/frameworks/mxnet/using_mxnet.html>`_.

Save a Checkpoint
-----------------
Expand Down Expand Up @@ -233,86 +113,44 @@ To save MXNet model checkpoints, do the following in your training script:

For a complete example of an MXNet training script that impelements checkpointing, see https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/mxnet_gluon_cifar10/cifar10.py.

Save the Model
--------------

Update your MXNet training script
---------------------------------

The structure for training scripts changed with MXNet version 1.3.
The ``train`` function is no longer be required; instead the training script must be able to be run as a standalone script.
In this way, the training script is similar to a training script you might run outside of Amazon SageMaker.

There are a few steps needed to make a training script with the old format compatible with the new format.

First, add a `main guard <https://docs.python.org/3/library/__main__.html>`__ (``if __name__ == '__main__':``).
The code executed from your main guard needs to:

1. Set hyperparameters and directory locations
2. Initiate training
3. Save the model

Hyperparameters are passed as command-line arguments to your training script.
In addition, the container defines the locations of input data and where to save the model artifacts and output data as environment variables rather than passing that information as arguments to the ``train`` function.
You can find the full list of available environment variables in the `SageMaker Containers README <https://github.com/aws/sagemaker-containers#list-of-provided-environment-variables-by-sagemaker-containers>`__.

We recommend using `an argument parser <https://docs.python.org/3.5/howto/argparse.html>`__ for this part.
Using the ``argparse`` library as an example, the code looks something like this:

.. code:: python

import argparse
import os

if __name__ == '__main__':
parser = argparse.ArgumentParser()

# hyperparameters sent by the client are passed as command-line arguments to the script.
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=100)
parser.add_argument('--learning-rate', type=float, default=0.1)

# input data and model directories
parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])

args, _ = parser.parse_known_args()

The code in the main guard should also take care of training and saving the model.
This can be as simple as just calling the ``train`` and ``save`` methods used in the previous training script format:

.. code:: python

if __name__ == '__main__':
# arg parsing (shown above) goes here

model = train(args.batch_size, args.epochs, args.learning_rate, args.train, args.test)
save(args.model_dir, model)

Note that saving the model is no longer be done by default; this must be done by the training script.
If you were previously relying on the default save method, you can import one from the container:
There is a default save method that can be imported when training on SageMaker:

.. code:: python

from sagemaker_mxnet_container.training_utils import save
from sagemaker_mxnet_training.training_utils import save

if __name__ == '__main__':
# arg parsing and training (shown above) goes here

save(args.model_dir, model)

Lastly, if you were relying on the container launching a parameter server for use with distributed training, you must set ``distributions`` to the following dictionary when creating an MXNet estimator:
The default serialization system generates three files:

.. code:: python
- ``model-shapes.json``: A JSON list, containing a serialization of the
``Module`` ``data_shapes`` property. Each object in the list contains
the serialization of one ``DataShape`` in the returned ``Module``.
Each object has a ``name`` property, containing the ``DataShape``
name and a ``shape`` property, which is a list of that dimensions for
the shape of that ``DataShape``. For example:

from sagemaker.mxnet import MXNet
.. code:: javascript

estimator = MXNet('path-to-distributed-training-script.py',
...,
distributions={'parameter_server': {'enabled': True}})
[
{"name":"images", "shape":[100, 1, 28, 28]},
{"name":"labels", "shape":[100, 1]}
]

- ``model-symbol.json``: The MXNet ``Module`` ``Symbol`` serialization,
produced by invoking ``save`` on the ``symbol`` property of the
``Module`` being saved.
- ``modle.params``: The MXNet ``Module`` parameters, produced by
invoking ``save_params`` on the ``Module`` being saved.

Use third-party libraries
-------------------------
=========================

When running your training script on Amazon SageMaker, it has access to some pre-installed third-party libraries, including ``mxnet``, ``numpy``, ``onnx``, and ``keras-mxnet``.
For more information on the runtime environment, including specific package versions, see `SageMaker MXNet Containers <#sagemaker-mxnet-containers>`__.
Expand Down