From 529b0c5b5f12ecf890927ae08f6842cf31cba8c2 Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 10:56:05 +0000 Subject: [PATCH 01/11] Intial change for Snowflake encrypted private key --- cosmos/profiles/__init__.py | 5 +- cosmos/profiles/snowflake/__init__.py | 4 +- .../user_encrypted_privatekey_env_variable.py | 86 +++++++ ...y.py => user_encrypted_privatekey_file.py} | 6 +- ..._user_encrypted_privatekey_env_variable.py | 216 ++++++++++++++++++ ...owflake_user_encrypted_privatekey_file.py} | 2 +- 6 files changed, 313 insertions(+), 6 deletions(-) create mode 100644 cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py rename cosmos/profiles/snowflake/{user_encrypted_privatekey.py => user_encrypted_privatekey_file.py} (95%) create mode 100644 tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py rename tests/profiles/snowflake/{test_snowflake_user_encrypted_privatekey.py => test_snowflake_user_encrypted_privatekey_file.py} (99%) diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py index 47c7309abc..8280cd9508 100644 --- a/cosmos/profiles/__init__.py +++ b/cosmos/profiles/__init__.py @@ -16,7 +16,8 @@ from .redshift.user_pass import RedshiftUserPasswordProfileMapping from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping -from .snowflake.user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping +from .snowflake.user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping +from .snowflake.user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping from .spark.thrift import SparkThriftProfileMapping from .trino.certificate import TrinoCertificateProfileMapping from .trino.jwt import TrinoJWTProfileMapping @@ -32,6 +33,7 @@ PostgresUserPasswordProfileMapping, RedshiftUserPasswordProfileMapping, SnowflakeUserPasswordProfileMapping, + SnowflakeEncryptedPrivateKeyFilePemProfileMapping, SnowflakeEncryptedPrivateKeyPemProfileMapping, SnowflakePrivateKeyPemProfileMapping, SparkThriftProfileMapping, @@ -71,6 +73,7 @@ def get_automatic_profile_mapping( "RedshiftUserPasswordProfileMapping", "SnowflakeUserPasswordProfileMapping", "SnowflakePrivateKeyPemProfileMapping", + "SnowflakeEncryptedPrivateKeyFilePemProfileMapping", "SparkThriftProfileMapping", "ExasolUserPasswordProfileMapping", "TrinoLDAPProfileMapping", diff --git a/cosmos/profiles/snowflake/__init__.py b/cosmos/profiles/snowflake/__init__.py index 26c3fb5954..fdf323a766 100644 --- a/cosmos/profiles/snowflake/__init__.py +++ b/cosmos/profiles/snowflake/__init__.py @@ -2,10 +2,12 @@ from .user_pass import SnowflakeUserPasswordProfileMapping from .user_privatekey import SnowflakePrivateKeyPemProfileMapping -from .user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping +from .user_encrypted_privatekey_file import SnowflakeEncryptedPrivateKeyFilePemProfileMapping +from .user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping __all__ = [ "SnowflakeUserPasswordProfileMapping", "SnowflakePrivateKeyPemProfileMapping", + "SnowflakeEncryptedPrivateKeyFilePemProfileMapping", "SnowflakeEncryptedPrivateKeyPemProfileMapping", ] diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py new file mode 100644 index 0000000000..3fa7aaaf4d --- /dev/null +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -0,0 +1,86 @@ +"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key." +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +from ..base import BaseProfileMapping + +if TYPE_CHECKING: + from airflow.models import Connection + + +class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): + """ + Maps Airflow Snowflake connections to dbt profiles if they use a user/private key. + https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication + https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html + """ + + airflow_connection_type: str = "snowflake" + dbt_profile_type: str = "snowflake" + is_community: bool = True + + required_fields = [ + "account", + "user", + "database", + "warehouse", + "schema", + "private_key", + "private_key_passphrase", + ] + secret_fields = [ + "private_key", + "private_key_passphrase", + ] + airflow_param_mapping = { + "account": "extra.account", + "user": "login", + "database": "extra.database", + "warehouse": "extra.warehouse", + "schema": "schema", + "role": "extra.role", + "private_key": "extra.private_key_content", + "private_key_passphrase": "password", + } + + @property + def conn(self) -> Connection: + """ + Snowflake can be odd because the fields used to be stored with keys in the format + 'extra__snowflake__account', but now are stored as 'account'. + + This standardizes the keys to be 'account', 'database', etc. + """ + conn = super().conn + + conn_dejson = conn.extra_dejson + + if conn_dejson.get("extra__snowflake__account"): + conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()} + + conn.extra = json.dumps(conn_dejson) + + return conn + + @property + def profile(self) -> dict[str, Any | None]: + "Gets profile." + profile_vars = { + **self.mapped_params, + **self.profile_args, + "private_key": self.get_env_var_format("private_key"), + "private_key_passphrase": self.get_env_var_format("private_key_passphrase"), + } + + # remove any null values + return self.filter_null(profile_vars) + + def transform_account(self, account: str) -> str: + "Transform the account to the format . if it's not already." + region = self.conn.extra_dejson.get("region") + if region and region not in account: + account = f"{account}.{region}" + + return str(account) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py similarity index 95% rename from cosmos/profiles/snowflake/user_encrypted_privatekey.py rename to cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index 0623598be6..96e45080af 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -1,4 +1,4 @@ -"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key." +"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path." from __future__ import annotations import json @@ -10,9 +10,9 @@ from airflow.models import Connection -class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): +class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping): """ - Maps Airflow Snowflake connections to dbt profiles if they use a user/private key. + Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path. https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html """ diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py new file mode 100644 index 0000000000..64c33d337b --- /dev/null +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -0,0 +1,216 @@ +"Tests for the Snowflake user/private key file profile." + +import json +from unittest.mock import patch + +import pytest +from airflow.models.connection import Connection + +from cosmos.profiles import get_automatic_profile_mapping +from cosmos.profiles.snowflake import ( + SnowflakeEncryptedPrivateKeyFilePemProfileMapping, +) + + +@pytest.fixture() +def mock_snowflake_conn(): # type: ignore + """ + Sets the connection as an environment variable. + """ + conn = Connection( + conn_id="my_snowflake_pk_connection", + conn_type="snowflake", + login="my_user", + schema="my_schema", + password="secret", + extra=json.dumps( + { + "account": "my_account", + "region": "my_region", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_content": "my_private_key", + } + ), + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +def test_connection_claiming() -> None: + """ + Tests that the Snowflake profile mapping claims the correct connection type. + """ + potential_values = { + "conn_type": "snowflake", + "login": "my_user", + "schema": "my_database", + "password": "secret", + "extra": json.dumps( + { + "account": "my_account", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_content": "my_private_key", + } + ), + } + + # if we're missing any of the values, it shouldn't claim + for key in potential_values: + values = potential_values.copy() + del values[key] + conn = Connection(**values) # type: ignore + + print("testing with", values) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping( + conn, + ) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the account + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the database + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"account": "my_account", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the warehouse + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"account": "my_account", "database": "my_database", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # if we have them all, it should claim + conn = Connection(**potential_values) # type: ignore + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert profile_mapping.can_claim_connection() + + +def test_profile_mapping_selected( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the correct profile mapping is selected. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + ) + assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyFilePemProfileMapping) + + +def test_profile_args( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the profile values get set correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + ) + + mock_account = mock_snowflake_conn.extra_dejson.get("account") + mock_region = mock_snowflake_conn.extra_dejson.get("region") + + assert profile_mapping.profile == { + "type": mock_snowflake_conn.conn_type, + "user": mock_snowflake_conn.login, + "private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}", + "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", + "schema": mock_snowflake_conn.schema, + "account": f"{mock_account}.{mock_region}", + "database": mock_snowflake_conn.extra_dejson.get("database"), + "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + } + + +def test_profile_args_overrides( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that you can override the profile values. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + profile_args={"database": "my_db_override"}, + ) + assert profile_mapping.profile_args == { + "database": "my_db_override", + } + + mock_account = mock_snowflake_conn.extra_dejson.get("account") + mock_region = mock_snowflake_conn.extra_dejson.get("region") + + assert profile_mapping.profile == { + "type": mock_snowflake_conn.conn_type, + "user": mock_snowflake_conn.login, + "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", + "private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}", + "schema": mock_snowflake_conn.schema, + "account": f"{mock_account}.{mock_region}", + "database": "my_db_override", + "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + } + + +def test_profile_env_vars( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the environment variables get set correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + ) + assert profile_mapping.env_vars == { + "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE": mock_snowflake_conn.password, + } + + +def test_old_snowflake_format() -> None: + """ + Tests that the old format still works. + """ + conn = Connection( + conn_id="my_snowflake_connection", + conn_type="snowflake", + login="my_user", + schema="my_schema", + password="secret", + extra=json.dumps( + { + "extra__snowflake__account": "my_account", + "extra__snowflake__database": "my_database", + "extra__snowflake__warehouse": "my_warehouse", + "extra__snowflake__private_key_content": "my_private_key", + } + ), + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + assert profile_mapping.profile == { + "type": conn.conn_type, + "user": conn.login, + "private_key": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY') }}", + "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", + "schema": conn.schema, + "account": conn.extra_dejson.get("account"), + "database": conn.extra_dejson.get("database"), + "warehouse": conn.extra_dejson.get("warehouse"), + } diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py similarity index 99% rename from tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py rename to tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py index b61b85094b..b3ccde9777 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py @@ -1,4 +1,4 @@ -"Tests for the Snowflake user/private key profile." +"Tests for the Snowflake user/private key file profile." import json from unittest.mock import patch From 231029af4550d70c2a0c7f58253f5a25f099bdb0 Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:02:21 +0000 Subject: [PATCH 02/11] Updating test description --- .../test_snowflake_user_encrypted_privatekey_env_variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index 64c33d337b..314bb489e4 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -1,4 +1,4 @@ -"Tests for the Snowflake user/private key file profile." +"Tests for the Snowflake user/private key enviroment variable profile." import json from unittest.mock import patch From 79c6adae9dd55207a6e1c7c0a96f25132c19d2ba Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:13:32 +0000 Subject: [PATCH 03/11] Work around for user/password mapping --- cosmos/profiles/snowflake/user_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index 2e1025a2c6..ce3ea64726 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -44,7 +44,7 @@ class SnowflakeUserPasswordProfileMapping(BaseProfileMapping): def can_claim_connection(self) -> bool: # Make sure this isn't a private key path credential result = super().can_claim_connection() - if result and self.conn.extra_dejson.get("private_key_file") is not None: + if result and self.conn.extra_dejson.get("private_key_file") is not None and self.conn.extra_dejson.get("private_key_content") is not None: return False return result From c00813ab82c7d892c0d10cff0e0bf409b750e880 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Nov 2023 11:14:01 +0000 Subject: [PATCH 04/11] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosmos/profiles/snowflake/user_pass.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index ce3ea64726..9ced6b691a 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -44,7 +44,11 @@ class SnowflakeUserPasswordProfileMapping(BaseProfileMapping): def can_claim_connection(self) -> bool: # Make sure this isn't a private key path credential result = super().can_claim_connection() - if result and self.conn.extra_dejson.get("private_key_file") is not None and self.conn.extra_dejson.get("private_key_content") is not None: + if ( + result + and self.conn.extra_dejson.get("private_key_file") is not None + and self.conn.extra_dejson.get("private_key_content") is not None + ): return False return result From d66a8a3b3ac20ef76c0391db406fbf507ebca77c Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:26:10 +0000 Subject: [PATCH 05/11] Adding conditions on claiming connections --- .../user_encrypted_privatekey_env_variable.py | 10 ++++++++++ .../snowflake/user_encrypted_privatekey_file.py | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py index 3fa7aaaf4d..f4c57154bc 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -45,6 +45,16 @@ class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): "private_key_passphrase": "password", } + def can_claim_connection(self) -> bool: + # Make sure this isn't a private key path credential + result = super().can_claim_connection() + if ( + result + and self.conn.extra_dejson.get("private_key_file") is not None + ): + return False + return result + @property def conn(self) -> Connection: """ diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index 96e45080af..8c57931bb4 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -44,6 +44,16 @@ class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping): "private_key_path": "extra.private_key_file", } + def can_claim_connection(self) -> bool: + # Make sure this isn't a private key enviroment variable + result = super().can_claim_connection() + if ( + result + and self.conn.extra_dejson.get("private_key_content") is not None + ): + return False + return result + @property def conn(self) -> Connection: """ From 339d500b057279e8c8b528f5ee87b66a2bf30ec3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Nov 2023 11:26:35 +0000 Subject: [PATCH 06/11] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../snowflake/user_encrypted_privatekey_env_variable.py | 5 +---- cosmos/profiles/snowflake/user_encrypted_privatekey_file.py | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py index f4c57154bc..fecfa97fe7 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -48,10 +48,7 @@ class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): def can_claim_connection(self) -> bool: # Make sure this isn't a private key path credential result = super().can_claim_connection() - if ( - result - and self.conn.extra_dejson.get("private_key_file") is not None - ): + if result and self.conn.extra_dejson.get("private_key_file") is not None: return False return result diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index 8c57931bb4..c806658e1e 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -47,10 +47,7 @@ class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping): def can_claim_connection(self) -> bool: # Make sure this isn't a private key enviroment variable result = super().can_claim_connection() - if ( - result - and self.conn.extra_dejson.get("private_key_content") is not None - ): + if result and self.conn.extra_dejson.get("private_key_content") is not None: return False return result From afa4aac44e5e7506ba91835da4208c8104a51c1a Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:40:57 +0000 Subject: [PATCH 07/11] Updating import on tests --- ...ake_user_encrypted_privatekey_env_variable.py | 16 ++++++++-------- ...t_snowflake_user_encrypted_privatekey_file.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index 314bb489e4..4ed746fb78 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -8,7 +8,7 @@ from cosmos.profiles import get_automatic_profile_mapping from cosmos.profiles.snowflake import ( - SnowflakeEncryptedPrivateKeyFilePemProfileMapping, + SnowflakeEncryptedPrivateKeyPemProfileMapping, ) @@ -66,7 +66,7 @@ def test_connection_claiming() -> None: print("testing with", values) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping( + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping( conn, ) assert not profile_mapping.can_claim_connection() @@ -76,7 +76,7 @@ def test_connection_claiming() -> None: conn.extra = '{"database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # test when we're missing the database @@ -84,7 +84,7 @@ def test_connection_claiming() -> None: conn.extra = '{"account": "my_account", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # test when we're missing the warehouse @@ -92,13 +92,13 @@ def test_connection_claiming() -> None: conn.extra = '{"account": "my_account", "database": "my_database", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # if we have them all, it should claim conn = Connection(**potential_values) # type: ignore with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) assert profile_mapping.can_claim_connection() @@ -111,7 +111,7 @@ def test_profile_mapping_selected( profile_mapping = get_automatic_profile_mapping( mock_snowflake_conn.conn_id, ) - assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyFilePemProfileMapping) + assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyPemProfileMapping) def test_profile_args( @@ -203,7 +203,7 @@ def test_old_snowflake_format() -> None: ) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) assert profile_mapping.profile == { "type": conn.conn_type, "user": conn.login, diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py index b3ccde9777..d8c3aedcf6 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_file.py @@ -8,7 +8,7 @@ from cosmos.profiles import get_automatic_profile_mapping from cosmos.profiles.snowflake import ( - SnowflakeEncryptedPrivateKeyPemProfileMapping, + SnowflakeEncryptedPrivateKeyFilePemProfileMapping, ) @@ -66,7 +66,7 @@ def test_connection_claiming() -> None: print("testing with", values) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping( + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping( conn, ) assert not profile_mapping.can_claim_connection() @@ -76,7 +76,7 @@ def test_connection_claiming() -> None: conn.extra = '{"database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # test when we're missing the database @@ -84,7 +84,7 @@ def test_connection_claiming() -> None: conn.extra = '{"account": "my_account", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # test when we're missing the warehouse @@ -92,13 +92,13 @@ def test_connection_claiming() -> None: conn.extra = '{"account": "my_account", "database": "my_database", "private_key_content": "my_private_key"}' print("testing with", conn.extra) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) assert not profile_mapping.can_claim_connection() # if we have them all, it should claim conn = Connection(**potential_values) # type: ignore with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) assert profile_mapping.can_claim_connection() @@ -111,7 +111,7 @@ def test_profile_mapping_selected( profile_mapping = get_automatic_profile_mapping( mock_snowflake_conn.conn_id, ) - assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyPemProfileMapping) + assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyFilePemProfileMapping) def test_profile_args( @@ -203,7 +203,7 @@ def test_old_snowflake_format() -> None: ) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + profile_mapping = SnowflakeEncryptedPrivateKeyFilePemProfileMapping(conn) assert profile_mapping.profile == { "type": conn.conn_type, "user": conn.login, From 210721be533bb71bb2dbf3299ce99de8db6a6a38 Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:47:46 +0000 Subject: [PATCH 08/11] Or condition for user pass --- cosmos/profiles/snowflake/user_pass.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index 9ced6b691a..c24571a39f 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -46,8 +46,8 @@ def can_claim_connection(self) -> bool: result = super().can_claim_connection() if ( result - and self.conn.extra_dejson.get("private_key_file") is not None - and self.conn.extra_dejson.get("private_key_content") is not None + and (self.conn.extra_dejson.get("private_key_file") is not None + or self.conn.extra_dejson.get("private_key_content") is not None) ): return False return result From 7d265e869b96198661cb1e5260f108bfbda61dcc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Nov 2023 11:48:08 +0000 Subject: [PATCH 09/11] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cosmos/profiles/snowflake/user_pass.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index c24571a39f..fa634d1a2e 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -44,10 +44,9 @@ class SnowflakeUserPasswordProfileMapping(BaseProfileMapping): def can_claim_connection(self) -> bool: # Make sure this isn't a private key path credential result = super().can_claim_connection() - if ( - result - and (self.conn.extra_dejson.get("private_key_file") is not None - or self.conn.extra_dejson.get("private_key_content") is not None) + if result and ( + self.conn.extra_dejson.get("private_key_file") is not None + or self.conn.extra_dejson.get("private_key_content") is not None ): return False return result From 7cc83777d8354c96f39ef5f40d715b9af5b21977 Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:52:24 +0000 Subject: [PATCH 10/11] Adding private key to env test --- .../test_snowflake_user_encrypted_privatekey_env_variable.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index 4ed746fb78..7ed7f7d3d1 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -178,6 +178,7 @@ def test_profile_env_vars( mock_snowflake_conn.conn_id, ) assert profile_mapping.env_vars == { + "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY": mock_snowflake_conn.extra_dejson.get("private_key_content"), "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE": mock_snowflake_conn.password, } From 685586f0a9cf45cdcb01a7e101c324b3eed713fd Mon Sep 17 00:00:00 2001 From: DanMawdsleyBA Date: Sat, 4 Nov 2023 11:57:42 +0000 Subject: [PATCH 11/11] Fixing typo --- cosmos/profiles/snowflake/user_encrypted_privatekey_file.py | 2 +- .../test_snowflake_user_encrypted_privatekey_env_variable.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index c806658e1e..6831cbd280 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -45,7 +45,7 @@ class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping): } def can_claim_connection(self) -> bool: - # Make sure this isn't a private key enviroment variable + # Make sure this isn't a private key environmentvariable result = super().can_claim_connection() if result and self.conn.extra_dejson.get("private_key_content") is not None: return False diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index 7ed7f7d3d1..2c7515f72f 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -1,4 +1,4 @@ -"Tests for the Snowflake user/private key enviroment variable profile." +"Tests for the Snowflake user/private key environmentvariable profile." import json from unittest.mock import patch