diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py index 1f39a91a0f..446207f353 100644 --- a/cosmos/profiles/__init__.py +++ b/cosmos/profiles/__init__.py @@ -6,7 +6,7 @@ from .athena import AthenaAccessKeyProfileMapping -from .base import BaseProfileMapping +from .base import BaseProfileMapping, DbtProfileConfigVars from .bigquery.service_account_file import GoogleCloudServiceAccountFileProfileMapping from .bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping from .bigquery.oauth import GoogleCloudOauthProfileMapping @@ -70,6 +70,7 @@ def get_automatic_profile_mapping( "GoogleCloudServiceAccountDictProfileMapping", "GoogleCloudOauthProfileMapping", "DatabricksTokenProfileMapping", + "DbtProfileConfigVars", "PostgresUserPasswordProfileMapping", "RedshiftUserPasswordProfileMapping", "SnowflakeUserPasswordProfileMapping", diff --git a/cosmos/profiles/base.py b/cosmos/profiles/base.py index 2b2a5c7e2a..c583c8edb0 100644 --- a/cosmos/profiles/base.py +++ b/cosmos/profiles/base.py @@ -5,12 +5,12 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any - -from typing import TYPE_CHECKING -import yaml +from typing import Any, Optional, Literal, Dict, TYPE_CHECKING +import warnings from airflow.hooks.base import BaseHook +from pydantic import dataclasses +import yaml from cosmos.exceptions import CosmosValueError from cosmos.log import get_logger @@ -24,6 +24,31 @@ logger = get_logger(__name__) +@dataclasses.dataclass +class DbtProfileConfigVars: + send_anonymous_usage_stats: Optional[bool] = False + partial_parse: Optional[bool] = None + use_experimental_parser: Optional[bool] = None + static_parser: Optional[bool] = None + printer_width: Optional[bool] = None + write_json: Optional[bool] = None + warn_error: Optional[bool] = None + warn_error_options: Optional[Dict[Literal["include", "exclude"], Any]] = None + log_format: Optional[Literal["text", "json", "default"]] = None + 
debug: Optional[bool] = None + version_check: Optional[bool] = None + + def as_dict(self) -> dict[str, Any] | None: + result = { + field.name: getattr(self, field.name) + for field in self.__dataclass_fields__.values() + if getattr(self, field.name) is not None + } + if result != {}: + return result + return None + + class BaseProfileMapping(ABC): """ A base class that other profile mappings should inherit from to ensure consistency. @@ -41,11 +66,19 @@ class BaseProfileMapping(ABC): _conn: Connection | None = None - def __init__(self, conn_id: str, profile_args: dict[str, Any] | None = None, disable_event_tracking: bool = False): + def __init__( + self, + conn_id: str, + profile_args: dict[str, Any] | None = None, + disable_event_tracking: bool | None = None, + dbt_config_vars: DbtProfileConfigVars | None = None, + ): self.conn_id = conn_id self.profile_args = profile_args or {} self._validate_profile_args() self.disable_event_tracking = disable_event_tracking + self.dbt_config_vars = dbt_config_vars + self._validate_disable_event_tracking() def _validate_profile_args(self) -> None: """ @@ -66,6 +99,25 @@ class variables when creating the profile. ) ) + def _validate_disable_event_tracking(self) -> None: + """ + Check if disable_event_tracking is set and warn that it is deprecated. + """ + if self.disable_event_tracking: + warnings.warn( + "Disabling dbt event tracking is deprecated since Cosmos 1.3 and will be removed in Cosmos 2.0. " + "Use dbt_config_vars=DbtProfileConfigVars(send_anonymous_usage_stats=False) instead.", + DeprecationWarning, + ) + if ( + isinstance(self.dbt_config_vars, DbtProfileConfigVars) + and self.dbt_config_vars.send_anonymous_usage_stats is not None + ): + raise CosmosValueError( + "Cannot set both disable_event_tracking and " + "dbt_config_vars=DbtProfileConfigVars(send_anonymous_usage_stats ..." + ) + @property def conn(self) -> Connection: "Returns the Airflow connection." 
@@ -180,6 +232,9 @@ def get_profile_file_contents( } } + if self.dbt_config_vars: + profile_contents["config"] = self.dbt_config_vars.as_dict() + if self.disable_event_tracking: profile_contents["config"] = {"send_anonymous_usage_stats": False} diff --git a/dev/dags/cosmos_manifest_example.py b/dev/dags/cosmos_manifest_example.py index c94ea41a2e..7b7f9d4aaa 100644 --- a/dev/dags/cosmos_manifest_example.py +++ b/dev/dags/cosmos_manifest_example.py @@ -7,7 +7,7 @@ from pathlib import Path from cosmos import DbtDag, ProjectConfig, ProfileConfig, RenderConfig, LoadMode, ExecutionConfig -from cosmos.profiles import PostgresUserPasswordProfileMapping +from cosmos.profiles import PostgresUserPasswordProfileMapping, DbtProfileConfigVars DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) @@ -18,6 +18,7 @@ profile_mapping=PostgresUserPasswordProfileMapping( conn_id="airflow_db", profile_args={"schema": "public"}, + dbt_config_vars=DbtProfileConfigVars(send_anonymous_usage_stats=True), ), ) diff --git a/docs/templates/index.rst.jinja2 b/docs/templates/index.rst.jinja2 index d5c3069111..802b075ed9 100644 --- a/docs/templates/index.rst.jinja2 +++ b/docs/templates/index.rst.jinja2 @@ -85,6 +85,9 @@ you specify in ``ProfileConfig``. Disabling dbt event tracking -------------------------------- + +.. note:: + Deprecated in v1.4.0 and will be removed in v2.0.0. Use ``dbt_config_vars=DbtProfileConfigVars(send_anonymous_usage_stats=False)`` instead. .. versionadded:: 1.3 By default `dbt will track events `_ by sending anonymous usage data @@ -112,6 +115,43 @@ the example below: dag = DbtDag(profile_config=profile_config, ...) +Dbt profile config variables +-------------------------------- +.. versionadded:: 1.4.0 + +Use ``DbtProfileConfigVars`` to set the parts of ``profiles.yml`` that aren't specific to a particular data platform (see `dbt docs `_). + +.. 
code-block:: python + + from cosmos.profiles import SnowflakeUserPasswordProfileMapping, DbtProfileConfigVars + + profile_config = ProfileConfig( + profile_name="my_profile_name", + target_name="my_target_name", + profile_mapping=SnowflakeUserPasswordProfileMapping( + conn_id="my_snowflake_conn_id", + profile_args={ + "database": "my_snowflake_database", + "schema": "my_snowflake_schema", + }, + dbt_config_vars=DbtProfileConfigVars( + send_anonymous_usage_stats=False, + partial_parse=True, + use_experimental_parser=True, + static_parser=True, + printer_width=120, + write_json=True, + warn_error=True, + warn_error_options={"include": "all"}, + log_format='text', + debug=True, + version_check=True, + ), + ), + ) + + dag = DbtDag(profile_config=profile_config, ...) + diff --git a/tests/profiles/test_base_profile.py b/tests/profiles/test_base_profile.py index 98c4004e71..b80912bcd8 100644 --- a/tests/profiles/test_base_profile.py +++ b/tests/profiles/test_base_profile.py @@ -1,9 +1,11 @@ from __future__ import annotations +from typing import Any import pytest import yaml +from pydantic.error_wrappers import ValidationError -from cosmos.profiles.base import BaseProfileMapping +from cosmos.profiles.base import BaseProfileMapping, DbtProfileConfigVars from cosmos.exceptions import CosmosValueError @@ -37,7 +39,7 @@ def test_validate_profile_args(profile_arg: str): @pytest.mark.parametrize("disable_event_tracking", [True, False]) -def test_disable_event_tracking(disable_event_tracking: str): +def test_disable_event_tracking(disable_event_tracking: bool): """ Tests the config block in the profile is set correctly if disable_event_tracking is set. 
""" @@ -50,3 +52,112 @@ def test_disable_event_tracking(disable_event_tracking: str): assert ("config" in profile_contents) == disable_event_tracking if disable_event_tracking: assert profile_contents["config"]["send_anonymous_usage_stats"] is False + + +def test_disable_event_tracking_and_send_anonymous_usage_stats(): + expected_cosmos_error = ( + "Cannot set both disable_event_tracking and " + "dbt_config_vars=DbtProfileConfigVars(send_anonymous_usage_stats ..." + ) + + with pytest.raises(CosmosValueError) as err_info: + TestProfileMapping( + conn_id="fake_conn_id", + dbt_config_vars=DbtProfileConfigVars(send_anonymous_usage_stats=False), + disable_event_tracking=True, + ) + assert err_info.value.args[0] == expected_cosmos_error + + +def test_dbt_profile_config_vars_none(): + """ + Tests the DbtProfileConfigVars return None. + """ + dbt_config_vars = DbtProfileConfigVars( + send_anonymous_usage_stats=None, + partial_parse=None, + use_experimental_parser=None, + static_parser=None, + printer_width=None, + write_json=None, + warn_error=None, + warn_error_options=None, + log_format=None, + debug=None, + version_check=None, + ) + assert dbt_config_vars.as_dict() is None + + +@pytest.mark.parametrize("config", [True, False]) +def test_dbt_config_vars_config(config: bool): + """ + Tests the config block in the profile is set correctly. + """ + + dbt_config_vars = None + if config: + dbt_config_vars = DbtProfileConfigVars(debug=False) + + test_profile = TestProfileMapping( + conn_id="fake_conn_id", + dbt_config_vars=dbt_config_vars, + ) + profile_contents = yaml.safe_load(test_profile.get_profile_file_contents(profile_name="fake-profile-name")) + + assert ("config" in profile_contents) == config + + +@pytest.mark.parametrize("dbt_config_var,dbt_config_value", [("debug", True), ("debug", False)]) +def test_validate_dbt_config_vars(dbt_config_var: str, dbt_config_value: Any): + """ + Tests the config block in the profile is set correctly. 
+ """ + dbt_config_vars = {dbt_config_var: dbt_config_value} + test_profile = TestProfileMapping( + conn_id="fake_conn_id", + dbt_config_vars=DbtProfileConfigVars(**dbt_config_vars), + ) + profile_contents = yaml.safe_load(test_profile.get_profile_file_contents(profile_name="fake-profile-name")) + + assert "config" in profile_contents + assert profile_contents["config"][dbt_config_var] == dbt_config_value + + +@pytest.mark.parametrize( + "dbt_config_var,dbt_config_value", + [("send_anonymous_usage_stats", 2), ("send_anonymous_usage_stats", "aaa")], +) +def test_profile_config_validate_dbt_config_vars_check_unexpected_types(dbt_config_var: str, dbt_config_value: Any): + dbt_config_vars = {dbt_config_var: dbt_config_value} + + with pytest.raises(ValidationError): + TestProfileMapping( + conn_id="fake_conn_id", + dbt_config_vars=DbtProfileConfigVars(**dbt_config_vars), + ) + + +@pytest.mark.parametrize("dbt_config_var,dbt_config_value", [("send_anonymous_usage_stats", True)]) +def test_profile_config_validate_dbt_config_vars_check_expected_types(dbt_config_var: str, dbt_config_value: Any): + dbt_config_vars = {dbt_config_var: dbt_config_value} + + profile_config = TestProfileMapping( + conn_id="fake_conn_id", + dbt_config_vars=DbtProfileConfigVars(**dbt_config_vars), + ) + assert profile_config.dbt_config_vars.as_dict() == dbt_config_vars + + +@pytest.mark.parametrize( + "dbt_config_var,dbt_config_value", + [("log_format", "text2")], +) +def test_profile_config_validate_dbt_config_vars_check_values(dbt_config_var: str, dbt_config_value: Any): + dbt_config_vars = {dbt_config_var: dbt_config_value} + + with pytest.raises(ValidationError): + TestProfileMapping( + conn_id="fake_conn_id", + dbt_config_vars=DbtProfileConfigVars(**dbt_config_vars), + )