From f770b763284a862617fa3b546e8672046c5e9849 Mon Sep 17 00:00:00 2001 From: eskarimov Date: Fri, 26 Nov 2021 08:58:37 +0100 Subject: [PATCH 1/4] Refactor DatabricksHook --- .../providers/databricks/hooks/databricks.py | 110 +++++++++--------- .../databricks/hooks/test_databricks.py | 12 +- 2 files changed, 68 insertions(+), 54 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 9e4ecaff372b8..7d2eda97a3a5e 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -48,6 +48,8 @@ USER_AGENT_HEADER = {'user-agent': f'airflow-{__version__}'} +RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR'] + # https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--get-an-azure-active-directory-access-token # https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com" @@ -64,7 +66,9 @@ class RunState: """Utility class for the run state concept of Databricks runs.""" - def __init__(self, life_cycle_state: str, result_state: str, state_message: str) -> None: + def __init__( + self, life_cycle_state: str, result_state: str = '', state_message: str = '', *args, **kwargs + ) -> None: self.life_cycle_state = life_cycle_state self.result_state = result_state self.state_message = state_message @@ -131,7 +135,11 @@ def __init__( ) -> None: super().__init__() self.databricks_conn_id = databricks_conn_id - self.databricks_conn = None + self.databricks_conn = self.get_connection(databricks_conn_id) + if 'host' in self.databricks_conn.extra_dejson: + self.host = self._parse_host(self.databricks_conn.extra_dejson['host']) + else: + self.host = self._parse_host(self.databricks_conn.host) self.timeout_seconds = timeout_seconds if retry_limit < 1: raise ValueError('Retry limit must be greater than equal to 1') @@ -173,13 +181,11 @@ def _get_aad_token(self, resource: str) -> str: :param resource: resource to issue token to :return: AAD token, or raise an exception """ - if resource in self.aad_tokens: - d = self.aad_tokens[resource] - now = int(time.time()) - if d['expires_on'] > (now + TOKEN_REFRESH_LEAD_TIME): # it expires in more than 2 minutes - return d['token'] - self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...") + aad_token = self.aad_tokens.get(resource) + if aad_token and self._is_aad_token_valid(aad_token): + return aad_token['token'] + self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...') attempt_num = 1 while True: try: @@ -235,21 +241,18 @@ def _get_aad_token(self, resource: str) -> str: attempt_num += 1 sleep(self.retry_delay) - def _fill_aad_tokens(self, headers: dict) -> str: + def _fill_aad_headers(self, headers: dict) -> dict: """ - Fills headers if necessary (SPN is outside of the workspace) and generates AAD token + Fills AAD headers if necessary (SPN is outside of the workspace) :param headers: dictionary with headers to fill-in - :return: AAD token + :return: dictionary with filled AAD headers """ - # SP is outside of the workspace - if 'azure_resource_id' in self.databricks_conn.extra_dejson: - mgmt_token = self._get_aad_token(AZURE_MANAGEMENT_ENDPOINT) - headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[ - 'azure_resource_id' - ] - headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token - - return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE) + mgmt_token = self._get_aad_token(AZURE_MANAGEMENT_ENDPOINT) + headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[ + 'azure_resource_id' + ] + headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token + return headers def _do_api_call(self, endpoint_info, json): """ @@ -265,14 +268,11 @@ def _do_api_call(self, endpoint_info, json): :rtype: dict """ method, endpoint = endpoint_info - - self.databricks_conn = self.get_connection(self.databricks_conn_id) + url = f'https://{self.host}/{endpoint}' headers = USER_AGENT_HEADER.copy() - if 'host' in self.databricks_conn.extra_dejson: - host = self._parse_host(self.databricks_conn.extra_dejson['host']) - else: - host = self.databricks_conn.host + if 'azure_resource_id' in self.databricks_conn.extra_dejson: + headers = self._fill_aad_headers(headers) if 'token' in self.databricks_conn.extra_dejson: self.log.info( @@ -285,9 +285,8 @@ def _do_api_call(self, endpoint_info, json): elif 'azure_tenant_id' in self.databricks_conn.extra_dejson: if self.databricks_conn.login == "" or self.databricks_conn.password == "": raise AirflowException("Azure SPN credentials aren't provided") - - self.log.info('Using AAD Token for SPN. ') - auth = _TokenAuth(self._fill_aad_tokens(headers)) + self.log.info('Using AAD Token for SPN.') + auth = _TokenAuth(self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)) elif self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False): self.log.info('Using AAD Token for managed identity.') # check for Azure Metadata Service @@ -306,13 +305,11 @@ def _do_api_call(self, endpoint_info, json): except (requests_exceptions.RequestException, ValueError) as e: raise AirflowException(f"Can't reach Azure Metadata Service: {e}") - auth = _TokenAuth(self._fill_aad_tokens(headers)) + auth = _TokenAuth(self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)) else: self.log.info('Using basic auth.') auth = (self.databricks_conn.login, self.databricks_conn.password) - url = f'https://{self._parse_host(host)}/{endpoint}' - if method == 'GET': request_func = requests.get elif method == 'POST': @@ -356,31 +353,31 @@ def _do_api_call(self, endpoint_info, json): def _log_request_error(self, attempt_num: int, error: str) -> None: self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error) - def run_now(self, json: dict) -> str: + def run_now(self, json: dict) -> int: """ Utility function to call the ``api/2.0/jobs/run-now`` endpoint. :param json: The data used in the body of the request to the ``run-now`` endpoint. :type json: dict - :return: the run_id as a string + :return: the run_id as an int :rtype: str """ response = self._do_api_call(RUN_NOW_ENDPOINT, json) return response['run_id'] - def submit_run(self, json: dict) -> str: + def submit_run(self, json: dict) -> int: """ Utility function to call the ``api/2.0/jobs/runs/submit`` endpoint. :param json: The data used in the body of the request to the ``submit`` endpoint. :type json: dict - :return: the run_id as a string + :return: the run_id as an int :rtype: str """ response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json) return response['run_id'] - def get_run_page_url(self, run_id: str) -> str: + def get_run_page_url(self, run_id: int) -> str: """ Retrieves run_page_url. @@ -391,19 +388,19 @@ def get_run_page_url(self, run_id: str) -> str: response = self._do_api_call(GET_RUN_ENDPOINT, json) return response['run_page_url'] - def get_job_id(self, run_id: str) -> str: + def get_job_id(self, run_id: int) -> int: """ Retrieves job_id from run_id. :param run_id: id of the run - :type run_id: str + :type run_id: int :return: Job id for given Databricks run """ json = {'run_id': run_id} response = self._do_api_call(GET_RUN_ENDPOINT, json) return response['job_id'] - def get_run_state(self, run_id: str) -> RunState: + def get_run_state(self, run_id: int) -> RunState: """ Retrieves run state of the run. @@ -421,13 +418,9 @@ def get_run_state(self, run_id: str) -> RunState: json = {'run_id': run_id} response = self._do_api_call(GET_RUN_ENDPOINT, json) state = response['state'] - life_cycle_state = state['life_cycle_state'] - # result_state may not be in the state if not terminal - result_state = state.get('result_state', None) - state_message = state['state_message'] - return RunState(life_cycle_state, result_state, state_message) + return RunState(**state) - def get_run_state_str(self, run_id: str) -> str: + def get_run_state_str(self, run_id: int) -> str: """ Return the string representation of RunState. @@ -440,7 +433,7 @@ def get_run_state_str(self, run_id: str) -> str: ) return run_state_str - def get_run_state_lifecycle(self, run_id: str) -> str: + def get_run_state_lifecycle(self, run_id: int) -> str: """ Returns the lifecycle state of the run @@ -449,7 +442,7 @@ def get_run_state_lifecycle(self, run_id: str) -> str: """ return self.get_run_state(run_id).life_cycle_state - def get_run_state_result(self, run_id: str) -> str: + def get_run_state_result(self, run_id: int) -> str: """ Returns the resulting state of the run @@ -458,7 +451,7 @@ def get_run_state_result(self, run_id: str) -> str: """ return self.get_run_state(run_id).result_state - def get_run_state_message(self, run_id: str) -> str: + def get_run_state_message(self, run_id: int) -> str: """ Returns the state message for the run @@ -467,7 +460,7 @@ def get_run_state_message(self, run_id: str) -> str: """ return self.get_run_state(run_id).state_message - def cancel_run(self, run_id: str) -> None: + def cancel_run(self, run_id: int) -> None: """ Cancels the run. @@ -522,6 +515,20 @@ def uninstall(self, json: dict) -> None: """ self._do_api_call(UNINSTALL_LIBS_ENDPOINT, json) + @staticmethod + def _is_aad_token_valid(aad_token: dict) -> bool: + """ + Utility function to check AAD token hasn't expired yet + :param aad_token: dict with properties of AAD token + :type aad_token: dict + :return: true if token is valid, false otherwise + :rtype: bool + """ + now = int(time.time()) + if aad_token['expires_on'] > (now + TOKEN_REFRESH_LEAD_TIME): + return True + return False + def _retryable_error(exception) -> bool: return ( @@ -531,9 +538,6 @@ def _retryable_error(exception) -> bool: ) -RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR'] - - class _TokenAuth(AuthBase): """ Helper class for requests Auth field. AuthBase requires you to implement the __call__ diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py index a5f0fb467fe64..ef7b31334da7f 100644 --- a/tests/providers/databricks/hooks/test_databricks.py +++ b/tests/providers/databricks/hooks/test_databricks.py @@ -19,6 +19,7 @@ import itertools import json +import time import unittest from unittest import mock @@ -34,6 +35,7 @@ AZURE_TOKEN_SERVICE_URL, DEFAULT_DATABRICKS_SCOPE, SUBMIT_RUN_ENDPOINT, + TOKEN_REFRESH_LEAD_TIME, DatabricksHook, RunState, ) @@ -63,7 +65,7 @@ } NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"} JAR_PARAMS = ["param1", "param2"] -RESULT_STATE = None # type: None +RESULT_STATE = '' LIBRARIES = [ {"jar": "dbfs:/mnt/libraries/library.jar"}, {"maven": {"coordinates": "org.jsoup:jsoup:1.7.2", "exclusions": ["slf4j:slf4j"]}}, @@ -520,6 +522,14 @@ def test_uninstall_libs_on_cluster(self, mock_requests): timeout=self.hook.timeout_seconds, ) + def test_is_aad_token_valid_returns_true(self): + aad_token = {'token': 'my_token', 'expires_on': int(time.time()) + TOKEN_REFRESH_LEAD_TIME + 10} + self.assertTrue(self.hook._is_aad_token_valid(aad_token)) + + def test_is_aad_token_valid_returns_false(self): + aad_token = {'token': 'my_token', 'expires_on': int(time.time())} + self.assertFalse(self.hook._is_aad_token_valid(aad_token)) + class TestDatabricksHookToken(unittest.TestCase): """ From 98d91c07684b4cb00f4f8b55759160b555e8f001 Mon Sep 17 00:00:00 2001 From: eskarimov Date: Sun, 28 Nov 2021 12:34:00 +0100 Subject: [PATCH 2/4] Refactor `_fill_add_headers` --- .../providers/databricks/hooks/databricks.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 7d2eda97a3a5e..bf77d94a05cc5 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -241,17 +241,18 @@ def _get_aad_token(self, resource: str) -> str: attempt_num += 1 sleep(self.retry_delay) - def _fill_aad_headers(self, headers: dict) -> dict: + def _get_aad_headers(self) -> dict: """ Fills AAD headers if necessary (SPN is outside of the workspace) - :param headers: dictionary with headers to fill-in :return: dictionary with filled AAD headers """ - mgmt_token = self._get_aad_token(AZURE_MANAGEMENT_ENDPOINT) - headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[ - 'azure_resource_id' - ] - headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token + headers = {} + if 'azure_resource_id' in self.databricks_conn.extra_dejson: + mgmt_token = self._get_aad_token(AZURE_MANAGEMENT_ENDPOINT) + headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[ + 'azure_resource_id' + ] + headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token return headers def _do_api_call(self, endpoint_info, json): @@ -270,9 +271,8 @@ def _do_api_call(self, endpoint_info, json): method, endpoint = endpoint_info url = f'https://{self.host}/{endpoint}' - headers = USER_AGENT_HEADER.copy() - if 'azure_resource_id' in self.databricks_conn.extra_dejson: - headers = self._fill_aad_headers(headers) + aad_headers = self._get_aad_headers() + headers = {**USER_AGENT_HEADER.copy(), **aad_headers} if 'token' in self.databricks_conn.extra_dejson: self.log.info( From 62347792789fd2099883c0cff95f8676423da876 Mon Sep 17 00:00:00 2001 From: eskarimov Date: Tue, 30 Nov 2021 20:46:18 +0100 Subject: [PATCH 3/4] Extract check for Azure Metadata Service into a separate function and cover with tests --- .../providers/databricks/hooks/databricks.py | 37 +++++++++------- .../databricks/hooks/test_databricks.py | 44 +++++++++++++++++++ 2 files changed, 65 insertions(+), 16 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index bf77d94a05cc5..6b0fb97a3d55b 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -255,6 +255,26 @@ def _get_aad_headers(self) -> dict: headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token return headers + @staticmethod + def _check_azure_metadata_service() -> None: + """ + Check for Azure Metadata Service + https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service + """ + try: + jsn = requests.get( + AZURE_METADATA_SERVICE_TOKEN_URL, + params={"api-version": "2021-02-01"}, + headers={"Metadata": "true"}, + timeout=2, + ).json() + if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']: + raise AirflowException( + f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}" + ) + except (requests_exceptions.RequestException, ValueError) as e: + raise AirflowException(f"Can't reach Azure Metadata Service: {e}") + def _do_api_call(self, endpoint_info, json): """ Utility function to perform an API call with retries @@ -289,22 +309,7 @@ def _do_api_call(self, endpoint_info, json): auth = _TokenAuth(self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)) elif self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False): self.log.info('Using AAD Token for managed identity.') - # check for Azure Metadata Service - # https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service - try: - jsn = requests.get( - AZURE_METADATA_SERVICE_TOKEN_URL, - params={"api-version": "2021-02-01"}, - headers={"Metadata": "true"}, - timeout=2, - ).json() - if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']: - raise AirflowException( - f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}" - ) - except (requests_exceptions.RequestException, ValueError) as e: - raise AirflowException(f"Can't reach Azure Metadata Service: {e}") - + self._check_azure_metadata_service() auth = _TokenAuth(self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)) else: self.log.info('Using basic auth.') diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py index ef7b31334da7f..ea688e87dc955 100644 --- a/tests/providers/databricks/hooks/test_databricks.py +++ b/tests/providers/databricks/hooks/test_databricks.py @@ -32,6 +32,7 @@ from airflow.providers.databricks.hooks.databricks import ( AZURE_DEFAULT_AD_ENDPOINT, AZURE_MANAGEMENT_ENDPOINT, + AZURE_METADATA_SERVICE_TOKEN_URL, AZURE_TOKEN_SERVICE_URL, DEFAULT_DATABRICKS_SCOPE, SUBMIT_RUN_ENDPOINT, @@ -772,3 +773,46 @@ def test_submit_run(self, mock_requests): assert kwargs['auth'].token == TOKEN assert kwargs['headers']['X-Databricks-Azure-Workspace-Resource-Id'] == '/Some/resource' assert kwargs['headers']['X-Databricks-Azure-SP-Management-Token'] == TOKEN + + +class TestDatabricksHookAadTokenManagedIdentity(unittest.TestCase): + """ + Tests for DatabricksHook when auth is done with AAD leveraging Managed Identity authentication + """ + + @provide_session + def setUp(self, session=None): + conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() + conn.host = HOST + conn.extra = json.dumps( + { + 'use_azure_managed_identity': True, + } + ) + session.commit() + self.hook = DatabricksHook() + + @mock.patch('airflow.providers.databricks.hooks.databricks.requests') + def test_submit_run(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.get.side_effect = [ + create_successful_response_mock({'compute': {'azEnvironment': 'AZUREPUBLICCLOUD'}}), + create_successful_response_mock(create_aad_token_for_resource(DEFAULT_DATABRICKS_SCOPE)), + ] + mock_requests.post.side_effect = [ + create_successful_response_mock({'run_id': '1'}), + ] + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + data = {'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER} + run_id = self.hook.submit_run(data) + + ad_call_args = mock_requests.method_calls[0] + assert ad_call_args[1][0] == AZURE_METADATA_SERVICE_TOKEN_URL + assert ad_call_args[2]['params']['api-version'] > '2018-02-01' + assert ad_call_args[2]['headers']['Metadata'] == 'true' + + assert run_id == '1' + args = mock_requests.post.call_args + kwargs = args[1] + assert kwargs['auth'].token == TOKEN From 9d08a127014d9f245163dbc5b9eb5f8138f48acd Mon Sep 17 00:00:00 2001 From: eskarimov Date: Tue, 30 Nov 2021 20:49:03 +0100 Subject: [PATCH 4/4] Place `_is_aad_token_valid()` together with other internal functions --- .../providers/databricks/hooks/databricks.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 6b0fb97a3d55b..ac8d9511e0676 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -255,6 +255,20 @@ def _get_aad_headers(self) -> dict: headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token return headers + @staticmethod + def _is_aad_token_valid(aad_token: dict) -> bool: + """ + Utility function to check AAD token hasn't expired yet + :param aad_token: dict with properties of AAD token + :type aad_token: dict + :return: true if token is valid, false otherwise + :rtype: bool + """ + now = int(time.time()) + if aad_token['expires_on'] > (now + TOKEN_REFRESH_LEAD_TIME): + return True + return False + @staticmethod def _check_azure_metadata_service() -> None: """ @@ -520,20 +534,6 @@ def uninstall(self, json: dict) -> None: """ self._do_api_call(UNINSTALL_LIBS_ENDPOINT, json) - @staticmethod - def _is_aad_token_valid(aad_token: dict) -> bool: - """ - Utility function to check AAD token hasn't expired yet - :param aad_token: dict with properties of AAD token - :type aad_token: dict - :return: true if token is valid, false otherwise - :rtype: bool - """ - now = int(time.time()) - if aad_token['expires_on'] > (now + TOKEN_REFRESH_LEAD_TIME): - return True - return False - def _retryable_error(exception) -> bool: return (