Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions airflow/providers/amazon/aws/hooks/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,16 @@ def parse_s3_url(s3url: str) -> tuple[str, str]:
:param s3url: The S3 Url to parse.
:return: the parsed bucket name and key
"""
valid_s3_format = "S3://bucket-name/key-name"
valid_s3_virtual_hosted_format = "https://bucket-name.s3.region-code.amazonaws.com/key-name"
format = s3url.split("//")
if re.match(r"s3[na]?:", format[0], re.IGNORECASE):
parsed_url = urlsplit(s3url)
if not parsed_url.netloc:
raise S3HookUriParseFailure(f'Please provide a bucket name using a valid format: "{s3url}"')
raise S3HookUriParseFailure(
"Please provide a bucket name using a valid format of the form: "
+ f'{valid_s3_format} or {valid_s3_virtual_hosted_format} but provided: "{s3url}"'
)

bucket_name = parsed_url.netloc
key = parsed_url.path.lstrip("/")
Expand All @@ -229,8 +234,16 @@ def parse_s3_url(s3url: str) -> tuple[str, str]:
elif temp_split[1] == "s3":
bucket_name = temp_split[0]
key = "/".join(format[1].split("/")[1:])
else:
raise S3HookUriParseFailure(
"Please provide a bucket name using a valid virtually hosted format which should"
+ f' be of the form: {valid_s3_virtual_hosted_format} but provided: "{s3url}"'
)
else:
raise S3HookUriParseFailure(f'Please provide a bucket name using a valid format: "{s3url}"')
raise S3HookUriParseFailure(
"Please provide a bucket name using a valid format of the form: "
+ f'{valid_s3_format} or {valid_s3_virtual_hosted_format} but provided: "{s3url}"'
)
return bucket_name, key

@staticmethod
Expand Down
10 changes: 10 additions & 0 deletions tests/providers/amazon/aws/hooks/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.amazon.aws.exceptions import S3HookUriParseFailure
from airflow.providers.amazon.aws.hooks.s3 import (
S3Hook,
provide_bucket_name,
Expand Down Expand Up @@ -94,6 +95,15 @@ def test_parse_s3_url_virtual_hosted_style(self):
parsed = S3Hook.parse_s3_url("https://DOC-EXAMPLE-BUCKET1.s3.us-west-2.amazonaws.com/test.png")
assert parsed == ("DOC-EXAMPLE-BUCKET1", "test.png"), "Incorrect parsing of the s3 url"

def test_parse_invalid_s3_url_virtual_hosted_style(self):
with pytest.raises(
S3HookUriParseFailure,
match="Please provide a bucket name using a valid virtually hosted format which should"
+ " be of the form: https://bucket-name.s3.region-code.amazonaws.com/key-name but "
+ 'provided: "https://DOC-EXAMPLE-BUCKET1.us-west-2.amazonaws.com/test.png"',
):
S3Hook.parse_s3_url("https://DOC-EXAMPLE-BUCKET1.us-west-2.amazonaws.com/test.png")

def test_parse_s3_object_directory(self):
parsed = S3Hook.parse_s3_url("s3://test/this/is/not/a-real-s3-directory/")
assert parsed == ("test", "this/is/not/a-real-s3-directory/"), "Incorrect parsing of the s3 url"
Expand Down