diff --git a/cosmos/profiles/base.py b/cosmos/profiles/base.py old mode 100644 new mode 100755 index 9d17c40195..308d71a814 --- a/cosmos/profiles/base.py +++ b/cosmos/profiles/base.py @@ -241,6 +241,18 @@ def get_profile_file_contents( return str(yaml.dump(profile_contents, indent=4)) + def _get_airflow_conn_field(self, airflow_field: str) -> Any: + # make sure there's no "extra." prefix + if airflow_field.startswith("extra."): + airflow_field = airflow_field.replace("extra.", "", 1) + value = self.conn.extra_dejson.get(airflow_field) + elif airflow_field.startswith("extra__"): + value = self.conn.extra_dejson.get(airflow_field) + else: + value = getattr(self.conn, airflow_field, None) + + return value + def get_dbt_value(self, name: str) -> Any: """ Gets values for the dbt profile based on the required_by_dbt and required_in_profile_args lists. @@ -260,16 +272,9 @@ def get_dbt_value(self, name: str) -> Any: airflow_fields = [airflow_fields] for airflow_field in airflow_fields: - # make sure there's no "extra." prefix - if airflow_field.startswith("extra."): - airflow_field = airflow_field.replace("extra.", "", 1) - value = self.conn.extra_dejson.get(airflow_field) - else: - value = getattr(self.conn, airflow_field, None) - + value = self._get_airflow_conn_field(airflow_field) if not value: continue - # if there's a transform method, use it if hasattr(self, f"transform_{name}"): return getattr(self, f"transform_{name}")(value) diff --git a/tests/profiles/bigquery/test_bq_service_account_keyfile_dict.py b/tests/profiles/bigquery/test_bq_service_account_keyfile_dict.py old mode 100644 new mode 100755 index 00cf070c3d..4e56f5ba13 --- a/tests/profiles/bigquery/test_bq_service_account_keyfile_dict.py +++ b/tests/profiles/bigquery/test_bq_service_account_keyfile_dict.py @@ -1,4 +1,5 @@ import json +from collections import namedtuple from unittest.mock import patch import pytest @@ -8,6 +9,7 @@ from cosmos.profiles import get_automatic_profile_mapping from cosmos.profiles.bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping +ConnExtraParams = namedtuple("ConnExtraParams", ["keyfile_dict", "keyfile_json_extra_key"]) sample_keyfile_dict = { "type": "service_account", "private_key_id": "my_private_key_id", @@ -15,7 +17,21 @@ } -@pytest.fixture(params=[sample_keyfile_dict, json.dumps(sample_keyfile_dict)]) +def get_fixture_params(): + """ + Make a matrix of params for the fixture that mock connection, as there are multiple fields in + the airflow param mapping for the "keyfile_json" in GoogleCloudServiceAccountDictProfileMapping + """ + fixture_params = [] + for d in (sample_keyfile_dict, json.dumps(sample_keyfile_dict)): + for key in GoogleCloudServiceAccountDictProfileMapping.airflow_param_mapping.get("keyfile_json"): + if key.startswith("extra."): + key = key.replace("extra.", "") + fixture_params.append(ConnExtraParams(keyfile_dict=d, keyfile_json_extra_key=key)) + return fixture_params + + +@pytest.fixture(params=get_fixture_params()) def mock_bigquery_conn_with_dict(request): # type: ignore """ Mocks and returns an Airflow BigQuery connection. @@ -23,14 +39,13 @@ def mock_bigquery_conn_with_dict(request): # type: ignore extra = { "project": "my_project", "dataset": "my_dataset", - "keyfile_dict": request.param, + request.param.keyfile_json_extra_key: request.param.keyfile_dict, } conn = Connection( conn_id="my_bigquery_connection", conn_type="google_cloud_platform", extra=json.dumps(extra), ) - with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): yield conn