diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index db19cfcb71..5b25b61d81 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -346,16 +346,21 @@ def download_folder(bucket_name, prefix, target, sagemaker_session): # Try to download the prefix as an object first, in case it is a file and not a 'directory'. # Do this first, in case the object has broader permissions than the bucket. - try: - s3.Object(bucket_name, prefix).download_file(os.path.join(target, os.path.basename(prefix))) - return - except botocore.exceptions.ClientError as e: - if e.response["Error"]["Code"] == "404" and e.response["Error"]["Message"] == "Not Found": - # S3 also throws this error if the object is a folder, - # so assume that is the case here, and then raise for an actual 404 later. - _download_files_under_prefix(bucket_name, prefix, target, s3) - else: - raise + if not prefix.endswith("/"): + try: + file_destination = os.path.join(target, os.path.basename(prefix)) + s3.Object(bucket_name, prefix).download_file(file_destination) + return + except botocore.exceptions.ClientError as e: + err_info = e.response["Error"] + if err_info["Code"] == "404" and err_info["Message"] == "Not Found": + # S3 also throws this error if the object is a folder, + # so assume that is the case here, and then raise for an actual 404 later. + pass + else: + raise + + _download_files_under_prefix(bucket_name, prefix, target, s3) def _download_files_under_prefix(bucket_name, prefix, target, s3): @@ -370,7 +375,7 @@ def _download_files_under_prefix(bucket_name, prefix, target, s3): bucket = s3.Bucket(bucket_name) for obj_sum in bucket.objects.filter(Prefix=prefix): # if obj_sum is a folder object skip it. - if obj_sum.key != "" and obj_sum.key[-1] == "/": + if obj_sum.key.endswith("/"): continue obj = s3.Object(obj_sum.bucket_name, obj_sum.key) s3_relative_path = obj_sum.key[len(prefix) :].lstrip("/") diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 2cbb6561e1..02db2e1533 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of @@ -349,12 +349,16 @@ def obj_mock_download(path): call(os.path.join("/tmp", "train", "validation_data.csv")), ] obj_mock.download_file.assert_has_calls(calls) + assert s3_mock.Object.call_count == 3 + + s3_mock.reset_mock() obj_mock.reset_mock() # Test with a trailing slash for the prefix. sagemaker.utils.download_folder(BUCKET_NAME, "/prefix/", "/tmp", session) obj_mock.download_file.assert_called() obj_mock.download_file.assert_has_calls(calls) + assert s3_mock.Object.call_count == 2 @patch("os.makedirs")