From 6d1f06bbe01f5d150f0957c4ec8cfe7fa9809385 Mon Sep 17 00:00:00 2001 From: Bruno Costa Martins Date: Mon, 9 Jun 2025 17:51:25 -0300 Subject: [PATCH 1/7] Support base64 encoded private_key_content for Airflow Snowflake provider >= 6.3 Starting from `apache-airflow-providers-snowflake` version 6.3, the provider expects the `private_key_content` field to be base64 encoded instead of plain text PEM. This change adds decoding of base64 encoded keys while preserving backwards compatibility with older versions that use plain text PEM. --- cosmos/profiles/snowflake/user_privatekey.py | 29 ++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/cosmos/profiles/snowflake/user_privatekey.py b/cosmos/profiles/snowflake/user_privatekey.py index 40a016af76..c88cfdae82 100644 --- a/cosmos/profiles/snowflake/user_privatekey.py +++ b/cosmos/profiles/snowflake/user_privatekey.py @@ -2,7 +2,10 @@ from __future__ import annotations +import base64 +import binascii import json +import logging from typing import TYPE_CHECKING, Any from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping @@ -10,6 +13,8 @@ if TYPE_CHECKING: from airflow.models import Connection +logger = logging.getLogger(__name__) + class SnowflakePrivateKeyPemProfileMapping(SnowflakeBaseProfileMapping): """ @@ -43,6 +48,29 @@ class SnowflakePrivateKeyPemProfileMapping(SnowflakeBaseProfileMapping): "private_key": "extra.private_key_content", } + def _decode_private_key_content(self, private_key_content: str) -> str: + """ + Decodes the private key content from either base64-encoded or plain-text PEM format. + + Starting from `apache-airflow-providers-snowflake` version 6.3.0, the provider expects the + `private_key_content` to be base64-encoded rather than raw PEM text. This method ensures + compatibility by attempting to decode the content from base64 first. If decoding fails, + the original content is assumed to be plain-text PEM (as used in older versions). + + This allows backward compatibility while supporting the new expected format. + + Args: + private_key_content: The private key content, either base64 encoded or plain-text PEM + + Returns: + The decoded private key in plain-text PEM format + """ + try: + decoded_key = base64.b64decode(private_key_content).decode("utf-8") + return decoded_key + except (UnicodeDecodeError, binascii.Error): + return private_key_content + @property def conn(self) -> Connection: """ @@ -58,6 +86,7 @@ def conn(self) -> Connection: if conn_dejson.get("extra__snowflake__account"): conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()} + conn_dejson["private_key_content"] = self._decode_private_key_content(conn_dejson["private_key_content"]) conn.extra = json.dumps(conn_dejson) return conn From 7415acc7ad6a38e5d850ff1c00c3bd69c2d439ad Mon Sep 17 00:00:00 2001 From: Bruno Costa Martins Date: Mon, 9 Jun 2025 18:24:24 -0300 Subject: [PATCH 2/7] Adjust code to pass hatch cov tests --- cosmos/profiles/snowflake/user_privatekey.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cosmos/profiles/snowflake/user_privatekey.py b/cosmos/profiles/snowflake/user_privatekey.py index c88cfdae82..e576b6e9c4 100644 --- a/cosmos/profiles/snowflake/user_privatekey.py +++ b/cosmos/profiles/snowflake/user_privatekey.py @@ -86,7 +86,10 @@ def conn(self) -> Connection: if conn_dejson.get("extra__snowflake__account"): conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()} - conn_dejson["private_key_content"] = self._decode_private_key_content(conn_dejson["private_key_content"]) + private_key_content = conn_dejson.get("private_key_content") + if private_key_content: + conn_dejson["private_key_content"] = self._decode_private_key_content(private_key_content) + conn.extra = json.dumps(conn_dejson) return conn From a15ded97860db1002016721c3ae02b161c2a2d2b Mon Sep 17 00:00:00 2001 From: Bruno Costa Martins Date: Mon, 9 Jun 2025 18:25:28 -0300 Subject: [PATCH 3/7] Remove unused logging --- cosmos/profiles/snowflake/user_privatekey.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/cosmos/profiles/snowflake/user_privatekey.py b/cosmos/profiles/snowflake/user_privatekey.py index e576b6e9c4..c5e4f22435 100644 --- a/cosmos/profiles/snowflake/user_privatekey.py +++ b/cosmos/profiles/snowflake/user_privatekey.py @@ -5,7 +5,6 @@ import base64 import binascii import json -import logging from typing import TYPE_CHECKING, Any from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping @@ -13,8 +12,6 @@ if TYPE_CHECKING: from airflow.models import Connection -logger = logging.getLogger(__name__) - class SnowflakePrivateKeyPemProfileMapping(SnowflakeBaseProfileMapping): """ From 3d6a31bce2d7f0ef2b6ff0e70e4d11407a6fb72c Mon Sep 17 00:00:00 2001 From: Bruno Costa Martins Date: Mon, 9 Jun 2025 21:08:34 -0300 Subject: [PATCH 4/7] Remove redundancy --- cosmos/profiles/snowflake/user_privatekey.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cosmos/profiles/snowflake/user_privatekey.py b/cosmos/profiles/snowflake/user_privatekey.py index c5e4f22435..3afa798cb5 100644 --- a/cosmos/profiles/snowflake/user_privatekey.py +++ b/cosmos/profiles/snowflake/user_privatekey.py @@ -63,8 +63,7 @@ def _decode_private_key_content(self, private_key_content: str) -> str: The decoded private key in plain-text PEM format """ try: - decoded_key = base64.b64decode(private_key_content).decode("utf-8") - return decoded_key + return base64.b64decode(private_key_content).decode("utf-8") except (UnicodeDecodeError, binascii.Error): return private_key_content From fc15641c415f636d389d740aff77921b1752850c Mon Sep 17 00:00:00 2001 From: Bruno Costa Martins Date: Tue, 10 Jun 2025 13:57:20 -0300 Subject: [PATCH 5/7] Move _decode_private_key_content function to SnowflakeBaseProfileMapping Move function to SnowflakeBaseProfileMapping to avoid duplicating code on SnowflakePrivateKeyPemProfileMapping and SnowflakeEncryptedPrivateKeyPemProfileMapping --- cosmos/profiles/snowflake/base.py | 24 +++++++++++++++++++ .../user_encrypted_privatekey_env_variable.py | 4 ++++ cosmos/profiles/snowflake/user_privatekey.py | 24 ------------------- 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/cosmos/profiles/snowflake/base.py b/cosmos/profiles/snowflake/base.py index 599a9c8e54..8a31486d02 100644 --- a/cosmos/profiles/snowflake/base.py +++ b/cosmos/profiles/snowflake/base.py @@ -1,5 +1,7 @@ from __future__ import annotations +import base64 +import binascii from typing import Any from cosmos.profiles.base import BaseProfileMapping @@ -26,3 +28,25 @@ def transform_account(self, account: str) -> str: account = f"{account}.{region}" return str(account) + + def _decode_private_key_content(self, private_key_content: str) -> str: + """ + Decodes the private key content from either base64-encoded or plain-text PEM format. + + Starting from `apache-airflow-providers-snowflake` version 6.3.0, the provider expects the + `private_key_content` to be base64-encoded rather than raw PEM text. This method ensures + compatibility by attempting to decode the content from base64 first. If decoding fails, + the original content is assumed to be plain-text PEM (as used in older versions). + + This allows backward compatibility while supporting the new expected format. + + Args: + private_key_content: The private key content, either base64 encoded or plain-text PEM + + Returns: + The decoded private key in plain-text PEM format + """ + try: + return base64.b64decode(private_key_content).decode("utf-8") + except (UnicodeDecodeError, binascii.Error): + return private_key_content diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py index 63a6c68d34..1b754d08a0 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -68,6 +68,10 @@ def conn(self) -> Connection: if conn_dejson.get("extra__snowflake__account"): conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()} + private_key_content = conn_dejson.get("private_key_content") + if private_key_content: + conn_dejson["private_key_content"] = self._decode_private_key_content(private_key_content) + conn.extra = json.dumps(conn_dejson) return conn diff --git a/cosmos/profiles/snowflake/user_privatekey.py b/cosmos/profiles/snowflake/user_privatekey.py index 3afa798cb5..d759ec468c 100644 --- a/cosmos/profiles/snowflake/user_privatekey.py +++ b/cosmos/profiles/snowflake/user_privatekey.py @@ -2,8 +2,6 @@ from __future__ import annotations -import base64 -import binascii import json from typing import TYPE_CHECKING, Any @@ -45,28 +43,6 @@ class SnowflakePrivateKeyPemProfileMapping(SnowflakeBaseProfileMapping): "private_key": "extra.private_key_content", } - def _decode_private_key_content(self, private_key_content: str) -> str: - """ - Decodes the private key content from either base64-encoded or plain-text PEM format. - - Starting from `apache-airflow-providers-snowflake` version 6.3.0, the provider expects the - `private_key_content` to be base64-encoded rather than raw PEM text. This method ensures - compatibility by attempting to decode the content from base64 first. If decoding fails, - the original content is assumed to be plain-text PEM (as used in older versions). - - This allows backward compatibility while supporting the new expected format. - - Args: - private_key_content: The private key content, either base64 encoded or plain-text PEM - - Returns: - The decoded private key in plain-text PEM format - """ - try: - return base64.b64decode(private_key_content).decode("utf-8") - except (UnicodeDecodeError, binascii.Error): - return private_key_content - @property def conn(self) -> Connection: """ From 8ad6a460a47021831827685cd3021c4adabc04b1 Mon Sep 17 00:00:00 2001 From: Bruno Costa Martins Date: Tue, 10 Jun 2025 15:13:15 -0300 Subject: [PATCH 6/7] Add test to support base64 decode from SnowflakeEncryptedPrivateKeyPemProfileMapping and SnowflakePrivateKeyPemProfileMapping --- ..._user_encrypted_privatekey_env_variable.py | 61 +++++++++++++++++++ .../test_snowflake_user_privatekey.py | 60 ++++++++++++++++++ 2 files changed, 121 insertions(+) diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index ff00ee6980..d3be0ba32b 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -1,5 +1,6 @@ """Tests for the Snowflake user/private key environmentvariable profile.""" +import base64 import json from unittest.mock import patch @@ -11,6 +12,33 @@ SnowflakeEncryptedPrivateKeyPemProfileMapping, ) +PLAIN_TEXT_KEY = "mocked-private-key-content" +BASE64_ENCODED_KEY = base64.b64encode(PLAIN_TEXT_KEY.encode("utf-8")).decode("utf-8") + + +@pytest.fixture() +def mock_snowflake_conn_base64(): # type: ignore + """ + Sets a connection with a base64 encoded private key as an environment variable. + """ + conn = Connection( + conn_id="my_snowflake_pk_connection_base64", + conn_type="snowflake", + login="my_user", + schema="my_schema", + extra=json.dumps( + { + "account": "my_account", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_content": BASE64_ENCODED_KEY, + } + ), + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + @pytest.fixture() def mock_snowflake_conn(): # type: ignore @@ -30,6 +58,7 @@ def mock_snowflake_conn(): # type: ignore "database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key", + "private_key_passphrase": "my_private_key_passphrase", } ), ) @@ -38,6 +67,24 @@ def mock_snowflake_conn(): # type: ignore yield conn +@pytest.mark.parametrize( + "input_key, expected_key", + [ + (BASE64_ENCODED_KEY, PLAIN_TEXT_KEY), + (PLAIN_TEXT_KEY, PLAIN_TEXT_KEY), + ], + ids=["base64_encoded", "plain_text"], +) +def test_decode_private_key_content(input_key: str, expected_key: str, mock_snowflake_conn_base64: Connection) -> None: + """ + Tests that the private key content is decoded correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn_base64.conn_id, + ) + assert profile_mapping._decode_private_key_content(input_key) == expected_key + + def test_connection_claiming() -> None: """ Tests that the Snowflake profile mapping claims the correct connection type. @@ -183,6 +230,20 @@ def test_profile_env_vars( } +def test_profile_env_vars_with_base64( + mock_snowflake_conn_base64: Connection, +) -> None: + """ + Tests that the environment variable get set correctly for a base64-encoded key. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn_base64.conn_id, + ) + assert profile_mapping.env_vars == { + "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY": PLAIN_TEXT_KEY, + } + + def test_old_snowflake_format() -> None: """ Tests that the old format still works. diff --git a/tests/profiles/snowflake/test_snowflake_user_privatekey.py b/tests/profiles/snowflake/test_snowflake_user_privatekey.py index edbf2b464f..bd7785fc4c 100644 --- a/tests/profiles/snowflake/test_snowflake_user_privatekey.py +++ b/tests/profiles/snowflake/test_snowflake_user_privatekey.py @@ -1,5 +1,6 @@ """Tests for the Snowflake user/private key profile.""" +import base64 import json from unittest.mock import patch @@ -11,6 +12,33 @@ SnowflakePrivateKeyPemProfileMapping, ) +PLAIN_TEXT_KEY = "mocked-private-key-content" +BASE64_ENCODED_KEY = base64.b64encode(PLAIN_TEXT_KEY.encode("utf-8")).decode("utf-8") + + +@pytest.fixture() +def mock_snowflake_conn_base64(): # type: ignore + """ + Sets a connection with a base64 encoded private key as an environment variable. + """ + conn = Connection( + conn_id="my_snowflake_pk_connection_base64", + conn_type="snowflake", + login="my_user", + schema="my_schema", + extra=json.dumps( + { + "account": "my_account", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_content": BASE64_ENCODED_KEY, + } + ), + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + @pytest.fixture() def mock_snowflake_conn(): # type: ignore @@ -36,6 +64,24 @@ def mock_snowflake_conn(): # type: ignore yield conn +@pytest.mark.parametrize( + "input_key, expected_key", + [ + (BASE64_ENCODED_KEY, PLAIN_TEXT_KEY), + (PLAIN_TEXT_KEY, PLAIN_TEXT_KEY), + ], + ids=["base64_encoded", "plain_text"], +) +def test_decode_private_key_content(input_key: str, expected_key: str, mock_snowflake_conn_base64: Connection) -> None: + """ + Tests that the private key content is decoded correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn_base64.conn_id, + ) + assert profile_mapping._decode_private_key_content(input_key) == expected_key + + def test_connection_claiming() -> None: """ Tests that the Snowflake profile mapping claims the correct connection type. @@ -180,6 +226,20 @@ def test_profile_env_vars( } +def test_profile_env_vars_with_base64( + mock_snowflake_conn_base64: Connection, +) -> None: + """ + Tests that the environment variable get set correctly for a base64-encoded key. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn_base64.conn_id, + ) + assert profile_mapping.env_vars == { + "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY": PLAIN_TEXT_KEY, + } + + def test_old_snowflake_format() -> None: """ Tests that the old format still works. From 3c0cbdc4fabf3f7128c39f9e1e0ca11570fa3de9 Mon Sep 17 00:00:00 2001 From: Bruno Costa Martins Date: Tue, 10 Jun 2025 15:18:02 -0300 Subject: [PATCH 7/7] Change private_key_passphrase to correct location --- .../test_snowflake_user_encrypted_privatekey_env_variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index d3be0ba32b..8ff6f121ae 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -32,6 +32,7 @@ def mock_snowflake_conn_base64(): # type: ignore "database": "my_database", "warehouse": "my_warehouse", "private_key_content": BASE64_ENCODED_KEY, + "private_key_passphrase": "my_private_key_passphrase", } ), ) @@ -58,7 +59,6 @@ def mock_snowflake_conn(): # type: ignore "database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key", - "private_key_passphrase": "my_private_key_passphrase", } ), )