Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions cosmos/profiles/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
if TYPE_CHECKING:
from airflow.models import Connection

DBT_PROFILE_TYPE_FIELD = "type"
DBT_PROFILE_METHOD_FIELD = "method"

logger = get_logger(__name__)


Expand All @@ -41,6 +44,26 @@ class BaseProfileMapping(ABC):
def __init__(self, conn_id: str, profile_args: dict[str, Any] | None = None):
self.conn_id = conn_id
self.profile_args = profile_args or {}
self._validate_profile_args()

def _validate_profile_args(self) -> None:
"""
Check if profile_args contains keys that should not be overridden from the
class variables when creating the profile.
"""
for profile_field in [DBT_PROFILE_TYPE_FIELD, DBT_PROFILE_METHOD_FIELD]:
if profile_field in self.profile_args and self.profile_args.get(profile_field) != getattr(
self, f"dbt_profile_{profile_field}"
):
raise CosmosValueError(
"`profile_args` for {0} has {1}='{2}' that will override the dbt profile required value of '{3}'. "
"To fix this, remove {1} from `profile_args`.".format(
self.__class__.__name__,
profile_field,
self.profile_args.get(profile_field),
getattr(self, f"dbt_profile_{profile_field}"),
)
)

@property
def conn(self) -> Connection:
Expand Down Expand Up @@ -100,11 +123,11 @@ def mock_profile(self) -> dict[str, Any]:
where live connection values don't matter.
"""
mock_profile = {
"type": self.dbt_profile_type,
DBT_PROFILE_TYPE_FIELD: self.dbt_profile_type,
}

if self.dbt_profile_method:
mock_profile["method"] = self.dbt_profile_method
mock_profile[DBT_PROFILE_METHOD_FIELD] = self.dbt_profile_method

for field in self.required_fields:
# if someone has passed in a value for this field, use it
Expand Down Expand Up @@ -199,11 +222,11 @@ def get_dbt_value(self, name: str) -> Any:
def mapped_params(self) -> dict[str, Any]:
"Turns the self.airflow_param_mapping into a dictionary of dbt fields and their values."
mapped_params = {
"type": self.dbt_profile_type,
DBT_PROFILE_TYPE_FIELD: self.dbt_profile_type,
}

if self.dbt_profile_method:
mapped_params["method"] = self.dbt_profile_method
mapped_params[DBT_PROFILE_METHOD_FIELD] = self.dbt_profile_method

for dbt_field in self.airflow_param_mapping:
mapped_params[dbt_field] = self.get_dbt_value(dbt_field)
Expand Down
31 changes: 31 additions & 0 deletions tests/profiles/test_base_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
from cosmos.profiles.base import BaseProfileMapping
from cosmos.exceptions import CosmosValueError


class TestProfileMapping(BaseProfileMapping):
dbt_profile_method: str = "fake-method"
dbt_profile_type: str = "fake-type"

def profile(self):
raise NotImplementedError


@pytest.mark.parametrize("profile_arg", ["type", "method"])
def test_validate_profile_args(profile_arg: str):
"""
An error should be raised if the profile_args contains a key that should not be overridden from the class variables.
"""
profile_args = {profile_arg: "fake-value"}
dbt_profile_value = getattr(TestProfileMapping, f"dbt_profile_{profile_arg}")

expected_cosmos_error = (
f"`profile_args` for TestProfileMapping has {profile_arg}='fake-value' that will override the dbt profile required value of "
f"'{dbt_profile_value}'. To fix this, remove {profile_arg} from `profile_args`."
)

with pytest.raises(CosmosValueError, match=expected_cosmos_error):
TestProfileMapping(
conn_id="fake_conn_id",
profile_args=profile_args,
)