Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,16 +346,21 @@ def download_folder(bucket_name, prefix, target, sagemaker_session):

# Try to download the prefix as an object first, in case it is a file and not a 'directory'.
# Do this first, in case the object has broader permissions than the bucket.
try:
s3.Object(bucket_name, prefix).download_file(os.path.join(target, os.path.basename(prefix)))
return
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] == "404" and e.response["Error"]["Message"] == "Not Found":
# S3 also throws this error if the object is a folder,
# so assume that is the case here, and then raise for an actual 404 later.
_download_files_under_prefix(bucket_name, prefix, target, s3)
else:
raise
if not prefix.endswith("/"):
try:
file_destination = os.path.join(target, os.path.basename(prefix))
s3.Object(bucket_name, prefix).download_file(file_destination)
return
except botocore.exceptions.ClientError as e:
err_info = e.response["Error"]
if err_info["Code"] == "404" and err_info["Message"] == "Not Found":
# S3 also throws this error if the object is a folder,
# so assume that is the case here, and then raise for an actual 404 later.
pass
else:
raise

_download_files_under_prefix(bucket_name, prefix, target, s3)


def _download_files_under_prefix(bucket_name, prefix, target, s3):
Expand All @@ -370,7 +375,7 @@ def _download_files_under_prefix(bucket_name, prefix, target, s3):
bucket = s3.Bucket(bucket_name)
for obj_sum in bucket.objects.filter(Prefix=prefix):
# if obj_sum is a folder object skip it.
if obj_sum.key != "" and obj_sum.key[-1] == "/":
if obj_sum.key.endswith("/"):
continue
obj = s3.Object(obj_sum.bucket_name, obj_sum.key)
s3_relative_path = obj_sum.key[len(prefix) :].lstrip("/")
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
Expand Down Expand Up @@ -349,12 +349,16 @@ def obj_mock_download(path):
call(os.path.join("/tmp", "train", "validation_data.csv")),
]
obj_mock.download_file.assert_has_calls(calls)
assert s3_mock.Object.call_count == 3

s3_mock.reset_mock()
obj_mock.reset_mock()

# Test with a trailing slash for the prefix.
sagemaker.utils.download_folder(BUCKET_NAME, "/prefix/", "/tmp", session)
obj_mock.download_file.assert_called()
obj_mock.download_file.assert_has_calls(calls)
assert s3_mock.Object.call_count == 2


@patch("os.makedirs")
Expand Down