diff --git a/cosmos/profiles/snowflake/base.py b/cosmos/profiles/snowflake/base.py index 599a9c8e54..8a31486d02 100644 --- a/cosmos/profiles/snowflake/base.py +++ b/cosmos/profiles/snowflake/base.py @@ -1,5 +1,7 @@ from __future__ import annotations +import base64 +import binascii from typing import Any from cosmos.profiles.base import BaseProfileMapping @@ -26,3 +28,25 @@ def transform_account(self, account: str) -> str: account = f"{account}.{region}" return str(account) + + def _decode_private_key_content(self, private_key_content: str) -> str: + """ + Decodes the private key content from either base64-encoded or plain-text PEM format. + + Starting from `apache-airflow-providers-snowflake` version 6.3.0, the provider expects the + `private_key_content` to be base64-encoded rather than raw PEM text. This method ensures + compatibility by attempting to decode the content from base64 first. If decoding fails, + the original content is assumed to be plain-text PEM (as used in older versions). + + This allows backward compatibility while supporting the new expected format. + + Args: + private_key_content: The private key content, either base64 encoded or plain-text PEM + + Returns: + The decoded private key in plain-text PEM format + """ + try: + return base64.b64decode(private_key_content).decode("utf-8") + except (UnicodeDecodeError, binascii.Error): + return private_key_content diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py index 63a6c68d34..1b754d08a0 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -68,6 +68,10 @@ def conn(self) -> Connection: if conn_dejson.get("extra__snowflake__account"): conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()} + private_key_content = conn_dejson.get("private_key_content") + if private_key_content: + conn_dejson["private_key_content"] = self._decode_private_key_content(private_key_content) + conn.extra = json.dumps(conn_dejson) return conn diff --git a/cosmos/profiles/snowflake/user_privatekey.py b/cosmos/profiles/snowflake/user_privatekey.py index 40a016af76..d759ec468c 100644 --- a/cosmos/profiles/snowflake/user_privatekey.py +++ b/cosmos/profiles/snowflake/user_privatekey.py @@ -58,6 +58,10 @@ def conn(self) -> Connection: if conn_dejson.get("extra__snowflake__account"): conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()} + private_key_content = conn_dejson.get("private_key_content") + if private_key_content: + conn_dejson["private_key_content"] = self._decode_private_key_content(private_key_content) + conn.extra = json.dumps(conn_dejson) return conn diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py index ff00ee6980..8ff6f121ae 100644 --- a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey_env_variable.py @@ -1,5 +1,6 @@ """Tests for the Snowflake user/private key environmentvariable profile.""" +import base64 import json from unittest.mock import patch @@ -11,6 +12,34 @@ SnowflakeEncryptedPrivateKeyPemProfileMapping, ) +PLAIN_TEXT_KEY = "mocked-private-key-content" +BASE64_ENCODED_KEY = base64.b64encode(PLAIN_TEXT_KEY.encode("utf-8")).decode("utf-8") + + +@pytest.fixture() +def mock_snowflake_conn_base64(): # type: ignore + """ + Sets a connection with a base64 encoded private key as an environment variable. + """ + conn = Connection( + conn_id="my_snowflake_pk_connection_base64", + conn_type="snowflake", + login="my_user", + schema="my_schema", + extra=json.dumps( + { + "account": "my_account", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_content": BASE64_ENCODED_KEY, + "private_key_passphrase": "my_private_key_passphrase", + } + ), + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + @pytest.fixture() def mock_snowflake_conn(): # type: ignore @@ -38,6 +67,24 @@ def mock_snowflake_conn(): # type: ignore yield conn +@pytest.mark.parametrize( + "input_key, expected_key", + [ + (BASE64_ENCODED_KEY, PLAIN_TEXT_KEY), + (PLAIN_TEXT_KEY, PLAIN_TEXT_KEY), + ], + ids=["base64_encoded", "plain_text"], +) +def test_decode_private_key_content(input_key: str, expected_key: str, mock_snowflake_conn_base64: Connection) -> None: + """ + Tests that the private key content is decoded correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn_base64.conn_id, + ) + assert profile_mapping._decode_private_key_content(input_key) == expected_key + + def test_connection_claiming() -> None: """ Tests that the Snowflake profile mapping claims the correct connection type. @@ -183,6 +230,20 @@ def test_profile_env_vars( } +def test_profile_env_vars_with_base64( + mock_snowflake_conn_base64: Connection, +) -> None: + """ + Tests that the environment variable get set correctly for a base64-encoded key. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn_base64.conn_id, + ) + assert profile_mapping.env_vars == { + "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY": PLAIN_TEXT_KEY, + } + + def test_old_snowflake_format() -> None: """ Tests that the old format still works. diff --git a/tests/profiles/snowflake/test_snowflake_user_privatekey.py b/tests/profiles/snowflake/test_snowflake_user_privatekey.py index edbf2b464f..bd7785fc4c 100644 --- a/tests/profiles/snowflake/test_snowflake_user_privatekey.py +++ b/tests/profiles/snowflake/test_snowflake_user_privatekey.py @@ -1,5 +1,6 @@ """Tests for the Snowflake user/private key profile.""" +import base64 import json from unittest.mock import patch @@ -11,6 +12,33 @@ SnowflakePrivateKeyPemProfileMapping, ) +PLAIN_TEXT_KEY = "mocked-private-key-content" +BASE64_ENCODED_KEY = base64.b64encode(PLAIN_TEXT_KEY.encode("utf-8")).decode("utf-8") + + +@pytest.fixture() +def mock_snowflake_conn_base64(): # type: ignore + """ + Sets a connection with a base64 encoded private key as an environment variable. + """ + conn = Connection( + conn_id="my_snowflake_pk_connection_base64", + conn_type="snowflake", + login="my_user", + schema="my_schema", + extra=json.dumps( + { + "account": "my_account", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_content": BASE64_ENCODED_KEY, + } + ), + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + @pytest.fixture() def mock_snowflake_conn(): # type: ignore @@ -36,6 +64,24 @@ def mock_snowflake_conn(): # type: ignore yield conn +@pytest.mark.parametrize( + "input_key, expected_key", + [ + (BASE64_ENCODED_KEY, PLAIN_TEXT_KEY), + (PLAIN_TEXT_KEY, PLAIN_TEXT_KEY), + ], + ids=["base64_encoded", "plain_text"], +) +def test_decode_private_key_content(input_key: str, expected_key: str, mock_snowflake_conn_base64: Connection) -> None: + """ + Tests that the private key content is decoded correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn_base64.conn_id, + ) + assert profile_mapping._decode_private_key_content(input_key) == expected_key + + def test_connection_claiming() -> None: """ Tests that the Snowflake profile mapping claims the correct connection type. @@ -180,6 +226,20 @@ def test_profile_env_vars( } +def test_profile_env_vars_with_base64( + mock_snowflake_conn_base64: Connection, +) -> None: + """ + Tests that the environment variable get set correctly for a base64-encoded key. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn_base64.conn_id, + ) + assert profile_mapping.env_vars == { + "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY": PLAIN_TEXT_KEY, + } + + def test_old_snowflake_format() -> None: """ Tests that the old format still works.