diff --git a/cosmos/profiles/athena/access_key.py b/cosmos/profiles/athena/access_key.py index a8f71c2b7a..02de2be247 100644 --- a/cosmos/profiles/athena/access_key.py +++ b/cosmos/profiles/athena/access_key.py @@ -3,20 +3,33 @@ from typing import Any +from cosmos.exceptions import CosmosValueError + from ..base import BaseProfileMapping class AthenaAccessKeyProfileMapping(BaseProfileMapping): """ - Maps Airflow AWS connections to a dbt Athena profile using an access key id and secret access key. + Uses the Airflow AWS Connection provided to get_credentials() to generate the profile for dbt. - https://docs.getdbt.com/docs/core/connect-data-platform/athena-setup https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/connections/aws.html + + + This behaves similarly to other provider operators such as the AWS Athena Operator. + Where you pass the aws_conn_id and the operator will generate the credentials for you. + + https://registry.astronomer.io/providers/amazon/versions/latest/modules/athenaoperator + + Information about the dbt Athena profile that is generated can be found here: + + https://github.com/dbt-athena/dbt-athena?tab=readme-ov-file#configuring-your-profile + https://docs.getdbt.com/docs/core/connect-data-platform/athena-setup """ airflow_connection_type: str = "aws" dbt_profile_type: str = "athena" is_community: bool = True + temporary_credentials = None required_fields = [ "aws_access_key_id", @@ -26,11 +39,7 @@ class AthenaAccessKeyProfileMapping(BaseProfileMapping): "s3_staging_dir", "schema", ] - secret_fields = ["aws_secret_access_key", "aws_session_token"] airflow_param_mapping = { - "aws_access_key_id": "login", - "aws_secret_access_key": "password", - "aws_session_token": "extra.aws_session_token", "aws_profile_name": "extra.aws_profile_name", "database": "extra.database", "debug_query_state": "extra.debug_query_state", @@ -49,11 +58,43 @@ class AthenaAccessKeyProfileMapping(BaseProfileMapping): @property def profile(self) -> dict[str, Any | None]: "Gets profile. The password is stored in an environment variable." + + self.temporary_credentials = self._get_temporary_credentials() # type: ignore + profile = { **self.mapped_params, **self.profile_args, - # aws_secret_access_key and aws_session_token should always get set as env var + "aws_access_key_id": self.temporary_credentials.access_key, "aws_secret_access_key": self.get_env_var_format("aws_secret_access_key"), "aws_session_token": self.get_env_var_format("aws_session_token"), } + return self.filter_null(profile) + + @property + def env_vars(self) -> dict[str, str]: + "Overwrites the env_vars for athena, Returns a dictionary of environment variables that should be set based on the self.temporary_credentials." + + if self.temporary_credentials is None: + raise CosmosValueError(f"Could not find the athena credentials.") + + env_vars = {} + + env_secret_key_name = self.get_env_var_name("aws_secret_access_key") + env_session_token_name = self.get_env_var_name("aws_session_token") + + env_vars[env_secret_key_name] = str(self.temporary_credentials.secret_key) + env_vars[env_session_token_name] = str(self.temporary_credentials.token) + + return env_vars + + def _get_temporary_credentials(self): # type: ignore + """ + Helper function to retrieve temporary short lived credentials + Returns an object including access_key, secret_key and token + """ + from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook + + hook = AwsGenericHook(self.conn_id) # type: ignore + credentials = hook.get_credentials() + return credentials diff --git a/pyproject.toml b/pyproject.toml index c08de4adea..9d367c075f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ dbt-all = [ ] dbt-athena = [ "dbt-athena-community", + "apache-airflow-providers-amazon>=8.0.0", ] dbt-bigquery = [ "dbt-bigquery", @@ -110,7 +111,6 @@ tests = [ "mypy", "sqlalchemy-stubs", # Change when sqlalchemy is upgraded https://docs.sqlalchemy.org/en/14/orm/extensions/mypy.html ] - docker = [ "apache-airflow-providers-docker>=3.5.0", ] @@ -121,7 +121,6 @@ pydantic = [ "pydantic>=1.10.0,<2.0.0", ] - [project.entry-points.cosmos] provider_info = "cosmos:get_provider_info" diff --git a/tests/profiles/athena/test_athena_access_key.py b/tests/profiles/athena/test_athena_access_key.py index 22c8efa2c0..c224a9d4b8 100644 --- a/tests/profiles/athena/test_athena_access_key.py +++ b/tests/profiles/athena/test_athena_access_key.py @@ -1,20 +1,49 @@ "Tests for the Athena profile." import json -from unittest.mock import patch - +from collections import namedtuple +import sys +from unittest.mock import MagicMock, patch import pytest from airflow.models.connection import Connection from cosmos.profiles import get_automatic_profile_mapping from cosmos.profiles.athena.access_key import AthenaAccessKeyProfileMapping +Credentials = namedtuple("Credentials", ["access_key", "secret_key", "token"]) + +mock_assumed_credentials = Credentials( + secret_key="my_aws_assumed_secret_key", + access_key="my_aws_assumed_access_key", + token="my_aws_assumed_token", +) + +mock_missing_credentials = Credentials(access_key=None, secret_key=None, token=None) + + +@pytest.fixture(autouse=True) +def mock_aws_module(): + mock_aws_hook = MagicMock() + + class MockAwsGenericHook: + def __init__(self, conn_id: str) -> None: + pass + + def get_credentials(self) -> Credentials: + return mock_assumed_credentials + + mock_aws_hook.AwsGenericHook = MockAwsGenericHook + + with patch.dict(sys.modules, {"airflow.providers.amazon.aws.hooks.base_aws": mock_aws_hook}): + yield mock_aws_hook + @pytest.fixture() def mock_athena_conn(): # type: ignore """ Sets the connection as an environment variable. """ + conn = Connection( conn_id="my_athena_connection", conn_type="aws", @@ -24,7 +53,7 @@ def mock_athena_conn(): # type: ignore { "aws_session_token": "token123", "database": "my_database", - "region_name": "my_region", + "region_name": "us-east-1", "s3_staging_dir": "s3://my_bucket/dbt/", "schema": "my_schema", } @@ -48,6 +77,7 @@ def test_athena_connection_claiming() -> None: # - region_name # - s3_staging_dir # - schema + potential_values = { "conn_type": "aws", "login": "my_aws_access_key_id", @@ -55,7 +85,7 @@ def test_athena_connection_claiming() -> None: "extra": json.dumps( { "database": "my_database", - "region_name": "my_region", + "region_name": "us-east-1", "s3_staging_dir": "s3://my_bucket/dbt/", "schema": "my_schema", } @@ -68,12 +98,14 @@ def test_athena_connection_claiming() -> None: del values[key] conn = Connection(**values) # type: ignore - print("testing with", values) - - with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - # should raise an InvalidMappingException - profile_mapping = AthenaAccessKeyProfileMapping(conn, {}) - assert not profile_mapping.can_claim_connection() + with patch( + "cosmos.profiles.athena.access_key.AthenaAccessKeyProfileMapping._get_temporary_credentials", + return_value=mock_missing_credentials, + ): + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + # should raise an InvalidMappingException + profile_mapping = AthenaAccessKeyProfileMapping(conn, {}) + assert not profile_mapping.can_claim_connection() # if we have them all, it should claim conn = Connection(**potential_values) # type: ignore @@ -88,6 +120,7 @@ def test_athena_profile_mapping_selected( """ Tests that the correct profile mapping is selected for Athena. """ + profile_mapping = get_automatic_profile_mapping( mock_athena_conn.conn_id, ) @@ -100,13 +133,14 @@ def test_athena_profile_args( """ Tests that the profile values get set correctly for Athena. """ + profile_mapping = get_automatic_profile_mapping( mock_athena_conn.conn_id, ) assert profile_mapping.profile == { "type": "athena", - "aws_access_key_id": mock_athena_conn.login, + "aws_access_key_id": mock_assumed_credentials.access_key, "aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}", "aws_session_token": "{{ env_var('COSMOS_CONN_AWS_AWS_SESSION_TOKEN') }}", "database": mock_athena_conn.extra_dejson.get("database"), @@ -122,9 +156,14 @@ def test_athena_profile_args_overrides( """ Tests that you can override the profile values for Athena. """ + profile_mapping = get_automatic_profile_mapping( mock_athena_conn.conn_id, - profile_args={"schema": "my_custom_schema", "database": "my_custom_db", "aws_session_token": "override_token"}, + profile_args={ + "schema": "my_custom_schema", + "database": "my_custom_db", + "aws_session_token": "override_token", + }, ) assert profile_mapping.profile_args == { @@ -135,7 +174,7 @@ def test_athena_profile_args_overrides( assert profile_mapping.profile == { "type": "athena", - "aws_access_key_id": mock_athena_conn.login, + "aws_access_key_id": mock_assumed_credentials.access_key, "aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}", "aws_session_token": "{{ env_var('COSMOS_CONN_AWS_AWS_SESSION_TOKEN') }}", "database": "my_custom_db", @@ -151,10 +190,12 @@ def test_athena_profile_env_vars( """ Tests that the environment variables get set correctly for Athena. """ + profile_mapping = get_automatic_profile_mapping( mock_athena_conn.conn_id, ) + assert profile_mapping.env_vars == { - "COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY": mock_athena_conn.password, - "COSMOS_CONN_AWS_AWS_SESSION_TOKEN": mock_athena_conn.extra_dejson.get("aws_session_token"), + "COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY": mock_assumed_credentials.secret_key, + "COSMOS_CONN_AWS_AWS_SESSION_TOKEN": mock_assumed_credentials.token, }