from __future__ import annotations

from typing import Any, Sequence

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.eks import EksHook
from airflow.utils.context import Context

from cosmos.operators.kubernetes import (
    DbtBuildKubernetesOperator,
    DbtKubernetesBaseOperator,
    DbtLSKubernetesOperator,
    DbtRunKubernetesOperator,
    DbtRunOperationKubernetesOperator,
    DbtSeedKubernetesOperator,
    DbtSnapshotKubernetesOperator,
    DbtTestKubernetesOperator,
)

DEFAULT_CONN_ID = "aws_default"
DEFAULT_NAMESPACE = "default"

# Template fields contributed by the EKS operators on top of the Kubernetes ones.
_EKS_TEMPLATE_FIELDS: tuple[str, ...] = (
    "cluster_name",
    "in_cluster",
    "namespace",
    "pod_name",
    "aws_conn_id",
    "region",
)


def _merge_template_fields(*field_groups: Sequence[str]) -> tuple[str, ...]:
    """
    Concatenate several ``template_fields`` sequences into a single tuple.

    Names are kept in first-seen order and every name appears exactly once.
    Unlike going through a ``set``, the result does not depend on string hash
    randomisation, and unlike ``a + b`` it accepts any mix of sequence types
    (tuple, list, ...) and never yields duplicate entries.

    :param field_groups: the sequences of field names to combine, in priority order.
    :return: the merged, de-duplicated field names.
    """
    merged: dict[str, None] = {}
    for group in field_groups:
        # dict keys preserve insertion order, which gives an ordered "set".
        merged.update(dict.fromkeys(group))
    return tuple(merged)


class DbtAwsEksBaseOperator(DbtKubernetesBaseOperator):
    """
    Base class for the operators that run dbt commands in a pod on an AWS EKS cluster.

    It extends ``DbtKubernetesBaseOperator``; the difference is that the kubeconfig is
    not supplied by the caller but generated by ``EksHook.generate_config_file`` for the
    duration of ``execute``.

    :param cluster_name: name of the EKS cluster, passed to the hook as ``eks_cluster_name``.
    :param pod_name: forwarded to the Kubernetes operator as ``name``.
    :param namespace: forwarded to the Kubernetes operator as ``namespace`` and to the hook
        as ``pod_namespace``. Defaults to ``DEFAULT_NAMESPACE``.
    :param aws_conn_id: connection id handed to ``EksHook``. Defaults to ``DEFAULT_CONN_ID``.
    :param region: handed to ``EksHook`` as ``region_name``.
    :param kwargs: forwarded unchanged to ``DbtKubernetesBaseOperator``.
    :raises AirflowException: if a truthy ``config_file`` ends up set on the operator.
    """

    template_fields: Sequence[str] = _merge_template_fields(
        _EKS_TEMPLATE_FIELDS,
        DbtKubernetesBaseOperator.template_fields,
    )

    def __init__(
        self,
        cluster_name: str,
        pod_name: str | None = None,
        namespace: str | None = DEFAULT_NAMESPACE,
        aws_conn_id: str = DEFAULT_CONN_ID,
        region: str | None = None,
        **kwargs: Any,
    ) -> None:
        self.cluster_name = cluster_name
        self.pod_name = pod_name
        self.namespace = namespace
        self.aws_conn_id = aws_conn_id
        self.region = region
        super().__init__(
            name=self.pod_name,
            namespace=self.namespace,
            **kwargs,
        )
        # There is no need to manage the kube_config file, as it will be generated automatically.
        # All Kubernetes parameters (except config_file) are also valid for the EksPodOperator.
        if self.config_file:
            raise AirflowException("The config_file is not an allowed parameter for the EksPodOperator.")

    def execute(self, context: Context) -> Any | None:  # type: ignore
        """
        Generate a kubeconfig for the EKS cluster and run the parent ``execute`` with it.

        :param context: the Airflow task context, passed through to the parent ``execute``.
        :return: whatever the parent ``execute`` returns.
        """
        eks_hook = EksHook(
            aws_conn_id=self.aws_conn_id,
            region_name=self.region,
        )
        # The value yielded by the hook's context manager is bound to ``self.config_file``
        # so the parent implementation can pick it up while the context is open.
        with eks_hook.generate_config_file(
            eks_cluster_name=self.cluster_name, pod_namespace=self.namespace
        ) as self.config_file:
            return super().execute(context)


class DbtBuildAwsEksOperator(DbtAwsEksBaseOperator, DbtBuildKubernetesOperator):
    """
    Executes a dbt core build command.
    """

    template_fields: Sequence[str] = _merge_template_fields(
        DbtAwsEksBaseOperator.template_fields,
        DbtBuildKubernetesOperator.template_fields,
    )


class DbtLSAwsEksOperator(DbtAwsEksBaseOperator, DbtLSKubernetesOperator):
    """
    Executes a dbt core ls command.
    """


class DbtSeedAwsEksOperator(DbtAwsEksBaseOperator, DbtSeedKubernetesOperator):
    """
    Executes a dbt core seed command.
    """

    template_fields: Sequence[str] = _merge_template_fields(
        DbtAwsEksBaseOperator.template_fields,
        DbtSeedKubernetesOperator.template_fields,
    )


class DbtSnapshotAwsEksOperator(DbtAwsEksBaseOperator, DbtSnapshotKubernetesOperator):
    """
    Executes a dbt core snapshot command.
    """


class DbtRunAwsEksOperator(DbtAwsEksBaseOperator, DbtRunKubernetesOperator):
    """
    Executes a dbt core run command.
    """

    template_fields: Sequence[str] = _merge_template_fields(
        DbtAwsEksBaseOperator.template_fields,
        DbtRunKubernetesOperator.template_fields,
    )


class DbtTestAwsEksOperator(DbtAwsEksBaseOperator, DbtTestKubernetesOperator):
    """
    Executes a dbt core test command.
    """

    template_fields: Sequence[str] = _merge_template_fields(
        DbtAwsEksBaseOperator.template_fields,
        DbtTestKubernetesOperator.template_fields,
    )


class DbtRunOperationAwsEksOperator(DbtAwsEksBaseOperator, DbtRunOperationKubernetesOperator):
    """
    Executes a dbt core run-operation command.
    """

    template_fields: Sequence[str] = _merge_template_fields(
        DbtAwsEksBaseOperator.template_fields,
        DbtRunOperationKubernetesOperator.template_fields,
    )
@@ -38,6 +39,10 @@ The choice of the ``execution mode`` can vary based on each user's needs and con - Slow - High - No + * - AWS_EKS + - Slow + - High + - No * - Azure Container Instance - Slow - High @@ -159,6 +164,38 @@ Example DAG: "secrets": [postgres_password_secret], }, ) +AWS_EKS +---------- + +The ``aws_eks`` approach is very similar to the ``kubernetes`` approach, but it is specifically designed to run on AWS EKS clusters. +It uses the `EksPodOperator <https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/operators/eks.html#perform-a-task-on-an-amazon-eks-cluster>`_ +to run the dbt commands. You need to provide the ``cluster_name`` in your operator_args to connect to the AWS EKS cluster. + + +Example DAG: + +.. code-block:: python + + postgres_password_secret = Secret( + deploy_type="env", + deploy_target="POSTGRES_PASSWORD", + secret="postgres-secrets", + key="password", + ) + + docker_cosmos_dag = DbtDag( + # ... + execution_config=ExecutionConfig( + execution_mode=ExecutionMode.AWS_EKS, + ), + operator_args={ + "image": "dbt-jaffle-shop:1.0.0", + "cluster_name": CLUSTER_NAME, + "get_logs": True, + "is_delete_operator_pod": False, + "secrets": [postgres_password_secret], + }, + ) Azure Container Instance ------------------------ diff --git a/pyproject.toml b/pyproject.toml index 8044162fc6..cb2530a5a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,9 @@ docker = [ kubernetes = [ "apache-airflow-providers-cncf-kubernetes>=5.1.1", ] +aws_eks = [ + "apache-airflow-providers-amazon>=8.0.0,<8.20.0", # https://github.com/apache/airflow/issues/39103 +] azure-container-instance = [ "apache-airflow-providers-microsoft-azure>=8.4.0", ] @@ -120,6 +123,7 @@ dependencies = [ "astronomer-cosmos[tests]", "apache-airflow-providers-postgres", "apache-airflow-providers-cncf-kubernetes>=5.1.1", + "apache-airflow-providers-amazon>=3.0.0,<8.20.0", # https://github.com/apache/airflow/issues/39103 "apache-airflow-providers-docker>=3.5.0", "apache-airflow-providers-microsoft-azure", "types-PyYAML", @@ -137,7 +141,7 @@ airflow = ["2.4", "2.5", "2.6", "2.7", "2.8", "2.9"] 
from unittest.mock import MagicMock, patch

import pytest
from airflow.exceptions import AirflowException

from cosmos.operators.aws_eks import (
    DbtBuildAwsEksOperator,
    DbtLSAwsEksOperator,
    DbtRunAwsEksOperator,
    DbtSeedAwsEksOperator,
    DbtTestAwsEksOperator,
)


@pytest.fixture()
def mock_kubernetes_execute():
    """Replace ``KubernetesPodOperator.execute`` with a mock for the duration of a test."""
    execute_patcher = patch("cosmos.operators.kubernetes.KubernetesPodOperator.execute")
    with execute_patcher as mock_execute:
        yield mock_execute


base_kwargs = {
    "conn_id": "my_airflow_connection",
    "cluster_name": "my-cluster",
    "task_id": "my-task",
    "image": "my_image",
    "project_dir": "my/dir",
    "vars": {
        "start_time": "{{ data_interval_start.strftime('%Y%m%d%H%M%S') }}",
        "end_time": "{{ data_interval_end.strftime('%Y%m%d%H%M%S') }}",
    },
    "no_version_check": True,
}


def test_dbt_kubernetes_build_command():
    """
    The underlying KubernetesOperator is tested elsewhere, so it is enough to check that
    each EKS operator builds the expected dbt command into its "arguments" attribute.
    """
    operators_by_command = {
        "ls": DbtLSAwsEksOperator(**base_kwargs),
        "run": DbtRunAwsEksOperator(**base_kwargs),
        "test": DbtTestAwsEksOperator(**base_kwargs),
        "build": DbtBuildAwsEksOperator(**base_kwargs),
        "seed": DbtSeedAwsEksOperator(**base_kwargs),
    }
    # The value expected after "--vars": the two templated vars, end_time first.
    expected_vars = (
        "end_time: '{{ data_interval_end.strftime(''%Y%m%d%H%M%S'') }}'\n"
        "start_time: '{{ data_interval_start.strftime(''%Y%m%d%H%M%S'') }}'\n"
    )

    for dbt_command, operator in operators_by_command.items():
        operator.build_kube_args(context=MagicMock(), cmd_flags=MagicMock())
        expected_arguments = [
            "dbt",
            dbt_command,
            "--vars",
            expected_vars,
            "--no-version-check",
            "--project-dir",
            "my/dir",
        ]
        assert operator.arguments == expected_arguments


@patch("cosmos.operators.kubernetes.DbtKubernetesBaseOperator.build_kube_args")
@patch("cosmos.operators.aws_eks.EksHook.generate_config_file")
def test_dbt_kubernetes_operator_execute(mock_generate_config_file, mock_build_kube_args, mock_kubernetes_execute):
    """Tests that execute calls build_kube_args, the EKS kubeconfig generation and the kubernetes execute method."""
    operator_kwargs = {
        "conn_id": "my_airflow_connection",
        "cluster_name": "my-cluster",
        "task_id": "my-task",
        "image": "my_image",
        "project_dir": "my/dir",
    }
    ls_operator = DbtLSAwsEksOperator(**operator_kwargs)

    ls_operator.execute(context={})

    # The dbt arguments must have been built exactly once during the execution.
    mock_build_kube_args.assert_called_once()

    # The kubeconfig for EKS must have been requested for our cluster and the default namespace.
    mock_generate_config_file.assert_called_once_with(eks_cluster_name="my-cluster", pod_namespace="default")

    # The kubernetes execute method must have been reached, with the context as last positional argument.
    mock_kubernetes_execute.assert_called_once()
    last_positional_argument = mock_kubernetes_execute.call_args.args[-1]
    assert last_positional_argument == {}


def test_provided_config_file_fails():
    """Tests that the constructor fails if it is called with a config_file."""
    operator_kwargs = {
        "conn_id": "my_airflow_connection",
        "cluster_name": "my-cluster",
        "task_id": "my-task",
        "image": "my_image",
        "project_dir": "my/dir",
        "config_file": "my/config",
    }
    with pytest.raises(AirflowException) as raised:
        DbtLSAwsEksOperator(**operator_kwargs)

    error_message = str(raised.value)
    assert "The config_file is not an allowed parameter for the EksPodOperator." in error_message