Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 48 additions & 7 deletions cosmos/profiles/athena/access_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,33 @@

from typing import Any

from cosmos.exceptions import CosmosValueError

from ..base import BaseProfileMapping


class AthenaAccessKeyProfileMapping(BaseProfileMapping):
"""
Maps Airflow AWS connections to a dbt Athena profile using an access key id and secret access key.
Uses the Airflow AWS Connection provided to get_credentials() to generate the profile for dbt.

https://docs.getdbt.com/docs/core/connect-data-platform/athena-setup
https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/connections/aws.html


This behaves similarly to other provider operators such as the AWS Athena Operator.
Where you pass the aws_conn_id and the operator will generate the credentials for you.

https://registry.astronomer.io/providers/amazon/versions/latest/modules/athenaoperator

Information about the dbt Athena profile that is generated can be found here:

https://github.com/dbt-athena/dbt-athena?tab=readme-ov-file#configuring-your-profile
https://docs.getdbt.com/docs/core/connect-data-platform/athena-setup
"""

airflow_connection_type: str = "aws"
dbt_profile_type: str = "athena"
is_community: bool = True
temporary_credentials = None

required_fields = [
"aws_access_key_id",
Expand All @@ -26,11 +39,7 @@ class AthenaAccessKeyProfileMapping(BaseProfileMapping):
"s3_staging_dir",
"schema",
]
secret_fields = ["aws_secret_access_key", "aws_session_token"]
airflow_param_mapping = {
"aws_access_key_id": "login",
"aws_secret_access_key": "password",
"aws_session_token": "extra.aws_session_token",
"aws_profile_name": "extra.aws_profile_name",
"database": "extra.database",
"debug_query_state": "extra.debug_query_state",
Expand All @@ -49,11 +58,43 @@ class AthenaAccessKeyProfileMapping(BaseProfileMapping):
@property
def profile(self) -> dict[str, Any | None]:
"Gets profile. The password is stored in an environment variable."

self.temporary_credentials = self._get_temporary_credentials() # type: ignore

profile = {
**self.mapped_params,
**self.profile_args,
# aws_secret_access_key and aws_session_token should always get set as env var
"aws_access_key_id": self.temporary_credentials.access_key,
"aws_secret_access_key": self.get_env_var_format("aws_secret_access_key"),
"aws_session_token": self.get_env_var_format("aws_session_token"),
}

return self.filter_null(profile)

@property
def env_vars(self) -> dict[str, str]:
    """
    Environment variables carrying the Athena secrets, overriding the base mapping.

    The values come from ``self.temporary_credentials``, which the ``profile`` property
    populates; the variable names come from ``get_env_var_name``.

    Raises:
        CosmosValueError: if ``self.temporary_credentials`` is still ``None``.
    """
    credentials = self.temporary_credentials
    if credentials is None:
        raise CosmosValueError("Could not find the athena credentials.")

    # NOTE(review): str() turns a missing value into the literal text "None"; confirm
    # whether the hook can return credentials without a session token.
    return {
        self.get_env_var_name("aws_secret_access_key"): str(credentials.secret_key),
        self.get_env_var_name("aws_session_token"): str(credentials.token),
    }

def _get_temporary_credentials(self):  # type: ignore
    """
    Fetch short-lived credentials for this mapping's Airflow connection.

    Delegates to ``AwsGenericHook.get_credentials()`` for ``self.conn_id`` and returns
    the resulting object, which includes access_key, secret_key and token.
    """
    # Imported at call time rather than module import time; presumably because the
    # Amazon provider is only installed with the dbt-athena extra - confirm.
    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook

    return AwsGenericHook(self.conn_id).get_credentials()  # type: ignore
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ dbt-all = [
]
dbt-athena = [
"dbt-athena-community",
"apache-airflow-providers-amazon>=8.0.0",
]
dbt-bigquery = [
"dbt-bigquery",
Expand Down Expand Up @@ -110,7 +111,6 @@ tests = [
"mypy",
"sqlalchemy-stubs", # Change when sqlalchemy is upgraded https://docs.sqlalchemy.org/en/14/orm/extensions/mypy.html
]

docker = [
"apache-airflow-providers-docker>=3.5.0",
]
Expand All @@ -121,7 +121,6 @@ pydantic = [
"pydantic>=1.10.0,<2.0.0",
]


[project.entry-points.cosmos]
provider_info = "cosmos:get_provider_info"

Expand Down
71 changes: 56 additions & 15 deletions tests/profiles/athena/test_athena_access_key.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,49 @@
"Tests for the Athena profile."

import json
from unittest.mock import patch

from collections import namedtuple
import sys
from unittest.mock import MagicMock, patch
import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_automatic_profile_mapping
from cosmos.profiles.athena.access_key import AthenaAccessKeyProfileMapping

# Minimal stand-in for the credentials object returned by the AWS hook in these tests.
Credentials = namedtuple("Credentials", "access_key secret_key token")

# Credentials handed out by the mocked hook.
mock_assumed_credentials = Credentials(
    access_key="my_aws_assumed_access_key",
    secret_key="my_aws_assumed_secret_key",
    token="my_aws_assumed_token",
)

# Credentials object with every field unset.
mock_missing_credentials = Credentials(None, None, None)


@pytest.fixture(autouse=True)
def mock_aws_module():
    """
    Replace the Amazon provider's base_aws module in sys.modules for every test.

    The substitute module exposes an AwsGenericHook whose get_credentials() always
    returns mock_assumed_credentials. Yields the substitute module.
    """

    class MockAwsGenericHook:
        """Hook double that ignores the connection id it is given."""

        def __init__(self, conn_id: str) -> None:
            pass

        def get_credentials(self) -> Credentials:
            return mock_assumed_credentials

    fake_module = MagicMock()
    fake_module.AwsGenericHook = MockAwsGenericHook

    # patch.dict restores sys.modules when the fixture is torn down.
    replaced_modules = {"airflow.providers.amazon.aws.hooks.base_aws": fake_module}
    with patch.dict(sys.modules, replaced_modules):
        yield fake_module


@pytest.fixture()
def mock_athena_conn(): # type: ignore
"""
Sets the connection as an environment variable.
"""
Comment thread
jbandoro marked this conversation as resolved.

conn = Connection(
conn_id="my_athena_connection",
conn_type="aws",
Expand All @@ -24,7 +53,7 @@ def mock_athena_conn(): # type: ignore
{
"aws_session_token": "token123",
"database": "my_database",
"region_name": "my_region",
"region_name": "us-east-1",
"s3_staging_dir": "s3://my_bucket/dbt/",
"schema": "my_schema",
}
Expand All @@ -48,14 +77,15 @@ def test_athena_connection_claiming() -> None:
# - region_name
# - s3_staging_dir
# - schema

potential_values = {
"conn_type": "aws",
"login": "my_aws_access_key_id",
"password": "my_aws_secret_key",
"extra": json.dumps(
{
"database": "my_database",
"region_name": "my_region",
"region_name": "us-east-1",
"s3_staging_dir": "s3://my_bucket/dbt/",
"schema": "my_schema",
}
Expand All @@ -68,12 +98,14 @@ def test_athena_connection_claiming() -> None:
del values[key]
conn = Connection(**values) # type: ignore

print("testing with", values)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
# should raise an InvalidMappingException
profile_mapping = AthenaAccessKeyProfileMapping(conn, {})
assert not profile_mapping.can_claim_connection()
with patch(
"cosmos.profiles.athena.access_key.AthenaAccessKeyProfileMapping._get_temporary_credentials",
return_value=mock_missing_credentials,
):
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
# should raise an InvalidMappingException
profile_mapping = AthenaAccessKeyProfileMapping(conn, {})
assert not profile_mapping.can_claim_connection()

# if we have them all, it should claim
conn = Connection(**potential_values) # type: ignore
Expand All @@ -88,6 +120,7 @@ def test_athena_profile_mapping_selected(
"""
Tests that the correct profile mapping is selected for Athena.
"""

profile_mapping = get_automatic_profile_mapping(
mock_athena_conn.conn_id,
)
Expand All @@ -100,13 +133,14 @@ def test_athena_profile_args(
"""
Tests that the profile values get set correctly for Athena.
"""

profile_mapping = get_automatic_profile_mapping(
mock_athena_conn.conn_id,
)

assert profile_mapping.profile == {
"type": "athena",
"aws_access_key_id": mock_athena_conn.login,
"aws_access_key_id": mock_assumed_credentials.access_key,
"aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}",
"aws_session_token": "{{ env_var('COSMOS_CONN_AWS_AWS_SESSION_TOKEN') }}",
"database": mock_athena_conn.extra_dejson.get("database"),
Expand All @@ -122,9 +156,14 @@ def test_athena_profile_args_overrides(
"""
Tests that you can override the profile values for Athena.
"""

profile_mapping = get_automatic_profile_mapping(
mock_athena_conn.conn_id,
profile_args={"schema": "my_custom_schema", "database": "my_custom_db", "aws_session_token": "override_token"},
profile_args={
"schema": "my_custom_schema",
"database": "my_custom_db",
"aws_session_token": "override_token",
},
Comment thread
octiva marked this conversation as resolved.
)

assert profile_mapping.profile_args == {
Expand All @@ -135,7 +174,7 @@ def test_athena_profile_args_overrides(

assert profile_mapping.profile == {
"type": "athena",
"aws_access_key_id": mock_athena_conn.login,
"aws_access_key_id": mock_assumed_credentials.access_key,
"aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}",
"aws_session_token": "{{ env_var('COSMOS_CONN_AWS_AWS_SESSION_TOKEN') }}",
"database": "my_custom_db",
Expand All @@ -151,10 +190,12 @@ def test_athena_profile_env_vars(
"""
Tests that the environment variables get set correctly for Athena.
"""

profile_mapping = get_automatic_profile_mapping(
mock_athena_conn.conn_id,
)

assert profile_mapping.env_vars == {
"COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY": mock_athena_conn.password,
"COSMOS_CONN_AWS_AWS_SESSION_TOKEN": mock_athena_conn.extra_dejson.get("aws_session_token"),
"COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY": mock_assumed_credentials.secret_key,
"COSMOS_CONN_AWS_AWS_SESSION_TOKEN": mock_assumed_credentials.token,
}