diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py index dea61c0858d50..185c7ab13f9c8 100644 --- a/airflow/providers/amazon/aws/hooks/emr.py +++ b/airflow/providers/amazon/aws/hooks/emr.py @@ -26,7 +26,7 @@ from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook -from airflow.utils.helpers import prune_dict +from airflow.providers.amazon.aws.utils.waiter_with_logging import wait class EmrHook(AwsBaseHook): @@ -158,6 +158,9 @@ def add_job_flow_steps( :param execution_role_arn: The ARN of the runtime role for a step on the cluster. """ config = {} + waiter_delay = waiter_delay or 30 + waiter_max_attempts = waiter_max_attempts or 60 + if execution_role_arn: config["ExecutionRoleArn"] = execution_role_arn response = self.get_conn().add_job_flow_steps(JobFlowId=job_flow_id, Steps=steps, **config) @@ -169,16 +172,23 @@ def add_job_flow_steps( if wait_for_completion: waiter = self.get_conn().get_waiter("step_complete") for step_id in response["StepIds"]: - waiter.wait( - ClusterId=job_flow_id, - StepId=step_id, - WaiterConfig=prune_dict( - { - "Delay": waiter_delay, - "MaxAttempts": waiter_max_attempts, - } - ), - ) + try: + wait( + waiter=waiter, + waiter_max_attempts=waiter_max_attempts, + waiter_delay=waiter_delay, + args={"ClusterId": job_flow_id, "StepId": step_id}, + failure_message=f"EMR Steps failed: {step_id}", + status_message="EMR Step status is", + status_args=["Step.Status.State", "Step.Status.StateChangeReason"], + ) + except AirflowException as ex: + if "EMR Steps failed" in str(ex): + resp = self.get_conn().describe_step(ClusterId=job_flow_id, StepId=step_id) + failure_details = resp["Step"]["Status"].get("FailureDetails", None) + if failure_details: + self.log.error("EMR Steps failed: %s", failure_details) + raise return response["StepIds"] def test_connection(self): diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index f75b12327fe7c..1bf2375a16a2f 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -100,8 +100,8 @@ def __init__( aws_conn_id: str = "aws_default", steps: list[dict] | str | None = None, wait_for_completion: bool = False, - waiter_delay: int | None = None, - waiter_max_attempts: int | None = None, + waiter_delay: int | None = 30, + waiter_max_attempts: int | None = 60, execution_role_arn: str | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, diff --git a/tests/providers/amazon/aws/hooks/test_emr.py b/tests/providers/amazon/aws/hooks/test_emr.py index 0507ad3718d89..c68b25fbdfea2 100644 --- a/tests/providers/amazon/aws/hooks/test_emr.py +++ b/tests/providers/amazon/aws/hooks/test_emr.py @@ -22,6 +22,7 @@ import boto3 import pytest +from botocore.exceptions import WaiterError from moto import mock_emr from airflow.exceptions import AirflowException @@ -113,6 +114,43 @@ def test_add_job_flow_steps_wait_for_completion(self, mock_conn): mock_conn.get_waiter.assert_called_once_with("step_complete") + @mock.patch("time.sleep", return_value=True) + @mock.patch.object(EmrHook, "conn") + def test_add_job_flow_steps_raises_exception_on_failure(self, mock_conn, mock_sleep, caplog): + hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="emr_default", region_name="us-east-1") + mock_conn.describe_step.return_value = { + "Step": { + "Status": { + "State": "FAILED", + "FailureDetails": "test failure details", + } + } + } + mock_conn.add_job_flow_steps.return_value = { + "StepIds": [ + "step_id", + ], + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + steps = [ + { + "ActionOnFailure": "test_step", + "HadoopJarStep": { + "Args": ["test args"], + "Jar": "test.jar", + }, + "Name": "step_1", + } + ] + waiter_error = WaiterError(name="test_error", reason="test_reason", last_response={}) + waiter_error_failure = WaiterError(name="test_error", reason="terminal failure", last_response={}) + mock_conn.get_waiter().wait.side_effect = [waiter_error, waiter_error_failure] + + with pytest.raises(AirflowException): + hook.add_job_flow_steps(job_flow_id="job_flow_id", steps=steps, wait_for_completion=True) + assert "test failure details" in caplog.messages[-1] + mock_conn.get_waiter.assert_called_with("step_complete") + @mock_emr def test_create_job_flow_extra_args(self): """ diff --git a/tests/providers/amazon/aws/operators/test_emr_add_steps.py b/tests/providers/amazon/aws/operators/test_emr_add_steps.py index 0b279c051f9c6..a2c69b659980a 100644 --- a/tests/providers/amazon/aws/operators/test_emr_add_steps.py +++ b/tests/providers/amazon/aws/operators/test_emr_add_steps.py @@ -241,8 +241,8 @@ def test_wait_for_completion(self, mock_add_job_flow_steps, *_): job_flow_id=job_flow_id, steps=[], wait_for_completion=False, - waiter_delay=None, - waiter_max_attempts=None, + waiter_delay=30, + waiter_max_attempts=60, execution_role_arn=None, )