Skip to content
8 changes: 6 additions & 2 deletions src/sagemaker/local/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,15 +327,17 @@ def describe(self):

class _LocalEndpointConfig(object):

def __init__(self, config_name, production_variants):
def __init__(self, config_name, production_variants, tags=None):
self.name = config_name
self.production_variants = production_variants
self.tags = tags
self.creation_time = datetime.datetime.now()

def describe(self):
response = {
'EndpointConfigName': self.name,
'EndpointConfigArn': _UNUSED_ARN,
'Tags': self.tags,
'CreationTime': self.creation_time,
'ProductionVariants': self.production_variants
}
Expand All @@ -348,7 +350,7 @@ class _LocalEndpoint(object):
_IN_SERVICE = 'InService'
_FAILED = 'Failed'

def __init__(self, endpoint_name, endpoint_config_name, local_session=None):
def __init__(self, endpoint_name, endpoint_config_name, tags=None, local_session=None):
# runtime import since there is a cyclic dependency between entities and local_session
from sagemaker.local import LocalSession
self.local_session = local_session or LocalSession()
Expand All @@ -357,6 +359,7 @@ def __init__(self, endpoint_name, endpoint_config_name, local_session=None):
self.name = endpoint_name
self.endpoint_config = local_client.describe_endpoint_config(endpoint_config_name)
self.production_variant = self.endpoint_config['ProductionVariants'][0]
self.tags = tags

model_name = self.production_variant['ModelName']
self.primary_container = local_client.describe_model(model_name)['PrimaryContainer']
Expand Down Expand Up @@ -392,6 +395,7 @@ def describe(self):
'EndpointConfigName': self.endpoint_config['EndpointConfigName'],
'CreationTime': self.create_time,
'ProductionVariants': self.endpoint_config['ProductionVariants'],
'Tags': self.tags,
'EndpointName': self.name,
'EndpointArn': _UNUSED_ARN,
'EndpointStatus': self.state
Expand Down
8 changes: 4 additions & 4 deletions src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ def describe_endpoint_config(self, EndpointConfigName):
'Code': 'ValidationException', 'Message': 'Could not find local endpoint config'}}
raise ClientError(error_response, 'describe_endpoint_config')

def create_endpoint_config(self, EndpointConfigName, ProductionVariants):
def create_endpoint_config(self, EndpointConfigName, ProductionVariants, Tags=None):
LocalSagemakerClient._endpoint_configs[EndpointConfigName] = _LocalEndpointConfig(
EndpointConfigName, ProductionVariants)
EndpointConfigName, ProductionVariants, Tags)

def describe_endpoint(self, EndpointName):
if EndpointName not in LocalSagemakerClient._endpoints:
Expand All @@ -138,8 +138,8 @@ def describe_endpoint(self, EndpointName):
else:
return LocalSagemakerClient._endpoints[EndpointName].describe()

def create_endpoint(self, EndpointName, EndpointConfigName):
endpoint = _LocalEndpoint(EndpointName, EndpointConfigName, self.sagemaker_session)
def create_endpoint(self, EndpointName, EndpointConfigName, Tags=None):
endpoint = _LocalEndpoint(EndpointName, EndpointConfigName, Tags, self.sagemaker_session)
LocalSagemakerClient._endpoints[EndpointName] = endpoint
endpoint.serve()

Expand Down
3 changes: 2 additions & 1 deletion src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
model_name=self.name,
initial_instance_count=initial_instance_count,
instance_type=instance_type,
accelerator_type=accelerator_type)
accelerator_type=accelerator_type,
tags=tags)
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
else:
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags)
Expand Down
9 changes: 6 additions & 3 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def create_endpoint_config(self, name, model_name, initial_instance_count, insta
)
return name

def create_endpoint(self, endpoint_name, config_name, wait=True):
def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True):
"""Create an Amazon SageMaker ``Endpoint`` according to the endpoint configuration specified in the request.

Once the ``Endpoint`` is created, client applications can send requests to obtain inferences.
Expand All @@ -764,7 +764,10 @@ def create_endpoint(self, endpoint_name, config_name, wait=True):
str: Name of the Amazon SageMaker ``Endpoint`` created.
"""
LOGGER.info('Creating endpoint with name {}'.format(endpoint_name))
self.sagemaker_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name)

tags = tags or []

self.sagemaker_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags)
if wait:
self.wait_for_endpoint(endpoint_name)
return endpoint_name
Expand Down Expand Up @@ -1052,7 +1055,7 @@ def endpoint_from_production_variants(self, name, production_variants, tags=None
config_options['Tags'] = tags

self.sagemaker_client.create_endpoint_config(**config_options)
return self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
return self.create_endpoint(endpoint_name=name, config_name=name, tags=tags, wait=wait)

def expand_role(self, role):
"""Expand an IAM role name into an ARN.
Expand Down
31 changes: 31 additions & 0 deletions tests/integ/test_mxnet_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,37 @@ def test_deploy_model(mxnet_training_job, sagemaker_session, mxnet_full_version)
assert 'Could not find model' in str(exception.value)


def test_deploy_model_with_tags(mxnet_training_job, sagemaker_session, mxnet_full_version):
endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())

with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=mxnet_training_job)
model_data = desc['ModelArtifacts']['S3ModelArtifacts']
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path,
py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session,
framework_version=mxnet_full_version)
tags = [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name, tags=tags)

returned_model = sagemaker_session.describe_model(EndpointName=model.name)
returned_model_tags = sagemaker_session.list_tags(ResourceArn=returned_model['ModelArn'])['Tags']

endpoint = sagemaker_session.describe_endpoint(EndpointName=endpoint_name)
endpoint_tags = sagemaker_session.list_tags(ResourceArn=endpoint['EndpointArn'])['Tags']

endpoint_config = sagemaker_session.describe_endpoint_config(EndpointConfigName=endpoint['EndpointConfigName'])
endpoint_config_tags = sagemaker_session.list_tags(ResourceArn=endpoint_config['EndpointConfigArn'])['Tags']

production_variants = endpoint_config['ProductionVariants']

assert returned_model_tags == tags
assert endpoint_config_tags == tags
assert endpoint_tags == tags
assert production_variants[0]['InstanceType'] == 'ml.m4.xlarge'
assert production_variants[0]['InitialInstanceCount'] == 1


def test_deploy_model_with_update_endpoint(mxnet_training_job, sagemaker_session, mxnet_full_version):
endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_create_deploy_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_create_endpoint_no_wait(sagemaker_session):

assert returned_name == ENDPOINT_NAME
sagemaker_session.sagemaker_client.create_endpoint.assert_called_once_with(
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME)
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME, Tags=[])


def test_create_endpoint_wait(sagemaker_session):
Expand All @@ -105,5 +105,5 @@ def test_create_endpoint_wait(sagemaker_session):

assert returned_name == ENDPOINT_NAME
sagemaker_session.sagemaker_client.create_endpoint.assert_called_once_with(
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME)
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME, Tags=[])
sagemaker_session.wait_for_endpoint.assert_called_once_with(ENDPOINT_NAME)
3 changes: 2 additions & 1 deletion tests/unit/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ def test_deploy_update_endpoint(sagemaker_session, tmpdir):
model_name=model.name,
initial_instance_count=INSTANCE_COUNT,
instance_type=INSTANCE_TYPE,
accelerator_type=ACCELERATOR_TYPE
accelerator_type=ACCELERATOR_TYPE,
tags=None
)
config_name = sagemaker_session.create_endpoint_config(
name=model.name,
Expand Down
9 changes: 6 additions & 3 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,8 @@ def test_endpoint_from_production_variants(sagemaker_session):
ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex)
sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs)
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
EndpointName='some-endpoint')
EndpointName='some-endpoint',
Tags=[])
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
EndpointConfigName='some-endpoint',
ProductionVariants=pvs)
Expand All @@ -936,7 +937,8 @@ def test_endpoint_from_production_variants_with_tags(sagemaker_session):
tags = [{'ModelName': 'TestModel'}]
sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs, tags)
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
EndpointName='some-endpoint')
EndpointName='some-endpoint',
Tags=tags)
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
EndpointConfigName='some-endpoint',
ProductionVariants=pvs,
Expand All @@ -953,7 +955,8 @@ def test_endpoint_from_production_variants_with_accelerator_type(sagemaker_sessi
tags = [{'ModelName': 'TestModel'}]
sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs, tags)
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
EndpointName='some-endpoint')
EndpointName='some-endpoint',
Tags=tags)
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
EndpointConfigName='some-endpoint',
ProductionVariants=pvs,
Expand Down