Merged
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/hooks/athena.py
@@ -253,7 +253,7 @@ def poll_query_status(
try:
wait(
waiter=self.get_waiter("query_complete"),
waiter_delay=sleep_time or self.sleep_time,
waiter_delay=self.sleep_time if sleep_time is None else sleep_time,
Contributor Author:
This is a somewhat unrelated fix that allows specifying a sleep_time of 0 in unit tests. Without it, the Athena unit tests were taking 30s each.

Contributor:
This is fine, but you can also just mock the sleep function, no?

Contributor Author:
I think just setting sleep_time=0 is so much simpler, cleaner, and easier to read (see the sketch after this hunk).

waiter_max_attempts=max_polling_attempts or 120,
args={"QueryExecutionId": query_execution_id},
failure_message=f"Error while waiting for query {query_execution_id} to complete",
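To make the sleep_time discussion above concrete, here is a minimal, hypothetical sketch of the difference between the old "or" fallback and the new "is None" check; the helper names resolve_delay_with_or and resolve_delay_explicit are illustrative and do not exist in the Airflow code base.

DEFAULT_SLEEP_TIME = 30

def resolve_delay_with_or(sleep_time=None, default=DEFAULT_SLEEP_TIME):
    # Old behaviour: 0 is falsy, so an explicit 0 silently falls back to the default.
    return sleep_time or default

def resolve_delay_explicit(sleep_time=None, default=DEFAULT_SLEEP_TIME):
    # New behaviour: only None falls back to the default.
    return default if sleep_time is None else sleep_time

assert resolve_delay_with_or(0) == 30      # a test asking for no delay still waits 30s per poll
assert resolve_delay_explicit(0) == 0      # a test can now disable the delay entirely
assert resolve_delay_explicit(None) == 30  # the default is kept when nothing is passed

Under that assumption, a unit test can pass sleep_time=0 to poll_query_status and skip the 30-second wait between polls, which is what the thread above is getting at.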
8 changes: 4 additions & 4 deletions airflow/providers/amazon/aws/operators/batch.py
@@ -41,7 +41,7 @@
from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
from airflow.providers.amazon.aws.triggers.batch import (
BatchCreateComputeEnvironmentTrigger,
BatchOperatorTrigger,
BatchJobTrigger,
)
from airflow.providers.amazon.aws.utils import trim_none_values
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
@@ -221,12 +221,12 @@ def execute(self, context: Context):
if self.deferrable:
self.defer(
timeout=self.execution_timeout,
trigger=BatchOperatorTrigger(
trigger=BatchJobTrigger(
job_id=self.job_id,
max_retries=self.max_retries or 10,
waiter_max_attempts=self.max_retries or 10,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
poll_interval=self.poll_interval,
waiter_delay=self.poll_interval,
),
method_name="execute_complete",
)
16 changes: 9 additions & 7 deletions airflow/providers/amazon/aws/operators/ecs.py
@@ -33,7 +33,11 @@
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, should_retry_eni
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.triggers.ecs import ClusterWaiterTrigger, TaskDoneTrigger
from airflow.providers.amazon.aws.triggers.ecs import (
ClusterActiveTrigger,
ClusterInactiveTrigger,
TaskDoneTrigger,
)
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.utils.helpers import prune_dict
from airflow.utils.session import provide_session
@@ -139,13 +143,12 @@ def execute(self, context: Context):
self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state)
elif self.deferrable:
self.defer(
trigger=ClusterWaiterTrigger(
waiter_name="cluster_active",
trigger=ClusterActiveTrigger(
cluster_arn=cluster_details["clusterArn"],
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region=self.region,
region_name=self.region,
),
method_name="_complete_exec_with_cluster_desc",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
@@ -217,13 +220,12 @@ def execute(self, context: Context):
self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state)
elif self.deferrable:
self.defer(
trigger=ClusterWaiterTrigger(
waiter_name="cluster_inactive",
trigger=ClusterInactiveTrigger(
cluster_arn=cluster_details["clusterArn"],
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region=self.region,
region_name=self.region,
),
method_name="_complete_exec_with_cluster_desc",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
13 changes: 6 additions & 7 deletions airflow/providers/amazon/aws/operators/eks.py
@@ -31,8 +31,9 @@
from airflow.providers.amazon.aws.hooks.eks import EksHook
from airflow.providers.amazon.aws.triggers.eks import (
EksCreateFargateProfileTrigger,
EksCreateNodegroupTrigger,
EksDeleteFargateProfileTrigger,
EksNodegroupTrigger,
EksDeleteNodegroupTrigger,
)
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
@@ -413,12 +414,11 @@ def execute(self, context: Context):

if self.deferrable:
self.defer(
trigger=EksNodegroupTrigger(
waiter_name="nodegroup_active",
trigger=EksCreateNodegroupTrigger(
cluster_name=self.cluster_name,
nodegroup_name=self.nodegroup_name,
aws_conn_id=self.aws_conn_id,
region=self.region,
region_name=self.region,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
@@ -711,12 +711,11 @@ def execute(self, context: Context):
eks_hook.delete_nodegroup(clusterName=self.cluster_name, nodegroupName=self.nodegroup_name)
if self.deferrable:
self.defer(
trigger=EksNodegroupTrigger(
waiter_name="nodegroup_deleted",
trigger=EksDeleteNodegroupTrigger(
cluster_name=self.cluster_name,
nodegroup_name=self.nodegroup_name,
aws_conn_id=self.aws_conn_id,
region=self.region,
region_name=self.region,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/operators/emr.py
@@ -577,7 +577,7 @@ def execute(self, context: Context) -> str | None:
virtual_cluster_id=self.virtual_cluster_id,
job_id=self.job_id,
aws_conn_id=self.aws_conn_id,
poll_interval=self.poll_interval,
waiter_delay=self.poll_interval,
),
method_name="execute_complete",
)
@@ -943,8 +943,8 @@ def execute(self, context: Context) -> None:
self.defer(
trigger=EmrTerminateJobFlowTrigger(
job_flow_id=self.job_flow_id,
poll_interval=self.waiter_delay,
max_attempts=self.waiter_max_attempts,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/glue_crawler.py
@@ -96,7 +96,7 @@ def execute(self, context: Context):
self.defer(
trigger=GlueCrawlerCompleteTrigger(
crawler_name=crawler_name,
poll_interval=self.poll_interval,
waiter_delay=self.poll_interval,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
20 changes: 10 additions & 10 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -267,8 +267,8 @@ def execute(self, context: Context):
self.defer(
trigger=RedshiftCreateClusterTrigger(
cluster_identifier=self.cluster_identifier,
poll_interval=self.poll_interval,
max_attempt=self.max_attempt,
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_attempt,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
@@ -361,8 +361,8 @@ def execute(self, context: Context) -> Any:
self.defer(
trigger=RedshiftCreateClusterSnapshotTrigger(
cluster_identifier=self.cluster_identifier,
poll_interval=self.poll_interval,
max_attempts=self.max_attempt,
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_attempt,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
@@ -510,8 +510,8 @@ def execute(self, context: Context):
self.defer(
trigger=RedshiftResumeClusterTrigger(
cluster_identifier=self.cluster_identifier,
poll_interval=self.poll_interval,
max_attempts=self.max_attempts,
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
@@ -598,8 +598,8 @@ def execute(self, context: Context):
self.defer(
trigger=RedshiftPauseClusterTrigger(
cluster_identifier=self.cluster_identifier,
poll_interval=self.poll_interval,
max_attempts=self.max_attempts,
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
@@ -690,8 +690,8 @@ def execute(self, context: Context):
timeout=timedelta(seconds=self.max_attempts * self.poll_interval + 60),
trigger=RedshiftDeleteClusterTrigger(
cluster_identifier=self.cluster_identifier,
poll_interval=self.poll_interval,
max_attempts=self.max_attempts,
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
14 changes: 8 additions & 6 deletions airflow/providers/amazon/aws/sensors/batch.py
@@ -25,7 +25,7 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.aws.triggers.batch import BatchSensorTrigger
from airflow.providers.amazon.aws.triggers.batch import BatchJobTrigger
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
@@ -98,11 +98,12 @@ def execute(self, context: Context) -> None:
)
self.defer(
timeout=timeout,
trigger=BatchSensorTrigger(
trigger=BatchJobTrigger(
job_id=self.job_id,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
poke_interval=self.poke_interval,
waiter_delay=int(self.poke_interval),
waiter_max_attempts=self.max_retries,
),
method_name="execute_complete",
)
@@ -113,9 +114,10 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:

Relies on trigger to throw an exception, otherwise it assumes execution was successful.
"""
if "status" in event and event["status"] == "failure":
raise AirflowException(event["message"])
self.log.info(event["message"])
if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
job_id = event["job_id"]
self.log.info("Batch Job %s complete", job_id)

@deprecated(reason="use `hook` property instead.")
def get_hook(self) -> BatchClientHook:
10 changes: 5 additions & 5 deletions airflow/providers/amazon/aws/sensors/emr.py
@@ -316,7 +316,7 @@ def execute(self, context: Context):
virtual_cluster_id=self.virtual_cluster_id,
job_id=self.job_id,
aws_conn_id=self.aws_conn_id,
poll_interval=self.poll_interval,
waiter_delay=self.poll_interval,
),
method_name="execute_complete",
)
@@ -501,9 +501,9 @@ def execute(self, context: Context) -> None:
timeout=timedelta(seconds=self.poke_interval * self.max_attempts),
trigger=EmrTerminateJobFlowTrigger(
job_flow_id=self.job_flow_id,
max_attempts=self.max_attempts,
waiter_max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
poll_interval=int(self.poke_interval),
waiter_delay=int(self.poke_interval),
),
method_name="execute_complete",
)
@@ -628,9 +628,9 @@ def execute(self, context: Context) -> None:
trigger=EmrStepSensorTrigger(
job_flow_id=self.job_flow_id,
step_id=self.step_id,
waiter_delay=int(self.poke_interval),
waiter_max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
max_attempts=self.max_attempts,
poke_interval=int(self.poke_interval),
),
method_name="execute_complete",
)
58 changes: 20 additions & 38 deletions airflow/providers/amazon/aws/triggers/athena.py
@@ -16,61 +16,43 @@
# under the License.
from __future__ import annotations

from typing import Any

from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger


class AthenaTrigger(BaseTrigger):
class AthenaTrigger(AwsBaseWaiterTrigger):
"""
Trigger for AthenaOperator.

The trigger will asynchronously poll the boto3 API and wait for the
Athena query to reach a terminal state.

:param query_execution_id: ID of the Athena query execution to watch
:param poll_interval: The amount of time in seconds to wait between attempts.
:param max_attempt: The maximum number of attempts to be made.
:param waiter_delay: The amount of time in seconds to wait between attempts.
:param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""

def __init__(
self,
query_execution_id: str,
poll_interval: int,
max_attempt: int,
waiter_delay: int,
waiter_max_attempts: int,
Contributor Author (comment on lines -42 to +41):
This trigger was added in #32186, merged on June 27th; the last provider release was on June 20th, so this breaking change is OK (see the sketch at the end of this diff).

aws_conn_id: str,
):
self.query_execution_id = query_execution_id
self.poll_interval = poll_interval
self.max_attempt = max_attempt
self.aws_conn_id = aws_conn_id

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
self.__class__.__module__ + "." + self.__class__.__qualname__,
{
"query_execution_id": str(self.query_execution_id),
"poll_interval": str(self.poll_interval),
"max_attempt": str(self.max_attempt),
"aws_conn_id": str(self.aws_conn_id),
},
super().__init__(
serialized_fields={"query_execution_id": query_execution_id},
waiter_name="query_complete",
waiter_args={"QueryExecutionId": query_execution_id},
failure_message=f"Error while waiting for query {query_execution_id} to complete",
status_message=f"Query execution id: {query_execution_id}",
status_queries=["QueryExecution.Status"],
return_value=query_execution_id,
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
)

async def run(self):
hook = AthenaHook(self.aws_conn_id)
async with hook.async_conn as client:
waiter = hook.get_waiter("query_complete", deferrable=True, client=client)
await async_wait(
waiter=waiter,
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_attempt,
args={"QueryExecutionId": self.query_execution_id},
failure_message=f"Error while waiting for query {self.query_execution_id} to complete",
status_message=f"Query execution id: {self.query_execution_id}, "
"Query is still in non-terminal state",
status_args=["QueryExecution.Status.State"],
)
yield TriggerEvent({"status": "success", "value": self.query_execution_id})
def hook(self) -> AwsGenericHook:
return AthenaHook(self.aws_conn_id)
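
As a rough illustration of the breaking constructor change flagged in the review comment above, here is a hedged sketch of how a caller's instantiation of AthenaTrigger would change; the query execution id and the numeric values are made up, and only the keyword names come from the diff.

from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger

# Before this PR (signature removed above): poll_interval / max_attempt
# trigger = AthenaTrigger(
#     query_execution_id="abc-123",
#     poll_interval=30,
#     max_attempt=120,
#     aws_conn_id="aws_default",
# )

# After this PR, the trigger takes the standardized waiter keyword names instead.
trigger = AthenaTrigger(
    query_execution_id="abc-123",   # illustrative value
    waiter_delay=30,                # seconds between polls (was poll_interval)
    waiter_max_attempts=120,        # polling attempts (was max_attempt)
    aws_conn_id="aws_default",
)

Any code that built the trigger with the old keyword names would break after upgrading, which is why the release-timing note above matters.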