Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions cosmos/profiles/base.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,18 @@ def get_profile_file_contents(

return str(yaml.dump(profile_contents, indent=4))

def _get_airflow_conn_field(self, airflow_field: str) -> Any:
# make sure there's no "extra." prefix
if airflow_field.startswith("extra."):
airflow_field = airflow_field.replace("extra.", "", 1)
value = self.conn.extra_dejson.get(airflow_field)
elif airflow_field.startswith("extra__"):
value = self.conn.extra_dejson.get(airflow_field)
else:
value = getattr(self.conn, airflow_field, None)

return value

def get_dbt_value(self, name: str) -> Any:
"""
Gets values for the dbt profile based on the required_by_dbt and required_in_profile_args lists.
Expand All @@ -260,16 +272,9 @@ def get_dbt_value(self, name: str) -> Any:
airflow_fields = [airflow_fields]

for airflow_field in airflow_fields:
# make sure there's no "extra." prefix
if airflow_field.startswith("extra."):
airflow_field = airflow_field.replace("extra.", "", 1)
value = self.conn.extra_dejson.get(airflow_field)
else:
value = getattr(self.conn, airflow_field, None)

value = self._get_airflow_conn_field(airflow_field)
if not value:
continue

# if there's a transform method, use it
if hasattr(self, f"transform_{name}"):
return getattr(self, f"transform_{name}")(value)
Expand Down
21 changes: 18 additions & 3 deletions tests/profiles/bigquery/test_bq_service_account_keyfile_dict.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from collections import namedtuple
from unittest.mock import patch

import pytest
Expand All @@ -8,29 +9,43 @@
from cosmos.profiles import get_automatic_profile_mapping
from cosmos.profiles.bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping

ConnExtraParams = namedtuple("ConnExtraParams", ["keyfile_dict", "keyfile_json_extra_key"])
sample_keyfile_dict = {
"type": "service_account",
"private_key_id": "my_private_key_id",
"private_key": "my_private_key",
}


@pytest.fixture(params=[sample_keyfile_dict, json.dumps(sample_keyfile_dict)])
def get_fixture_params():
"""
Make a matrix of params for the fixture that mock connection, as there are multiple fields in
the airflow param mapping for the "keyfile_json" in GoogleCloudServiceAccountDictProfileMapping
"""
fixture_params = []
for d in (sample_keyfile_dict, json.dumps(sample_keyfile_dict)):
for key in GoogleCloudServiceAccountDictProfileMapping.airflow_param_mapping.get("keyfile_json"):
if key.startswith("extra."):
key = key.replace("extra.", "")
fixture_params.append(ConnExtraParams(keyfile_dict=d, keyfile_json_extra_key=key))
return fixture_params


@pytest.fixture(params=get_fixture_params())
def mock_bigquery_conn_with_dict(request): # type: ignore
"""
Mocks and returns an Airflow BigQuery connection.
"""
extra = {
"project": "my_project",
"dataset": "my_dataset",
"keyfile_dict": request.param,
request.param.keyfile_json_extra_key: request.param.keyfile_dict,
}
conn = Connection(
conn_id="my_bigquery_connection",
conn_type="google_cloud_platform",
extra=json.dumps(extra),
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn

Expand Down