Skip to content
24 changes: 24 additions & 0 deletions cosmos/profiles/snowflake/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import base64
import binascii
from typing import Any

from cosmos.profiles.base import BaseProfileMapping
Expand All @@ -26,3 +28,25 @@ def transform_account(self, account: str) -> str:
account = f"{account}.{region}"

return str(account)

def _decode_private_key_content(self, private_key_content: str) -> str:
"""
Decodes the private key content from either base64-encoded or plain-text PEM format.

Starting from `apache-airflow-providers-snowflake` version 6.3.0, the provider expects the
`private_key_content` to be base64-encoded rather than raw PEM text. This method ensures
compatibility by attempting to decode the content from base64 first. If decoding fails,
the original content is assumed to be plain-text PEM (as used in older versions).

This allows backward compatibility while supporting the new expected format.

Args:
private_key_content: The private key content, either base64 encoded or plain-text PEM

Returns:
The decoded private key in plain-text PEM format
"""
try:
return base64.b64decode(private_key_content).decode("utf-8")
except (UnicodeDecodeError, binascii.Error):
return private_key_content
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def conn(self) -> Connection:
if conn_dejson.get("extra__snowflake__account"):
conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()}

private_key_content = conn_dejson.get("private_key_content")
if private_key_content:
conn_dejson["private_key_content"] = self._decode_private_key_content(private_key_content)
Comment thread
pankajastro marked this conversation as resolved.

conn.extra = json.dumps(conn_dejson)

return conn
Expand Down
4 changes: 4 additions & 0 deletions cosmos/profiles/snowflake/user_privatekey.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def conn(self) -> Connection:
if conn_dejson.get("extra__snowflake__account"):
conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()}

private_key_content = conn_dejson.get("private_key_content")
if private_key_content:
conn_dejson["private_key_content"] = self._decode_private_key_content(private_key_content)

conn.extra = json.dumps(conn_dejson)

return conn
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for the Snowflake user/private key environmentvariable profile."""

import base64
import json
from unittest.mock import patch

Expand All @@ -11,6 +12,34 @@
SnowflakeEncryptedPrivateKeyPemProfileMapping,
)

PLAIN_TEXT_KEY = "mocked-private-key-content"
BASE64_ENCODED_KEY = base64.b64encode(PLAIN_TEXT_KEY.encode("utf-8")).decode("utf-8")


@pytest.fixture()
def mock_snowflake_conn_base64(): # type: ignore
"""
Sets a connection with a base64 encoded private key as an environment variable.
"""
conn = Connection(
conn_id="my_snowflake_pk_connection_base64",
conn_type="snowflake",
login="my_user",
schema="my_schema",
extra=json.dumps(
{
"account": "my_account",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_content": BASE64_ENCODED_KEY,
"private_key_passphrase": "my_private_key_passphrase",
}
),
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


@pytest.fixture()
def mock_snowflake_conn(): # type: ignore
Expand Down Expand Up @@ -38,6 +67,24 @@ def mock_snowflake_conn(): # type: ignore
yield conn


@pytest.mark.parametrize(
"input_key, expected_key",
[
(BASE64_ENCODED_KEY, PLAIN_TEXT_KEY),
(PLAIN_TEXT_KEY, PLAIN_TEXT_KEY),
],
ids=["base64_encoded", "plain_text"],
)
def test_decode_private_key_content(input_key: str, expected_key: str, mock_snowflake_conn_base64: Connection) -> None:
"""
Tests that the private key content is decoded correctly.
"""
profile_mapping = get_automatic_profile_mapping(
mock_snowflake_conn_base64.conn_id,
)
assert profile_mapping._decode_private_key_content(input_key) == expected_key


def test_connection_claiming() -> None:
"""
Tests that the Snowflake profile mapping claims the correct connection type.
Expand Down Expand Up @@ -183,6 +230,20 @@ def test_profile_env_vars(
}


def test_profile_env_vars_with_base64(
mock_snowflake_conn_base64: Connection,
) -> None:
"""
Tests that the environment variable get set correctly for a base64-encoded key.
"""
profile_mapping = get_automatic_profile_mapping(
mock_snowflake_conn_base64.conn_id,
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY": PLAIN_TEXT_KEY,
}


def test_old_snowflake_format() -> None:
"""
Tests that the old format still works.
Expand Down
60 changes: 60 additions & 0 deletions tests/profiles/snowflake/test_snowflake_user_privatekey.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for the Snowflake user/private key profile."""

import base64
import json
from unittest.mock import patch

Expand All @@ -11,6 +12,33 @@
SnowflakePrivateKeyPemProfileMapping,
)

PLAIN_TEXT_KEY = "mocked-private-key-content"
BASE64_ENCODED_KEY = base64.b64encode(PLAIN_TEXT_KEY.encode("utf-8")).decode("utf-8")


@pytest.fixture()
def mock_snowflake_conn_base64(): # type: ignore
"""
Sets a connection with a base64 encoded private key as an environment variable.
"""
conn = Connection(
conn_id="my_snowflake_pk_connection_base64",
conn_type="snowflake",
login="my_user",
schema="my_schema",
extra=json.dumps(
{
"account": "my_account",
"database": "my_database",
"warehouse": "my_warehouse",
"private_key_content": BASE64_ENCODED_KEY,
}
),
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


@pytest.fixture()
def mock_snowflake_conn(): # type: ignore
Expand All @@ -36,6 +64,24 @@ def mock_snowflake_conn(): # type: ignore
yield conn


@pytest.mark.parametrize(
"input_key, expected_key",
[
(BASE64_ENCODED_KEY, PLAIN_TEXT_KEY),
(PLAIN_TEXT_KEY, PLAIN_TEXT_KEY),
],
ids=["base64_encoded", "plain_text"],
)
def test_decode_private_key_content(input_key: str, expected_key: str, mock_snowflake_conn_base64: Connection) -> None:
"""
Tests that the private key content is decoded correctly.
"""
profile_mapping = get_automatic_profile_mapping(
mock_snowflake_conn_base64.conn_id,
)
assert profile_mapping._decode_private_key_content(input_key) == expected_key


def test_connection_claiming() -> None:
"""
Tests that the Snowflake profile mapping claims the correct connection type.
Expand Down Expand Up @@ -180,6 +226,20 @@ def test_profile_env_vars(
}


def test_profile_env_vars_with_base64(
mock_snowflake_conn_base64: Connection,
) -> None:
"""
Tests that the environment variable get set correctly for a base64-encoded key.
"""
profile_mapping = get_automatic_profile_mapping(
mock_snowflake_conn_base64.conn_id,
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY": PLAIN_TEXT_KEY,
}


def test_old_snowflake_format() -> None:
"""
Tests that the old format still works.
Expand Down