diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py
index 392b4f78b8..00b9348819 100644
--- a/cosmos/profiles/__init__.py
+++ b/cosmos/profiles/__init__.py
@@ -10,6 +10,7 @@
 from .bigquery.service_account_file import GoogleCloudServiceAccountFileProfileMapping
 from .bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping
 from .clickhouse.user_pass import ClickhouseUserPasswordProfileMapping
+from .databricks.oauth import DatabricksOauthProfileMapping
 from .databricks.token import DatabricksTokenProfileMapping
 from .exasol.user_pass import ExasolUserPasswordProfileMapping
 from .postgres.user_pass import PostgresUserPasswordProfileMapping
@@ -32,6 +33,7 @@
     GoogleCloudServiceAccountDictProfileMapping,
     GoogleCloudOauthProfileMapping,
     DatabricksTokenProfileMapping,
+    DatabricksOauthProfileMapping,
     PostgresUserPasswordProfileMapping,
     RedshiftUserPasswordProfileMapping,
     SnowflakeUserPasswordProfileMapping,
@@ -73,6 +75,7 @@ def get_automatic_profile_mapping(
     "GoogleCloudServiceAccountDictProfileMapping",
     "GoogleCloudOauthProfileMapping",
     "DatabricksTokenProfileMapping",
+    "DatabricksOauthProfileMapping",
     "DbtProfileConfigVars",
     "PostgresUserPasswordProfileMapping",
     "RedshiftUserPasswordProfileMapping",
diff --git a/cosmos/profiles/databricks/__init__.py b/cosmos/profiles/databricks/__init__.py
index 2e3a9d1143..7ce683c7af 100644
--- a/cosmos/profiles/databricks/__init__.py
+++ b/cosmos/profiles/databricks/__init__.py
@@ -1,5 +1,6 @@
 """Databricks Airflow connection -> dbt profile mappings"""
 
+from .oauth import DatabricksOauthProfileMapping
 from .token import DatabricksTokenProfileMapping
 
-__all__ = ["DatabricksTokenProfileMapping"]
+__all__ = ["DatabricksTokenProfileMapping", "DatabricksOauthProfileMapping"]
diff --git a/cosmos/profiles/databricks/oauth.py b/cosmos/profiles/databricks/oauth.py
new file mode 100644
index 0000000000..fd6875c497
--- /dev/null
+++ b/cosmos/profiles/databricks/oauth.py
@@ -0,0 +1,57 @@
+"""Maps Airflow Databricks connections with the client auth to dbt profiles."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from ..base import BaseProfileMapping
+
+
+class DatabricksOauthProfileMapping(BaseProfileMapping):
+    """
+    Maps Airflow Databricks connections with the client auth to dbt profiles.
+
+    The OAuth client id is read from the connection's ``login`` or ``extra.client_id``, the client secret from
+    ``password`` or ``extra.client_secret`` and the HTTP path from ``extra.http_path``.
+
+    https://docs.getdbt.com/reference/warehouse-setups/databricks-setup
+    https://airflow.apache.org/docs/apache-airflow-providers-databricks/stable/connections/databricks.html
+    """
+
+    airflow_connection_type: str = "databricks"
+    dbt_profile_type: str = "databricks"
+
+    required_fields = [
+        "host",
+        "schema",
+        "client_secret",
+        "client_id",
+        "http_path",
+    ]
+
+    secret_fields = ["client_secret", "client_id"]
+
+    airflow_param_mapping = {
+        "host": "host",
+        "schema": "schema",
+        "client_id": ["login", "extra.client_id"],
+        "client_secret": ["password", "extra.client_secret"],
+        "http_path": "extra.http_path",
+    }
+
+    @property
+    def profile(self) -> dict[str, Any | None]:
+        """Generates profile. The client-id and client-secret are stored in environment variables."""
+        return {
+            **self.mapped_params,
+            **self.profile_args,
+            "auth_type": "oauth",
+            # always reference the credentials through env vars instead of embedding them
+            "client_secret": self.get_env_var_format("client_secret"),
+            "client_id": self.get_env_var_format("client_id"),
+        }
+
+    def transform_host(self, host: str) -> str:
+        """Removes the https:// prefix; dbt-databricks documents ``host`` as a bare hostname."""
+        # NOTE(review): relies on the base mapping applying ``transform_<field>`` hooks to mapped values - confirm.
+        return host.replace("https://", "")
diff --git a/tests/profiles/databricks/test_dbr_oauth.py b/tests/profiles/databricks/test_dbr_oauth.py
new file mode 100644
index 0000000000..96228c4bb8
--- /dev/null
+++ b/tests/profiles/databricks/test_dbr_oauth.py
@@ -0,0 +1,77 @@
+"""Tests for the databricks OAuth profile."""
+
+from unittest.mock import patch
+
+import pytest
+from airflow.models.connection import Connection
+
+from cosmos.profiles.databricks import DatabricksOauthProfileMapping
+
+
+@pytest.fixture()
+def mock_databricks_conn():  # type: ignore
+    """
+    Mocks and returns an Airflow Databricks connection.
+    """
+    conn = Connection(
+        conn_id="my_databricks_connection",
+        conn_type="databricks",
+        host="https://my_host",
+        login="my_client_id",
+        password="my_client_secret",
+        extra='{"http_path": "my_http_path"}',
+    )
+
+    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
+        yield conn
+
+
+def test_connection_claiming() -> None:
+    """
+    Tests that the Databricks profile mapping claims the correct connection type.
+    """
+    # should only claim when:
+    # - conn_type == databricks
+    # and the following exist:
+    # - schema
+    # - host
+    # - http_path
+    # - client_id
+    # - client_secret
+    potential_values = {
+        "conn_type": "databricks",
+        "host": "my_host",
+        "login": "my_client_id",
+        "password": "my_client_secret",
+        "extra": '{"http_path": "my_http_path"}',
+    }
+
+    # if we're missing any of the values, it shouldn't claim
+    for key in potential_values:
+        values = potential_values.copy()
+        del values[key]
+        conn = Connection(**values)  # type: ignore
+
+        with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
+            profile_mapping = DatabricksOauthProfileMapping(conn, {"schema": "my_schema"})
+            assert not profile_mapping.can_claim_connection()
+
+    # also test when there's no schema
+    conn = Connection(**potential_values)  # type: ignore
+    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
+        profile_mapping = DatabricksOauthProfileMapping(conn, {})
+        assert not profile_mapping.can_claim_connection()
+
+    # if we have them all, it should claim
+    conn = Connection(**potential_values)  # type: ignore
+    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
+        profile_mapping = DatabricksOauthProfileMapping(conn, {"schema": "my_schema"})
+        assert profile_mapping.can_claim_connection()
+
+
+def test_host_scheme_is_stripped(mock_databricks_conn: Connection) -> None:
+    """
+    Tests that the https:// prefix of the Airflow host is not carried into the dbt profile.
+    """
+    profile_mapping = DatabricksOauthProfileMapping(mock_databricks_conn.conn_id, {"schema": "my_schema"})
+    assert profile_mapping.profile["host"] == "my_host"