Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
c7493c1
Add pytorch estimator and model. Initial commit.
nadiaya Mar 28, 2018
c638cb9
Add integ tests.
nadiaya Mar 28, 2018
01ea9aa
Update integ tests.
nadiaya Apr 5, 2018
5d07cf7
Add unit tests.
nadiaya Apr 5, 2018
10673b2
Merge branch 'master' into pytorch
nadiaya Apr 5, 2018
922a7a5
Update unit tests.
nadiaya Apr 5, 2018
5003333
Run integ tests for pytorch against local mode and real cpu instance.
nadiaya Apr 5, 2018
2d4cb33
Do not try to reattach to the training job when running locally since…
nadiaya Apr 11, 2018
b15b7c2
Local mode does not propagate errors raised in the customer script.
nadiaya Apr 11, 2018
d603c51
Do not use local mode in integ tests yet.
nadiaya Apr 12, 2018
ac28b7c
Merge branch 'master' into pytorch
nadiaya Apr 21, 2018
bc0f5eb
Merge pull request #1 from aws/pytorch
nadiaya Apr 21, 2018
58b9743
Merge branch 'master' of https://github.com/aws/sagemaker-python-sdk
nadiaya Apr 21, 2018
7ed5cc0
Add integ tests for prediction. Change tests to pick image python ver…
nadiaya Apr 26, 2018
1844ce5
Add prediction output assertions.
nadiaya Apr 26, 2018
7646115
Fix typo,
nadiaya Apr 26, 2018
e2334a9
Merge pull request #4 from aws/pytorch-prediction
nadiaya Apr 30, 2018
55a7acb
Add support for pytorch 0.4.0. Make it default version. Get rid of 0.…
nadiaya May 8, 2018
cb676eb
Add correct file format to model file.
nadiaya May 8, 2018
b39ed6a
Merge pull request #6 from aws/pytorch-0.4
nadiaya May 8, 2018
b47c3df
Merge branch 'master' of https://github.com/aws/sagemaker-python-sdk
nadiaya May 8, 2018
aa99ba9
Switch pytorch estimator integ tests to use script mode. (#60)
nadiaya Jun 5, 2018
d254128
Merge remote-tracking branch 'public/master' into pytorch-release
nadiaya Jun 5, 2018
3a9dedd
Make pytorch code and test conform with latest python sdk guidlines.
nadiaya Jun 5, 2018
73b0daa
Use npy as a default format for prediction instead of json. (#63)
nadiaya Jun 13, 2018
ff596ea
Merge remote-tracking branch 'public/master' into pytorch-release
nadiaya Jun 13, 2018
9b75787
Merge remote-tracking branch 'public/master' into pytorch-release
nadiaya Jun 14, 2018
684dfb3
Merge remote-tracking branch 'public/master' into pytorch-release
nadiaya Jun 15, 2018
89538c4
Add PyTorch documentaion. (#65)
nadiaya Jun 15, 2018
104a9ed
Merge remote-tracking branch 'public/master' into pytorch-release
nadiaya Jun 15, 2018
6703a96
Merge branch 'pytorch-release' into pytorch
nadiaya Jun 20, 2018
6e31341
Update Changelog and Readme with new version.
nadiaya Jun 20, 2018
276082e
Merge branch 'master' into pytorch
winstonaws Jun 20, 2018
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
CHANGELOG
=========

1.4.3dev
========
1.5.0
=====
* feature: Add Support for PyTorch Framework
* feature: Estimators: add support for TensorFlow 1.7.0
* feature: Estimators: add support for TensorFlow 1.8.0
* feature: Allow Local Serving of Models in S3
Expand Down
16 changes: 15 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ You can install from source by cloning this repository and issuing a pip install

git clone https://github.com/aws/sagemaker-python-sdk.git
python setup.py sdist
pip install dist/sagemaker-1.4.2.tar.gz
pip install dist/sagemaker-1.5.0.tar.gz

Supported Python versions
~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -236,6 +236,20 @@ More details at `Chainer SageMaker Estimators and Models`_.
.. _Chainer SageMaker Estimators and Models: src/sagemaker/chainer/README.rst


PyTorch SageMaker Estimators
-------------------------------

With PyTorch Estimators, you can train and host PyTorch models on Amazon SageMaker.

Supported versions of PyTorch: ``0.4.0``

You can visit the PyTorch repository at https://github.com/pytorch/pytorch.

More details at `PyTorch SageMaker Estimators and Models`_.

.. _PyTorch SageMaker Estimators and Models: src/sagemaker/pytorch/README.rst


AWS SageMaker Estimators
------------------------
Amazon SageMaker provides several built-in machine learning algorithms that you can use for a variety of problem types.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def read(fname):


setup(name="sagemaker",
version="1.4.2",
version="1.5.0",
description="Open source library for training and deploying models on Amazon SageMaker.",
packages=find_packages('src'),
package_dir={'': 'src'},
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def framework_name_from_image(image_name):
else:
# extract framework, python version and image tag
# We must support both the legacy and current image name format.
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet|chainer):(.*?)-(.*?)-(py2|py3)$')
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet|chainer|pytorch):(.*?)-(.*?)-(py2|py3)$')
legacy_name_pattern = re.compile('^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$')
name_match = name_pattern.match(sagemaker_match.group(8))
legacy_match = legacy_name_pattern.match(sagemaker_match.group(8))
Expand Down
711 changes: 711 additions & 0 deletions src/sagemaker/pytorch/README.rst

Large diffs are not rendered by default.

17 changes: 17 additions & 0 deletions src/sagemaker/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
from sagemaker.pytorch.estimator import PyTorch
from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor

__all__ = [PyTorch, PyTorchModel, PyTorchPredictor]
16 changes: 16 additions & 0 deletions src/sagemaker/pytorch/defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

PYTORCH_VERSION = '0.4'
PYTHON_VERSION = 'py3'
112 changes: 112 additions & 0 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
from sagemaker.estimator import Framework
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
from sagemaker.pytorch.model import PyTorchModel


class PyTorch(Framework):
"""Handle end-to-end training and deployment of custom PyTorch code."""

__framework_name__ = "pytorch"

def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version=PYTHON_VERSION,
framework_version=PYTORCH_VERSION, **kwargs):
"""
This ``Estimator`` executes an PyTorch script in a managed PyTorch execution environment, within a SageMaker
Training Job. The managed PyTorch environment is an Amazon-built Docker container that executes functions
defined in the supplied ``entry_point`` Python script.

Training is started by calling :meth:`~sagemaker.amazon.estimator.Framework.fit` on this Estimator.
After training is complete, calling :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a
hosted SageMaker endpoint and returns an :class:`~sagemaker.amazon.pytorch.model.PyTorchPredictor` instance
that can be used to perform inference against the hosted model.

Technical documentation on preparing PyTorch scripts for SageMaker training and using the PyTorch Estimator is
available on the project home-page: https://github.com/aws/sagemaker-python-sdk

Args:
entry_point (str): Path (absolute or relative) to the Python source file which should be executed
as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5.
source_dir (str): Path (absolute or relative) to a directory with any other training
source code dependencies aside from tne entry point file (default: None). Structure within this
directory are preserved when training on Amazon SageMaker.
hyperparameters (dict): Hyperparameters that will be used for training (default: None).
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
For convenience, this accepts other types for keys and values, but ``str()`` will be called
to convert them before training.
py_version (str): Python version you want to use for executing your model training code (default: 'py3').
One of 'py2' or 'py3'.
framework_version (str): PyTorch version you want to use for executing your model training code.
List of supported versions https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor.
"""
super(PyTorch, self).__init__(entry_point, source_dir, hyperparameters, **kwargs)
self.py_version = py_version
self.framework_version = framework_version

def train_image(self):
"""Return the Docker image to use for training.

The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, calls this method to
find the image to use for model training.

Returns:
str: The URI of the Docker image.
"""
return create_image_uri(self.sagemaker_session.boto_session.region_name, self.__framework_name__,
self.train_instance_type, framework_version=self.framework_version,
py_version=self.py_version)

def create_model(self, model_server_workers=None):
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.

Args:
model_server_workers (int): Optional. The number of worker processes used by the inference server.
If None, server will use one worker per vCPU.

Returns:
sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel`` object.
See :func:`~sagemaker.pytorch.model.PyTorchModel` for full details.
"""
return PyTorchModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
container_log_level=self.container_log_level, code_location=self.code_location,
py_version=self.py_version, framework_version=self.framework_version,
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session)

@classmethod
def _prepare_init_params_from_job_description(cls, job_details):
"""Convert the job description to init params that can be handled by the class constructor

Args:
job_details: the returned job details from a describe_training_job API call.

Returns:
dictionary: The transformed init_params

"""
init_params = super(PyTorch, cls)._prepare_init_params_from_job_description(job_details)
framework, py_version, tag = framework_name_from_image(init_params.pop('image'))

init_params['py_version'] = py_version
init_params['framework_version'] = framework_version_from_tag(tag)

training_job_name = init_params['base_job_name']

if framework != cls.__framework_name__:
raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name))

return init_params
94 changes: 94 additions & 0 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import sagemaker
from sagemaker.fw_utils import create_image_uri
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
from sagemaker.utils import name_from_image


class PyTorchPredictor(RealTimePredictor):
"""A RealTimePredictor for inference against PyTorch Endpoints.

This is able to serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for PyTorch
inference."""

def __init__(self, endpoint_name, sagemaker_session=None):
"""Initialize an ``PyTorchPredictor``.

Args:
endpoint_name (str): The name of the endpoint to perform inference on.
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.
"""
super(PyTorchPredictor, self).__init__(endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer)


class PyTorchModel(FrameworkModel):
"""An PyTorch SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""

__framework_name__ = 'pytorch'

def __init__(self, model_data, role, entry_point, image=None, py_version=PYTHON_VERSION,
framework_version=PYTORCH_VERSION, predictor_cls=PyTorchPredictor,
model_server_workers=None, **kwargs):
"""Initialize an PyTorchModel.

Args:
model_data (str): The S3 location of a SageMaker model data ``.tar.gz`` file.
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs
that create Amazon SageMaker endpoints use this role to access training data and model artifacts.
After the endpoint is created, the inference code might use the IAM role,
if it needs to access an AWS resource.
entry_point (str): Path (absolute or relative) to the Python source file which should be executed
as the entry point to model hosting. This should be compatible with either Python 2.7 or Python 3.5.
image (str): A Docker image URI (default: None). If not specified, a default image for PyTorch will be used.
py_version (str): Python version you want to use for executing your model training code (default: 'py3').
framework_version (str): PyTorch version you want to use for executing your model training code.
predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a predictor
with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of
invoking this function on the created endpoint name.
model_server_workers (int): Optional. The number of worker processes used by the inference server.
If None, server will use one worker per vCPU.
**kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer.
"""
super(PyTorchModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs)
self.py_version = py_version
self.framework_version = framework_version
self.model_server_workers = model_server_workers

def prepare_container_def(self, instance_type):
"""Return a container definition with framework configuration set in model environment variables.

Args:
instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.

Returns:
dict[str, str]: A container definition object usable with the CreateModel API.
"""
deploy_image = self.image
if not deploy_image:
region_name = self.sagemaker_session.boto_session.region_name
deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
self.framework_version, self.py_version)
deploy_key_prefix = self.key_prefix or self.name or name_from_image(deploy_image)
self._upload_code(deploy_key_prefix)
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def mxnet_version(request):
return request.param


@pytest.fixture(scope='module', params=["0.4", "0.4.0"])
def pytorch_version(request):
return request.param


@pytest.fixture(scope='module', params=['4.0', '4.0.0'])
def chainer_version(request):
return request.param
Expand All @@ -96,6 +101,11 @@ def mxnet_full_version(request):
return request.param


@pytest.fixture(scope='module', params=["0.4.0"])
def pytorch_full_version(request):
return request.param


@pytest.fixture(scope='module', params=['4.0.0'])
def chainer_full_version(request):
return request.param
3 changes: 3 additions & 0 deletions tests/data/pytorch_mnist/failure_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
if __name__ == '__main__':
"""For use with integration tests expecting failures."""
raise Exception('This failure is expected.')
Loading