Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/v2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ The following estimator parameters have been renamed:
+------------------------------+------------------------+
| ``train_use_spot_instances`` | ``use_spot_instances`` |
+------------------------------+------------------------+
| ``train_max_run_wait`` | ``max_wait`` |
| ``train_max_wait`` | ``max_wait`` |
+------------------------------+------------------------+
| ``train_volume_size`` | ``volume_size`` |
+------------------------------+------------------------+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"train_instance_count",
"train_instance_type",
"train_max_run",
"train_max_run_wait",
"train_max_wait",
"train_use_spot_instances",
"train_volume_size",
"train_volume_kms_key",
Expand All @@ -63,7 +63,7 @@ def node_should_be_modified(self, node):
- ``train_instance_count``
- ``train_instance_type``
- ``train_max_run``
- ``train_max_run_wait``
- ``train_max_wait``
- ``train_use_spot_instances``
- ``train_volume_kms_key``
- ``train_volume_size``
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def __init__(
use_spot_instances = renamed_kwargs(
"train_use_spot_instances", "use_spot_instances", use_spot_instances, kwargs
)
max_wait = renamed_kwargs("train_max_run_wait", "max_wait", max_wait, kwargs)
max_wait = renamed_kwargs("train_max_wait", "max_wait", max_wait, kwargs)
volume_size = renamed_kwargs("train_volume_size", "volume_size", volume_size, kwargs)
volume_kms_key = renamed_kwargs(
"train_volume_kms_key", "volume_kms_key", volume_kms_key, kwargs
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ def __init__(
:class:`~sagemaker.estimator.EstimatorBase`.
"""
distribution = renamed_kwargs("distributions", "distribution", distribution, kwargs)
instance_type = renamed_kwargs(
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
)
validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version == "py2":
logger.warning(
Expand All @@ -168,7 +171,6 @@ def __init__(
)

if distribution is not None:
instance_type = kwargs.get("instance_type")
warn_if_parameter_server_with_multi_gpu(
training_instance_type=instance_type, distribution=distribution
)
Expand Down
9 changes: 7 additions & 2 deletions src/sagemaker/sklearn/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging

from sagemaker import image_uris
from sagemaker.deprecations import renamed_kwargs
from sagemaker.estimator import Framework
from sagemaker.fw_utils import (
framework_name_from_image,
Expand Down Expand Up @@ -107,6 +108,12 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
instance_type = renamed_kwargs(
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
)
instance_count = renamed_kwargs(
"train_instance_count", "instance_count", kwargs.get("instance_count"), kwargs
)
validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version and py_version != "py3":
raise AttributeError(
Expand All @@ -117,10 +124,8 @@ def __init__(

# SciKit-Learn does not support distributed training or training on GPU instance types.
# Fail fast.
instance_type = kwargs.get("instance_type")
_validate_not_gpu_instance_type(instance_type)

instance_count = kwargs.get("instance_count")
if instance_count:
if instance_count != 1:
raise AttributeError(
Expand Down
5 changes: 4 additions & 1 deletion src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from sagemaker import image_uris, s3, utils
from sagemaker.debugger import DebuggerHookConfig
from sagemaker.deprecations import renamed_kwargs
from sagemaker.estimator import Framework
import sagemaker.fw_utils as fw
from sagemaker.tensorflow import defaults
Expand Down Expand Up @@ -112,6 +113,9 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
instance_type = renamed_kwargs(
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
)
fw.validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version == "py2":
logger.warning(
Expand All @@ -121,7 +125,6 @@ def __init__(
self.py_version = py_version

if distribution is not None:
instance_type = kwargs.get("instance_type")
fw.warn_if_parameter_server_with_multi_gpu(
training_instance_type=instance_type, distribution=distribution
)
Expand Down
6 changes: 5 additions & 1 deletion src/sagemaker/xgboost/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging

from sagemaker import image_uris
from sagemaker.deprecations import renamed_kwargs
from sagemaker.estimator import Framework, _TrainingJob
from sagemaker.fw_utils import (
framework_name_from_image,
Expand Down Expand Up @@ -95,6 +96,9 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
instance_type = renamed_kwargs(
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
)
super(XGBoost, self).__init__(
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
)
Expand All @@ -111,7 +115,7 @@ def __init__(
self.sagemaker_session.boto_region_name,
version=framework_version,
py_version=self.py_version,
instance_type=kwargs.get("instance_type"),
instance_type=instance_type,
image_scope="training",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"train_instance_count=1",
"train_instance_type='ml.c4.xlarge'",
"train_max_run=8 * 60 * 60",
"train_max_run_wait=1 * 60 * 60",
"train_max_wait=1 * 60 * 60",
"train_use_spot_instances=True",
"train_volume_size=30",
"train_volume_kms_key='key'",
Expand Down