diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index 88feb01311784..b9b3322c49102 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -38,7 +38,10 @@
     BatchJobQueueLink,
 )
 from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
-from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
+from airflow.providers.amazon.aws.triggers.batch import (
+    BatchCreateComputeEnvironmentTrigger,
+    BatchOperatorTrigger,
+)
 from airflow.providers.amazon.aws.utils import trim_none_values
 from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher

@@ -402,14 +405,16 @@ class BatchCreateComputeEnvironmentOperator(BaseOperator):
         services on your behalf (templated).
     :param tags: Tags that you apply to the compute-environment to help you categorize and
         organize your resources.
-    :param max_retries: Exponential back-off retries, 4200 = 48 hours; polling
-        is only used when waiters is None.
-    :param status_retries: Number of HTTP retries to get job status, 10; polling
-        is only used when waiters is None.
+    :param poll_interval: How long to wait, in seconds, between two consecutive polls of the
+        environment status. Only used when deferrable is True.
+    :param max_retries: How many times to poll for the environment status.
+        Only used when deferrable is True.
     :param aws_conn_id: Connection ID of AWS credentials / region name.
         If None, credential boto3 strategy will be used.
     :param region_name: Region name to use in AWS Hook.
         Overrides the ``region_name`` in connection if provided.
+    :param deferrable: If True, the operator will wait asynchronously for the environment to be created.
+        This mode requires the aiobotocore module to be installed. (default: False)
     """

     template_fields: Sequence[str] = (
@@ -428,13 +433,24 @@ def __init__(
         unmanaged_v_cpus: int | None = None,
         service_role: str | None = None,
         tags: dict | None = None,
+        poll_interval: int = 30,
         max_retries: int | None = None,
-        status_retries: int | None = None,
         aws_conn_id: str | None = None,
         region_name: str | None = None,
+        deferrable: bool = False,
         **kwargs,
     ):
+        if "status_retries" in kwargs:
+            warnings.warn(
+                "The `status_retries` parameter is unused and should be removed. "
" + "It'll be deleted in a future version.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + kwargs.pop("status_retries") # remove before calling super() to prevent unexpected arg error + super().__init__(**kwargs) + self.compute_environment_name = compute_environment_name self.environment_type = environment_type self.state = state @@ -442,17 +458,16 @@ def __init__( self.compute_resources = compute_resources self.service_role = service_role self.tags = tags or {} - self.max_retries = max_retries - self.status_retries = status_retries + self.poll_interval = poll_interval + self.max_retries = max_retries or 120 self.aws_conn_id = aws_conn_id self.region_name = region_name + self.deferrable = deferrable @cached_property def hook(self): """Create and return a BatchClientHook.""" return BatchClientHook( - max_retries=self.max_retries, - status_retries=self.status_retries, aws_conn_id=self.aws_conn_id, region_name=self.region_name, ) @@ -468,6 +483,21 @@ def execute(self, context: Context): "serviceRole": self.service_role, "tags": self.tags, } - self.hook.client.create_compute_environment(**trim_none_values(kwargs)) + response = self.hook.client.create_compute_environment(**trim_none_values(kwargs)) + arn = response["computeEnvironmentArn"] + + if self.deferrable: + self.defer( + trigger=BatchCreateComputeEnvironmentTrigger( + arn, self.poll_interval, self.max_retries, self.aws_conn_id, self.region_name + ), + method_name="execute_complete", + ) self.log.info("AWS Batch compute environment created successfully") + return arn + + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error while waiting for the compute environment to be ready: {event}") + return event["value"] diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py index f4a5de15254fa..b0bdbc0d4578b 100644 --- a/airflow/providers/amazon/aws/triggers/batch.py +++ b/airflow/providers/amazon/aws/triggers/batch.py @@ -23,6 +23,7 @@ from botocore.exceptions import WaiterError from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook +from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -188,3 +189,59 @@ async def run(self): "message": f"Job {self.job_id} Succeeded", } ) + + +class BatchCreateComputeEnvironmentTrigger(BaseTrigger): + """ + Trigger for BatchCreateComputeEnvironmentOperator. + The trigger will asynchronously poll the boto3 API and wait for the compute environment to be ready. + + :param job_id: A unique identifier for the cluster. + :param max_retries: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region_name: region name to use in AWS Hook + :param poll_interval: The amount of time in seconds to wait between attempts. + """ + + def __init__( + self, + compute_env_arn: str | None = None, + poll_interval: int = 30, + max_retries: int = 10, + aws_conn_id: str | None = "aws_default", + region_name: str | None = None, + ): + super().__init__() + self.compute_env_arn = compute_env_arn + self.max_retries = max_retries + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.poll_interval = poll_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serializes BatchOperatorTrigger arguments and classpath.""" + return ( + self.__class__.__module__ + "." 
+            {
+                "compute_env_arn": self.compute_env_arn,
+                "max_retries": self.max_retries,
+                "aws_conn_id": self.aws_conn_id,
+                "region_name": self.region_name,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    async def run(self):
+        hook = BatchClientHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+        async with hook.async_conn as client:
+            waiter = hook.get_waiter("compute_env_ready", deferrable=True, client=client)
+            await async_wait(
+                waiter=waiter,
+                waiter_delay=self.poll_interval,
+                waiter_max_attempts=self.max_retries,
+                args={"computeEnvironments": [self.compute_env_arn]},
+                failure_message="Failure while creating Compute Environment",
+                status_message="Compute Environment not ready yet",
+                status_args=["computeEnvironments[].status", "computeEnvironments[].statusReason"],
+            )
+        yield TriggerEvent({"status": "success", "value": self.compute_env_arn})
diff --git a/airflow/providers/amazon/aws/waiters/batch.json b/airflow/providers/amazon/aws/waiters/batch.json
index fa9752ea14c41..3fbdd433771c8 100644
--- a/airflow/providers/amazon/aws/waiters/batch.json
+++ b/airflow/providers/amazon/aws/waiters/batch.json
@@ -20,6 +20,32 @@
                     "state": "failed"
                 }
             ]
+        },
+
+        "compute_env_ready": {
+            "delay": 30,
+            "operation": "DescribeComputeEnvironments",
+            "maxAttempts": 100,
+            "acceptors": [
+                {
+                    "argument": "computeEnvironments[].status",
+                    "expected": "VALID",
+                    "matcher": "pathAll",
+                    "state": "success"
+                },
+                {
+                    "argument": "computeEnvironments[].status",
+                    "expected": "INVALID",
+                    "matcher": "pathAny",
+                    "state": "failed"
+                },
+                {
+                    "argument": "computeEnvironments[].status",
+                    "expected": "DELETED",
+                    "matcher": "pathAny",
+                    "state": "failed"
+                }
+            ]
         }
     }
 }
diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py
index a65e00d8db04b..3aace0bb3e5bb 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -27,7 +27,10 @@
 from airflow.providers.amazon.aws.operators.batch import BatchCreateComputeEnvironmentOperator, BatchOperator

 # Use dummy AWS credentials
-from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
+from airflow.providers.amazon.aws.triggers.batch import (
+    BatchCreateComputeEnvironmentTrigger,
+    BatchOperatorTrigger,
+)

 AWS_REGION = "eu-west-1"
 AWS_ACCESS_KEY_ID = "airflow_dummy_key"
@@ -326,3 +329,26 @@ def test_execute(self, mock_conn):
             computeResources=compute_resources,
             tags=tags,
         )
+
+    @mock.patch.object(BatchClientHook, "client")
+    def test_defer(self, client_mock):
+        client_mock.create_compute_environment.return_value = {"computeEnvironmentArn": "my_arn"}
+
+        operator = BatchCreateComputeEnvironmentOperator(
+            task_id="task",
+            compute_environment_name="my_env_name",
+            environment_type="my_env_type",
+            state="my_state",
+            compute_resources={},
+            max_retries=123456,
+            poll_interval=456789,
+            deferrable=True,
+        )
+
+        with pytest.raises(TaskDeferred) as deferred:
+            operator.execute(None)
+
+        assert isinstance(deferred.value.trigger, BatchCreateComputeEnvironmentTrigger)
+        assert deferred.value.trigger.compute_env_arn == "my_arn"
+        assert deferred.value.trigger.poll_interval == 456789
+        assert deferred.value.trigger.max_retries == 123456
diff --git a/tests/providers/amazon/aws/triggers/test_batch.py b/tests/providers/amazon/aws/triggers/test_batch.py
index 5cf125f8280a5..e33736076237f 100644
--- a/tests/providers/amazon/aws/triggers/test_batch.py
+++ b/tests/providers/amazon/aws/triggers/test_batch.py
@@ -22,7 +22,13 @@
 import pytest
 from botocore.exceptions import WaiterError

-from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger, BatchSensorTrigger
+from airflow import AirflowException
+from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.aws.triggers.batch import (
+    BatchCreateComputeEnvironmentTrigger,
+    BatchOperatorTrigger,
+    BatchSensorTrigger,
+)
 from airflow.triggers.base import TriggerEvent

 BATCH_JOB_ID = "job_id"
@@ -181,3 +187,38 @@ async def test_batch_sensor_trigger_failure(
         assert actual_response == TriggerEvent(
             {"status": "failure", "message": f"Job Failed: Waiter {name} failed: {reason}"}
         )
+
+
+class TestBatchCreateComputeEnvironmentTrigger:
+    @pytest.mark.asyncio
+    @mock.patch.object(BatchClientHook, "async_conn")
+    @mock.patch.object(BatchClientHook, "get_waiter")
+    async def test_success(self, get_waiter_mock, conn_mock):
+        get_waiter_mock().wait = AsyncMock(
+            side_effect=[
+                WaiterError(
+                    "situation normal", "first try", {"computeEnvironments": [{"status": "my_status"}]}
+                ),
+                {},
+            ]
+        )
+        trigger = BatchCreateComputeEnvironmentTrigger("my_arn", poll_interval=0, max_retries=3)
+
+        generator = trigger.run()
+        response: TriggerEvent = await generator.asend(None)
+
+        assert response.payload["status"] == "success"
+        assert response.payload["value"] == "my_arn"
+
+    @pytest.mark.asyncio
+    @mock.patch.object(BatchClientHook, "async_conn")
+    @mock.patch.object(BatchClientHook, "get_waiter")
+    async def test_failure(self, get_waiter_mock, conn_mock):
+        get_waiter_mock().wait = AsyncMock(
+            side_effect=[WaiterError("terminal failure", "terminal failure reason", {})]
+        )
+        trigger = BatchCreateComputeEnvironmentTrigger("my_arn", poll_interval=0, max_retries=3)
+
+        with pytest.raises(AirflowException):
+            generator = trigger.run()
+            await generator.asend(None)
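
Below is a minimal usage sketch of the deferrable mode this diff adds. It is illustrative and not part of the change itself: the DAG id, schedule, environment name, and compute-resource values are hypothetical, while the operator and its deferrable, poll_interval, and max_retries parameters come from the new signature above.

# Usage sketch (hypothetical DAG): create a Batch compute environment without
# holding a worker slot while AWS provisions it.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.batch import BatchCreateComputeEnvironmentOperator

with DAG(dag_id="example_batch_compute_env", start_date=datetime(2023, 1, 1), schedule=None):
    create_env = BatchCreateComputeEnvironmentOperator(
        task_id="create_compute_environment",
        compute_environment_name="my-fargate-env",  # hypothetical name
        environment_type="MANAGED",
        state="ENABLED",
        compute_resources={  # hypothetical resources
            "type": "FARGATE",
            "maxvCpus": 10,
            "subnets": ["subnet-12345678"],
            "securityGroupIds": ["sg-12345678"],
        },
        deferrable=True,  # poll from the triggerer instead of the worker
        poll_interval=30,  # seconds between two polls of the environment status
        max_retries=120,  # maximum number of polls before giving up
    )

With deferrable=True, execute() submits the environment and defer() raises TaskDeferred, freeing the worker; the triggerer then runs BatchCreateComputeEnvironmentTrigger against the new compute_env_ready waiter, and execute_complete() returns the environment ARN once the trigger fires a success event.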