Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions airflow/providers/amazon/aws/operators/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,27 @@ class RdsBaseOperator(BaseOperator):
ui_color = "#eeaa88"
ui_fgcolor = "#ffffff"

def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params: dict | None = None, **kwargs):
def __init__(
self,
*args,
aws_conn_id: str = "aws_conn_id",
region_name: str | None = None,
hook_params: dict | None = None,
**kwargs,
):
if hook_params is not None:
warnings.warn(
"The parameter hook_params is deprecated and will be removed. "
"If you were using it, please get in touch either on airflow slack, "
"or by opening a github issue on the project. "
"Note that it is also incompatible with deferrable mode. "
"You can use the region_name parameter to specify the region. "
"If you were using hook_params for other purposes, please get in touch either on "
"airflow slack, or by opening a github issue on the project. "
"You can mention https://github.com/apache/airflow/pull/32352",
AirflowProviderDeprecationWarning,
stacklevel=3, # 2 is in the operator's init, 3 is in the user code creating the operator
)
self.hook_params = hook_params or {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if someone wants to pass something like region_name to the hook? This is a fairly common use case, and there is no reason to exclude that from being used in deferrable operators.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a bunch of operators that don't allow passing a region at all, like the sagemaker operators, DMS, etc.

If we want users to be able to pass a region, I think we should add that as an explicit parameter.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pushed a change to that effect.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is something we might want to think about, actually. In which case would a user want to specify another region? Regions are set as part of connections. That means a user sets regionA in their connection, hence targeting this region for all AWS calls, but would want to target another region for a specific operator? I think that can be done by using another connection (and passing it through the aws_conn_id param). WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't exactly know how that works in practice... I suppose you'd have a "main" region where you do the heavy lifting (SageMaker stuff, for instance), but you'd have DBs in various regions, and you'd like to be able to hit those without having to change too much stuff?
Are you saying that the region param should be deprecated everywhere in the AWS provider in favor of several connection IDs?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like duplication to me... But that's my opinion as a developer; I can understand that, as a user, just passing a region as a parameter is easier than creating a new connection. Maybe a question for @shubham22.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a bunch of operators that don't allow passing a region at all, like the sagemaker operators, DMS, etc.

I agree with you here, but in the interest of keeping things backwards compatible, if a particular operator allows passing region config, then we should keep supporting that; or, if we want to stop supporting it, we would need to go through the whole deprecation process.
At this point, I don't know how you want to proceed though :p

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just passing a region as parameter is easier than creating a new connection

I agree with this. I haven't had a direct conversation about this with customers, but I would assume this is commonly done. Yes, many operators do not support it today, and maybe some users will request it for them in the future. In any case, we shouldn't take it away where it is supported unless we think it is causing some other problems.

self.hook = RdsHook(aws_conn_id=aws_conn_id, **self.hook_params)
self.region_name = region_name
self.hook = RdsHook(aws_conn_id=aws_conn_id, region_name=region_name, **(hook_params or {}))
super().__init__(*args, **kwargs)

self._await_interval = 60 # seconds
Expand Down Expand Up @@ -588,7 +597,7 @@ def execute(self, context: Context) -> str:
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
hook_params=self.hook_params,
region_name=self.region_name,
waiter_name="db_instance_available",
# ignoring type because create_db_instance is a dict
response=create_db_instance, # type: ignore[arg-type]
Expand Down Expand Up @@ -674,7 +683,7 @@ def execute(self, context: Context) -> str:
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
hook_params=self.hook_params,
region_name=self.region_name,
waiter_name="db_instance_deleted",
# ignoring type because delete_db_instance is a dict
response=delete_db_instance, # type: ignore[arg-type]
Expand Down
8 changes: 4 additions & 4 deletions airflow/providers/amazon/aws/triggers/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ def __init__(
waiter_delay: int,
waiter_max_attempts: int,
aws_conn_id: str,
hook_params: dict[str, Any],
region_name: str | None,
response: dict[str, Any],
):
self.db_instance_identifier = db_instance_identifier
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.aws_conn_id = aws_conn_id
self.hook_params = hook_params
self.region_name = region_name
self.waiter_name = waiter_name
self.response = response

Expand All @@ -67,14 +67,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"waiter_delay": str(self.waiter_delay),
"waiter_max_attempts": str(self.waiter_max_attempts),
"aws_conn_id": self.aws_conn_id,
"hook_params": self.hook_params,
"region_name": self.region_name,
"waiter_name": self.waiter_name,
"response": self.response,
},
)

async def run(self):
self.hook = RdsHook(aws_conn_id=self.aws_conn_id, **self.hook_params)
self.hook = RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
async with self.hook.async_conn as client:
waiter = client.get_waiter(self.waiter_name)
await async_wait(
Expand Down
13 changes: 7 additions & 6 deletions tests/providers/amazon/aws/triggers/test_rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
TEST_WAITER_DELAY = 10
TEST_WAITER_MAX_ATTEMPTS = 10
TEST_AWS_CONN_ID = "test-aws-id"
TEST_REGION = "sa-east-1"
TEST_RESPONSE = {
"DBInstance": {
"DBInstanceIdentifier": "test-db-instance-identifier",
Expand All @@ -47,7 +48,7 @@ def test_rds_db_instance_trigger_serialize(self):
waiter_delay=TEST_WAITER_DELAY,
waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
aws_conn_id=TEST_AWS_CONN_ID,
hook_params={},
region_name=TEST_REGION,
response=TEST_RESPONSE,
)
class_path, args = rds_db_instance_trigger.serialize()
Expand All @@ -58,7 +59,7 @@ def test_rds_db_instance_trigger_serialize(self):
assert args["waiter_delay"] == str(TEST_WAITER_DELAY)
assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS)
assert args["aws_conn_id"] == TEST_AWS_CONN_ID
assert args["hook_params"] == {}
assert args["region_name"] == TEST_REGION
assert args["response"] == TEST_RESPONSE

@pytest.mark.asyncio
Expand All @@ -75,7 +76,7 @@ async def test_rds_db_instance_trigger_run(self, mock_async_conn):
waiter_delay=TEST_WAITER_DELAY,
waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
aws_conn_id=TEST_AWS_CONN_ID,
hook_params={},
region_name=TEST_REGION,
response=TEST_RESPONSE,
)

Expand Down Expand Up @@ -104,7 +105,7 @@ async def test_rds_db_instance_trigger_run_multiple_attempts(self, mock_async_co
waiter_delay=TEST_WAITER_DELAY,
waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
aws_conn_id=TEST_AWS_CONN_ID,
hook_params={},
region_name=TEST_REGION,
response=TEST_RESPONSE,
)

Expand Down Expand Up @@ -135,7 +136,7 @@ async def test_rds_db_instance_trigger_run_attempts_exceeded(self, mock_async_co
waiter_delay=TEST_WAITER_DELAY,
waiter_max_attempts=2,
aws_conn_id=TEST_AWS_CONN_ID,
hook_params={},
region_name=TEST_REGION,
response=TEST_RESPONSE,
)

Expand Down Expand Up @@ -173,7 +174,7 @@ async def test_rds_db_instance_trigger_run_attempts_failed(self, mock_async_conn
waiter_delay=TEST_WAITER_DELAY,
waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
aws_conn_id=TEST_AWS_CONN_ID,
hook_params={},
region_name=TEST_REGION,
response=TEST_RESPONSE,
)

Expand Down