Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions airflow/providers/amazon/aws/operators/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Sequence
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand All @@ -42,6 +45,11 @@ class StepFunctionStartExecutionOperator(BaseOperator):
:param state_machine_input: JSON data input to pass to the State Machine
:param aws_conn_id: AWS connection to use
:param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn.
:param waiter_max_attempts: Maximum number of attempts to poll the execution.
:param waiter_delay: Number of seconds between polling the state of the execution.
:param deferrable: If True, the operator will wait asynchronously for the job to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False, but can be overridden in config file by setting default_deferrable to True)
"""

template_fields: Sequence[str] = ("state_machine_arn", "name", "input")
Expand All @@ -56,6 +64,9 @@ def __init__(
state_machine_input: dict | str | None = None,
aws_conn_id: str = "aws_default",
region_name: str | None = None,
waiter_max_attempts: int = 30,
waiter_delay: int = 60,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -64,6 +75,9 @@ def __init__(
self.input = state_machine_input
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable

def execute(self, context: Context):
hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
Expand All @@ -74,9 +88,27 @@ def execute(self, context: Context):
raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}")

self.log.info("Started State Machine execution for %s: %s", self.state_machine_arn, execution_arn)

if self.deferrable:
self.defer(
trigger=StepFunctionsExecutionCompleteTrigger(
execution_arn=execution_arn,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
)
return execution_arn

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
    """
    Resume the task after the trigger has fired and return the execution ARN.

    :param context: Task context; not used by this method.
    :param event: Payload emitted by the trigger. It must have ``status`` equal to
        ``"success"`` and carry the ``execution_arn`` key.
    :return: The value stored under ``execution_arn`` in the event.
    :raises AirflowException: If no event is given or its status is not ``"success"``.
    """
    # ``.get`` makes an event without a "status" key fail with the same
    # AirflowException as any other non-successful event, not a bare KeyError.
    if event is None or event.get("status") != "success":
        raise AirflowException(f"Trigger error: event is {event}")

    self.log.info("State Machine execution completed successfully")
    return event["execution_arn"]


class StepFunctionGetExecutionOutputOperator(BaseOperator):
"""
Expand Down
59 changes: 59 additions & 0 deletions airflow/providers/amazon/aws/triggers/step_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger


class StepFunctionsExecutionCompleteTrigger(AwsBaseWaiterTrigger):
    """
    Trigger to poll for the completion of a Step Functions execution.

    :param execution_arn: ARN of the state machine execution to poll
    :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param waiter_max_attempts: The maximum number of attempts to be made.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param region_name: AWS region name handed to the hook created by :meth:`hook`.
    """

    def __init__(
        self,
        *,
        execution_arn: str,
        waiter_delay: int = 60,
        waiter_max_attempts: int = 30,
        aws_conn_id: str | None = None,
        region_name: str | None = None,
    ) -> None:
        super().__init__(
            # Subclass-specific constructor arguments that must survive serialization.
            serialized_fields={"execution_arn": execution_arn, "region_name": region_name},
            waiter_name="step_function_succeeded",
            waiter_args={"executionArn": execution_arn},
            failure_message="Step function failed",
            status_message="Status of step function execution is",
            status_queries=["status", "error", "cause"],
            return_key="execution_arn",
            return_value=execution_arn,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )
        # hook() reads self.region_name, but region_name is not forwarded to the base
        # constructor above, so store it explicitly to make sure the hook targets the
        # requested region.
        # NOTE(review): confirm whether AwsBaseWaiterTrigger also accepts region_name directly.
        self.region_name = region_name

    def hook(self) -> AwsGenericHook:
        """Build the Step Functions hook from this trigger's connection id and region."""
        return StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
36 changes: 36 additions & 0 deletions airflow/providers/amazon/aws/waiters/stepfunctions.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
{
    "version": 2,
    "waiters": {
        "step_function_succeeded": {
            "operation": "DescribeExecution",
            "delay": 30,
            "maxAttempts": 60,
            "acceptors": [
                {
                    "matcher": "path",
                    "argument": "status",
                    "expected": "SUCCEEDED",
                    "state": "success"
                },
                {
                    "matcher": "path",
                    "argument": "status",
                    "expected": "RUNNING",
                    "state": "retry"
                },
                {
                    "matcher": "path",
                    "argument": "status",
                    "expected": "FAILED",
                    "state": "failure"
                },
                {
                    "matcher": "path",
                    "argument": "status",
                    "expected": "TIMED_OUT",
                    "state": "failure"
                },
                {
                    "matcher": "path",
                    "argument": "status",
                    "expected": "ABORTED",
                    "state": "failure"
                }
            ]
        }
    }
}
3 changes: 3 additions & 0 deletions airflow/providers/amazon/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,9 @@ triggers:
- integration-name: Amazon RDS
python-modules:
- airflow.providers.amazon.aws.triggers.rds
- integration-name: AWS Step Functions
python-modules:
- airflow.providers.amazon.aws.triggers.step_function

transfers:
- source-integration-name: Amazon DynamoDB
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Start an AWS Step Functions state machine execution

To start a new AWS Step Functions state machine execution you can use
:class:`~airflow.providers.amazon.aws.operators.step_function.StepFunctionStartExecutionOperator`.
You can also run this operator in deferrable mode by setting the ``deferrable`` parameter to ``True``.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_step_functions.py
:language: python
Expand Down
17 changes: 17 additions & 0 deletions tests/providers/amazon/aws/operators/test_step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

import pytest

from airflow.exceptions import TaskDeferred
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
from airflow.providers.amazon.aws.operators.step_function import (
StepFunctionGetExecutionOutputOperator,
StepFunctionStartExecutionOperator,
Expand Down Expand Up @@ -132,3 +134,18 @@ def test_execute(self, mock_hook):

# Then
assert hook_response == result

@mock.patch.object(StepFunctionHook, "start_execution")
def test_step_function_start_execution_deferrable(self, mock_start_execution):
    """In deferrable mode, ``execute`` must raise ``TaskDeferred``."""
    # Given: the hook's start_execution is patched to hand back a fixed ARN.
    mock_start_execution.return_value = "test-execution-arn"
    operator_kwargs = {
        "task_id": self.TASK_ID,
        "state_machine_arn": STATE_MACHINE_ARN,
        "name": NAME,
        "state_machine_input": INPUT,
        "aws_conn_id": AWS_CONN_ID,
        "region_name": REGION_NAME,
        "deferrable": True,
    }
    deferrable_operator = StepFunctionStartExecutionOperator(**operator_kwargs)

    # When / Then: executing the operator defers the task.
    with pytest.raises(TaskDeferred):
        deferrable_operator.execute(None)
53 changes: 53 additions & 0 deletions tests/providers/amazon/aws/triggers/test_step_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import pytest

from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger

TEST_EXECUTION_ARN = "test-execution-arn"
TEST_WAITER_DELAY = 10
TEST_WAITER_MAX_ATTEMPTS = 10
TEST_AWS_CONN_ID = "test-conn-id"
TEST_REGION_NAME = "test-region-name"


class TestStepFunctionsTriggers:
    """Serialization round-trip checks for the Step Functions triggers."""

    @pytest.mark.parametrize(
        "trigger",
        [
            StepFunctionsExecutionCompleteTrigger(
                execution_arn=TEST_EXECUTION_ARN,
                region_name=TEST_REGION_NAME,
                aws_conn_id=TEST_AWS_CONN_ID,
                waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
                waiter_delay=TEST_WAITER_DELAY,
            )
        ],
    )
    def test_serialize_recreate(self, trigger):
        """A trigger rebuilt from its serialized form must serialize to the same data."""
        original_path, original_kwargs = trigger.serialize()

        # Look the class up by its bare name among this module's globals.
        bare_name = original_path.rsplit(".", 1)[-1]
        trigger_cls = globals()[bare_name]
        rebuilt = trigger_cls(**original_kwargs)

        rebuilt_path, rebuilt_kwargs = rebuilt.serialize()

        assert original_path == rebuilt_path
        assert original_kwargs == rebuilt_kwargs