Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ CHANGELOG
========

* bug-fix: Local Mode: Create output/data directory expected by SageMaker Container.
* enhancement: Add support for volume KMS key to Transformer

1.9.2
=====
Expand Down
12 changes: 8 additions & 4 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def delete_endpoint(self):

def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
max_payload=None, tags=None, role=None):
max_payload=None, tags=None, role=None, volume_kms_key=None):
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
SageMaker Session and base job name used by the Estimator.

Expand All @@ -350,6 +350,8 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
the training job are used for the transform job.
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
transform jobs. If not specified, the role from the Estimator will be used.
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
compute instance (default: None).
"""
self._ensure_latest_training_job()

Expand All @@ -360,7 +362,7 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
output_path=output_path, output_kms_key=output_kms_key, accept=accept,
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
env=env, tags=tags, base_transform_job_name=self.base_job_name,
sagemaker_session=self.sagemaker_session)
volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session)

@property
def training_job_analytics(self):
Expand Down Expand Up @@ -756,7 +758,7 @@ def _update_init_params(cls, hp, tf_arguments):

def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
max_payload=None, tags=None, role=None, model_server_workers=None):
max_payload=None, tags=None, role=None, model_server_workers=None, volume_kms_key=None):
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
SageMaker Session and base job name used by the Estimator.

Expand All @@ -780,6 +782,8 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
transform jobs. If not specified, the role from the Estimator will be used.
model_server_workers (int): Optional. The number of worker processes used by the inference server.
If None, server will use one worker per vCPU.
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
compute instance (default: None).
"""
self._ensure_latest_training_job()
role = role or self.role
Expand All @@ -799,7 +803,7 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
output_path=output_path, output_kms_key=output_kms_key, accept=accept,
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
env=transform_env, tags=tags, base_transform_job_name=self.base_job_name,
sagemaker_session=self.sagemaker_session)
volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session)


def _s3_uri_prefix(channel_name, s3_data):
Expand Down
13 changes: 9 additions & 4 deletions src/sagemaker/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Transformer(object):

def __init__(self, model_name, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
output_kms_key=None, accept=None, max_concurrent_transforms=None, max_payload=None, tags=None,
env=None, base_transform_job_name=None, sagemaker_session=None):
env=None, base_transform_job_name=None, sagemaker_session=None, volume_kms_key=None):
"""Initialize a ``Transformer``.

Args:
Expand All @@ -50,6 +50,8 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
compute instance (default: None).
"""
self.model_name = model_name
self.strategy = strategy
Expand All @@ -62,6 +64,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass

self.instance_count = instance_count
self.instance_type = instance_type
self.volume_kms_key = volume_kms_key

self.max_concurrent_transforms = max_concurrent_transforms
self.max_payload = max_payload
Expand Down Expand Up @@ -159,6 +162,7 @@ def _prepare_init_params_from_job_description(cls, job_details):
init_params['model_name'] = job_details['ModelName']
init_params['instance_count'] = job_details['TransformResources']['InstanceCount']
init_params['instance_type'] = job_details['TransformResources']['InstanceType']
init_params['volume_kms_key'] = job_details['TransformResources'].get('VolumeKmsKeyId')
init_params['strategy'] = job_details.get('BatchStrategy')
init_params['assemble_with'] = job_details['TransformOutput'].get('AssembleWith')
init_params['output_path'] = job_details['TransformOutput']['S3OutputPath']
Expand Down Expand Up @@ -200,7 +204,8 @@ def _load_config(data, data_type, content_type, compression_type, split_type, tr
output_config = _TransformJob._prepare_output_config(transformer.output_path, transformer.output_kms_key,
transformer.assemble_with, transformer.accept)

resource_config = _TransformJob._prepare_resource_config(transformer.instance_count, transformer.instance_type)
resource_config = _TransformJob._prepare_resource_config(transformer.instance_count, transformer.instance_type,
transformer.volume_kms_key)

return {'input_config': input_config,
'output_config': output_config,
Expand Down Expand Up @@ -241,5 +246,5 @@ def _prepare_output_config(s3_path, kms_key_id, assemble_with, accept):
return config

@staticmethod
def _prepare_resource_config(instance_count, instance_type):
return {'InstanceCount': instance_count, 'InstanceType': instance_type}
def _prepare_resource_config(instance_count, instance_type, volume_kms_key):
return {'InstanceCount': instance_count, 'InstanceType': instance_type, 'VolumeKmsKeyId': volume_kms_key}
91 changes: 91 additions & 0 deletions tests/integ/kms_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

KEY_ALIAS = "SageMakerKmsKey"
KEY_POLICY = '''
{{
"Version": "2012-10-17",
"Id": "sagemaker-kms-integ-test-policy",
"Statement": [
{{
"Sid": "Enable IAM User Permissions",
"Effect": "Allow",
"Principal": {{
"AWS": "arn:aws:iam::{account_id}:root"
}},
"Action": "kms:*",
"Resource": "*"
}},
{{
"Sid": "Allow use of the key",
"Effect": "Allow",
"Principal": {{
"AWS": "arn:aws:iam::{account_id}:role/SageMakerRole"
}},
"Action": [
"kms:Encrypt",
"kms:Decrypt",
"kms:ReEncrypt*",
"kms:GenerateDataKey*",
"kms:DescribeKey"
],
"Resource": "*"
}},
{{
"Sid": "Allow attachment of persistent resources",
"Effect": "Allow",
"Principal": {{
"AWS": "arn:aws:iam::{account_id}:role/SageMakerRole"
}},
"Action": [
"kms:CreateGrant",
"kms:ListGrants",
"kms:RevokeGrant"
],
"Resource": "*",
"Condition": {{
"Bool": {{
"kms:GrantIsForAWSResource": "true"
}}
}}
}}
]
}}
'''


def _get_kms_key_arn(kms_client, alias):
try:
response = kms_client.describe_key(KeyId='alias/' + alias)
return response['KeyMetadata']['Arn']
except kms_client.exceptions.NotFoundException:
return None


def _create_kms_key(kms_client, account_id):
response = kms_client.create_key(
Policy=KEY_POLICY.format(account_id=account_id),
Description='KMS key for SageMaker Python SDK integ tests',
)
key_arn = response['KeyMetadata']['Arn']
response = kms_client.create_alias(AliasName='alias/' + KEY_ALIAS, TargetKeyId=key_arn)
return key_arn


def get_or_create_kms_key(kms_client, account_id):
kms_key_arn = _get_kms_key_arn(kms_client, KEY_ALIAS)
if kms_key_arn is not None:
return kms_key_arn
else:
return _create_kms_key(kms_client, account_id)
15 changes: 12 additions & 3 deletions tests/integ/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from sagemaker.mxnet import MXNet
from sagemaker.transformer import Transformer
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
from tests.integ.kms_utils import get_or_create_kms_key
from tests.integ.timeout import timeout


Expand All @@ -47,8 +48,16 @@ def test_transform_mxnet(sagemaker_session):
transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
key_prefix=transform_input_key_prefix)

transformer = _create_transformer_and_transform_job(mx, transform_input)
sts_client = sagemaker_session.boto_session.client('sts')
account_id = sts_client.get_caller_identity()['Account']
kms_client = sagemaker_session.boto_session.client('kms')
kms_key_arn = get_or_create_kms_key(kms_client, account_id)

transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn)
transformer.wait()
job_desc = transformer.sagemaker_session.sagemaker_client.describe_transform_job(
TransformJobName=transformer.latest_transform_job.name)
assert kms_key_arn == job_desc['TransformResources']['VolumeKmsKeyId']


@pytest.mark.continuous_testing
Expand Down Expand Up @@ -90,7 +99,7 @@ def test_attach_transform_kmeans(sagemaker_session):
attached_transformer.wait()


def _create_transformer_and_transform_job(estimator, transform_input):
transformer = estimator.transformer(1, 'ml.m4.xlarge')
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
transformer.transform(transform_input, content_type='text/csv')
return transformer
3 changes: 2 additions & 1 deletion tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with,
output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS,
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
env=env, role=new_role, model_server_workers=1)
volume_kms_key=kms_key, env=env, role=new_role, model_server_workers=1)

sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF)
assert transformer.strategy == strategy
Expand All @@ -501,6 +501,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
assert transformer.env == env
assert transformer.base_transform_job_name == base_name
assert transformer.tags == TAGS
assert transformer.volume_kms_key == kms_key


def test_ensure_latest_training_job(sagemaker_session):
Expand Down
13 changes: 9 additions & 4 deletions tests/unit/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

INSTANCE_COUNT = 1
INSTANCE_TYPE = 'ml.m4.xlarge'
KMS_KEY_ID = 'kms-key-id'

S3_DATA_TYPE = 'S3Prefix'
S3_BUCKET = 'bucket'
Expand All @@ -48,7 +49,8 @@ def sagemaker_session():
@pytest.fixture()
def transformer(sagemaker_session):
return Transformer(MODEL_NAME, INSTANCE_COUNT, INSTANCE_TYPE,
output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session)
output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session,
volume_kms_key=KMS_KEY_ID)


@patch('sagemaker.transformer._TransformJob.start_new')
Expand Down Expand Up @@ -178,7 +180,8 @@ def test_prepare_init_params_from_job_description_all_keys(transformer):
'ModelName': MODEL_NAME,
'TransformResources': {
'InstanceCount': INSTANCE_COUNT,
'InstanceType': INSTANCE_TYPE
'InstanceType': INSTANCE_TYPE,
'VolumeKmsKeyId': KMS_KEY_ID
},
'BatchStrategy': None,
'TransformOutput': {
Expand All @@ -197,6 +200,7 @@ def test_prepare_init_params_from_job_description_all_keys(transformer):
assert init_params['model_name'] == MODEL_NAME
assert init_params['instance_count'] == INSTANCE_COUNT
assert init_params['instance_type'] == INSTANCE_TYPE
assert init_params['volume_kms_key'] == KMS_KEY_ID


# _TransformJob tests
Expand Down Expand Up @@ -227,6 +231,7 @@ def test_load_config(transformer):
'resource_config': {
'InstanceCount': INSTANCE_COUNT,
'InstanceType': INSTANCE_TYPE,
'VolumeKmsKeyId': KMS_KEY_ID,
},
}

Expand Down Expand Up @@ -292,8 +297,8 @@ def test_prepare_output_config_with_optional_params():


def test_prepare_resource_config():
config = _TransformJob._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE)
assert config == {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE}
config = _TransformJob._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE, KMS_KEY_ID)
assert config == {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE, 'VolumeKmsKeyId': KMS_KEY_ID}


def test_transform_job_wait(sagemaker_session):
Expand Down