diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e684b1f4bd..2f0516a80f 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,6 +8,8 @@ CHANGELOG * bug-fix: Unit Tests: Improve unit test runtime * bug-fix: Estimators: Fix attach for LDA * bug-fix: Estimators: allow code_location to have no key prefix +* bug-fix: Local Mode: Fix s3 training data download when there is a trailing slash + 1.4.1 ===== diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index f4a6d74837..432567484b 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -267,7 +267,8 @@ def _download_folder(self, bucket_name, prefix, target): for obj_sum in bucket.objects.filter(Prefix=prefix): obj = s3.Object(obj_sum.bucket_name, obj_sum.key) - file_path = os.path.join(target, obj_sum.key[len(prefix) + 1:]) + s3_relative_path = obj_sum.key[len(prefix):].lstrip('/') + file_path = os.path.join(target, s3_relative_path) try: os.makedirs(os.path.dirname(file_path)) @@ -275,7 +276,6 @@ def _download_folder(self, bucket_name, prefix, target): if exc.errno != errno.EEXIST: raise pass - obj.download_file(file_path) def _prepare_training_volumes(self, data_dir, input_data_config, hyperparameters): diff --git a/tests/unit/test_image.py b/tests/unit/test_image.py index 0d3b490104..e32f9427e6 100644 --- a/tests/unit/test_image.py +++ b/tests/unit/test_image.py @@ -366,6 +366,15 @@ def test_download_folder(makedirs): calls = [call(os.path.join('/tmp', 'train/train_data.csv')), call(os.path.join('/tmp', 'train/validation_data.csv'))] obj_mock.download_file.assert_has_calls(calls) + obj_mock.reset_mock() + + # Testing with a trailing slash for the prefix. + sagemaker_container._download_folder(BUCKET_NAME, '/prefix/', '/tmp') + obj_mock.download_file.assert_called() + calls = [call(os.path.join('/tmp', 'train/train_data.csv')), + call(os.path.join('/tmp', 'train/validation_data.csv'))] + + obj_mock.download_file.assert_has_calls(calls) def test_ecr_login_non_ecr():