From 328eaec80f0b8ce67c36a179e663d0c7e468212f Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Wed, 12 Jul 2023 09:53:50 -0700 Subject: [PATCH 1/6] Add Deferrable mode to StepFunctionStartExecutionOperator Add unit tests --- .../amazon/aws/operators/step_function.py | 34 ++++++++++- .../amazon/aws/triggers/stepfunction.py | 61 +++++++++++++++++++ .../amazon/aws/waiters/stepfunctions.json | 36 +++++++++++ .../operators/step_functions.rst | 1 + .../aws/operators/test_step_function.py | 17 ++++++ .../amazon/aws/triggers/test_stepfunction.py | 53 ++++++++++++++++ 6 files changed, 200 insertions(+), 2 deletions(-) create mode 100644 airflow/providers/amazon/aws/triggers/stepfunction.py create mode 100644 airflow/providers/amazon/aws/waiters/stepfunctions.json create mode 100644 tests/providers/amazon/aws/triggers/test_stepfunction.py diff --git a/airflow/providers/amazon/aws/operators/step_function.py b/airflow/providers/amazon/aws/operators/step_function.py index 2aa8bdd8e2571..c203227dc5c59 100644 --- a/airflow/providers/amazon/aws/operators/step_function.py +++ b/airflow/providers/amazon/aws/operators/step_function.py @@ -17,11 +17,14 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Sequence +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Sequence +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook +from airflow.providers.amazon.aws.triggers.stepfunction import StepFunctionsStartExecutionTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -42,6 +45,11 @@ class StepFunctionStartExecutionOperator(BaseOperator): :param state_machine_input: JSON data input to pass to the State Machine :param aws_conn_id: aws connection to uses :param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn. + :param waiter_max_attempts: Maximum number of attempts to poll the execution. + :param waiter_delay: Number of seconds between polling the state of the execution. + :param deferrable: If True, the operator will wait asynchronously for the job to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) """ template_fields: Sequence[str] = ("state_machine_arn", "name", "input") @@ -56,6 +64,9 @@ def __init__( state_machine_input: dict | str | None = None, aws_conn_id: str = "aws_default", region_name: str | None = None, + waiter_max_attempts: int = 30, + waiter_delay: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(**kwargs) @@ -64,6 +75,9 @@ def __init__( self.input = state_machine_input self.aws_conn_id = aws_conn_id self.region_name = region_name + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable def execute(self, context: Context): hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) @@ -74,9 +88,25 @@ def execute(self, context: Context): raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}") self.log.info("Started State Machine execution for %s: %s", self.state_machine_arn, execution_arn) - + if self.deferrable: + self.defer( + trigger=StepFunctionsStartExecutionTrigger( + execution_arn=execution_arn, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + ), + method_name="execute_complete", + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay), + ) return execution_arn + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: + if event and event["status"] == "success": + self.log.info("State Machine execution completed successfully") + return event["execution_arn"] + class StepFunctionGetExecutionOutputOperator(BaseOperator): """ diff --git a/airflow/providers/amazon/aws/triggers/stepfunction.py b/airflow/providers/amazon/aws/triggers/stepfunction.py new file mode 100644 index 0000000000000..3dde342c6e5c4 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/stepfunction.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook +from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger + + +class StepFunctionsStartExecutionTrigger(AwsBaseWaiterTrigger): + """ + Trigger to poll for the completion of a Step Functions execution. + + :param execution_arn: ARN of the state machine to poll + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + """ + + def __init__( + self, + *, + execution_arn: str, + waiter_delay: int = 60, + waiter_max_attempts: int = 30, + aws_conn_id: str | None = None, + region_name: str | None = None, + ): + self.aws_conn_id = aws_conn_id + self.region_name = region_name + + super().__init__( + serialized_fields={"execution_arn": execution_arn, "region_name": region_name}, + waiter_name="step_function_succeeded", + waiter_args={"executionArn": execution_arn}, + failure_message="Step function failed", + status_message="Status of step function execution is", + status_queries=["status"], + return_key="execution_arn", + return_value=execution_arn, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + + def hook(self) -> AwsGenericHook: + return StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) diff --git a/airflow/providers/amazon/aws/waiters/stepfunctions.json b/airflow/providers/amazon/aws/waiters/stepfunctions.json new file mode 100644 index 0000000000000..7a7af36786dc7 --- /dev/null +++ b/airflow/providers/amazon/aws/waiters/stepfunctions.json @@ -0,0 +1,36 @@ +{ + "version": 2, + "waiters": { + "step_function_succeeded": { + "operation": "DescribeExecution", + "delay": 30, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "path", + "argument": "status", + "expected": "SUCCEEDED", + "state": "success" + }, + { + "matcher": "error", + "argument": "status", + "expected": "RUNNING", + "state": "retry" + }, + { + "matcher": "path", + "argument": "status", + "expected": "FAILED", + "state": "failure" + }, + { + "matcher": "path", + "argument": "status", + "expected": "ABORTED", + "state": "failure" + } + ] + } + } +} diff --git a/docs/apache-airflow-providers-amazon/operators/step_functions.rst b/docs/apache-airflow-providers-amazon/operators/step_functions.rst index 1f207b4576f98..7736fa9b16747 100644 --- a/docs/apache-airflow-providers-amazon/operators/step_functions.rst +++ b/docs/apache-airflow-providers-amazon/operators/step_functions.rst @@ -38,6 +38,7 @@ Start an AWS Step Functions state machine execution To start a new AWS Step Functions state machine execution you can use :class:`~airflow.providers.amazon.aws.operators.step_function.StepFunctionStartExecutionOperator`. +You can also run this operator in deferrable mode by setting ``deferrable`` param to ``True``. .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_step_functions.py :language: python diff --git a/tests/providers/amazon/aws/operators/test_step_function.py b/tests/providers/amazon/aws/operators/test_step_function.py index 566e134a86eb7..91ccebf7c6e29 100644 --- a/tests/providers/amazon/aws/operators/test_step_function.py +++ b/tests/providers/amazon/aws/operators/test_step_function.py @@ -22,6 +22,8 @@ import pytest +from airflow.exceptions import TaskDeferred +from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook from airflow.providers.amazon.aws.operators.step_function import ( StepFunctionGetExecutionOutputOperator, StepFunctionStartExecutionOperator, @@ -132,3 +134,18 @@ def test_execute(self, mock_hook): # Then assert hook_response == result + + @mock.patch.object(StepFunctionHook, "start_execution") + def test_step_function_start_execution_deferrable(self, mock_start_execution): + mock_start_execution.return_value = "test-execution-arn" + operator = StepFunctionStartExecutionOperator( + task_id=self.TASK_ID, + state_machine_arn=STATE_MACHINE_ARN, + name=NAME, + state_machine_input=INPUT, + aws_conn_id=AWS_CONN_ID, + region_name=REGION_NAME, + deferrable=True, + ) + with pytest.raises(TaskDeferred): + operator.execute(None) diff --git a/tests/providers/amazon/aws/triggers/test_stepfunction.py b/tests/providers/amazon/aws/triggers/test_stepfunction.py new file mode 100644 index 0000000000000..c0dcb0e4b22d1 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_stepfunction.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.providers.amazon.aws.triggers.stepfunctions import StepFunctionsStartExecutionTrigger + +TEST_EXECUTION_ARN = "test-execution-arn" +TEST_WAITER_DELAY = 10 +TEST_WAITER_MAX_ATTEMPTS = 10 +TEST_AWS_CONN_ID = "test-conn-id" +TEST_REGION_NAME = "test-region-name" + + +class TestStepFunctionsTriggers: + @pytest.mark.parametrize( + "trigger", + [ + StepFunctionsStartExecutionTrigger( + execution_arn=TEST_EXECUTION_ARN, + aws_conn_id=TEST_AWS_CONN_ID, + waiter_delay=TEST_WAITER_DELAY, + waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS, + region_name=TEST_REGION_NAME, + ) + ], + ) + def test_serialize_recreate(self, trigger): + class_path, args = trigger.serialize() + + class_name = class_path.split(".")[-1] + clazz = globals()[class_name] + instance = clazz(**args) + + class_path2, args2 = instance.serialize() + + assert class_path == class_path2 + assert args == args2 From b6a4a027ff19936336acbf910829a26b54ba02cf Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Wed, 12 Jul 2023 10:12:59 -0700 Subject: [PATCH 2/6] fix import path --- tests/providers/amazon/aws/triggers/test_stepfunction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/amazon/aws/triggers/test_stepfunction.py b/tests/providers/amazon/aws/triggers/test_stepfunction.py index c0dcb0e4b22d1..093de81f281cb 100644 --- a/tests/providers/amazon/aws/triggers/test_stepfunction.py +++ b/tests/providers/amazon/aws/triggers/test_stepfunction.py @@ -18,7 +18,7 @@ import pytest -from airflow.providers.amazon.aws.triggers.stepfunctions import StepFunctionsStartExecutionTrigger +from airflow.providers.amazon.aws.triggers.stepfunction import StepFunctionsStartExecutionTrigger TEST_EXECUTION_ARN = "test-execution-arn" TEST_WAITER_DELAY = 10 From 0effb2f1c0e7f397ae91e53b17e93e8d500a1c72 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Wed, 12 Jul 2023 10:54:41 -0700 Subject: [PATCH 3/6] Add stepfunction trigger to provider.yaml --- airflow/providers/amazon/provider.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 3622896504da4..3b513a5761d0a 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -553,6 +553,9 @@ triggers: - integration-name: Amazon RDS python-modules: - airflow.providers.amazon.aws.triggers.rds + - integration-name: AWS Step Functions + python-modules: + - airflow.providers.amazon.aws.triggers.stepfunction transfers: - source-integration-name: Amazon DynamoDB From 62adc87f38f94cb043eed75cc65e56d6832f8b2d Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Mon, 17 Jul 2023 15:47:29 -0700 Subject: [PATCH 4/6] Update docstring about default value of deferrable Minor refactor --- airflow/providers/amazon/aws/operators/step_function.py | 6 +++--- airflow/providers/amazon/aws/triggers/stepfunction.py | 8 +++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/step_function.py b/airflow/providers/amazon/aws/operators/step_function.py index c203227dc5c59..240b46cc7efe5 100644 --- a/airflow/providers/amazon/aws/operators/step_function.py +++ b/airflow/providers/amazon/aws/operators/step_function.py @@ -24,7 +24,7 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook -from airflow.providers.amazon.aws.triggers.stepfunction import StepFunctionsStartExecutionTrigger +from airflow.providers.amazon.aws.triggers.stepfunction import StepFunctionsExecutionCompleteTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -49,7 +49,7 @@ class StepFunctionStartExecutionOperator(BaseOperator): :param waiter_delay: Number of seconds between polling the state of the execution. :param deferrable: If True, the operator will wait asynchronously for the job to complete. This implies waiting for completion. This mode requires aiobotocore module to be installed. - (default: False) + (default: False, but can be overridden in config file by setting default_deferrable to True) """ template_fields: Sequence[str] = ("state_machine_arn", "name", "input") @@ -90,7 +90,7 @@ def execute(self, context: Context): self.log.info("Started State Machine execution for %s: %s", self.state_machine_arn, execution_arn) if self.deferrable: self.defer( - trigger=StepFunctionsStartExecutionTrigger( + trigger=StepFunctionsExecutionCompleteTrigger( execution_arn=execution_arn, waiter_delay=self.waiter_delay, waiter_max_attempts=self.waiter_max_attempts, diff --git a/airflow/providers/amazon/aws/triggers/stepfunction.py b/airflow/providers/amazon/aws/triggers/stepfunction.py index 3dde342c6e5c4..c4875f078f196 100644 --- a/airflow/providers/amazon/aws/triggers/stepfunction.py +++ b/airflow/providers/amazon/aws/triggers/stepfunction.py @@ -21,7 +21,7 @@ from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger -class StepFunctionsStartExecutionTrigger(AwsBaseWaiterTrigger): +class StepFunctionsExecutionCompleteTrigger(AwsBaseWaiterTrigger): """ Trigger to poll for the completion of a Step Functions execution. @@ -39,9 +39,7 @@ def __init__( waiter_max_attempts: int = 30, aws_conn_id: str | None = None, region_name: str | None = None, - ): - self.aws_conn_id = aws_conn_id - self.region_name = region_name + ) -> None: super().__init__( serialized_fields={"execution_arn": execution_arn, "region_name": region_name}, @@ -49,7 +47,7 @@ def __init__( waiter_args={"executionArn": execution_arn}, failure_message="Step function failed", status_message="Status of step function execution is", - status_queries=["status"], + status_queries=["status", "error", "cause"], return_key="execution_arn", return_value=execution_arn, waiter_delay=waiter_delay, From b7228239db0354dc4f2d83bb7c7bc0a80b176607 Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Wed, 19 Jul 2023 15:25:57 -0700 Subject: [PATCH 5/6] Change file name from test_stepfunction.py to test_step_function.py to match other tests Fix failing test --- .../triggers/{test_stepfunction.py => test_step_function.py} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename tests/providers/amazon/aws/triggers/{test_stepfunction.py => test_step_function.py} (95%) diff --git a/tests/providers/amazon/aws/triggers/test_stepfunction.py b/tests/providers/amazon/aws/triggers/test_step_function.py similarity index 95% rename from tests/providers/amazon/aws/triggers/test_stepfunction.py rename to tests/providers/amazon/aws/triggers/test_step_function.py index 093de81f281cb..2d29dfb8fc2ba 100644 --- a/tests/providers/amazon/aws/triggers/test_stepfunction.py +++ b/tests/providers/amazon/aws/triggers/test_step_function.py @@ -18,7 +18,7 @@ import pytest -from airflow.providers.amazon.aws.triggers.stepfunction import StepFunctionsStartExecutionTrigger +from airflow.providers.amazon.aws.triggers.stepfunction import StepFunctionsExecutionCompleteTrigger TEST_EXECUTION_ARN = "test-execution-arn" TEST_WAITER_DELAY = 10 @@ -31,7 +31,7 @@ class TestStepFunctionsTriggers: @pytest.mark.parametrize( "trigger", [ - StepFunctionsStartExecutionTrigger( + StepFunctionsExecutionCompleteTrigger( execution_arn=TEST_EXECUTION_ARN, aws_conn_id=TEST_AWS_CONN_ID, waiter_delay=TEST_WAITER_DELAY, From af9a8defcfb53422f09e77e03a1512f56d57feaa Mon Sep 17 00:00:00 2001 From: Syed Hussain Date: Wed, 19 Jul 2023 15:36:14 -0700 Subject: [PATCH 6/6] Rename triggers/stepfunction.py to triggers/step_function.py --- .../providers/amazon/aws/operators/step_function.py | 10 ++++++---- .../aws/triggers/{stepfunction.py => step_function.py} | 0 airflow/providers/amazon/provider.yaml | 2 +- .../amazon/aws/triggers/test_step_function.py | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) rename airflow/providers/amazon/aws/triggers/{stepfunction.py => step_function.py} (100%) diff --git a/airflow/providers/amazon/aws/operators/step_function.py b/airflow/providers/amazon/aws/operators/step_function.py index 240b46cc7efe5..68324df731dec 100644 --- a/airflow/providers/amazon/aws/operators/step_function.py +++ b/airflow/providers/amazon/aws/operators/step_function.py @@ -24,7 +24,7 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook -from airflow.providers.amazon.aws.triggers.stepfunction import StepFunctionsExecutionCompleteTrigger +from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -103,9 +103,11 @@ def execute(self, context: Context): return execution_arn def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: - if event and event["status"] == "success": - self.log.info("State Machine execution completed successfully") - return event["execution_arn"] + if event is None or event["status"] != "success": + raise AirflowException(f"Trigger error: event is {event}") + + self.log.info("State Machine execution completed successfully") + return event["execution_arn"] class StepFunctionGetExecutionOutputOperator(BaseOperator): diff --git a/airflow/providers/amazon/aws/triggers/stepfunction.py b/airflow/providers/amazon/aws/triggers/step_function.py similarity index 100% rename from airflow/providers/amazon/aws/triggers/stepfunction.py rename to airflow/providers/amazon/aws/triggers/step_function.py diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 3b513a5761d0a..a4c483ff6b3e1 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -555,7 +555,7 @@ triggers: - airflow.providers.amazon.aws.triggers.rds - integration-name: AWS Step Functions python-modules: - - airflow.providers.amazon.aws.triggers.stepfunction + - airflow.providers.amazon.aws.triggers.step_function transfers: - source-integration-name: Amazon DynamoDB diff --git a/tests/providers/amazon/aws/triggers/test_step_function.py b/tests/providers/amazon/aws/triggers/test_step_function.py index 2d29dfb8fc2ba..d0c25e096f586 100644 --- a/tests/providers/amazon/aws/triggers/test_step_function.py +++ b/tests/providers/amazon/aws/triggers/test_step_function.py @@ -18,7 +18,7 @@ import pytest -from airflow.providers.amazon.aws.triggers.stepfunction import StepFunctionsExecutionCompleteTrigger +from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger TEST_EXECUTION_ARN = "test-execution-arn" TEST_WAITER_DELAY = 10