Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from .redshift.user_pass import RedshiftUserPasswordProfileMapping
from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping
from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .snowflake.user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping
from .snowflake.user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping
from .snowflake.user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping
from .spark.thrift import SparkThriftProfileMapping
from .trino.certificate import TrinoCertificateProfileMapping
from .trino.jwt import TrinoJWTProfileMapping
Expand All @@ -32,6 +33,7 @@
PostgresUserPasswordProfileMapping,
RedshiftUserPasswordProfileMapping,
SnowflakeUserPasswordProfileMapping,
SnowflakeEncryptedPrivateKeyFilePemProfileMapping,
SnowflakeEncryptedPrivateKeyPemProfileMapping,
SnowflakePrivateKeyPemProfileMapping,
SparkThriftProfileMapping,
Expand Down Expand Up @@ -71,6 +73,7 @@ def get_automatic_profile_mapping(
"RedshiftUserPasswordProfileMapping",
"SnowflakeUserPasswordProfileMapping",
"SnowflakePrivateKeyPemProfileMapping",
"SnowflakeEncryptedPrivateKeyFilePemProfileMapping",
"SparkThriftProfileMapping",
"ExasolUserPasswordProfileMapping",
"TrinoLDAPProfileMapping",
Expand Down
4 changes: 3 additions & 1 deletion cosmos/profiles/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from .user_pass import SnowflakeUserPasswordProfileMapping
from .user_privatekey import SnowflakePrivateKeyPemProfileMapping
from .user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping
from .user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping
from .user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping

__all__ = [
"SnowflakeUserPasswordProfileMapping",
"SnowflakePrivateKeyPemProfileMapping",
"SnowflakeEncryptedPrivateKeyFilePemProfileMapping",
"SnowflakeEncryptedPrivateKeyPemProfileMapping",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key."
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping

if TYPE_CHECKING:
from airflow.models import Connection


class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html
"""

airflow_connection_type: str = "snowflake"
dbt_profile_type: str = "snowflake"
is_community: bool = True

required_fields = [
"account",
"user",
"database",
"warehouse",
"schema",
"private_key",
"private_key_passphrase",
]
secret_fields = [
"private_key",
"private_key_passphrase",
]
airflow_param_mapping = {
"account": "extra.account",
"user": "login",
"database": "extra.database",
"warehouse": "extra.warehouse",
"schema": "schema",
"role": "extra.role",
"private_key": "extra.private_key_content",
"private_key_passphrase": "password",
}

def can_claim_connection(self) -> bool:
# Make sure this isn't a private key path credential
result = super().can_claim_connection()
if result and self.conn.extra_dejson.get("private_key_file") is not None:
return False
return result

@property
def conn(self) -> Connection:
"""
Snowflake can be odd because the fields used to be stored with keys in the format
'extra__snowflake__account', but now are stored as 'account'.

This standardizes the keys to be 'account', 'database', etc.
"""
conn = super().conn

conn_dejson = conn.extra_dejson

if conn_dejson.get("extra__snowflake__account"):
conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()}

conn.extra = json.dumps(conn_dejson)

return conn

@property
def profile(self) -> dict[str, Any | None]:
"Gets profile."
profile_vars = {
**self.mapped_params,
**self.profile_args,
"private_key": self.get_env_var_format("private_key"),
"private_key_passphrase": self.get_env_var_format("private_key_passphrase"),
}

# remove any null values
return self.filter_null(profile_vars)

def transform_account(self, account: str) -> str:
"Transform the account to the format <account>.<region> if it's not already."
region = self.conn.extra_dejson.get("region")
if region and region not in account:
account = f"{account}.{region}"

return str(account)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key."
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path."
from __future__ import annotations

import json
Expand All @@ -10,9 +10,9 @@
from airflow.models import Connection


class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html
"""
Expand Down Expand Up @@ -44,6 +44,13 @@ class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
"private_key_path": "extra.private_key_file",
}

def can_claim_connection(self) -> bool:
# Make sure this isn't a private key environmentvariable
result = super().can_claim_connection()
if result and self.conn.extra_dejson.get("private_key_content") is not None:
return False
return result

@property
def conn(self) -> Connection:
"""
Expand Down
5 changes: 4 additions & 1 deletion cosmos/profiles/snowflake/user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ class SnowflakeUserPasswordProfileMapping(BaseProfileMapping):
def can_claim_connection(self) -> bool:
# Make sure this isn't a private key path credential
result = super().can_claim_connection()
if result and self.conn.extra_dejson.get("private_key_file") is not None:
if result and (
self.conn.extra_dejson.get("private_key_file") is not None
or self.conn.extra_dejson.get("private_key_content") is not None
):
return False
return result

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"Tests for the Snowflake user/private key profile."
"Tests for the Snowflake user/private key environmentvariable profile."

import json
from unittest.mock import patch
Expand Down Expand Up @@ -29,7 +29,7 @@ def mock_snowflake_conn(): # type: ignore
"region": "my_region",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_file": "path/to/private_key.p8",
"private_key_content": "my_private_key",
}
),
)
Expand All @@ -52,7 +52,7 @@ def test_connection_claiming() -> None:
"account": "my_account",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_file": "path/to/private_key.p8",
"private_key_content": "my_private_key",
}
),
}
Expand Down Expand Up @@ -130,8 +130,8 @@ def test_profile_args(
assert profile_mapping.profile == {
"type": mock_snowflake_conn.conn_type,
"user": mock_snowflake_conn.login,
"private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}",
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"private_key_path": mock_snowflake_conn.extra_dejson.get("private_key_file"),
"schema": mock_snowflake_conn.schema,
"account": f"{mock_account}.{mock_region}",
"database": mock_snowflake_conn.extra_dejson.get("database"),
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_profile_args_overrides(
"type": mock_snowflake_conn.conn_type,
"user": mock_snowflake_conn.login,
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"private_key_path": mock_snowflake_conn.extra_dejson.get("private_key_file"),
"private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}",
"schema": mock_snowflake_conn.schema,
"account": f"{mock_account}.{mock_region}",
"database": "my_db_override",
Expand All @@ -178,6 +178,7 @@ def test_profile_env_vars(
mock_snowflake_conn.conn_id,
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY": mock_snowflake_conn.extra_dejson.get("private_key_content"),
"COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE": mock_snowflake_conn.password,
}

Expand All @@ -197,7 +198,7 @@ def test_old_snowflake_format() -> None:
"extra__snowflake__account": "my_account",
"extra__snowflake__database": "my_database",
"extra__snowflake__warehouse": "my_warehouse",
"extra__snowflake__private_key_file": "path/to/private_key.p8",
"extra__snowflake__private_key_content": "my_private_key",
}
),
)
Expand All @@ -207,8 +208,8 @@ def test_old_snowflake_format() -> None:
assert profile_mapping.profile == {
"type": conn.conn_type,
"user": conn.login,
"private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}",
"private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
"private_key_path": conn.extra_dejson.get("private_key_file"),
"schema": conn.schema,
"account": conn.extra_dejson.get("account"),
"database": conn.extra_dejson.get("database"),
Expand Down
Loading