Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion doc/frameworks/chainer/using_chainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ directories ('train' and 'test').
train_instance_type='ml.p3.2xlarge',
train_instance_count=1,
framework_version='5.0.0',
py_version='py3',
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1})
chainer_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
'test': 's3://my-data-bucket/path/to/my/test/data'})
Expand Down Expand Up @@ -222,7 +223,8 @@ operation.
chainer_estimator = Chainer(entry_point='train_and_deploy.py',
train_instance_type='ml.p3.2xlarge',
train_instance_count=1,
framework_version='5.0.0')
framework_version='5.0.0',
py_version='py3')
chainer_estimator.fit('s3://my_bucket/my_training_data/')

# Deploy my estimator to a SageMaker Endpoint and get a Predictor
Expand Down
38 changes: 21 additions & 17 deletions src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from sagemaker.fw_utils import (
framework_name_from_image,
framework_version_from_tag,
empty_framework_version_warning,
python_deprecation_warning,
validate_version_or_image_args,
)
from sagemaker.chainer import defaults
from sagemaker.chainer.model import ChainerModel
Expand Down Expand Up @@ -51,8 +51,8 @@ def __init__(
additional_mpi_options=None,
source_dir=None,
hyperparameters=None,
py_version="py3",
framework_version=None,
py_version=None,
image_name=None,
**kwargs
):
Expand Down Expand Up @@ -103,11 +103,12 @@ def __init__(
and values, but ``str()`` will be called to convert them before
training.
py_version (str): Python version you want to use for executing your
model training code (default: 'py2'). One of 'py2' or 'py3'.
model training code. Defaults to ``None``. Required unless ``image_name``
is provided.
framework_version (str): Chainer version you want to use for
executing your model training code. List of supported versions
executing your model training code. Defaults to ``None``. Required unless
``image_name`` is provided. List of supported versions:
https://github.com/aws/sagemaker-python-sdk#chainer-sagemaker-estimators.
If not specified, this will default to 4.1.
image_name (str): If specified, the estimator will use this image
for training and hosting, instead of selecting the appropriate
SageMaker official image based on framework_version and
Expand All @@ -117,6 +118,9 @@ def __init__(
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
* ``custom-image:latest``

If ``framework_version`` or ``py_version`` are ``None``, then
``image_name`` is required. If also ``None``, then a ``ValueError``
will be raised.
**kwargs: Additional kwargs passed to the
:class:`~sagemaker.estimator.Framework` constructor.

Expand All @@ -126,22 +130,18 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
if framework_version is None:
validate_version_or_image_args(framework_version, py_version, image_name)
if py_version == "py2":
logger.warning(
empty_framework_version_warning(defaults.CHAINER_VERSION, self.LATEST_VERSION)
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version or defaults.CHAINER_VERSION
self.framework_version = framework_version
self.py_version = py_version

super(Chainer, self).__init__(
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
)

if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)

self.py_version = py_version
self.use_mpi = use_mpi
self.num_processes = num_processes
self.process_slots_per_host = process_slots_per_host
Expand Down Expand Up @@ -262,15 +262,19 @@ class constructor
image_name = init_params.pop("image")
framework, py_version, tag, _ = framework_name_from_image(image_name)

if tag is None:
framework_version = None
else:
framework_version = framework_version_from_tag(tag)
init_params["framework_version"] = framework_version
init_params["py_version"] = py_version

if not framework:
# If we were unable to parse the framework name from the image it is not one of our
# officially supported images, in this case just add the image to the init params.
init_params["image_name"] = image_name
return init_params

init_params["py_version"] = py_version
init_params["framework_version"] = framework_version_from_tag(tag)

training_job_name = init_params["base_job_name"]

if framework != cls.__framework_name__:
Expand Down
31 changes: 16 additions & 15 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
create_image_uri,
model_code_key_prefix,
python_deprecation_warning,
empty_framework_version_warning,
validate_version_or_image_args,
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.chainer import defaults
Expand Down Expand Up @@ -67,8 +67,8 @@ def __init__(
role,
entry_point,
image=None,
py_version="py3",
framework_version=None,
py_version=None,
predictor_cls=ChainerPredictor,
model_server_workers=None,
**kwargs
Expand All @@ -88,11 +88,15 @@ def __init__(
hosting. If ``source_dir`` is specified, then ``entry_point``
must point to a file located at the root of ``source_dir``.
image (str): A Docker image URI (default: None). If not specified, a
default image for Chainer will be used.
py_version (str): Python version you want to use for executing your
model training code (default: 'py2').
default image for Chainer will be used. If ``framework_version``
or ``py_version`` are ``None``, then ``image`` is required. If
also ``None``, then a ``ValueError`` will be raised.
framework_version (str): Chainer version you want to use for
executing your model training code.
executing your model training code. Defaults to ``None``. Required
unless ``image`` is provided.
py_version (str): Python version you want to use for executing your
model training code. Defaults to ``None``. Required unless
``image`` is provided.
predictor_cls (callable[str, sagemaker.session.Session]): A function
to call to create a predictor with an endpoint name and
SageMaker ``Session``. If specified, ``deploy()`` returns the
Expand All @@ -109,21 +113,18 @@ def __init__(
:class:`~sagemaker.model.FrameworkModel` and
:class:`~sagemaker.model.Model`.
"""
super(ChainerModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
)
validate_version_or_image_args(framework_version, py_version, image)
if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version
self.py_version = py_version

if framework_version is None:
logger.warning(
empty_framework_version_warning(defaults.CHAINER_VERSION, defaults.LATEST_VERSION)
)
super(ChainerModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
)

self.py_version = py_version
self.framework_version = framework_version or defaults.CHAINER_VERSION
self.model_server_workers = model_server_workers

def prepare_container_def(self, instance_type, accelerator_type=None):
Expand Down
78 changes: 32 additions & 46 deletions tests/unit/test_chainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _get_full_cpu_image_uri_with_ei(version, py_version=PYTHON_VERSION):

def _chainer_estimator(
sagemaker_session,
framework_version=defaults.CHAINER_VERSION,
framework_version,
train_instance_type=None,
base_job_name=None,
use_mpi=None,
Expand Down Expand Up @@ -202,13 +202,14 @@ def _create_train_job_with_additional_hyperparameters(version):
}


def test_additional_hyperparameters(sagemaker_session):
def test_additional_hyperparameters(sagemaker_session, chainer_version):
chainer = _chainer_estimator(
sagemaker_session,
use_mpi=True,
num_processes=4,
process_slots_per_host=10,
additional_mpi_options="-x MY_ENVIRONMENT_VARIABLE",
framework_version=chainer_version,
)
assert bool(strtobool(chainer.hyperparameters()["sagemaker_use_mpi"]))
assert int(chainer.hyperparameters()["sagemaker_num_processes"]) == 4
Expand Down Expand Up @@ -300,7 +301,7 @@ def test_create_model(sagemaker_session, chainer_version):
assert model.vpc_config is None


def test_create_model_with_optional_params(sagemaker_session):
def test_create_model_with_optional_params(sagemaker_session, chainer_version):
container_log_level = '"logging.INFO"'
source_dir = "s3://mybucket/source"
enable_cloudwatch_metrics = "true"
Expand All @@ -311,6 +312,7 @@ def test_create_model_with_optional_params(sagemaker_session):
train_instance_count=INSTANCE_COUNT,
train_instance_type=INSTANCE_TYPE,
container_log_level=container_log_level,
framework_version=chainer_version,
py_version=PYTHON_VERSION,
base_job_name="job",
source_dir=source_dir,
Expand Down Expand Up @@ -372,8 +374,8 @@ def test_chainer(strftime, sagemaker_session, chainer_version):
sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT,
train_instance_type=INSTANCE_TYPE,
py_version=PYTHON_VERSION,
framework_version=chainer_version,
py_version=PYTHON_VERSION,
)

inputs = "s3://mybucket/train"
Expand Down Expand Up @@ -414,62 +416,72 @@ def test_chainer(strftime, sagemaker_session, chainer_version):


@patch("sagemaker.utils.create_tar_file", MagicMock())
def test_model(sagemaker_session):
def test_model(sagemaker_session, chainer_version):
model = ChainerModel(
"s3://some/data.tar.gz",
role=ROLE,
entry_point=SCRIPT_PATH,
sagemaker_session=sagemaker_session,
framework_version=chainer_version,
py_version=PYTHON_VERSION,
)
predictor = model.deploy(1, GPU)
assert isinstance(predictor, ChainerPredictor)


@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
def test_model_prepare_container_def_accelerator_error(sagemaker_session):
def test_model_prepare_container_def_accelerator_error(sagemaker_session, chainer_version):
model = ChainerModel(
MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session
MODEL_DATA,
role=ROLE,
entry_point=SCRIPT_PATH,
sagemaker_session=sagemaker_session,
framework_version=chainer_version,
py_version=PYTHON_VERSION,
)
with pytest.raises(ValueError):
model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE)


def test_train_image_default(sagemaker_session):
def test_train_image_default(sagemaker_session, chainer_version):
chainer = Chainer(
entry_point=SCRIPT_PATH,
role=ROLE,
sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT,
train_instance_type=INSTANCE_TYPE,
framework_version=chainer_version,
py_version=PYTHON_VERSION,
)

assert _get_full_cpu_image_uri(defaults.CHAINER_VERSION) in chainer.train_image()
assert _get_full_cpu_image_uri(chainer_version) in chainer.train_image()


def test_train_image_cpu_instances(sagemaker_session, chainer_version):
chainer = _chainer_estimator(
sagemaker_session, chainer_version, train_instance_type="ml.c2.2xlarge"
sagemaker_session, framework_version=chainer_version, train_instance_type="ml.c2.2xlarge"
)
assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version)

chainer = _chainer_estimator(
sagemaker_session, chainer_version, train_instance_type="ml.c4.2xlarge"
sagemaker_session, framework_version=chainer_version, train_instance_type="ml.c4.2xlarge"
)
assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version)

chainer = _chainer_estimator(sagemaker_session, chainer_version, train_instance_type="ml.m16")
chainer = _chainer_estimator(
sagemaker_session, framework_version=chainer_version, train_instance_type="ml.m16"
)
assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version)


def test_train_image_gpu_instances(sagemaker_session, chainer_version):
chainer = _chainer_estimator(
sagemaker_session, chainer_version, train_instance_type="ml.g2.2xlarge"
sagemaker_session, framework_version=chainer_version, train_instance_type="ml.g2.2xlarge"
)
assert chainer.train_image() == _get_full_gpu_image_uri(chainer_version)

chainer = _chainer_estimator(
sagemaker_session, chainer_version, train_instance_type="ml.p2.2xlarge"
sagemaker_session, framework_version=chainer_version, train_instance_type="ml.p2.2xlarge"
)
assert chainer.train_image() == _get_full_gpu_image_uri(chainer_version)

Expand Down Expand Up @@ -597,13 +609,14 @@ def test_attach_custom_image(sagemaker_session):


@patch("sagemaker.chainer.estimator.python_deprecation_warning")
def test_estimator_py2_warning(warning, sagemaker_session):
def test_estimator_py2_warning(warning, sagemaker_session, chainer_version):
estimator = Chainer(
entry_point=SCRIPT_PATH,
role=ROLE,
sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT,
train_instance_type=INSTANCE_TYPE,
framework_version=chainer_version,
py_version="py2",
)

Expand All @@ -612,49 +625,22 @@ def test_estimator_py2_warning(warning, sagemaker_session):


@patch("sagemaker.chainer.model.python_deprecation_warning")
def test_model_py2_warning(warning, sagemaker_session):
def test_model_py2_warning(warning, sagemaker_session, chainer_version):
model = ChainerModel(
MODEL_DATA,
role=ROLE,
entry_point=SCRIPT_PATH,
sagemaker_session=sagemaker_session,
framework_version=chainer_version,
py_version="py2",
)
assert model.py_version == "py2"
warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION)


@patch("sagemaker.chainer.estimator.empty_framework_version_warning")
def test_empty_framework_version(warning, sagemaker_session):
estimator = Chainer(
entry_point=SCRIPT_PATH,
role=ROLE,
sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT,
train_instance_type=INSTANCE_TYPE,
framework_version=None,
)

assert estimator.framework_version == defaults.CHAINER_VERSION
warning.assert_called_with(defaults.CHAINER_VERSION, Chainer.LATEST_VERSION)


@patch("sagemaker.chainer.model.empty_framework_version_warning")
def test_model_empty_framework_version(warning, sagemaker_session):
model = ChainerModel(
MODEL_DATA,
role=ROLE,
entry_point=SCRIPT_PATH,
sagemaker_session=sagemaker_session,
framework_version=None,
)
assert model.framework_version == defaults.CHAINER_VERSION
warning.assert_called_with(defaults.CHAINER_VERSION, defaults.LATEST_VERSION)


def test_custom_image_estimator_deploy(sagemaker_session):
def test_custom_image_estimator_deploy(sagemaker_session, chainer_version):
custom_image = "mycustomimage:latest"
chainer = _chainer_estimator(sagemaker_session)
chainer = _chainer_estimator(sagemaker_session, framework_version=chainer_version)
chainer.fit(inputs="s3://mybucket/train", job_name="new_name")
model = chainer.create_model(image=custom_image)
assert model.image == custom_image
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ extras = test

[testenv:py27]
setenv =
IGNORE_COVERAGE = true
IGNORE_COVERAGE = -

[testenv:flake8]
basepython = python3
Expand Down