diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 14447a09d422d..5156cdcbaf8b0 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -22,6 +22,7 @@ operators talk to the ``api/2.0/jobs/runs/submit`` `endpoint `_. """ +import sys import time from time import sleep from typing import Dict @@ -34,6 +35,12 @@ from airflow import __version__ from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook +from airflow.models import Connection + +if sys.version_info >= (3, 8): + from functools import cached_property +else: + from cached_property import cached_property RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart") START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start") @@ -143,11 +150,10 @@ def __init__( self.retry_delay = retry_delay self.aad_tokens: Dict[str, dict] = {} self.aad_timeout_seconds = 10 - self.databricks_conn = self.get_connection(self.databricks_conn_id) - if 'host' in self.databricks_conn.extra_dejson: - self.host = self._parse_host(self.databricks_conn.extra_dejson['host']) - else: - self.host = self._parse_host(self.databricks_conn.host) + + @cached_property + def databricks_conn(self) -> Connection: + return self.get_connection(self.databricks_conn_id) @staticmethod def _parse_host(host: str) -> str: @@ -305,6 +311,11 @@ def _do_api_call(self, endpoint_info, json): """ method, endpoint = endpoint_info + if 'host' in self.databricks_conn.extra_dejson: + self.host = self._parse_host(self.databricks_conn.extra_dejson['host']) + else: + self.host = self._parse_host(self.databricks_conn.host) + url = f'https://{self.host}/{endpoint}' aad_headers = self._get_aad_headers()