diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 61b53b8bcbf66..2c99d0e07ff6a 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -30,6 +30,7 @@
 from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
 from airflow.providers.amazon.aws.triggers.emr import (
     EmrAddStepsTrigger,
+    EmrContainerTrigger,
     EmrCreateJobFlowTrigger,
     EmrTerminateJobFlowTrigger,
 )
@@ -480,6 +481,7 @@ class EmrContainerOperator(BaseOperator):
         Defaults to None, which will poll until the job is *not* in a pending, submitted, or
         running state.
     :param tags: The tags assigned to job runs. Defaults to None
+    :param deferrable: Run operator in the deferrable mode.
     """

     template_fields: Sequence[str] = (
@@ -508,6 +510,7 @@ def __init__(
         max_tries: int | None = None,
         tags: dict | None = None,
         max_polling_attempts: int | None = None,
+        deferrable: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -524,6 +527,7 @@ def __init__(
         self.max_polling_attempts = max_polling_attempts
         self.tags = tags
         self.job_id: str | None = None
+        self.deferrable = deferrable

         if max_tries:
             warnings.warn(
@@ -556,6 +560,26 @@ def execute(self, context: Context) -> str | None:
             self.client_request_token,
             self.tags,
         )
+        if self.deferrable:
+            query_status = self.hook.check_query_status(job_id=self.job_id)
+            self.check_failure(query_status)
+            if query_status in EmrContainerHook.SUCCESS_STATES:
+                return self.job_id
+            timeout = (
+                timedelta(seconds=self.max_polling_attempts * self.poll_interval)
+                if self.max_polling_attempts
+                else self.execution_timeout
+            )
+            self.defer(
+                timeout=timeout,
+                trigger=EmrContainerTrigger(
+                    virtual_cluster_id=self.virtual_cluster_id,
+                    job_id=self.job_id,
+                    aws_conn_id=self.aws_conn_id,
+                    poll_interval=self.poll_interval,
+                ),
+                method_name="execute_complete",
+            )
         if self.wait_for_completion:
             query_status = self.hook.poll_query_status(
                 self.job_id,
@@ -563,13 +587,8 @@ def execute(self, context: Context) -> str | None:
                 poll_interval=self.poll_interval,
             )

-            if query_status in EmrContainerHook.FAILURE_STATES:
-                error_message = self.hook.get_job_failure_reason(self.job_id)
-                raise AirflowException(
-                    f"EMR Containers job failed. Final state is {query_status}. "
-                    f"query_execution_id is {self.job_id}. Error: {error_message}"
-                )
-            elif not query_status or query_status in EmrContainerHook.INTERMEDIATE_STATES:
+            self.check_failure(query_status)
+            if not query_status or query_status in EmrContainerHook.INTERMEDIATE_STATES:
                 raise AirflowException(
                     f"Final state of EMR Containers job is {query_status}. "
                     f"Max tries of poll status exceeded, query_execution_id is {self.job_id}."
@@ -577,6 +596,21 @@ def execute(self, context: Context) -> str | None:

         return self.job_id

+    def check_failure(self, query_status):
+        if query_status in EmrContainerHook.FAILURE_STATES:
+            error_message = self.hook.get_job_failure_reason(self.job_id)
+            raise AirflowException(
+                f"EMR Containers job failed. Final state is {query_status}. "
+                f"query_execution_id is {self.job_id}. Error: {error_message}"
+            )
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+
+        self.log.info("%s", event["message"])
+        return event["job_id"]
+
     def on_kill(self) -> None:
         """Cancel the submitted job run."""
         if self.job_id:
diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py
index f7644e379d003..f21d1aacd87ed 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -26,7 +26,11 @@
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
 from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
-from airflow.providers.amazon.aws.triggers.emr import EmrContainerSensorTrigger
+from airflow.providers.amazon.aws.triggers.emr import (
+    EmrContainerTrigger,
+    EmrStepSensorTrigger,
+    EmrTerminateJobFlowTrigger,
+)
 from airflow.sensors.base import BaseSensorOperator

 if TYPE_CHECKING:
@@ -310,7 +314,7 @@ def execute(self, context: Context):
             )
             self.defer(
                 timeout=timeout,
-                trigger=EmrContainerSensorTrigger(
+                trigger=EmrContainerTrigger(
                     virtual_cluster_id=self.virtual_cluster_id,
                     job_id=self.job_id,
                     aws_conn_id=self.aws_conn_id,
@@ -406,9 +410,12 @@ class EmrJobFlowSensor(EmrBaseSensor):

     :param job_flow_id: job_flow_id to check the state of
     :param target_states: the target states, sensor waits until
-        job flow reaches any of these states
+        job flow reaches any of these states. In deferrable mode it will
+        run until the job flow reaches a terminal state.
     :param failed_states: the failure states, sensor fails when
         job flow reaches any of these states
+    :param max_attempts: Maximum number of tries before failing
+    :param deferrable: Run sensor in the deferrable mode.
     """

     template_fields: Sequence[str] = ("job_flow_id", "target_states", "failed_states")
@@ -424,12 +431,16 @@ def __init__(
         job_flow_id: str,
         target_states: Iterable[str] | None = None,
         failed_states: Iterable[str] | None = None,
+        max_attempts: int = 60,
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.job_flow_id = job_flow_id
         self.target_states = target_states or ["TERMINATED"]
         self.failed_states = failed_states or ["TERMINATED_WITH_ERRORS"]
+        self.max_attempts = max_attempts
+        self.deferrable = deferrable

     def get_emr_response(self, context: Context) -> dict[str, Any]:
         """
@@ -488,6 +499,26 @@ def failure_message_from_response(response: dict[str, Any]) -> str | None:
             )
         return None

+    def execute(self, context: Context) -> None:
+        if not self.deferrable:
+            super().execute(context=context)
+        elif not self.poke(context):
+            self.defer(
+                timeout=timedelta(seconds=self.poke_interval * self.max_attempts),
+                trigger=EmrTerminateJobFlowTrigger(
+                    job_flow_id=self.job_flow_id,
+                    max_attempts=self.max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    poll_interval=int(self.poke_interval),
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+        self.log.info("Job completed.")
+

 class EmrStepSensor(EmrBaseSensor):
     """
@@ -503,9 +534,12 @@ class EmrStepSensor(EmrBaseSensor):
     :param job_flow_id: job_flow_id which contains the step check the state of
     :param step_id: step to check the state of
     :param target_states: the target states, sensor waits until
-        step reaches any of these states
+        step reaches any of these states. In deferrable mode it will run
+        until the step reaches a terminal state.
     :param failed_states: the failure states, sensor fails when
         step reaches any of these states
+    :param max_attempts: Maximum number of tries before failing
+    :param deferrable: Run sensor in the deferrable mode.
     """

     template_fields: Sequence[str] = ("job_flow_id", "step_id", "target_states", "failed_states")
@@ -522,6 +556,8 @@ def __init__(
         step_id: str,
         target_states: Iterable[str] | None = None,
         failed_states: Iterable[str] | None = None,
+        max_attempts: int = 60,
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -529,6 +565,8 @@ def __init__(
         self.step_id = step_id
         self.target_states = target_states or ["COMPLETED"]
         self.failed_states = failed_states or ["CANCELLED", "FAILED", "INTERRUPTED"]
+        self.max_attempts = max_attempts
+        self.deferrable = deferrable

     def get_emr_response(self, context: Context) -> dict[str, Any]:
         """
@@ -587,3 +625,25 @@ def failure_message_from_response(response: dict[str, Any]) -> str | None:
                 f"with message {fail_details.get('Message')} and log file {fail_details.get('LogFile')}"
             )
         return None
+
+    def execute(self, context: Context) -> None:
+        if not self.deferrable:
+            super().execute(context=context)
+        elif not self.poke(context):
+            self.defer(
+                timeout=timedelta(seconds=self.max_attempts * self.poke_interval),
+                trigger=EmrStepSensorTrigger(
+                    job_flow_id=self.job_flow_id,
+                    step_id=self.step_id,
+                    aws_conn_id=self.aws_conn_id,
+                    max_attempts=self.max_attempts,
+                    poke_interval=int(self.poke_interval),
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+
+        self.log.info("Job completed.")
diff --git a/airflow/providers/amazon/aws/triggers/emr.py b/airflow/providers/amazon/aws/triggers/emr.py
index 6ea2c3b25cd61..632144c8509cb 100644
--- a/airflow/providers/amazon/aws/triggers/emr.py
+++ b/airflow/providers/amazon/aws/triggers/emr.py
@@ -24,6 +24,7 @@

 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 from airflow.utils.helpers import prune_dict

@@ -249,7 +250,7 @@ async def run(self):
             )


-class EmrContainerSensorTrigger(BaseTrigger):
+class EmrContainerTrigger(BaseTrigger):
     """
     Poll for the status of EMR container until reaches terminal state.

@@ -278,9 +279,9 @@ def hook(self) -> EmrContainerHook:
         return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)

     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes EmrContainerSensorTrigger arguments and classpath."""
+        """Serializes EmrContainerTrigger arguments and classpath."""
         return (
-            "airflow.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger",
+            "airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger",
             {
                 "virtual_cluster_id": self.virtual_cluster_id,
                 "job_id": self.job_id,
@@ -317,3 +318,63 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
                 await asyncio.sleep(int(self.poll_interval))

         yield TriggerEvent({"status": "success", "job_id": self.job_id})
+
+
+class EmrStepSensorTrigger(BaseTrigger):
+    """
+    Poll for the status of an EMR step until it reaches a terminal state.
+
+    :param job_flow_id: job_flow_id which contains the step check the state of
+    :param step_id: step to check the state of
+    :param aws_conn_id: Reference to AWS connection id
+    :param max_attempts: The maximum number of attempts to be made
+    :param poke_interval: polling period in seconds to check for the status
+    """
+
+    def __init__(
+        self,
+        job_flow_id: str,
+        step_id: str,
+        aws_conn_id: str = "aws_default",
+        max_attempts: int = 60,
+        poke_interval: int = 30,
+        **kwargs: Any,
+    ):
+        self.job_flow_id = job_flow_id
+        self.step_id = step_id
+        self.aws_conn_id = aws_conn_id
+        self.max_attempts = max_attempts
+        self.poke_interval = poke_interval
+        super().__init__(**kwargs)
+
+    @cached_property
+    def hook(self) -> EmrHook:
+        return EmrHook(self.aws_conn_id)
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.amazon.aws.triggers.emr.EmrStepSensorTrigger",
+            {
+                "job_flow_id": self.job_flow_id,
+                "step_id": self.step_id,
+                "aws_conn_id": self.aws_conn_id,
+                "max_attempts": self.max_attempts,
+                "poke_interval": self.poke_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+
+        async with self.hook.async_conn as client:
+            waiter = self.hook.get_waiter("step_wait_for_terminal", deferrable=True, client=client)
+            await async_wait(
+                waiter=waiter,
+                waiter_delay=self.poke_interval,
+                waiter_max_attempts=self.max_attempts,
+                args={"ClusterId": self.job_flow_id, "StepId": self.step_id},
+                failure_message=f"Error while waiting for step {self.step_id} to complete",
+                status_message=f"Step id: {self.step_id}, Step is still in non-terminal state",
+                status_args=["Step.Status.State"],
+            )
+
+        yield TriggerEvent({"status": "success"})
diff --git a/airflow/providers/amazon/aws/waiters/emr.json b/airflow/providers/amazon/aws/waiters/emr.json
index d27cd08e00145..33a90c77514d1 100644
--- a/airflow/providers/amazon/aws/waiters/emr.json
+++ b/airflow/providers/amazon/aws/waiters/emr.json
@@ -94,6 +94,37 @@
                     "state": "failure"
                 }
             ]
+        },
+        "step_wait_for_terminal": {
+            "operation": "DescribeStep",
+            "delay": 30,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "Step.Status.State",
+                    "expected": "COMPLETED",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "Step.Status.State",
+                    "expected": "CANCELLED",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "Step.Status.State",
+                    "expected": "FAILED",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "Step.Status.State",
+                    "expected": "INTERRUPTED",
+                    "state": "failure"
+                }
+            ]
         }
     }
 }
diff --git a/tests/providers/amazon/aws/hooks/test_emr.py b/tests/providers/amazon/aws/hooks/test_emr.py
index 8ff9bc5a36edf..0507ad3718d89 100644
--- a/tests/providers/amazon/aws/hooks/test_emr.py
+++ b/tests/providers/amazon/aws/hooks/test_emr.py
@@ -32,7 +32,13 @@ class TestEmrHook:
     def test_service_waiters(self):
         hook = EmrHook(aws_conn_id=None)
         official_waiters = hook.conn.waiter_names
-        custom_waiters = ["job_flow_waiting", "job_flow_terminated", "notebook_running", "notebook_stopped"]
+        custom_waiters = [
+            "job_flow_waiting",
+            "job_flow_terminated",
+            "notebook_running",
+            "notebook_stopped",
+            "step_wait_for_terminal",
+        ]

         assert sorted(hook.list_waiters()) == sorted([*official_waiters, *custom_waiters])

diff --git a/tests/providers/amazon/aws/operators/test_emr_containers.py b/tests/providers/amazon/aws/operators/test_emr_containers.py
index ddc11b15c56ce..00a2eb22aae2e 100644
--- a/tests/providers/amazon/aws/operators/test_emr_containers.py
+++ b/tests/providers/amazon/aws/operators/test_emr_containers.py
@@ -22,9 +22,10 @@
 import pytest

 from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
 from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator, EmrEksCreateClusterOperator
+from airflow.providers.amazon.aws.triggers.emr import EmrContainerTrigger

 SUBMIT_JOB_SUCCESS_RETURN = {
     "ResponseMetadata": {"HTTPStatusCode": 200},
@@ -144,6 +145,20 @@ def test_execute_with_polling_timeout(self, mock_check_query_status):
         assert "Final state of EMR Containers job is SUBMITTED" in str(ctx.value)
         assert "Max tries of poll status exceeded" in str(ctx.value)

+    @mock.patch.object(EmrContainerHook, "submit_job")
+    @mock.patch.object(
+        EmrContainerHook, "check_query_status", return_value=EmrContainerHook.INTERMEDIATE_STATES[0]
+    )
+    def test_operator_defer(self, mock_submit_job, mock_check_query_status):
+        """Test the execute method raise TaskDeferred if running operator in deferrable mode"""
+        self.emr_container.deferrable = True
+        self.emr_container.wait_for_completion = False
+        with pytest.raises(TaskDeferred) as exc:
+            self.emr_container.execute(context=None)
+        assert isinstance(
+            exc.value.trigger, EmrContainerTrigger
+        ), f"{exc.value.trigger} is not a EmrContainerTrigger"
+

 class TestEmrEksCreateClusterOperator:
     def setup_method(self):
diff --git a/tests/providers/amazon/aws/sensors/test_emr_containers.py b/tests/providers/amazon/aws/sensors/test_emr_containers.py
index 0df3657288b9b..606281e70a620 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_containers.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_containers.py
@@ -24,7 +24,7 @@
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
 from airflow.providers.amazon.aws.sensors.emr import EmrContainerSensor
-from airflow.providers.amazon.aws.triggers.emr import EmrContainerSensorTrigger
+from airflow.providers.amazon.aws.triggers.emr import EmrContainerTrigger


 class TestEmrContainerSensor:
@@ -79,8 +79,8 @@ def test_poke_cancel_pending(self, mock_check_query_status):
     def test_sensor_defer(self, mock_poke):
         self.sensor.deferrable = True
         mock_poke.return_value = False
-        with pytest.raises(TaskDeferred) as exc:
+        with pytest.raises(TaskDeferred) as e:
             self.sensor.execute(context=None)
         assert isinstance(
-            exc.value.trigger, EmrContainerSensorTrigger
-        ), "Trigger is not a EmrContainerSensorTrigger"
+            e.value.trigger, EmrContainerTrigger
+        ), f"{e.value.trigger} is not a EmrContainerTrigger"
diff --git a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
index 2b5eaf8f749af..ffad3c0ce54c9 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
@@ -24,9 +24,10 @@
 import pytest
 from dateutil.tz import tzlocal

-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.sensors.emr import EmrJobFlowSensor
+from airflow.providers.amazon.aws.triggers.emr import EmrTerminateJobFlowTrigger

 DESCRIBE_CLUSTER_STARTING_RETURN = {
     "Cluster": {
@@ -276,3 +277,21 @@ def test_different_target_states(self):
         # make sure it was called with the job_flow_id
         calls = [mock.call(ClusterId="j-8989898989")]
         self.mock_emr_client.describe_cluster.assert_has_calls(calls)
+
+    @mock.patch("airflow.providers.amazon.aws.sensors.emr.EmrJobFlowSensor.poke")
+    def test_sensor_defer(self, mock_poke):
+        """Test the execute method raise TaskDeferred if running sensor in deferrable mode"""
+        sensor = EmrJobFlowSensor(
+            task_id="test_task",
+            poke_interval=0,
+            job_flow_id="j-8989898989",
+            aws_conn_id="aws_default",
+            target_states=["RUNNING", "WAITING"],
+            deferrable=True,
+        )
+        mock_poke.return_value = False
+        with pytest.raises(TaskDeferred) as exc:
+            sensor.execute(context=None)
+        assert isinstance(
+            exc.value.trigger, EmrTerminateJobFlowTrigger
+        ), f"{exc.value.trigger} is not a EmrTerminateJobFlowTrigger"
diff --git a/tests/providers/amazon/aws/sensors/test_emr_step.py b/tests/providers/amazon/aws/sensors/test_emr_step.py
index 7c9569d6002f3..8387207ba1e02 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_step.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_step.py
@@ -24,10 +24,11 @@
 import pytest
 from dateutil.tz import tzlocal

-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink
 from airflow.providers.amazon.aws.sensors.emr import EmrStepSensor
+from airflow.providers.amazon.aws.triggers.emr import EmrStepSensorTrigger

 DESCRIBE_JOB_STEP_RUNNING_RETURN = {
     "ResponseMetadata": {"HTTPStatusCode": 200, "RequestId": "8dee8db2-3719-11e6-9e20-35b2f861a2a6"},
@@ -230,3 +231,20 @@ def test_step_interrupted(self, *_):
             mock_isinstance.return_value = True
             with pytest.raises(AirflowException):
                 self.sensor.execute(None)
+
+    @mock.patch("airflow.providers.amazon.aws.sensors.emr.EmrStepSensor.poke")
+    def test_sensor_defer(self, mock_poke):
+        """Test the execute method raise TaskDeferred if running sensor in deferrable mode"""
+        sensor = EmrStepSensor(
+            task_id="test_task",
+            poke_interval=0,
+            job_flow_id="j-8989898989",
+            step_id="s-VK57YR1Z9Z5N",
+            aws_conn_id="aws_default",
+            deferrable=True,
+        )
+
+        mock_poke.return_value = False
+        with pytest.raises(TaskDeferred) as exc:
+            sensor.execute(context=None)
+        assert isinstance(exc.value.trigger, EmrStepSensorTrigger), "Trigger is not a EmrStepSensorTrigger"
diff --git a/tests/providers/amazon/aws/triggers/test_emr.py b/tests/providers/amazon/aws/triggers/test_emr.py
index 86e54cb94ae23..ff57eab4a5d15 100644
--- a/tests/providers/amazon/aws/triggers/test_emr.py
+++ b/tests/providers/amazon/aws/triggers/test_emr.py
@@ -25,8 +25,9 @@
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.emr import EmrHook
 from airflow.providers.amazon.aws.triggers.emr import (
-    EmrContainerSensorTrigger,
+    EmrContainerTrigger,
     EmrCreateJobFlowTrigger,
+    EmrStepSensorTrigger,
     EmrTerminateJobFlowTrigger,
 )
 from airflow.triggers.base import TriggerEvent
@@ -39,6 +40,8 @@
 JOB_ID = "job-1234"
 AWS_CONN_ID = "aws_emr_conn"
 POLL_INTERVAL = 60
+TARGET_STATE = ["TERMINATED"]
+STEP_ID = "s-1234"


 class TestEmrCreateJobFlowTrigger:
@@ -360,16 +363,16 @@ async def test_emr_terminate_job_flow_trigger_run_attempts_failed(
         assert mock_get_waiter().wait.call_count == 3


-class TestEmrContainerSensorTrigger:
-    def test_emr_container_sensor_trigger_serialize(self):
-        emr_trigger = EmrContainerSensorTrigger(
+class TestEmrContainerTrigger:
+    def test_emr_container_trigger_serialize(self):
+        emr_trigger = EmrContainerTrigger(
             virtual_cluster_id=VIRTUAL_CLUSTER_ID,
             job_id=JOB_ID,
             aws_conn_id=AWS_CONN_ID,
             poll_interval=POLL_INTERVAL,
         )
         class_path, args = emr_trigger.serialize()
-        assert class_path == "airflow.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger"
+        assert class_path == "airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
         assert args["virtual_cluster_id"] == VIRTUAL_CLUSTER_ID
         assert args["job_id"] == JOB_ID
         assert args["aws_conn_id"] == AWS_CONN_ID
@@ -384,7 +387,7 @@ async def test_emr_container_trigger_run(self, mock_async_conn, mock_get_waiter)

         mock_get_waiter().wait = AsyncMock()

-        emr_trigger = EmrContainerSensorTrigger(
+        emr_trigger = EmrContainerTrigger(
             virtual_cluster_id=VIRTUAL_CLUSTER_ID,
             job_id=JOB_ID,
             aws_conn_id=AWS_CONN_ID,
@@ -412,7 +415,7 @@ async def test_emr_trigger_run_multiple_attempts(self, mock_async_conn, mock_get
         mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, error, True])
         mock_sleep.return_value = True

-        emr_trigger = EmrContainerSensorTrigger(
+        emr_trigger = EmrContainerTrigger(
             virtual_cluster_id=VIRTUAL_CLUSTER_ID,
             job_id=JOB_ID,
             aws_conn_id=AWS_CONN_ID,
@@ -448,7 +451,7 @@ async def test_emr_trigger_run_attempts_failed(self, mock_async_conn, mock_get_w
         )
         mock_sleep.return_value = True

-        emr_trigger = EmrContainerSensorTrigger(
+        emr_trigger = EmrContainerTrigger(
             virtual_cluster_id=VIRTUAL_CLUSTER_ID,
             job_id=JOB_ID,
             aws_conn_id=AWS_CONN_ID,
@@ -460,3 +463,108 @@ async def test_emr_trigger_run_attempts_failed(self, mock_async_conn, mock_get_w

         assert mock_get_waiter().wait.call_count == 3
         assert response == TriggerEvent({"status": "failure", "message": f"Job Failed: {error_failed}"})
+
+
+class TestEmrStepSensorTrigger:
+    def test_emr_step_trigger_serialize(self):
+        """Test trigger serialize object and path as expected"""
+        emr_trigger = EmrStepSensorTrigger(
+            job_flow_id=TEST_JOB_FLOW_ID,
+            step_id=STEP_ID,
+            aws_conn_id=AWS_CONN_ID,
+            poke_interval=POLL_INTERVAL,
+        )
+        class_path, args = emr_trigger.serialize()
+        assert class_path == "airflow.providers.amazon.aws.triggers.emr.EmrStepSensorTrigger"
+        assert args["job_flow_id"] == TEST_JOB_FLOW_ID
+        assert args["step_id"] == STEP_ID
+        assert args["aws_conn_id"] == AWS_CONN_ID
+        assert args["max_attempts"] == 60
+        assert args["poke_interval"] == POLL_INTERVAL
+
+    @pytest.mark.asyncio
+    @mock.patch.object(EmrHook, "async_conn")
+    async def test_emr_step_trigger_run(self, mock_async_conn):
+        """Test trigger emit success if condition met"""
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+
+        a_mock.get_waiter().wait = AsyncMock()
+
+        emr_trigger = EmrStepSensorTrigger(
+            job_flow_id=TEST_JOB_FLOW_ID,
+            step_id=STEP_ID,
+            aws_conn_id=AWS_CONN_ID,
+            poke_interval=POLL_INTERVAL,
+        )
+
+        generator = emr_trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent({"status": "success"})
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch.object(EmrHook, "async_conn")
+    async def test_emr_trigger_run_multiple_attempts(self, mock_async_conn, mock_sleep):
+        """Test trigger try max attempt if attempt not exceeded and job still running"""
+        mock_sleep.return_value = True
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+
+        error = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={"Step": {"Status": {"State": "RUNNING"}}},
+        )
+        a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, error, True])
+
+        emr_trigger = EmrStepSensorTrigger(
+            job_flow_id=TEST_JOB_FLOW_ID,
+            step_id=STEP_ID,
+            aws_conn_id=AWS_CONN_ID,
+            poke_interval=POLL_INTERVAL,
+        )
+
+        generator = emr_trigger.run()
+        response = await generator.asend(None)
+
+        assert a_mock.get_waiter().wait.call_count == 4
+        assert response == TriggerEvent({"status": "success"})
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch.object(EmrHook, "async_conn")
+    async def test_emr_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep):
+        """Test trigger does fail if max attempt exceeded and job still not succeeded"""
+        mock_sleep.return_value = True
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+
+        error_running = WaiterError(
+            name="test_name",
+            reason="test reason",
+            last_response={"Step": {"Status": {"State": "RUNNING"}}},
+        )
+        error_failed = WaiterError(
+            name="test_name",
+            reason="Waiter encountered a terminal failure state",
+            last_response={"Step": {"Status": {"State": "CANCELLED"}}},
+        )
+
+        a_mock.get_waiter().wait = AsyncMock(side_effect=[error_running, error_failed])
+        mock_sleep.return_value = True
+
+        emr_trigger = EmrStepSensorTrigger(
+            job_flow_id=TEST_JOB_FLOW_ID,
+            step_id=STEP_ID,
+            aws_conn_id=AWS_CONN_ID,
+            poke_interval=POLL_INTERVAL,
+        )
+
+        with pytest.raises(AirflowException) as exc:
+            generator = emr_trigger.run()
+            await generator.asend(None)
+
+        assert a_mock.get_waiter().wait.call_count == 2
+        assert "Error while waiting for step s-1234 to complete" in str(exc.value)
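Usage note (not part of the patch above): a minimal DAG sketch showing how the new deferrable flag from this change might be used. The cluster id, role ARN, release label, and job driver below are hypothetical placeholders; only the operator/sensor names and the deferrable, max_attempts, and poke_interval parameters come from the diff. With deferrable=True the task hands polling off to the triggerer (EmrContainerTrigger / EmrStepSensorTrigger) instead of occupying a worker slot.

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator
from airflow.providers.amazon.aws.sensors.emr import EmrStepSensor

with DAG(
    dag_id="example_emr_deferrable",  # hypothetical DAG id
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
):
    # Submit an EMR on EKS job, then defer instead of polling on the worker.
    run_emr_containers_job = EmrContainerOperator(
        task_id="run_emr_containers_job",
        name="example-emr-containers-job",  # hypothetical
        virtual_cluster_id="vc-1234567890abcdef0",  # hypothetical
        execution_role_arn="arn:aws:iam::123456789012:role/emr-eks-role",  # hypothetical
        release_label="emr-6.10.0-latest",  # hypothetical
        job_driver={"sparkSubmitJobDriver": {"entryPoint": "s3://my-bucket/app.py"}},  # hypothetical
        deferrable=True,
    )

    # Wait for an already-submitted EMR step; deferral uses the new
    # step_wait_for_terminal waiter via EmrStepSensorTrigger.
    wait_for_step = EmrStepSensor(
        task_id="wait_for_step",
        job_flow_id="j-8989898989",  # same cluster id used in the tests
        step_id="s-VK57YR1Z9Z5N",  # same step id used in the tests
        poke_interval=60,
        max_attempts=60,
        deferrable=True,
    )

    run_emr_containers_job >> wait_for_step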