diff --git a/airflow/providers/amazon/aws/operators/step_function.py b/airflow/providers/amazon/aws/operators/step_function.py index 2aa8bdd8e2571..68324df731dec 100644 --- a/airflow/providers/amazon/aws/operators/step_function.py +++ b/airflow/providers/amazon/aws/operators/step_function.py @@ -17,11 +17,14 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Sequence +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Sequence +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook +from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -42,6 +45,11 @@ class StepFunctionStartExecutionOperator(BaseOperator): :param state_machine_input: JSON data input to pass to the State Machine :param aws_conn_id: aws connection to uses :param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn. + :param waiter_max_attempts: Maximum number of attempts to poll the execution. + :param waiter_delay: Number of seconds between polling the state of the execution. + :param deferrable: If True, the operator will wait asynchronously for the job to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False, but can be overridden in config file by setting default_deferrable to True) """ template_fields: Sequence[str] = ("state_machine_arn", "name", "input") @@ -56,6 +64,9 @@ def __init__( state_machine_input: dict | str | None = None, aws_conn_id: str = "aws_default", region_name: str | None = None, + waiter_max_attempts: int = 30, + waiter_delay: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(**kwargs) @@ -64,6 +75,9 @@ def __init__( self.input = state_machine_input self.aws_conn_id = aws_conn_id self.region_name = region_name + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable def execute(self, context: Context): hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) @@ -74,9 +88,27 @@ def execute(self, context: Context): raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}") self.log.info("Started State Machine execution for %s: %s", self.state_machine_arn, execution_arn) - + if self.deferrable: + self.defer( + trigger=StepFunctionsExecutionCompleteTrigger( + execution_arn=execution_arn, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + ), + method_name="execute_complete", + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay), + ) return execution_arn + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: + if event is None or event["status"] != "success": + raise AirflowException(f"Trigger error: event is {event}") + + self.log.info("State Machine execution completed successfully") + return event["execution_arn"] + class StepFunctionGetExecutionOutputOperator(BaseOperator): """ diff --git a/airflow/providers/amazon/aws/triggers/step_function.py b/airflow/providers/amazon/aws/triggers/step_function.py new file mode 100644 index 0000000000000..c4875f078f196 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/step_function.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook +from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger + + +class StepFunctionsExecutionCompleteTrigger(AwsBaseWaiterTrigger): + """ + Trigger to poll for the completion of a Step Functions execution. + + :param execution_arn: ARN of the state machine to poll + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + """ + + def __init__( + self, + *, + execution_arn: str, + waiter_delay: int = 60, + waiter_max_attempts: int = 30, + aws_conn_id: str | None = None, + region_name: str | None = None, + ) -> None: + + super().__init__( + serialized_fields={"execution_arn": execution_arn, "region_name": region_name}, + waiter_name="step_function_succeeded", + waiter_args={"executionArn": execution_arn}, + failure_message="Step function failed", + status_message="Status of step function execution is", + status_queries=["status", "error", "cause"], + return_key="execution_arn", + return_value=execution_arn, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + + def hook(self) -> AwsGenericHook: + return StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) diff --git a/airflow/providers/amazon/aws/waiters/stepfunctions.json b/airflow/providers/amazon/aws/waiters/stepfunctions.json new file mode 100644 index 0000000000000..7a7af36786dc7 --- /dev/null +++ b/airflow/providers/amazon/aws/waiters/stepfunctions.json @@ -0,0 +1,36 @@ +{ + "version": 2, + "waiters": { + "step_function_succeeded": { + "operation": "DescribeExecution", + "delay": 30, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "path", + "argument": "status", + "expected": "SUCCEEDED", + "state": "success" + }, + { + "matcher": "error", + "argument": "status", + "expected": "RUNNING", + "state": "retry" + }, + { + "matcher": "path", + "argument": "status", + "expected": "FAILED", + "state": "failure" + }, + { + "matcher": "path", + "argument": "status", + "expected": "ABORTED", + "state": "failure" + } + ] + } + } +} diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 3622896504da4..a4c483ff6b3e1 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -553,6 +553,9 @@ triggers: - integration-name: Amazon RDS python-modules: - airflow.providers.amazon.aws.triggers.rds + - integration-name: AWS Step Functions + python-modules: + - airflow.providers.amazon.aws.triggers.step_function transfers: - source-integration-name: Amazon DynamoDB diff --git a/docs/apache-airflow-providers-amazon/operators/step_functions.rst b/docs/apache-airflow-providers-amazon/operators/step_functions.rst index 1f207b4576f98..7736fa9b16747 100644 --- a/docs/apache-airflow-providers-amazon/operators/step_functions.rst +++ b/docs/apache-airflow-providers-amazon/operators/step_functions.rst @@ -38,6 +38,7 @@ Start an AWS Step Functions state machine execution To start a new AWS Step Functions state machine execution you can use :class:`~airflow.providers.amazon.aws.operators.step_function.StepFunctionStartExecutionOperator`. +You can also run this operator in deferrable mode by setting ``deferrable`` param to ``True``. .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_step_functions.py :language: python diff --git a/tests/providers/amazon/aws/operators/test_step_function.py b/tests/providers/amazon/aws/operators/test_step_function.py index 566e134a86eb7..91ccebf7c6e29 100644 --- a/tests/providers/amazon/aws/operators/test_step_function.py +++ b/tests/providers/amazon/aws/operators/test_step_function.py @@ -22,6 +22,8 @@ import pytest +from airflow.exceptions import TaskDeferred +from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook from airflow.providers.amazon.aws.operators.step_function import ( StepFunctionGetExecutionOutputOperator, StepFunctionStartExecutionOperator, @@ -132,3 +134,18 @@ def test_execute(self, mock_hook): # Then assert hook_response == result + + @mock.patch.object(StepFunctionHook, "start_execution") + def test_step_function_start_execution_deferrable(self, mock_start_execution): + mock_start_execution.return_value = "test-execution-arn" + operator = StepFunctionStartExecutionOperator( + task_id=self.TASK_ID, + state_machine_arn=STATE_MACHINE_ARN, + name=NAME, + state_machine_input=INPUT, + aws_conn_id=AWS_CONN_ID, + region_name=REGION_NAME, + deferrable=True, + ) + with pytest.raises(TaskDeferred): + operator.execute(None) diff --git a/tests/providers/amazon/aws/triggers/test_step_function.py b/tests/providers/amazon/aws/triggers/test_step_function.py new file mode 100644 index 0000000000000..d0c25e096f586 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_step_function.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger + +TEST_EXECUTION_ARN = "test-execution-arn" +TEST_WAITER_DELAY = 10 +TEST_WAITER_MAX_ATTEMPTS = 10 +TEST_AWS_CONN_ID = "test-conn-id" +TEST_REGION_NAME = "test-region-name" + + +class TestStepFunctionsTriggers: + @pytest.mark.parametrize( + "trigger", + [ + StepFunctionsExecutionCompleteTrigger( + execution_arn=TEST_EXECUTION_ARN, + aws_conn_id=TEST_AWS_CONN_ID, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, + region_name=TEST_REGION_NAME, + ) + ], + ) + def test_serialize_recreate(self, trigger): + class_path, args = trigger.serialize() + + class_name = class_path.split(".")[-1] + clazz = globals()[class_name] + instance = clazz(**args) + + class_path2, args2 = instance.serialize() + + assert class_path == class_path2 + assert args == args2