Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/v2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ The following estimator parameters have been renamed:
+------------------------------+------------------------+
| ``train_use_spot_instances`` | ``use_spot_instances`` |
+------------------------------+------------------------+
| ``train_max_run_wait`` | ``max_wait`` |
| ``train_max_wait`` | ``max_wait`` |
+------------------------------+------------------------+
| ``train_volume_size`` | ``volume_size`` |
+------------------------------+------------------------+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"train_instance_count",
"train_instance_type",
"train_max_run",
"train_max_run_wait",
"train_max_wait",
"train_use_spot_instances",
"train_volume_size",
"train_volume_kms_key",
Expand All @@ -63,7 +63,7 @@ def node_should_be_modified(self, node):
- ``train_instance_count``
- ``train_instance_type``
- ``train_max_run``
- ``train_max_run_wait``
- ``train_max_wait``
- ``train_use_spot_instances``
- ``train_volume_kms_key``
- ``train_volume_size``
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def __init__(
use_spot_instances = renamed_kwargs(
"train_use_spot_instances", "use_spot_instances", use_spot_instances, kwargs
)
max_wait = renamed_kwargs("train_max_run_wait", "max_wait", max_wait, kwargs)
max_wait = renamed_kwargs("train_max_wait", "max_wait", max_wait, kwargs)
volume_size = renamed_kwargs("train_volume_size", "volume_size", volume_size, kwargs)
volume_kms_key = renamed_kwargs(
"train_volume_kms_key", "volume_kms_key", volume_kms_key, kwargs
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ def __init__(
:class:`~sagemaker.estimator.EstimatorBase`.
"""
distribution = renamed_kwargs("distributions", "distribution", distribution, kwargs)
instance_type = renamed_kwargs(
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
)
validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version == "py2":
logger.warning(
Expand All @@ -168,7 +171,6 @@ def __init__(
)

if distribution is not None:
instance_type = kwargs.get("instance_type")
warn_if_parameter_server_with_multi_gpu(
training_instance_type=instance_type, distribution=distribution
)
Expand Down
9 changes: 7 additions & 2 deletions src/sagemaker/sklearn/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging

from sagemaker import image_uris
from sagemaker.deprecations import renamed_kwargs
from sagemaker.estimator import Framework
from sagemaker.fw_utils import (
framework_name_from_image,
Expand Down Expand Up @@ -107,6 +108,12 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
instance_type = renamed_kwargs(
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
)
instance_count = renamed_kwargs(
"train_instance_count", "instance_count", kwargs.get("instance_count"), kwargs
)
validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version and py_version != "py3":
raise AttributeError(
Expand All @@ -117,10 +124,8 @@ def __init__(

# SciKit-Learn does not support distributed training or training on GPU instance types.
# Fail fast.
instance_type = kwargs.get("instance_type")
_validate_not_gpu_instance_type(instance_type)

instance_count = kwargs.get("instance_count")
if instance_count:
if instance_count != 1:
raise AttributeError(
Expand Down
5 changes: 4 additions & 1 deletion src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from sagemaker import image_uris, s3, utils
from sagemaker.debugger import DebuggerHookConfig
from sagemaker.deprecations import renamed_kwargs
from sagemaker.estimator import Framework
import sagemaker.fw_utils as fw
from sagemaker.tensorflow import defaults
Expand Down Expand Up @@ -112,6 +113,9 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
instance_type = renamed_kwargs(
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
)
fw.validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version == "py2":
logger.warning(
Expand All @@ -121,7 +125,6 @@ def __init__(
self.py_version = py_version

if distribution is not None:
instance_type = kwargs.get("instance_type")
fw.warn_if_parameter_server_with_multi_gpu(
training_instance_type=instance_type, distribution=distribution
)
Expand Down
6 changes: 5 additions & 1 deletion src/sagemaker/xgboost/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging

from sagemaker import image_uris
from sagemaker.deprecations import renamed_kwargs
from sagemaker.estimator import Framework, _TrainingJob
from sagemaker.fw_utils import (
framework_name_from_image,
Expand Down Expand Up @@ -95,6 +96,9 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
instance_type = renamed_kwargs(
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
)
super(XGBoost, self).__init__(
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
)
Expand All @@ -111,7 +115,7 @@ def __init__(
self.sagemaker_session.boto_region_name,
version=framework_version,
py_version=self.py_version,
instance_type=kwargs.get("instance_type"),
instance_type=instance_type,
image_scope="training",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"train_instance_count=1",
"train_instance_type='ml.c4.xlarge'",
"train_max_run=8 * 60 * 60",
"train_max_run_wait=1 * 60 * 60",
"train_max_wait=1 * 60 * 60",
"train_use_spot_instances=True",
"train_volume_size=30",
"train_volume_kms_key='key'",
Expand Down