Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/sagemaker/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,15 @@ def _download_folder(self, bucket_name, prefix, target):

for obj_sum in bucket.objects.filter(Prefix=prefix):
obj = s3.Object(obj_sum.bucket_name, obj_sum.key)
file_path = os.path.join(target, obj_sum.key[len(prefix) + 1:])
s3_relative_path = obj_sum.key[len(prefix):].lstrip('/')
file_path = os.path.join(target, s3_relative_path)

try:
os.makedirs(os.path.dirname(file_path))
except OSError as exc:
if exc.errno != errno.EEXIST:
raise
pass

obj.download_file(file_path)

def _prepare_training_volumes(self, data_dir, input_data_config, hyperparameters):
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,14 @@ def test_download_folder(makedirs):
calls = [call(os.path.join('/tmp', 'train/train_data.csv')),
call(os.path.join('/tmp', 'train/validation_data.csv'))]
obj_mock.download_file.assert_has_calls(calls)
obj_mock.reset_mock()

sagemaker_container._download_folder(BUCKET_NAME, '/prefix/', '/tmp')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: either parametrize these 2 test cases or add a comment about the difference between then.

obj_mock.download_file.assert_called()
calls = [call(os.path.join('/tmp', 'train/train_data.csv')),
call(os.path.join('/tmp', 'train/validation_data.csv'))]

obj_mock.download_file.assert_has_calls(calls)


def test_ecr_login_non_ecr():
Expand Down