Skip to content
3 changes: 2 additions & 1 deletion src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
model_name=self.name,
initial_instance_count=initial_instance_count,
instance_type=instance_type,
accelerator_type=accelerator_type)
accelerator_type=accelerator_type,
tags=tags)
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
else:
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags)
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def create_endpoint_config(self, name, model_name, initial_instance_count, insta
)
return name

def create_endpoint(self, endpoint_name, config_name, wait=True):
def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True):
"""Create an Amazon SageMaker ``Endpoint`` according to the endpoint configuration specified in the request.

Once the ``Endpoint`` is created, client applications can send requests to obtain inferences.
Expand All @@ -764,7 +764,7 @@ def create_endpoint(self, endpoint_name, config_name, wait=True):
str: Name of the Amazon SageMaker ``Endpoint`` created.
"""
LOGGER.info('Creating endpoint with name {}'.format(endpoint_name))
self.sagemaker_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name)
self.sagemaker_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags)
if wait:
self.wait_for_endpoint(endpoint_name)
return endpoint_name
Expand Down Expand Up @@ -1052,7 +1052,7 @@ def endpoint_from_production_variants(self, name, production_variants, tags=None
config_options['Tags'] = tags

self.sagemaker_client.create_endpoint_config(**config_options)
return self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
return self.create_endpoint(endpoint_name=name, config_name=name, tags=tags, wait=wait)

def expand_role(self, role):
"""Expand an IAM role name into an ARN.
Expand Down
25 changes: 25 additions & 0 deletions tests/integ/test_mxnet_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,31 @@ def test_deploy_model(mxnet_training_job, sagemaker_session, mxnet_full_version)
sagemaker_session.sagemaker_client.describe_model(ModelName=model.name)
assert 'Could not find model' in str(exception.value)

def test_deploy_model_with_tags(mxnet_training_job, sagemaker_session, mxnet_full_version):
    """Deploy an MXNet model with tags and verify the tags on the created resources.

    The model is built from the artifacts of ``mxnet_training_job`` and deployed
    to a single ``ml.m4.xlarge`` instance. The test then asserts that both the
    endpoint and its endpoint configuration carry exactly the tags passed to
    ``deploy``, and that the production variant matches the requested instance
    type and count.
    """
    endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())

    # All checks stay inside the context manager: judging by its name it deletes
    # the endpoint on exit, so the endpoint must be inspected before leaving it.
    with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        # The describe_*/list_tags calls below use boto3-style CamelCase keyword
        # arguments, so they must go through the low-level client rather than
        # the Session wrapper itself.
        client = sagemaker_session.sagemaker_client

        desc = client.describe_training_job(TrainingJobName=mxnet_training_job)
        model_data = desc['ModelArtifacts']['S3ModelArtifacts']
        script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
        model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path,
                           py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session,
                           framework_version=mxnet_full_version)
        tags = [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
        # The returned predictor is not needed; only the created resources are inspected.
        model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name, tags=tags)

        endpoint = client.describe_endpoint(EndpointName=endpoint_name)
        endpoint_tags = client.list_tags(ResourceArn=endpoint['EndpointArn'])['Tags']

        endpoint_config = client.describe_endpoint_config(
            EndpointConfigName=endpoint['EndpointConfigName'])
        endpoint_config_tags = client.list_tags(
            ResourceArn=endpoint_config['EndpointConfigArn'])['Tags']

        production_variants = endpoint_config['ProductionVariants']

        assert endpoint_config_tags == tags
        assert endpoint_tags == tags
        assert production_variants[0]['InstanceType'] == 'ml.m4.xlarge'
        assert production_variants[0]['InitialInstanceCount'] == 1

def test_deploy_model_with_update_endpoint(mxnet_training_job, sagemaker_session, mxnet_full_version):
endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_create_deploy_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_create_endpoint_no_wait(sagemaker_session):

assert returned_name == ENDPOINT_NAME
sagemaker_session.sagemaker_client.create_endpoint.assert_called_once_with(
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME)
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME, Tags=None)


def test_create_endpoint_wait(sagemaker_session):
Expand All @@ -105,5 +105,5 @@ def test_create_endpoint_wait(sagemaker_session):

assert returned_name == ENDPOINT_NAME
sagemaker_session.sagemaker_client.create_endpoint.assert_called_once_with(
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME)
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME, Tags=None)
sagemaker_session.wait_for_endpoint.assert_called_once_with(ENDPOINT_NAME)
3 changes: 2 additions & 1 deletion tests/unit/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ def test_deploy_update_endpoint(sagemaker_session, tmpdir):
model_name=model.name,
initial_instance_count=INSTANCE_COUNT,
instance_type=INSTANCE_TYPE,
accelerator_type=ACCELERATOR_TYPE
accelerator_type=ACCELERATOR_TYPE,
tags=None
)
config_name = sagemaker_session.create_endpoint_config(
name=model.name,
Expand Down
9 changes: 6 additions & 3 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,8 @@ def test_endpoint_from_production_variants(sagemaker_session):
ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex)
sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs)
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
EndpointName='some-endpoint')
EndpointName='some-endpoint',
Tags=None)
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
EndpointConfigName='some-endpoint',
ProductionVariants=pvs)
Expand All @@ -936,7 +937,8 @@ def test_endpoint_from_production_variants_with_tags(sagemaker_session):
tags = [{'ModelName': 'TestModel'}]
sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs, tags)
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
EndpointName='some-endpoint')
EndpointName='some-endpoint',
Tags=tags)
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
EndpointConfigName='some-endpoint',
ProductionVariants=pvs,
Expand All @@ -953,7 +955,8 @@ def test_endpoint_from_production_variants_with_accelerator_type(sagemaker_sessi
tags = [{'ModelName': 'TestModel'}]
sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs, tags)
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
EndpointName='some-endpoint')
EndpointName='some-endpoint',
Tags=tags)
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
EndpointConfigName='some-endpoint',
ProductionVariants=pvs,
Expand Down