Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m
return estimator

def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None,
use_compiled_model=False, update_endpoint=False, **kwargs):
use_compiled_model=False, update_endpoint=False, wait=True, **kwargs):
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object.

More information:
Expand All @@ -355,6 +355,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
wait (bool): Whether the call should wait until the deployment of model completes (default: True).

**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
``create_model()`` to accept ``**kwargs`` to customize model creation during deploy.
Expand All @@ -381,7 +382,8 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
accelerator_type=accelerator_type,
endpoint_name=endpoint_name,
update_endpoint=update_endpoint,
tags=self.tags)
tags=self.tags,
wait=wait)

@property
def model_data(self):
Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def compile(self, target_instance_family, input_shape, output_path, role,
return self

def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None,
update_endpoint=False, tags=None, kms_key=None):
update_endpoint=False, tags=None, kms_key=None, wait=True):
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.

Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an ``Endpoint`` from this ``Model``.
Expand Down Expand Up @@ -256,6 +256,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
tags(List[dict[str, str]]): The list of tags to attach to this specific endpoint.
kms_key (str): The ARN of the KMS key that is used to encrypt the data on the
storage volume attached to the instance hosting the endpoint.
wait (bool): Whether the call should wait until the deployment of this model completes (default: True).

Returns:
callable[string, sagemaker.session.Session] or None: Invocation of ``self.predictor_cls`` on
Expand Down Expand Up @@ -296,7 +297,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
else:
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant],
tags, kms_key)
tags, kms_key, wait)

if self.predictor_cls:
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)
Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def pipeline_container_def(self, instance_type):

return sagemaker.pipeline_container_def(self.models, instance_type)

def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags=None):
def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags=None, wait=True):
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.

Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an ``Endpoint`` from this ``Model``.
Expand All @@ -86,6 +86,7 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags
endpoint_name (str): The name of the endpoint to create (default: None).
If not specified, a unique endpoint name will be created.
tags(List[dict[str, str]]): The list of tags to attach to this specific endpoint.
wait (bool): Whether the call should wait until the deployment of model completes (default: True).

Returns:
callable[string, sagemaker.session.Session] or None: Invocation of ``self.predictor_cls`` on
Expand All @@ -101,7 +102,8 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags

production_variant = sagemaker.production_variant(self.name, instance_type, initial_instance_count)
self.endpoint_name = endpoint_name or self.name
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags)
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags,
wait=wait)
if self.predictor_cls:
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)

Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,8 @@ def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estim

return tuner

def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None, **kwargs):
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None, wait=True,
**kwargs):
"""Deploy the best trained or user specified model to an Amazon SageMaker endpoint and return a
``sagemaker.RealTimePredictor`` object.

Expand All @@ -342,6 +343,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not specified,
the name of the training job is used.
wait (bool): Whether the call should wait until the deployment of model completes (default: True).
**kwargs: Other arguments needed for deployment. Please refer to the ``create_model()`` method of
the associated estimator to see what other arguments are needed.

Expand All @@ -354,7 +356,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
sagemaker_session=self.estimator.sagemaker_session)
return best_estimator.deploy(initial_instance_count, instance_type,
accelerator_type=accelerator_type,
endpoint_name=endpoint_name, **kwargs)
endpoint_name=endpoint_name, wait=wait, **kwargs)

def stop_tuning_job(self):
"""Stop latest running hyperparameter tuning job.
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,8 @@ def test_fit_deploy_keep_tags(sagemaker_session):
sagemaker_session.endpoint_from_production_variants.assert_called_with(job_name,
variant,
tags,
None)
None,
True)

sagemaker_session.create_model.assert_called_with(
ANY,
Expand Down
15 changes: 10 additions & 5 deletions tests/unit/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ def test_deploy(sagemaker_session, tmpdir):
'InitialInstanceCount': 1,
'VariantName': 'AllTraffic'}],
None,
None)
None,
True)


@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
Expand All @@ -182,7 +183,8 @@ def test_deploy_endpoint_name(sagemaker_session, tmpdir):
'InitialInstanceCount': 55,
'VariantName': 'AllTraffic'}],
None,
None)
None,
True)


@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
Expand All @@ -199,7 +201,8 @@ def test_deploy_tags(sagemaker_session, tmpdir):
'InitialInstanceCount': 1,
'VariantName': 'AllTraffic'}],
tags,
None)
None,
True)


@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
Expand All @@ -217,7 +220,8 @@ def test_deploy_accelerator_type(tfo, time, sagemaker_session):
'VariantName': 'AllTraffic',
'AcceleratorType': ACCELERATOR_TYPE}],
None,
None)
None,
True)


@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
Expand All @@ -235,7 +239,8 @@ def test_deploy_kms_key(tfo, time, sagemaker_session):
'InitialInstanceCount': 1,
'VariantName': 'AllTraffic'}],
None,
key)
key,
True)


@patch('sagemaker.session.Session')
Expand Down
9 changes: 6 additions & 3 deletions tests/unit/test_pipeline_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def test_deploy(tfo, time, sagemaker_session):
'InstanceType': INSTANCE_TYPE,
'InitialInstanceCount': 1,
'VariantName': 'AllTraffic'}],
None)
None,
wait=True)


@patch('tarfile.open')
Expand All @@ -119,7 +120,8 @@ def test_deploy_endpoint_name(tfo, time, sagemaker_session):
'InstanceType': INSTANCE_TYPE,
'InitialInstanceCount': 1,
'VariantName': 'AllTraffic'}],
None)
None,
wait=True)


@patch('tarfile.open')
Expand Down Expand Up @@ -178,7 +180,8 @@ def test_deploy_tags(tfo, time, sagemaker_session):
'InstanceType': INSTANCE_TYPE,
'InitialInstanceCount': 1,
'VariantName': 'AllTraffic'}],
tags)
tags,
wait=True)


def test_delete_model_without_deploy(sagemaker_session):
Expand Down