diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py b/airflow/providers/dbt/cloud/hooks/dbt.py index d88c0053d318f..31214d73424d7 100644 --- a/airflow/providers/dbt/cloud/hooks/dbt.py +++ b/airflow/providers/dbt/cloud/hooks/dbt.py @@ -144,14 +144,17 @@ class DbtCloudHook(HttpHook): def get_ui_field_behaviour() -> Dict[str, Any]: """Builds custom field behavior for the dbt Cloud connection form in the Airflow UI.""" return { - "hidden_fields": ["host", "port", "schema", "extra"], - "relabeling": {"login": "Account ID", "password": "API Token"}, + "hidden_fields": ["host", "port", "extra"], + "relabeling": {"login": "Account ID", "password": "API Token", "schema": "Tenant"}, + "placeholders": {"schema": "Defaults to 'cloud'."}, } def __init__(self, dbt_cloud_conn_id: str = default_conn_name, *args, **kwargs) -> None: super().__init__(auth_type=TokenAuth) self.dbt_cloud_conn_id = dbt_cloud_conn_id - self.base_url = "https://cloud.getdbt.com/api/v2/accounts/" + tenant = self.connection.schema if self.connection.schema else 'cloud' + + self.base_url = f"https://{tenant}.getdbt.com/api/v2/accounts/" @cached_property def connection(self) -> Connection: diff --git a/tests/providers/dbt/cloud/hooks/test_dbt_cloud.py b/tests/providers/dbt/cloud/hooks/test_dbt_cloud.py index 0c4ad654563f4..012138e5aac0c 100644 --- a/tests/providers/dbt/cloud/hooks/test_dbt_cloud.py +++ b/tests/providers/dbt/cloud/hooks/test_dbt_cloud.py @@ -34,14 +34,17 @@ ACCOUNT_ID_CONN = "account_id_conn" NO_ACCOUNT_ID_CONN = "no_account_id_conn" +SINGLE_TENANT_CONN = "single_tenant_conn" DEFAULT_ACCOUNT_ID = 11111 ACCOUNT_ID = 22222 +SINGLE_TENANT_SCHEMA = "single.tenant" TOKEN = "token" PROJECT_ID = 33333 JOB_ID = 4444 RUN_ID = 5555 BASE_URL = "https://cloud.getdbt.com/api/v2/accounts/" +SINGLE_TENANT_URL = "https://single.tenant.getdbt.com/api/v2/accounts/" class TestDbtCloudJobRunStatus: @@ -119,15 +122,30 @@ def setup_class(self): password=TOKEN, ) + # Connection with `schema` parameter set + schema_conn = Connection( + conn_id=SINGLE_TENANT_CONN, + conn_type=DbtCloudHook.conn_type, + login=DEFAULT_ACCOUNT_ID, + password=TOKEN, + schema=SINGLE_TENANT_SCHEMA, + ) + db.merge_conn(account_id_conn) db.merge_conn(no_account_id_conn) + db.merge_conn(schema_conn) - def test_init_hook(self): - hook = DbtCloudHook() - assert hook.dbt_cloud_conn_id == "dbt_cloud_default" - assert hook.base_url == BASE_URL + @pytest.mark.parametrize( + argnames="conn_id, url", + argvalues=[(ACCOUNT_ID_CONN, BASE_URL), (SINGLE_TENANT_CONN, SINGLE_TENANT_URL)], + ids=["multi-tenant", "single-tenant"], + ) + def test_init_hook(self, conn_id, url): + hook = DbtCloudHook(conn_id) assert hook.auth_type == TokenAuth assert hook.method == "POST" + assert hook.dbt_cloud_conn_id == conn_id + assert hook.base_url == url @pytest.mark.parametrize( argnames="conn_id, account_id", diff --git a/tests/providers/dbt/cloud/sensors/test_dbt_cloud.py b/tests/providers/dbt/cloud/sensors/test_dbt_cloud.py index 6cd50690288ba..c5d834fd1e7ef 100644 --- a/tests/providers/dbt/cloud/sensors/test_dbt_cloud.py +++ b/tests/providers/dbt/cloud/sensors/test_dbt_cloud.py @@ -19,11 +19,14 @@ import pytest +from airflow.models.connection import Connection from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor +from airflow.utils import db ACCOUNT_ID = 11111 RUN_ID = 5555 +TOKEN = "token" class TestDbtCloudJobRunSensor: @@ -37,6 +40,11 @@ def setup_class(self): poke_interval=15, ) + # Connection + conn = Connection(conn_id="dbt", conn_type=DbtCloudHook.conn_type, login=ACCOUNT_ID, password=TOKEN) + + db.merge_conn(conn) + def test_init(self): assert self.sensor.dbt_cloud_conn_id == "dbt" assert self.sensor.run_id == RUN_ID