diff --git a/airflow/providers/amazon/aws/operators/eks.py b/airflow/providers/amazon/aws/operators/eks.py
index 75fb48fdfeef4..eb6ebb5b0fdde 100644
--- a/airflow/providers/amazon/aws/operators/eks.py
+++ b/airflow/providers/amazon/aws/operators/eks.py
@@ -31,6 +31,7 @@
 from airflow.providers.amazon.aws.triggers.eks import (
     EksCreateFargateProfileTrigger,
     EksDeleteFargateProfileTrigger,
+    EksNodegroupTrigger,
 )
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 
@@ -183,8 +184,8 @@ class EksCreateClusterOperator(BaseOperator):
     :param fargate_selectors: The selectors to match for pods to use this AWS Fargate profile. (templated)
     :param create_fargate_profile_kwargs: Optional parameters to pass to the CreateFargateProfile API
         (templated)
-    :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check cluster status
-    :param waiter_max_attempts: The maximum number of attempts to check the status of the cluster.
+    :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check cluster state
+    :param waiter_max_attempts: The maximum number of attempts to check cluster state
 
     """
 
@@ -333,8 +334,11 @@ class EksCreateNodegroupOperator(BaseOperator):
         maintained on each worker node).
     :param region: Which AWS region the connection should use. (templated)
         If this is None or empty then the default boto3 behaviour is used.
-    :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check nodegroup status
-    :param waiter_max_attempts: The maximum number of attempts to check the status of the nodegroup.
+    :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check nodegroup state
+    :param waiter_max_attempts: The maximum number of attempts to check nodegroup state
+    :param deferrable: If True, the operator will wait asynchronously for the nodegroup to be created.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
 
     """
 
@@ -361,6 +365,7 @@ def __init__(
         region: str | None = None,
         waiter_delay: int = 30,
         waiter_max_attempts: int = 80,
+        deferrable: bool = False,
         **kwargs,
     ) -> None:
         self.nodegroup_subnets = nodegroup_subnets
@@ -369,15 +374,13 @@ def __init__(
         self.nodegroup_role_arn = nodegroup_role_arn
         self.nodegroup_name = nodegroup_name
         self.create_nodegroup_kwargs = create_nodegroup_kwargs or {}
-        self.wait_for_completion = wait_for_completion
+        self.wait_for_completion = False if deferrable else wait_for_completion
         self.aws_conn_id = aws_conn_id
         self.region = region
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
-
-        super().__init__(
-            **kwargs,
-        )
+        self.deferrable = deferrable
+        super().__init__(**kwargs)
 
     def execute(self, context: Context):
         self.log.info(self.task_id)
@@ -393,6 +396,7 @@ def execute(self, context: Context):
                     self.nodegroup_subnets,
                 )
             self.nodegroup_subnets = nodegroup_subnets_list
+
         _create_compute(
             compute=self.compute,
             cluster_name=self.cluster_name,
@@ -407,6 +411,28 @@ def execute(self, context: Context):
             subnets=self.nodegroup_subnets,
         )
 
+        if self.deferrable:
+            self.defer(
+                trigger=EksNodegroupTrigger(
+                    waiter_name="nodegroup_active",
+                    cluster_name=self.cluster_name,
+                    nodegroup_name=self.nodegroup_name,
+                    aws_conn_id=self.aws_conn_id,
+                    region=self.region,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                ),
+                method_name="execute_complete",
+                # timeout is set to ensure that if a trigger dies, the timeout does not restart
+                # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
+                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
+            )
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error creating nodegroup: {event}")
+        return
+
 
 class EksCreateFargateProfileOperator(BaseOperator):
     """
@@ -638,6 +664,11 @@ class EksDeleteNodegroupOperator(BaseOperator):
         maintained on each worker node).
     :param region: Which AWS region the connection should use. (templated)
         If this is None or empty then the default boto3 behaviour is used.
+    :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check nodegroup state
+    :param waiter_max_attempts: The maximum number of attempts to check nodegroup state
+    :param deferrable: If True, the operator will wait asynchronously for the nodegroup to be deleted.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
 
     """
 
@@ -656,6 +687,9 @@ def __init__(
         wait_for_completion: bool = False,
         aws_conn_id: str = DEFAULT_CONN_ID,
         region: str | None = None,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 40,
+        deferrable: bool = False,
         **kwargs,
     ) -> None:
         self.cluster_name = cluster_name
@@ -663,6 +697,9 @@ def __init__(
         self.wait_for_completion = wait_for_completion
         self.aws_conn_id = aws_conn_id
         self.region = region
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
         super().__init__(**kwargs)
 
     def execute(self, context: Context):
@@ -672,12 +709,33 @@ def execute(self, context: Context):
         )
 
         eks_hook.delete_nodegroup(clusterName=self.cluster_name, nodegroupName=self.nodegroup_name)
-        if self.wait_for_completion:
+        if self.deferrable:
+            self.defer(
+                trigger=EksNodegroupTrigger(
+                    waiter_name="nodegroup_deleted",
+                    cluster_name=self.cluster_name,
+                    nodegroup_name=self.nodegroup_name,
+                    aws_conn_id=self.aws_conn_id,
+                    region=self.region,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                ),
+                method_name="execute_complete",
+                # timeout is set to ensure that if a trigger dies, the timeout does not restart
+                # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
+                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
+            )
+        elif self.wait_for_completion:
             self.log.info("Waiting for nodegroup to delete. This will take some time.")
             eks_hook.conn.get_waiter("nodegroup_deleted").wait(
                 clusterName=self.cluster_name, nodegroupName=self.nodegroup_name
             )
 
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error deleting nodegroup: {event}")
+        return
+
 
 class EksDeleteFargateProfileOperator(BaseOperator):
     """
diff --git a/airflow/providers/amazon/aws/triggers/eks.py b/airflow/providers/amazon/aws/triggers/eks.py
index 8ccd88167cc63..be5d50ab4cafb 100644
--- a/airflow/providers/amazon/aws/triggers/eks.py
+++ b/airflow/providers/amazon/aws/triggers/eks.py
@@ -23,6 +23,7 @@
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.eks import EksHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
@@ -164,3 +165,75 @@ async def run(self):
                 )
             else:
                 yield TriggerEvent({"status": "success", "message": "Fargate Profile Deleted"})
+
+
+class EksNodegroupTrigger(BaseTrigger):
+    """
+    Trigger for EksCreateNodegroupOperator and EksDeleteNodegroupOperator.
+
+    The trigger will asynchronously poll the boto3 API and wait for the
+    nodegroup to be in the state specified by the waiter.
+
+    :param waiter_name: Name of the waiter to use, for instance 'nodegroup_active' or 'nodegroup_deleted'
+    :param cluster_name: The name of the EKS cluster associated with the node group.
+    :param nodegroup_name: The name of the nodegroup to check.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region: Which AWS region the connection should use. (templated)
+        If this is None or empty then the default boto3 behaviour is used.
+    """
+
+    def __init__(
+        self,
+        waiter_name: str,
+        cluster_name: str,
+        nodegroup_name: str,
+        waiter_delay: int,
+        waiter_max_attempts: int,
+        aws_conn_id: str,
+        region: str | None,
+    ):
+        self.waiter_name = waiter_name
+        self.cluster_name = cluster_name
+        self.nodegroup_name = nodegroup_name
+        self.aws_conn_id = aws_conn_id
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.region = region
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "waiter_name": self.waiter_name,
+                "cluster_name": self.cluster_name,
+                "nodegroup_name": self.nodegroup_name,
+                "waiter_delay": str(self.waiter_delay),
+                "waiter_max_attempts": str(self.waiter_max_attempts),
+                "aws_conn_id": self.aws_conn_id,
+                "region": self.region,
+            },
+        )
+
+    async def run(self):
+        self.hook = EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
+        async with self.hook.async_conn as client:
+            waiter = client.get_waiter(self.waiter_name)
+            await async_wait(
+                waiter=waiter,
+                waiter_max_attempts=int(self.waiter_max_attempts),
+                waiter_delay=int(self.waiter_delay),
+                args={"clusterName": self.cluster_name, "nodegroupName": self.nodegroup_name},
+                failure_message="Error checking nodegroup",
+                status_message="Nodegroup status is",
+                status_args=["nodegroup.status"],
+            )
+
+        yield TriggerEvent(
+            {
+                "status": "success",
+                "cluster_name": self.cluster_name,
+                "nodegroup_name": self.nodegroup_name,
+            }
+        )
diff --git a/docs/apache-airflow-providers-amazon/operators/eks.rst b/docs/apache-airflow-providers-amazon/operators/eks.rst
index e11b721be37b5..cff682db2aa8d 100644
--- a/docs/apache-airflow-providers-amazon/operators/eks.rst
+++ b/docs/apache-airflow-providers-amazon/operators/eks.rst
@@ -121,6 +121,7 @@ Create an Amazon EKS managed node group
 
 To create an Amazon EKS managed node group you can use
 :class:`~airflow.providers.amazon.aws.operators.eks.EksCreateNodegroupOperator`.
+You can also run this operator in deferrable mode by setting ``deferrable`` param to ``True``.
 
 Note: An AWS IAM role with the following permissions is required:
   ``ec2.amazon.aws.com`` must be in the Trusted Relationships
@@ -140,6 +141,7 @@ Delete an Amazon EKS managed node group
 
 To delete an existing Amazon EKS managed node group you can use
 :class:`~airflow.providers.amazon.aws.operators.eks.EksDeleteNodegroupOperator`.
+You can also run this operator in deferrable mode by setting ``deferrable`` param to ``True``.
 
 .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_eks_with_nodegroups.py
     :language: python
diff --git a/tests/providers/amazon/aws/operators/test_eks.py b/tests/providers/amazon/aws/operators/test_eks.py
index 311aad972d8a6..5534f635d8f30 100644
--- a/tests/providers/amazon/aws/operators/test_eks.py
+++ b/tests/providers/amazon/aws/operators/test_eks.py
@@ -37,6 +37,7 @@
 from airflow.providers.amazon.aws.triggers.eks import (
     EksCreateFargateProfileTrigger,
     EksDeleteFargateProfileTrigger,
+    EksNodegroupTrigger,
 )
 from airflow.typing_compat import TypedDict
 from tests.providers.amazon.aws.utils.eks_test_constants import (
@@ -476,6 +477,36 @@ def test_execute_with_wait_when_nodegroup_does_not_already_exist(
         mock_waiter.assert_called_with(mock.ANY, clusterName=CLUSTER_NAME, nodegroupName=NODEGROUP_NAME)
         assert_expected_waiter_type(mock_waiter, "NodegroupActive")
 
+    @mock.patch.object(EksHook, "create_nodegroup")
+    def test_create_nodegroup_deferrable(self, mock_create_nodegroup):
+        mock_create_nodegroup.return_value = True
+        op_kwargs = {**self.create_nodegroup_params}
+        operator = EksCreateNodegroupOperator(
+            task_id=TASK_ID,
+            **op_kwargs,
+            deferrable=True,
+        )
+        with pytest.raises(TaskDeferred) as exc:
+            operator.execute({})
+        assert isinstance(exc.value.trigger, EksNodegroupTrigger), "Trigger is not a EksNodegroupTrigger"
+
+    def test_create_nodegroup_deferrable_versus_wait_for_completion(self):
+        op_kwargs = {**self.create_nodegroup_params}
+        operator = EksCreateNodegroupOperator(
+            task_id=TASK_ID,
+            **op_kwargs,
+            deferrable=True,
+            wait_for_completion=True,
+        )
+        assert operator.wait_for_completion is False
+
+        operator = EksCreateNodegroupOperator(
+            task_id=TASK_ID,
+            **op_kwargs,
+            deferrable=False,
+            wait_for_completion=True,
+        )
+        assert operator.wait_for_completion is True
+
 
 class TestEksDeleteClusterOperator:
     def setup_method(self) -> None:
diff --git a/tests/providers/amazon/aws/triggers/test_eks.py b/tests/providers/amazon/aws/triggers/test_eks.py
index dbc71e7296b78..0b94e957db6cb 100644
--- a/tests/providers/amazon/aws/triggers/test_eks.py
+++ b/tests/providers/amazon/aws/triggers/test_eks.py
@@ -27,14 +27,17 @@
 from airflow.providers.amazon.aws.triggers.eks import (
     EksCreateFargateProfileTrigger,
     EksDeleteFargateProfileTrigger,
+    EksNodegroupTrigger,
 )
 from airflow.triggers.base import TriggerEvent
 
 TEST_CLUSTER_IDENTIFIER = "test-cluster"
 TEST_FARGATE_PROFILE_NAME = "test-fargate-profile"
+TEST_NODEGROUP_NAME = "test-nodegroup"
 TEST_WAITER_DELAY = 10
 TEST_WAITER_MAX_ATTEMPTS = 10
 TEST_AWS_CONN_ID = "test-aws-id"
+TEST_REGION = "test-region"
 
 
 class TestEksCreateFargateProfileTrigger:
@@ -297,3 +300,156 @@ async def test_eks_delete_fargate_profile_trigger_run_attempts_failed(self, moc
             await generator.asend(None)
         assert f"Delete Fargate Profile failed: {error_failed}" in str(exc.value)
         assert a_mock.get_waiter().wait.call_count == 3
+
+
+class TestEksNodegroupTrigger:
+    def test_eks_nodegroup_trigger_serialize(self):
+        eks_nodegroup_trigger = EksNodegroupTrigger(
+            waiter_name="test_waiter_name",
+            cluster_name=TEST_CLUSTER_IDENTIFIER,
+            nodegroup_name=TEST_NODEGROUP_NAME,
+            aws_conn_id=TEST_AWS_CONN_ID,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+            region=TEST_REGION,
+        )
+
+        class_path, args = eks_nodegroup_trigger.serialize()
+        assert class_path == "airflow.providers.amazon.aws.triggers.eks.EksNodegroupTrigger"
+        assert args["waiter_name"] == "test_waiter_name"
+        assert args["cluster_name"] == TEST_CLUSTER_IDENTIFIER
+        assert args["nodegroup_name"] == TEST_NODEGROUP_NAME
+        assert args["aws_conn_id"] == TEST_AWS_CONN_ID
+        assert args["waiter_delay"] == str(TEST_WAITER_DELAY)
+        assert args["waiter_max_attempts"] == str(TEST_WAITER_MAX_ATTEMPTS)
+        assert args["region"] == TEST_REGION
+
+    @pytest.mark.asyncio
+    @mock.patch.object(EksHook, "async_conn")
+    async def test_eks_nodegroup_trigger_run(self, mock_async_conn):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+
+        a_mock.get_waiter().wait = AsyncMock()
+
+        eks_nodegroup_trigger = EksNodegroupTrigger(
+            waiter_name="test_waiter_name",
+            cluster_name=TEST_CLUSTER_IDENTIFIER,
+            nodegroup_name=TEST_NODEGROUP_NAME,
+            aws_conn_id=TEST_AWS_CONN_ID,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+            region=TEST_REGION,
+        )
+
+        generator = eks_nodegroup_trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent(
+            {
+                "status": "success",
+                "cluster_name": TEST_CLUSTER_IDENTIFIER,
+                "nodegroup_name": TEST_NODEGROUP_NAME,
+            }
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch.object(EksHook, "async_conn")
+    async def test_eks_nodegroup_trigger_run_multiple_attempts(self, mock_async_conn, mock_sleep):
+        mock_sleep.return_value = True
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+        error = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={"nodegroup": {"status": "CREATING"}},
+        )
+        a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, error, True])
+
+        eks_nodegroup_trigger = EksNodegroupTrigger(
+            waiter_name="test_waiter_name",
+            cluster_name=TEST_CLUSTER_IDENTIFIER,
+            nodegroup_name=TEST_NODEGROUP_NAME,
+            aws_conn_id=TEST_AWS_CONN_ID,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+            region=TEST_REGION,
+        )
+
+        generator = eks_nodegroup_trigger.run()
+        response = await generator.asend(None)
+        assert a_mock.get_waiter().wait.call_count == 4
+        assert response == TriggerEvent(
+            {
+                "status": "success",
+                "cluster_name": TEST_CLUSTER_IDENTIFIER,
+                "nodegroup_name": TEST_NODEGROUP_NAME,
+            }
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch.object(EksHook, "async_conn")
+    async def test_eks_nodegroup_trigger_run_attempts_exceeded(self, mock_async_conn, mock_sleep):
+        mock_sleep.return_value = True
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+        error = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={"nodegroup": {"status": "CREATING"}},
+        )
+        a_mock.get_waiter().wait = AsyncMock(side_effect=[error, error, error, True])
+
+        eks_nodegroup_trigger = EksNodegroupTrigger(
+            waiter_name="test_waiter_name",
+            cluster_name=TEST_CLUSTER_IDENTIFIER,
+            nodegroup_name=TEST_NODEGROUP_NAME,
+            aws_conn_id=TEST_AWS_CONN_ID,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=2,
+            region=TEST_REGION,
+        )
+
+        with pytest.raises(AirflowException) as exc:
+            generator = eks_nodegroup_trigger.run()
+            await generator.asend(None)
+        assert "Waiter error: max attempts reached" in str(exc.value)
+        assert a_mock.get_waiter().wait.call_count == 2
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    @mock.patch.object(EksHook, "async_conn")
+    async def test_eks_nodegroup_trigger_run_attempts_failed(self, mock_async_conn, mock_sleep):
+        mock_sleep.return_value = True
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+        error_creating = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={"nodegroup": {"status": "CREATING"}},
+        )
+        error_failed = WaiterError(
+            name="test_name",
+            reason="Waiter encountered a terminal failure state:",
+            last_response={"nodegroup": {"status": "DELETE_FAILED"}},
+        )
+        a_mock.get_waiter().wait = AsyncMock(side_effect=[error_creating, error_creating, error_failed])
+        mock_sleep.return_value = True
+
+        eks_nodegroup_trigger = EksNodegroupTrigger(
+            waiter_name="test_waiter_name",
+            cluster_name=TEST_CLUSTER_IDENTIFIER,
+            nodegroup_name=TEST_NODEGROUP_NAME,
+            aws_conn_id=TEST_AWS_CONN_ID,
+            waiter_delay=TEST_WAITER_DELAY,
+            waiter_max_attempts=TEST_WAITER_MAX_ATTEMPTS,
+            region=TEST_REGION,
+        )
+        with pytest.raises(AirflowException) as exc:
+            generator = eks_nodegroup_trigger.run()
+            await generator.asend(None)
+
+        assert "Error checking nodegroup" in str(exc.value)
+        assert a_mock.get_waiter().wait.call_count == 3