From 58c34e562fb67b43e3b2ded7eb380fd5b1ea8e7c Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Mon, 5 Aug 2024 14:39:20 +0530 Subject: [PATCH 1/5] Deprecate the provider and proxy to upstream Airflow provider --- README.md | 3 + docs/index.rst | 23 + pyproject.toml | 8 +- src/astro_databricks/__init__.py | 2 +- src/astro_databricks/operators/common.py | 328 +------------ src/astro_databricks/operators/notebook.py | 369 +------------- src/astro_databricks/operators/workflow.py | 429 +--------------- src/astro_databricks/plugins/plugin.py | 537 +++------------------ src/astro_databricks/settings.py | 1 + tests/databricks/test_common.py | 423 +--------------- tests/databricks/test_notebook.py | 437 +---------------- tests/databricks/test_plugin.py | 417 +--------------- tests/databricks/test_workflow.py | 433 +---------------- 13 files changed, 164 insertions(+), 3246 deletions(-) diff --git a/README.md b/README.md index 725d268..9524b91 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,6 @@ +## Deprecation Notice +This provider is now deprecated since version 0.3.0 and will not be maintained. Please use the official [Apache Airflow Databricks Provider >= 6.8.0](https://airflow.apache.org/docs/apache-airflow-providers-databricks/stable/index.html) instead. +

Databricks Workflows in Airflow

diff --git a/docs/index.rst b/docs/index.rst index 96af895..4bac7a9 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,3 +1,26 @@ +.. warning:: + All the operators and their functionality within this repository have been deprecated and will not receive further updates. + Read more about the deprecation in the `Deprecation Notice` section below. + +Deprecation Notice +------------------ + +With the release ``0.3.0`` of the ``astro-provider-databricks`` package, this provider stands deprecated and will +no longer receive updates. We recommend migrating to the official Apache Airflow Databricks Provider for the latest features and support. +For the operators and sensors that are deprecated in this repository, migrating to the official Apache Airflow Databricks Provider +is as simple as changing the import path from + +.. code-block:: + + from astro_databricks import import SomeOperator + +to + +.. code-block:: + + from airflow.providers.databricks.operators.operator_module import SomeOperator + + Astro Databricks Provider ========================= diff --git a/pyproject.toml b/pyproject.toml index 280439e..8b619bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,12 +35,8 @@ classifiers = [ "Programming Language :: Python :: 3.10", ] dependencies = [ - "apache-airflow>=2.3", - "databricks-sql-connector>=2.0.4;python_version>='3.10'", - "databricks-cli>=0.17.7", - "apache-airflow-providers-databricks>=2.2.0", - "mergedeep", - "pydantic>=1.10.0", + "apache-airflow>=2.7", + "apache-airflow-providers-databricks>=6.8.0rc1", ] [project.optional-dependencies] diff --git a/src/astro_databricks/__init__.py b/src/astro_databricks/__init__.py index ab31753..7d435e7 100644 --- a/src/astro_databricks/__init__.py +++ b/src/astro_databricks/__init__.py @@ -3,7 +3,7 @@ from astro_databricks.operators.notebook import DatabricksNotebookOperator from astro_databricks.operators.workflow import DatabricksWorkflowTaskGroup -__version__ = "0.2.2" +__version__ = "0.3.0" __all__ = [ "DatabricksNotebookOperator", "DatabricksWorkflowTaskGroup", diff --git a/src/astro_databricks/operators/common.py b/src/astro_databricks/operators/common.py index 4a7f5a9..e944615 100644 --- a/src/astro_databricks/operators/common.py +++ b/src/astro_databricks/operators/common.py @@ -1,325 +1,25 @@ -"""DatabricksNotebookOperator for submitting notebook jobs to databricks.""" +""" +This module is deprecated and will be removed in future versions. +Please use `airflow.providers.databricks.operators.databricks.DatabricksTaskOperator` instead. +""" from __future__ import annotations -import time from typing import Any -import airflow -from airflow.exceptions import AirflowException -from airflow.models import BaseOperator -from airflow.providers.databricks.hooks.databricks import DatabricksHook -from airflow.utils.context import Context -from databricks_cli.runs.api import RunsApi -from databricks_cli.sdk.api_client import ApiClient - -from astro_databricks.operators.workflow import ( - DatabricksMetaData, - DatabricksWorkflowTaskGroup, -) -from astro_databricks.plugins.plugin import ( - DatabricksJobRepairSingleFailedLink, - DatabricksJobRunLink, +from airflow.providers.databricks.operators.databricks import ( + DatabricksTaskOperator as UpstreamDatabricksTaskOperator, ) -from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION -class DatabricksTaskOperator(BaseOperator): +class DatabricksTaskOperator(UpstreamDatabricksTaskOperator): """ - Launches a All Types Task to databricks using an Airflow operator. - - The DatabricksTaskOperator allows users to launch and monitor task - deployments on Databricks as Aiflow tasks. - It can be used as a part of a DatabricksWorkflowTaskGroup to take advantage of job clusters, - which allows users to run their tasks on cheaper clusters that can be shared between tasks. - - Here is an example of running a notebook as a part of a workflow task group: - - .. code-block: python - - with dag: - task_group = DatabricksWorkflowTaskGroup( - group_id="test_workflow", - databricks_conn_id="databricks_conn", - job_clusters=job_cluster_spec, - ) - with task_group: - task_1 = DatabricksTaskOperator( - task_id="task_1", - databricks_conn_id="databricks_conn", - job_cluster_key="Shared_job_cluster", - task_config={ - "notebook_task": { - "notebook_path": "/Users/daniel@astronomer.io/Test workflow", - "source": "WORKSPACE", - "base_parameters": { - "end_time": "{{ ts }}", - "start_time": "{{ ts }}", - }, - }, - "libraries": [ - {"pypi": {"package": "scikit-learn"}}, - {"pypi": {"package": "pandas"}}, - ], - }, - ) - task_2 = DatabricksTaskOperator( - task_id="task_2", - databricks_conn_id="databricks_conn", - job_cluster_key="Shared_job_cluster", - task_config={ - "spark_jar_task": { - "main_class_name": "jar.main.class.here", - "parameters": [ - "--key", - "value", - ], - "run_as_repl": "true", - }, - "libraries": [ - { - "jar": "your.jar.path/file.jar" - } - ], - }, - ) - task_1 >> task_2 - - :param task_id: the task name displayed in Databricks and Airflow. - :param databricks_conn_id: the connection id to use to connect to Databricks - :param job_cluster_key: the connection id to use to connect to Databricks - :param task_config: Please write appropriate configuration values for various tasks provided by Databricks - such as notebook_task, spark_jar_task, spark_python_task, spark_submit_task, etc. - For more information please visit - https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate + This class is deprecated and will be removed in future versions. + Please use `airflow.providers.databricks.operators.databricks.DatabricksTaskOperator` instead. """ - operator_extra_links = ( - DatabricksJobRunLink(), - DatabricksJobRepairSingleFailedLink(), - ) - template_fields = ("databricks_metadata",) - - def __init__( - self, - databricks_conn_id: str, - task_config: dict | None = None, - job_cluster_key: str | None = None, - new_cluster: dict | None = None, - existing_cluster_id: str | None = None, - **kwargs, - ): - if new_cluster and existing_cluster_id: - raise ValueError( - "Both new_cluster and existing_cluster_id are set. Only one can be set." - ) - - self.task_config = task_config or {} - self.databricks_conn_id = databricks_conn_id - self.databricks_run_id = "" - self.databricks_metadata: dict | None = None - self.job_cluster_key = job_cluster_key - self.new_cluster = new_cluster - self.existing_cluster_id = existing_cluster_id or "" - super().__init__(**kwargs) - - # For Airflow versions <2.3, the `task_group` attribute is unassociated, and hence we need to add that. - if not hasattr(self, "task_group"): - from airflow.utils.task_group import TaskGroupContext - - self.task_group = TaskGroupContext.get_current_task_group(self.dag) - - def _get_task_base_json(self) -> dict[str, Any]: - """Get task base json to be used for task group tasks and single task submissions.""" - return self.task_config - - def find_parent_databricks_workflow_task_group(self): - """ - Find the closest parent_group which is an instance of DatabricksWorkflowTaskGroup. - - In the case of Airflow 2.2.x, inner Task Groups do not inherit properties from Parent Task Groups like more - recent versions of Airflow. This lead to the issue: - https://github.com/astronomer/astro-provider-databricks/pull/47 - """ - parent_group = self.task_group - while parent_group: - if parent_group.__class__.__name__ == "DatabricksWorkflowTaskGroup": - return parent_group - parent_group = parent_group._parent_group - - def convert_to_databricks_workflow_task( - self, relevant_upstreams: list[BaseOperator], context: Context | None = None - ): - """ - Convert the operator to a Databricks workflow task that can be a task in a workflow - """ - if airflow.__version__ in ("2.2.4", "2.2.5"): - self.find_parent_databricks_workflow_task_group() - else: - pass - - if context: - # The following exception currently only happens on Airflow 2.3, with the following error: - # airflow.exceptions.AirflowException: XComArg result from test_workflow.launch at example_databricks_workflow with key="return_value" is not found! - try: - self.render_template_fields(context) - except AirflowException: - self.log.exception("Unable to process template fields") - - base_task_json = self._get_task_base_json() - result = { - "task_key": self._get_databricks_task_id(self.task_id), - "depends_on": [ - {"task_key": self._get_databricks_task_id(t)} - for t in self.upstream_task_ids - if t in relevant_upstreams - ], - **base_task_json, - } - - if self.job_cluster_key: - result["job_cluster_key"] = self.job_cluster_key - - return result - - def _get_databricks_task_id(self, task_id: str): - """Get the databricks task ID using dag_id and task_id. removes illegal characters.""" - return self.dag_id + "__" + task_id.replace(".", "__") - - def monitor_databricks_job(self): - """Monitor the Databricks job until it completes. Raises Airflow exception if the job fails.""" - api_client = self._get_api_client() - runs_api = RunsApi(api_client) - current_task = self._get_current_databricks_task(runs_api) - url = runs_api.get_run( - self.databricks_run_id, version=DATABRICKS_JOBS_API_VERSION - )["run_page_url"] - self.log.info(f"Check the job run in Databricks: {url}") - self._wait_for_pending_task(current_task, runs_api) - self._wait_for_running_task(current_task, runs_api) - self._wait_for_terminating_task(current_task, runs_api) - final_state = runs_api.get_run( - current_task["run_id"], version=DATABRICKS_JOBS_API_VERSION - )["state"] - self._handle_final_state(final_state) - - def _get_current_databricks_task(self, runs_api): - return { - x["task_key"]: x - for x in runs_api.get_run( - self.databricks_run_id, version=DATABRICKS_JOBS_API_VERSION - )["tasks"] - }[self._get_databricks_task_id(self.task_id)] - - def _handle_final_state(self, final_state): - if final_state.get("life_cycle_state", None) != "TERMINATED": - raise AirflowException( - f"Databricks job failed with state {final_state}. Message: {final_state['state_message']}" - ) - if final_state["result_state"] != "SUCCESS": - raise AirflowException( - "Task failed. Final State %s. Reason: %s", - final_state["result_state"], - final_state["state_message"], - ) - - def _get_lifestyle_state(self, current_task, runs_api): - return runs_api.get_run( - current_task["run_id"], version=DATABRICKS_JOBS_API_VERSION - )["state"]["life_cycle_state"] - - def _wait_on_state(self, current_task, runs_api, state): - while self._get_lifestyle_state(current_task, runs_api) == state: - print(f"task {self.task_id.replace('.', '__')} {state.lower()}...") - time.sleep(5) - - def _wait_for_terminating_task(self, current_task, runs_api): - self._wait_on_state(current_task, runs_api, "TERMINATING") - - def _wait_for_running_task(self, current_task, runs_api): - self._wait_on_state(current_task, runs_api, "RUNNING") - - def _wait_for_pending_task(self, current_task, runs_api): - self._wait_on_state(current_task, runs_api, "PENDING") - - def _get_api_client(self): - hook = DatabricksHook(self.databricks_conn_id) - databricks_conn = hook.get_conn() - return ApiClient( - user=databricks_conn.login, - token=databricks_conn.password, - host=databricks_conn.host, + def __init__(self, *args: Any, **kwargs: Any): + self.log.warning( + "DatabricksTaskOperator is deprecated and will be removed in future versions. " + "Please use `airflow.providers.databricks.operators.databricks.DatabricksTaskOperator` instead." ) - - def launch_task_job(self): - """Launch the notebook as a one-time job to Databricks.""" - api_client = self._get_api_client() - base_task_json = self._get_task_base_json() - run_json = { - "run_name": self._get_databricks_task_id(self.task_id), - **base_task_json, - } - if self.new_cluster and self.existing_cluster_id: - raise ValueError( - "Both new_cluster and existing_cluster_id are set. Only one can be set." - ) - if self.existing_cluster_id: - run_json["existing_cluster_id"] = self.existing_cluster_id - elif self.new_cluster: - run_json["new_cluster"] = self.new_cluster - else: - raise ValueError("Must specify either existing_cluster_id or new_cluster") - runs_api = RunsApi(api_client) - run = runs_api.submit_run(run_json) - self.databricks_run_id = run["run_id"] - return run - - def execute(self, context: Context) -> Any: - """ - Execute the DataBricksNotebookOperator. - - Executes the DataBricksNotebookOperator. If the task is inside of a - DatabricksWorkflowTaskGroup, it assumes the notebook is already launched - and proceeds to monitor the running notebook. - - :param context: - :return: - """ - if self.databricks_task_group: - # if we are in a workflow, we assume there is an upstream launch task - if not self.databricks_metadata: - launch_task_id = [ - task for task in self.upstream_task_ids if task.endswith(".launch") - ][0] - self.databricks_metadata = context["ti"].xcom_pull( - task_ids=launch_task_id - ) - databricks_metadata = DatabricksMetaData(**self.databricks_metadata) - self.databricks_run_id = databricks_metadata.databricks_run_id - self.databricks_conn_id = databricks_metadata.databricks_conn_id - else: - self.launch_task_job() - - self.monitor_databricks_job() - - @property - def databricks_task_group(self) -> DatabricksWorkflowTaskGroup | None: - """ - Traverses up parent TaskGroups until the `is_databricks` flag is found. - If found, returns the task group. Otherwise, returns None. - """ - parent_tg = self.task_group - - while parent_tg: - if hasattr(parent_tg, "is_databricks") and getattr( - parent_tg, "is_databricks" - ): - return parent_tg - - # here, we rely on the fact that Airflow sets the task_group property on tasks/task groups - # if that ever changes, we will need to update this - if hasattr(parent_tg, "task_group") and getattr(parent_tg, "task_group"): - parent_tg = parent_tg.task_group - else: - return None - - return None + super().__init__(*args, **kwargs) diff --git a/src/astro_databricks/operators/notebook.py b/src/astro_databricks/operators/notebook.py index 4146938..4d0578b 100644 --- a/src/astro_databricks/operators/notebook.py +++ b/src/astro_databricks/operators/notebook.py @@ -1,366 +1,25 @@ -"""DatabricksNotebookOperator for submitting notebook jobs to databricks.""" +""" +This module is deprecated and will be removed in future versions. +Please use `airflow.providers.databricks.operators.databricks.DatabricksNotebookOperator` instead. +""" from __future__ import annotations -import time from typing import Any -import airflow -from airflow.exceptions import AirflowException -from airflow.models import BaseOperator -from airflow.providers.databricks.hooks.databricks import DatabricksHook -from airflow.utils.context import Context -from airflow.utils.task_group import TaskGroup -from databricks_cli.runs.api import RunsApi -from databricks_cli.sdk.api_client import ApiClient - -from astro_databricks import settings -from astro_databricks.operators.workflow import ( - DatabricksMetaData, - DatabricksWorkflowTaskGroup, -) -from astro_databricks.plugins.plugin import ( - DatabricksJobRepairSingleFailedLink, - DatabricksJobRunLink, +from airflow.providers.databricks.operators.databricks import ( + DatabricksNotebookOperator as UpstreamDatabricksNotebookOperator, ) -class DatabricksNotebookOperator(BaseOperator): +class DatabricksNotebookOperator(UpstreamDatabricksNotebookOperator): """ - Launches a notebook to databricks using an Airflow operator. - - The DatabricksNotebookOperator allows users to launch and monitor notebook - deployments on Databricks as Aiflow tasks. - It can be used as a part of a DatabricksWorkflowTaskGroup to take advantage of job clusters, - which allows users to run their tasks on cheaper clusters that can be shared between tasks. - - Here is an example of running a notebook as a part of a workflow task group: - - .. code-block: python - - with dag: - task_group = DatabricksWorkflowTaskGroup( - group_id="test_workflow", - databricks_conn_id="databricks_conn", - job_clusters=job_cluster_spec, - notebook_params={}, - ) - with task_group: - notebook_1 = DatabricksNotebookOperator( - task_id="notebook_1", - databricks_conn_id="databricks_conn", - notebook_path="/Users/daniel@astronomer.io/Test workflow", - source="WORKSPACE", - job_cluster_key="Shared_job_cluster", - ) - notebook_2 = DatabricksNotebookOperator( - task_id="notebook_2", - databricks_conn_id="databricks_conn", - notebook_path="/Users/daniel@astronomer.io/Test workflow", - source="WORKSPACE", - job_cluster_key="Shared_job_cluster", - notebook_params={ - "foo": "bar", - }, - ) - notebook_1 >> notebook_2 - - - :param notebook_path: the path to the notebook in Databricks - :param source: Optional location type of the notebook. When set to WORKSPACE, the notebook will be retrieved - from the local Databricks workspace. When set to GIT, the notebook will be retrieved from a Git repository - defined in git_source. If the value is empty, the task will use GIT if git_source is defined - and WORKSPACE otherwise. For more information please visit - https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate - :param databricks_conn_id: the connection id to use to connect to Databricks - :param notebook_params: the parameters to pass to the notebook + This class is deprecated and will be removed in future versions. + Please use `airflow.providers.databricks.operators.databricks.DatabricksNotebookOperator` instead. """ - operator_extra_links = ( - DatabricksJobRunLink(), - DatabricksJobRepairSingleFailedLink(), - ) - template_fields = ( - "databricks_metadata", - "notebook_params", - ) - - def __init__( - self, - notebook_path: str, - source: str, - databricks_conn_id: str, - notebook_params: dict | None = None, - notebook_packages: list[dict[str, Any]] = None, - job_cluster_key: str | None = None, - new_cluster: dict | None = None, - existing_cluster_id: str | None = None, - **kwargs, - ): - if new_cluster and existing_cluster_id: - raise ValueError( - "Both new_cluster and existing_cluster_id are set. Only one can be set." - ) - - self.notebook_path = notebook_path - self.source = source - self.notebook_params = notebook_params or {} - self.notebook_packages = notebook_packages or [] - self.databricks_conn_id = databricks_conn_id - self.databricks_run_id = "" - self.databricks_metadata: dict | None = None - self.job_cluster_key = job_cluster_key or "" - self.new_cluster = new_cluster or {} - self.existing_cluster_id = existing_cluster_id or "" - super().__init__(**kwargs) - - # For Airflow versions <2.3, the `task_group` attribute is unassociated, and hence we need to add that. - if not hasattr(self, "task_group"): - from airflow.utils.task_group import TaskGroupContext - - self.task_group = TaskGroupContext.get_current_task_group(self.dag) - - def _get_task_base_json(self) -> dict[str, Any]: - """Get task base json to be used for task group tasks and single task submissions.""" - return { - # Timeout seconds value of 0 for the Databricks Jobs API means the job runs forever. - # That is also the default behavior of Databricks jobs to run a job forever without a default timeout value. - "timeout_seconds": int(self.execution_timeout.total_seconds()) - if self.execution_timeout - else 0, - "email_notifications": {}, - "notebook_task": { - "notebook_path": self.notebook_path, - "source": self.source, - "base_parameters": self.notebook_params, - }, - "libraries": self.notebook_packages, - } - - def merge_notebook_packages(self, databricks_task_group: TaskGroup): - """ - Merge the task group notebook packages into the notebook's packages, without adding any identical duplicates. - Modifies self.notebook_packages in place. - - Example value for self.notebook_packages: - [ - {"pypi": {"package": "requests_toolbelt==1.0.0"}} - ] - - """ - for task_group_package in databricks_task_group.notebook_packages: - exists = False - for existing_package in self.notebook_packages: - if task_group_package == existing_package: - exists = True - break - if not exists: - self.notebook_packages.append(task_group_package) - - def find_parent_databricks_workflow_task_group(self): - """ - Find the closest parent_group which is an instance of DatabricksWorkflowTaskGroup. - - In the case of Airflow 2.2.x, inner Task Groups do not inherit properties from Parent Task Groups like more - recent versions of Airflow. This lead to the issue: - https://github.com/astronomer/astro-provider-databricks/pull/47 - """ - parent_group = self.task_group - while parent_group: - if parent_group.__class__.__name__ == "DatabricksWorkflowTaskGroup": - return parent_group - parent_group = parent_group._parent_group - - def convert_to_databricks_workflow_task( - self, relevant_upstreams: list[BaseOperator], context: Context | None = None - ): - """ - Convert the operator to a Databricks workflow task that can be a task in a workflow - """ - if airflow.__version__ in ("2.2.4", "2.2.5"): - databricks_task_group = self.find_parent_databricks_workflow_task_group() - else: - databricks_task_group = self.databricks_task_group - - if databricks_task_group and hasattr( - databricks_task_group, "notebook_packages" - ): - self.merge_notebook_packages(databricks_task_group) - - if databricks_task_group and hasattr(databricks_task_group, "notebook_params"): - self.notebook_params = { - **self.notebook_params, - **databricks_task_group.notebook_params, - } - if context: - # The following exception currently only happens on Airflow 2.3, with the following error: - # airflow.exceptions.AirflowException: XComArg result from test_workflow.launch at example_databricks_workflow with key="return_value" is not found! - try: - self.render_template_fields(context) - except AirflowException: - self.log.exception("Unable to process template fields") - - base_task_json = self._get_task_base_json() - result = { - "task_key": self._get_databricks_task_id(self.task_id), - "depends_on": [ - {"task_key": self._get_databricks_task_id(t)} - for t in self.upstream_task_ids - if t in relevant_upstreams - ], - **base_task_json, - } - - if self.existing_cluster_id and self.job_cluster_key: - raise ValueError ("Both existing_cluster_id and job_cluster_key are set. Only one cluster can be set per task.") - - if self.existing_cluster_id: - result['existing_cluster_id'] = self.existing_cluster_id - elif self.job_cluster_key: - result['job_cluster_key'] = self.job_cluster_key - - return result - - def _get_databricks_task_id(self, task_id: str): - """Get the databricks task ID using dag_id and task_id. removes illegal characters.""" - return self.dag_id + "__" + task_id.replace(".", "__") - - def monitor_databricks_job(self): - """Monitor the Databricks job until it completes. Raises Airflow exception if the job fails.""" - api_client = self._get_api_client() - runs_api = RunsApi(api_client) - current_task = self._get_current_databricks_task(runs_api) - url = runs_api.get_run( - self.databricks_run_id, version=settings.DATABRICKS_JOBS_API_VERSION - )["run_page_url"] - self.log.info(f"Check the job run in Databricks: {url}") - self._wait_for_pending_task(current_task, runs_api) - self._wait_for_running_task(current_task, runs_api) - self._wait_for_terminating_task(current_task, runs_api) - final_state = runs_api.get_run( - current_task["run_id"], version=settings.DATABRICKS_JOBS_API_VERSION - )["state"] - self._handle_final_state(final_state) - - def _get_current_databricks_task(self, runs_api): - return { - x["task_key"]: x - for x in runs_api.get_run( - self.databricks_run_id, version=settings.DATABRICKS_JOBS_API_VERSION - )["tasks"] - }[self._get_databricks_task_id(self.task_id)] - - def _handle_final_state(self, final_state): - if final_state.get("life_cycle_state", None) != "TERMINATED": - raise AirflowException( - f"Databricks job failed with state {final_state}. Message: {final_state['state_message']}" - ) - if final_state["result_state"] != "SUCCESS": - raise AirflowException( - "Task failed. Final State %s. Reason: %s", - final_state["result_state"], - final_state["state_message"], - ) - - def _get_lifestyle_state(self, current_task, runs_api): - return runs_api.get_run( - current_task["run_id"], version=settings.DATABRICKS_JOBS_API_VERSION - )["state"]["life_cycle_state"] - - def _wait_on_state(self, current_task, runs_api, state): - while self._get_lifestyle_state(current_task, runs_api) == state: - print(f"task {self.task_id.replace('.', '__')} {state.lower()}...") - time.sleep(5) - - def _wait_for_terminating_task(self, current_task, runs_api): - self._wait_on_state(current_task, runs_api, "TERMINATING") - - def _wait_for_running_task(self, current_task, runs_api): - self._wait_on_state(current_task, runs_api, "RUNNING") - - def _wait_for_pending_task(self, current_task, runs_api): - self._wait_on_state(current_task, runs_api, "PENDING") - - def _get_api_client(self): - hook = DatabricksHook(self.databricks_conn_id) - databricks_conn = hook.get_conn() - return ApiClient( - user=databricks_conn.login, - token=databricks_conn.password, - host=databricks_conn.host, - ) - - def launch_notebook_job(self): - """Launch the notebook as a one-time job to Databricks.""" - api_client = self._get_api_client() - base_task_json = self._get_task_base_json() - run_json = { - "run_name": self._get_databricks_task_id(self.task_id), - **base_task_json, - } - if self.new_cluster and self.existing_cluster_id: - raise ValueError( - "Both new_cluster and existing_cluster_id are set. Only one can be set." - ) - if self.existing_cluster_id: - run_json["existing_cluster_id"] = self.existing_cluster_id - elif self.new_cluster: - run_json["new_cluster"] = self.new_cluster - else: - raise ValueError("Must specify either existing_cluster_id or new_cluster") - runs_api = RunsApi(api_client) - run = runs_api.submit_run( - run_json, version=settings.DATABRICKS_JOBS_API_VERSION + def __init__(self, *args: Any, **kwargs: Any): + self.log.warning( + "DatabricksNotebookOperator is deprecated and will be removed in future versions." + "Please use `airflow.providers.databricks.operators.databricks.DatabricksNotebookOperator` instead." ) - self.databricks_run_id = run["run_id"] - return run - - def execute(self, context: Context) -> Any: - """ - Execute the DataBricksNotebookOperator. - - Executes the DataBricksNotebookOperator. If the task is inside of a - DatabricksWorkflowTaskGroup, it assumes the notebook is already launched - and proceeds to monitor the running notebook. - - :param context: - :return: - """ - if self.databricks_task_group: - # if we are in a workflow, we assume there is an upstream launch task - if not self.databricks_metadata: - launch_task_id = [ - task for task in self.upstream_task_ids if task.endswith(".launch") - ][0] - self.databricks_metadata = context["ti"].xcom_pull( - task_ids=launch_task_id - ) - databricks_metadata = DatabricksMetaData(**self.databricks_metadata) - self.databricks_run_id = databricks_metadata.databricks_run_id - self.databricks_conn_id = databricks_metadata.databricks_conn_id - else: - self.launch_notebook_job() - - self.monitor_databricks_job() - - @property - def databricks_task_group(self) -> DatabricksWorkflowTaskGroup | None: - """ - Traverses up parent TaskGroups until the `is_databricks` flag is found. - If found, returns the task group. Otherwise, returns None. - """ - parent_tg = self.task_group - - while parent_tg: - if hasattr(parent_tg, "is_databricks") and getattr( - parent_tg, "is_databricks" - ): - return parent_tg - - # here, we rely on the fact that Airflow sets the task_group property on tasks/task groups - # if that ever changes, we will need to update this - if hasattr(parent_tg, "task_group") and getattr(parent_tg, "task_group"): - parent_tg = parent_tg.task_group - else: - return None - - return None + super().__init__(*args, **kwargs) diff --git a/src/astro_databricks/operators/workflow.py b/src/astro_databricks/operators/workflow.py index 625c4c3..282cfd3 100644 --- a/src/astro_databricks/operators/workflow.py +++ b/src/astro_databricks/operators/workflow.py @@ -1,428 +1,25 @@ -"""DatabricksWorkflowTaskGroup for submitting jobs to Databricks.""" +""" +This module is deprecated and will be removed in future versions. +Please use `airflow.providers.databricks.operators.databricks.DatabricksWorkflowTaskGroup` instead. +""" from __future__ import annotations -from logging import Logger -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - pass -import json -import time from typing import Any -from airflow.exceptions import AirflowException -from airflow.models import BaseOperator -from airflow.providers.databricks.hooks.databricks import DatabricksHook -from airflow.utils.context import Context -from airflow.utils.task_group import TaskGroup - -try: - from attrs import define -except ModuleNotFoundError: - from attr import define - -from databricks_cli.jobs.api import JobsApi -from databricks_cli.runs.api import RunsApi -from databricks_cli.sdk.api_client import ApiClient -from mergedeep import merge - -from astro_databricks.plugins.plugin import ( - DatabricksJobRepairAllFailedLink, - DatabricksJobRunLink, +from airflow.providers.databricks.operators.databricks_workflow import ( + DatabricksWorkflowTaskGroup as UpstreamDatabricksWorkflowTaskGroup, ) -from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION - - -@define -class DatabricksMetaData: - databricks_conn_id: str - databricks_run_id: str - databricks_job_id: str - - -def _get_job_by_name(job_name: str, jobs_api: JobsApi) -> dict | None: - jobs = jobs_api.list_jobs(version=DATABRICKS_JOBS_API_VERSION, name=job_name).get("jobs", []) - return jobs[0] if jobs else None - - -def flatten_node( - node: TaskGroup | BaseOperator, tasks: list[BaseOperator] = [] -) -> list[BaseOperator]: - """ - Flattens a node (either a TaskGroup or Operator) to a list of nodes - """ - if isinstance(node, BaseOperator): - return [node] - - if isinstance(node, TaskGroup): - new_tasks = [] - for id_, child in node.children.items(): - new_tasks += flatten_node(child, tasks) - - return tasks + new_tasks - - return tasks - - -class _CreateDatabricksWorkflowOperator(BaseOperator): - """Creates a databricks workflow from a DatabricksWorkflowTaskGroup. - - :param task_id: The task id of the operator - :param databricks_conn_id: The databricks connection id - :param job_clusters: A list of job clusters to use in the workflow - :param existing_clusters: A list of existing clusters to use in the workflow - :param max_concurrent_runs: The maximum number of concurrent runs - :param tasks_to_convert: A list of tasks to convert to a workflow. This list can also - be populated after initialization by calling add_task. - :param extra_job_params: A dictionary containing properties which will override the - default Databricks Workflow Job definitions. - :param notebook_params: A dictionary of notebook parameters to pass to the workflow.These parameters will be passed to - all notebook tasks in the workflow. - """ - - template_fields = ("notebook_params",) - - operator_extra_links = (DatabricksJobRunLink(), DatabricksJobRepairAllFailedLink()) - databricks_conn_id: str - databricks_run_id: str - databricks_job_id: str - - def __init__( - self, - task_id, - databricks_conn_id, - job_clusters: list[dict[str, object]] = None, - existing_clusters: list[str] = None, - max_concurrent_runs: int = 1, - tasks_to_convert: list[BaseOperator] = None, - extra_job_params: dict[str, Any] = None, - notebook_params: dict | None = None, - **kwargs, - ): - self.existing_clusters = existing_clusters or [] - self.job_clusters = job_clusters or [] - self.job_cluster_dict = {j["job_cluster_key"]: j for j in self.job_clusters} - self.tasks_to_convert = tasks_to_convert or [] - self.relevant_upstreams = [task_id] - self.databricks_conn_id = databricks_conn_id - self.databricks_run_id = None - self.max_concurrent_runs = max_concurrent_runs - self.extra_job_params = extra_job_params or {} - self.notebook_params = notebook_params or {} - super().__init__(task_id=task_id, **kwargs) - - # For Airflow versions <2.3, the `task_group` attribute is unassociated, and hence we need to add that. - if not hasattr(self, "task_group"): - from airflow.utils.task_group import TaskGroupContext - - self.task_group = TaskGroupContext.get_current_task_group(self.dag) - - def add_task(self, task: BaseOperator): - """ - Add a task to the list of tasks to convert to a workflow. - - :param task: - :return: - """ - self.tasks_to_convert.append(task) - - def create_workflow_json(self, context: Context | None = None) -> dict[str, object]: - """Create a workflow json that can be submitted to databricks. - - :return: A workflow json - """ - task_json = [ - task.convert_to_databricks_workflow_task( - relevant_upstreams=self.relevant_upstreams, context=context - ) - for task in self.tasks_to_convert - ] - default_json = { - "name": self.databricks_job_name, - "email_notifications": {"no_alert_for_skipped_runs": False}, - "timeout_seconds": 0, - "tasks": task_json, - "format": "MULTI_TASK", - "job_clusters": self.job_clusters, - "max_concurrent_runs": self.max_concurrent_runs, - } - merged_json = merge(default_json, self.extra_job_params) - return merged_json - - @property - def databricks_job_name(self): - return self.dag_id + "." + self.task_group.group_id - - def execute(self, context: Context) -> Any: - hook = DatabricksHook(self.databricks_conn_id) - databricks_conn = hook.get_conn() - api_client = ApiClient( - token=databricks_conn.password, host=databricks_conn.host - ) - jobs_api = JobsApi(api_client) - job = _get_job_by_name(self.databricks_job_name, jobs_api) - - job_id = job["job_id"] if job else None - current_job_spec = self.create_workflow_json(context) - if not isinstance(self.task_group, DatabricksWorkflowTaskGroup): - raise AirflowException("Task group must be a DatabricksWorkflowTaskGroup") - if job_id: - self.log.info( - "Updating existing job with spec %s", - json.dumps(current_job_spec, indent=4), - ) - - jobs_api.reset_job( - json={"job_id": job_id, "new_settings": current_job_spec}, - version=DATABRICKS_JOBS_API_VERSION, - ) - else: - self.log.info( - "Creating new job with spec %s", json.dumps(current_job_spec, indent=4) - ) - job_id = jobs_api.create_job( - json=current_job_spec, version=DATABRICKS_JOBS_API_VERSION - )["job_id"] - - run_id = jobs_api.run_now( - job_id=job_id, - jar_params=self.task_group.jar_params, - notebook_params=self.notebook_params, - python_params=self.task_group.python_params, - spark_submit_params=self.task_group.spark_submit_params, - version=DATABRICKS_JOBS_API_VERSION, - )["run_id"] - self.databricks_run_id = run_id - - runs_api = RunsApi(api_client) - url = runs_api.get_run(run_id, version=DATABRICKS_JOBS_API_VERSION).get( - "run_page_url" - ) - self.log.info(f"Check the job run in Databricks: {url}") - state = runs_api.get_run(run_id, version=DATABRICKS_JOBS_API_VERSION)["state"][ - "life_cycle_state" - ] - self.log.info(f"Job state: {state}") - - if state not in ("PENDING", "BLOCKED", "RUNNING"): - raise AirflowException( - f"Could not start the workflow job, it had state {state}" - ) - while state in ("PENDING", "BLOCKED"): - self.log.info(f"Job {state}") - time.sleep(5) - state = runs_api.get_run(run_id, version=DATABRICKS_JOBS_API_VERSION)[ - "state" - ]["life_cycle_state"] - return { - "databricks_conn_id": self.databricks_conn_id, - "databricks_job_id": job_id, - "databricks_run_id": run_id, - } - - -class DatabricksWorkflowTaskGroup(TaskGroup): +class DatabricksWorkflowTaskGroup(UpstreamDatabricksWorkflowTaskGroup): """ - A task group that takes a list of tasks and creates a databricks workflow. - - The DatabricksWorkflowTaskGroup takes a list of tasks and creates a databricks workflow - based on the metadata produced by those tasks. For a task to be eligible for this - TaskGroup, it must contain the ``convert_to_databricks_workflow_task`` method. If any tasks - do not contain this method then the Taskgroup will raise an error at parse time. - - Here is an example of what a DAG looks like with a DatabricksWorkflowTaskGroup: - - .. code-block:: python - - job_clusters = [ - { - "job_cluster_key": "Shared_job_cluster", - "new_cluster": { - "cluster_name": "", - "spark_version": "11.3.x-scala2.12", - "aws_attributes": { - "first_on_demand": 1, - "availability": "SPOT_WITH_FALLBACK", - "zone_id": "us-east-2b", - "spot_bid_price_percent": 100, - "ebs_volume_count": 0, - }, - "node_type_id": "i3.xlarge", - "spark_env_vars": {"PYSPARK_PYTHON": "/databricks/python3/bin/python3"}, - "enable_elastic_disk": False, - "data_security_mode": "LEGACY_SINGLE_USER_STANDARD", - "runtime_engine": "STANDARD", - "num_workers": 8, - }, - } - ] - - with dag: - task_group = DatabricksWorkflowTaskGroup( - group_id="test_workflow", - databricks_conn_id="databricks_conn", - job_clusters=job_cluster_spec, - notebook_params={}, - notebook_packages=[ - { - "pypi": { - "package": "simplejson" - } - }, - ] - ) - with task_group: - notebook_1 = DatabricksNotebookOperator( - task_id="notebook_1", - databricks_conn_id="databricks_conn", - notebook_path="/Users//Test workflow", - source="WORKSPACE", - job_cluster_key="Shared_job_cluster", - notebook_packages=[ - { - "pypi": { - "package": "Faker" - } - } - ] - ) - notebook_2 = DatabricksNotebookOperator( - task_id="notebook_2", - databricks_conn_id="databricks_conn", - notebook_path="/Users//Test workflow", - source="WORKSPACE", - job_cluster_key="Shared_job_cluster", - notebook_params={ - "foo": "bar", - }, - ) - notebook_1 >> notebook_2 - - With this example, Airflow will produce a job named .test_workflow that will - run notebook_1 and then notebook_2. The job will be created in the databricks workspace - if it does not already exist. If the job already exists, it will be updated to match - the workflow defined in the DAG. - - To minimize update conflicts, we recommend that you keep parameters in the ``notebook_params`` of the - ``DatabricksWorkflowTaskGroup`` and not in the ``DatabricksNotebookOperator`` whenever possible. - This is because tasks in the - ``DatabricksWorkflowTaskGroup`` are passed in at the job trigger time and do not modify the job definition - - :param group_id: The name of the task group - :param databricks_conn_id: The name of the databricks connection to use - :param job_clusters: A list of job clusters to use for this workflow. - :param notebook_params: A dictionary of notebook parameters to pass to the workflow. These parameters will be passed to - all notebook tasks in the workflow. - :param notebook_packages: A list of dictionary of Python packages to be installed. Packages defined at the - workflow task group level are installed for each of the notebook tasks under it. And packages defined at the - notebook task level are installed specific for the notebook task. - :param jar_params: A list of jar parameters to pass to the workflow. These parameters will be passed to all jar - tasks - in the workflow. - :param python_params: A list of python parameters to pass to the workflow. These parameters will be passed to - all python tasks - in the workflow. - :param spark_submit_params: A list of spark submit parameters to pass to the workflow. These parameters - will be passed to all spark submit tasks - :param extra_job_params: A dictionary containing properties which will override the default Databricks Workflow - Job definitions. - :param max_concurrent_runs: The maximum number of concurrent runs for this workflow. - + This class is deprecated and will be removed in future versions. + Please use `airflow.providers.databricks.operators.databricks.DatabricksWorkflowTaskGroup` instead. """ - @property - def log(self) -> Logger: - """Returns logger.""" - pass - - is_databricks = True - - def __init__( - self, - databricks_conn_id, - existing_clusters=None, - job_clusters=None, - jar_params: dict = None, - notebook_params: dict | None = None, - notebook_packages: list[dict[str, Any]] = None, - python_params: list = None, - spark_submit_params: list = None, - max_concurrent_runs: int = 1, - extra_job_params: dict[str, Any] = None, - **kwargs, - ): + def __init__(self, *args: Any, **kwargs: Any): """ - Create a new DatabricksWorkflowTaskGroup. - - :param group_id: The name of the task group - :param databricks_conn_id: The name of the databricks connection to use - :param job_clusters: A list of job clusters to use for this workflow. - :param notebook_params: A dictionary of notebook parameters to pass to the workflow.These parameters will be passed to - all notebook tasks in the workflow. - :param notebook_packages: A list of dictionary of Python packages to be installed. These packages will be passed - to all notebook tasks in the workflow. - :param jar_params: A list of jar parameters to pass to the workflow. - These parameters will be passed to all jar tasks - in the workflow. - :param python_params: A dictionary of python parameters to pass to the workflow. - These parameters will be passed to all python tasks - in the workflow. - :param spark_submit_params: A list of spark submit parameters to pass to the workflow. - These parameters will be passed to all spark submit tasks - :param max_concurrent_runs: The maximum number of concurrent runs for this workflow. - :param extra_job_params: A dictionary containing properties which will override the default Databricks - Workflow Job definitions. + This class is deprecated and will be removed in future versions. + Please use `airflow.providers.databricks.operators.databricks.DatabricksWorkflowTaskGroup` instead. """ - self.databricks_conn_id = databricks_conn_id - self.existing_clusters = existing_clusters or [] - self.job_clusters = job_clusters or [] - self.notebook_params = notebook_params or {} - self.notebook_packages = notebook_packages or [] - self.python_params = python_params or [] - self.spark_submit_params = spark_submit_params or [] - self.jar_params = jar_params or [] - self.max_concurrent_runs = max_concurrent_runs - self.extra_job_params = extra_job_params or {} - super().__init__(**kwargs) - - def __exit__(self, _type, _value, _tb): - """Exit the context manager and add tasks to a single _CreateDatabricksWorkflowOperator.""" - roots = list(self.get_roots()) - tasks = flatten_node(self) - - # For Airflow versions <2.3, the `dag` attribute is unassociated, and hence we need to add that. - if not hasattr(self, "dag"): - from airflow.models.dag import DagContext - - self.dag = DagContext.get_current_dag() - - create_databricks_workflow_task = _CreateDatabricksWorkflowOperator( - dag=self.dag, - task_group=self, - task_id="launch", - databricks_conn_id=self.databricks_conn_id, - job_clusters=self.job_clusters, - existing_clusters=self.existing_clusters, - extra_job_params=self.extra_job_params, - notebook_params=self.notebook_params, - ) - - for task in tasks: - if not ( - hasattr(task, "convert_to_databricks_workflow_task") - and callable(task.convert_to_databricks_workflow_task) - ): - raise AirflowException( - f"Task {task.task_id} does not support conversion to databricks workflow task." - ) - - task.databricks_metadata = create_databricks_workflow_task.output - create_databricks_workflow_task.relevant_upstreams.append(task.task_id) - create_databricks_workflow_task.add_task(task) - - for root_task in roots: - root_task.set_upstream(create_databricks_workflow_task) - - super().__exit__(_type, _value, _tb) + super().__init__(*args, **kwargs) diff --git a/src/astro_databricks/plugins/plugin.py b/src/astro_databricks/plugins/plugin.py index 6f17513..7194c10 100644 --- a/src/astro_databricks/plugins/plugin.py +++ b/src/astro_databricks/plugins/plugin.py @@ -1,503 +1,74 @@ -"""DatabricksWorkflowTaskGroup for submitting jobs to Databricks.""" +""" +This module is deprecated and will be removed in future versions. +Please use `airflow.providers.databricks.plugins.databricks_workflow.DatabricksWorkflowPlugin` instead. +""" from __future__ import annotations -import logging -from operator import itemgetter from typing import Any -from airflow.configuration import conf -from airflow.models import BaseOperator, BaseOperatorLink -from airflow.models.dag import DAG, clear_task_instances -from airflow.models.dagrun import DagRun -from airflow.models.taskinstance import TaskInstance, TaskInstanceKey -from airflow.models.xcom import XCom -from airflow.plugins_manager import AirflowPlugin -from airflow.providers.databricks.hooks.databricks import DatabricksHook -from airflow.security import permissions -from airflow.version import version as airflow_version -from packaging import version - -try: - # The following utility was included in Airflow version 2.3.3; we handle the needed import in the exception block. - from airflow.utils.airflow_flask_app import get_airflow_app -except ModuleNotFoundError: - # For older versions of airflow < 2.3.3 that don't have the utility. - from flask import current_app - -from airflow.exceptions import TaskInstanceNotFound +from airflow.providers.databricks.plugins.databricks_workflow import ( + DatabricksWorkflowPlugin, + RepairDatabricksTasks as UpstreamRepairDatabricksTasks, + WorkflowJobRepairAllFailedLink, + WorkflowJobRepairSingleTaskLink, + WorkflowJobRunLink, +) from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.task_group import TaskGroup -from airflow.www.auth import has_access -from airflow.www.views import AirflowBaseView -from databricks_cli.sdk import JobsService -from databricks_cli.sdk.api_client import ApiClient -from flask import flash, redirect, request -from flask_appbuilder.api import expose -from sqlalchemy.orm.session import Session - -def _get_flask_app(): - """Get the Airflow flask app instance""" - try: - flask_app = get_airflow_app() - except NameError: - flask_app = current_app - return flask_app - -def _get_databricks_task_id(task: BaseOperator) -> str: - """Get the databricks task ID using dag_id and task_id. removes illegal characters. - - :param task: The task to get the databricks task ID for. - :return: The databricks task ID. +class DatabricksJobRunLink(WorkflowJobRunLink, LoggingMixin): """ - return task.dag_id + "__" + task.task_id.replace(".", "__") - - -def get_databricks_task_ids( - group_id: str, task_map: dict[str, BaseOperator], log: logging.Logger -) -> list[str]: - """ - Returns a list of all Databricks task IDs for a dictionary of Airflow tasks. - - :param group_id: The task group ID. - :param task_map: A dictionary mapping task IDs to BaseOperator instances. - :return: A list of Databricks task IDs for the given task group. + This class is deprecated and will be removed in future versions. + Please use `airflow.providers.databricks.plugins.databricks_workflow.WorkflowJobRunLink` instead. """ - task_ids = [] - log.debug("Getting databricks task ids for group %s", group_id) - for task_id, task in task_map.items(): - if task_id == f"{group_id}.launch": - continue - databricks_task_id = _get_databricks_task_id(task) - log.debug("databricks task id for task %s is %s", task_id, databricks_task_id) - task_ids.append(databricks_task_id) - return task_ids - - -@provide_session -def _get_dagrun(dag: DAG, run_id, session=None) -> DagRun: - """ - Retrieves the DagRun object associated with the specified DAG and run_id. - - :param dag: The DAG object associated with the DagRun to retrieve. - :param run_id: The run_id associated with the DagRun to retrieve. - :param session: The SQLAlchemy session to use for the query. If None, uses the default session. - :return: The DagRun object associated with the specified DAG and run_id. - """ - return ( - session.query(DagRun) - .filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id) - .first() - ) - - -@provide_session -def _clear_task_instances( - dag_id: str, run_id: str, task_ids: list[str], log: logging.Logger, session=None -): - dag = _get_flask_app().dag_bag.get_dag(dag_id) - log.debug("task_ids to clear", str(task_ids)) - dr: DagRun = _get_dagrun(dag, run_id) - tis_to_clear = [ - ti for ti in dr.get_task_instances() if _get_databricks_task_id(ti) in task_ids - ] - clear_task_instances(tis_to_clear, session) - - -def _repair_task( - databricks_conn_id: str, - databricks_run_id: str, - tasks_to_repair: list[str], - log: logging.Logger, -) -> dict: - """ - This function allows the Airflow retry function to create a repair job for Databricks. - It uses the Databricks API to get the latest repair ID before sending the repair query. - - Note that we use the `JobsService` class instead of the `RunsApi` class. This is because the - `RunsApi` class does not allow sending the `include_history` parameter which is necessary for - repair jobs. - Also for the moment we don't allow custom retry_callbacks. We might implement this in - the future if users ask for it, but for the moment we want to keep things simple while the API - stabilizes. - - :param databricks_conn_id: The Databricks connection ID. - :param databricks_run_id: The Databricks run ID. - :param tasks_to_repair: A list of Databricks task IDs to repair. - :return: None - """ - - def _get_api_client(): - hook = DatabricksHook(databricks_conn_id) - databricks_conn = hook.get_conn() - return ApiClient( - user=databricks_conn.login, - token=databricks_conn.password, - host=databricks_conn.host, + def __init__(self, *args: Any, **kwargs: Any): + self.log.warning( + "DatabricksJobRunLink is deprecated and will be removed in future versions. " + "Please use `airflow.providers.databricks.plugins.databricks_workflow.WorkflowJobRunLink` instead." ) + super().__init__(*args, **kwargs) - api_client = _get_api_client() - log.debug("Getting latest repair ID") - jobs_service = JobsService(api_client) - current_job = jobs_service.get_run(run_id=databricks_run_id, include_history=True) - repair_history = current_job.get("repair_history") - repair_history_id = None - if ( - repair_history and len(repair_history) > 1 - ): # We use >1 because the first entry is the original run. - # We use the last item in the array to get the latest repair ID - repair_history_id = repair_history[-1]["id"] - log.debug("Latest repair ID is %s", repair_history_id) - log.debug( - "Sending repair query for tasks %s on run %s", - tasks_to_repair, - databricks_run_id, - ) - return jobs_service.repair( - run_id=databricks_run_id, - version="2.1", - latest_repair_id=repair_history_id, - rerun_tasks=tasks_to_repair, - ) - - -def get_task_group_legacy(operator: BaseOperator) -> TaskGroup: - """ - Returns the task group for a given operator. This is a workaround for Airflow 2.2.4. - Unfortunately in Airflow 2.2.4 the task_group property is not set on the operator, so we - have to get the taskgroup tree from the DAG and search for the operator. This allows us to - return the operators group, so we can find the xcom result of the launch task. - - :param operator: The operator to get the task group for. - :return: The task group for the given operator. - """ - - def find_my_group(group: TaskGroup, task_id: str): - groups_to_recurse = set() - for elem in group.children.values(): - if isinstance(elem, TaskGroup): - groups_to_recurse.add(elem) - else: - if operator.task_id == task_id: - return group - for group in groups_to_recurse: - val = find_my_group(group, task_id) - if val: - return val - - return find_my_group(operator.dag.task_group, operator.task_id) - -def get_launch_task_id(task_group: TaskGroup) -> str: +class DatabricksJobRepairAllFailedLink(WorkflowJobRepairAllFailedLink, LoggingMixin): """ - Retrieve the launch task ID from the current task group or a parent task group, - recursively. - - :param task_group: Task Group to be inspected - :return: launch Task ID + This class is deprecated and will be removed in future versions. + Please use `airflow.providers.databricks.plugins.databricks_workflow.WorkflowJobRepairAllFailedLink` instead. """ - try: - launch_task_id = task_group.get_child_by_label("launch").task_id - print("launch task id %s", launch_task_id) - except KeyError: - launch_task_id = get_launch_task_id(task_group.parent_group) - return launch_task_id + def __init__(self, *args: Any, **kwargs: Any): + self.log.warning( + "DatabricksJobRepairAllFailedLink is deprecated and will be removed in future versions. " + "Please use " + "`airflow.providers.databricks.plugins.databricks_workflow.WorkflowJobRepairAllFailedLink` instead." + ) + super().__init__(*args, **kwargs) -def _get_launch_task_key( - current_task_key: TaskInstanceKey, task_id: str -) -> TaskInstanceKey: +class DatabricksJobRepairSingleFailedLink( + WorkflowJobRepairSingleTaskLink, LoggingMixin +): """ - Returns the task key for the launch task. This allows us to gather databricks Metadata - even if the current task has failed (since tasks only create xcom values if they succeed). - - :param current_task_key: The task key for the current task. - :param task_id: The task ID for the current task. - :return: The task key for the launch task. + This class is deprecated and will be removed in future versions. + Please use `airflow.providers.databricks.plugins.databricks_workflow.WorkflowJobRepairSingleTaskLink` instead. """ - if task_id: - return TaskInstanceKey( - dag_id=current_task_key.dag_id, - task_id=task_id, - run_id=current_task_key.run_id, - try_number=current_task_key.try_number, - ) - else: - return current_task_key - - -@provide_session -def get_task_instance(operator, dttm, session: Session = NEW_SESSION): - dag_id = operator.dag.dag_id - dag_run = DagRun.find(dag_id, execution_date=dttm)[0] - ti = ( - session.query(TaskInstance) - .filter( - TaskInstance.dag_id == dag_id, - TaskInstance.run_id == dag_run.run_id, - TaskInstance.task_id == operator.task_id, - ) - .one_or_none() - ) - if not ti: - raise TaskInstanceNotFound("Task instance not found") - return ti - -def get_task_group(operator): - if not hasattr(operator, "task_group"): - task_group = get_task_group_legacy(operator) - else: - task_group = operator.task_group - return task_group - - -def get_xcom_result( - ti_key: TaskInstanceKey, - key: str, - ti: TaskInstance | None, -) -> Any: - # Pull the xcom result for the given task instance and task instance key - try: - result = XCom.get_value( - ti_key=ti_key, - key=key, - ) - except AttributeError: - # For Airflow versions < 2.3.0 which do not contain the XCOM.get_value method implementation. - if not ti: - raise TaskInstanceNotFound( - "Valid task instance is needed to fetch the XCOM result." - ) - result = XCom.get_one( - task_id=ti_key.task_id, - dag_id=ti_key.dag_id, - execution_date=ti.execution_date, - key=key, + def __init__(self, *args: Any, **kwargs: Any): + self.log.warning( + "DatabricksJobRepairSingleFailedLink is deprecated and will be removed in future versions. " + "Please use " + "`airflow.providers.databricks.plugins.databricks_workflow.WorkflowJobRepairSingleTaskLink` instead." ) - from astro_databricks.operators.workflow import DatabricksMetaData - - return DatabricksMetaData(**result) - - -class DatabricksJobRunLink(BaseOperatorLink, LoggingMixin): - """Constructs a link to monitor a Databricks Job Run.""" - - name = "See Databricks Job Run" - - def get_link( - self, - operator: BaseOperator, - dttm=None, - *, - ti_key: TaskInstanceKey | None = None, - ) -> str: - ti = None - if not ti_key: - ti = get_task_instance(operator, dttm) - ti_key = ti.key - task_group = get_task_group(operator) - - dag = _get_flask_app().dag_bag.get_dag(ti_key.dag_id) - dag.get_task(ti_key.task_id) - self.log.info("Getting link for task %s", ti_key.task_id) - if ".launch" not in ti_key.task_id: - self.log.debug( - "Finding the launch task for job run metadata %s", ti_key.task_id - ) - launch_task_id = get_launch_task_id(task_group) - ti_key = _get_launch_task_key(ti_key, task_id=launch_task_id) - # Should we catch the exception here if there is no return value? - metadata = get_xcom_result(ti_key, "return_value", ti) - - hook = DatabricksHook(metadata.databricks_conn_id) - return f"https://{hook.host}/#job/{metadata.databricks_job_id}/run/{metadata.databricks_run_id}" - + super().__init__(*args, **kwargs) -class DatabricksJobRepairAllFailedLink(BaseOperatorLink, LoggingMixin): - """Constructs a link to send a request to repair all failed databricks tasks.""" - name = "Repair All Failed Tasks" - - def get_link( - self, - operator, - dttm=None, - *, - ti_key: TaskInstanceKey | None = None, - ) -> str: - ti = None - if not ti_key: - ti = get_task_instance(operator, dttm) - ti_key = ti.key - task_group = get_task_group(operator) - self.log.debug( - "Creating link to repair all tasks for databricks job run %s", - task_group.group_id, - ) - # Should we catch the exception here if there is no return value? - metadata = get_xcom_result(ti_key, "return_value", ti) - - tasks_str = self.get_tasks_to_run(ti_key, operator, self.log) - self.log.debug("tasks to rerun: %s", tasks_str) - return ( - f"/repair_databricks_job?dag_id={ti_key.dag_id}&" - f"databricks_conn_id={metadata.databricks_conn_id}&" - f"databricks_run_id={metadata.databricks_run_id}&" - f"run_id={ti_key.run_id}&" - f"tasks_to_repair={tasks_str}" - ) - - @classmethod - def get_task_group_children(cls, task_group): - """ - Given a TaskGroup, return children which are Tasks, inspecting recursively any TaskGroups within. - - :param task_group: An Airflow TaskGroup - :return: Dictionary that contains Task IDs as keys and Tasks as values. - """ - children = {} - for child_id, child in task_group.children.items(): - if isinstance(child, TaskGroup): - child_children = cls.get_task_group_children(child) - children = {**children, **child_children} - else: - children[child_id] = child - return children - - def get_tasks_to_run( - self, ti_key: TaskInstanceKey, operator: BaseOperator, log: logging.Logger - ) -> str: - task_group = get_task_group(operator) - dag = _get_flask_app().dag_bag.get_dag(ti_key.dag_id) - dr = _get_dagrun(dag, ti_key.run_id) - log.debug("Getting failed and skipped tasks for dag run %s", dr.run_id) - task_group_sub_tasks = self.get_task_group_children(task_group).items() - failed_and_skipped_tasks = self._get_failed_and_skipped_tasks(dr) - log.debug("Failed and skipped tasks: %s", failed_and_skipped_tasks) - - tasks_to_run = { - ti: t - for ti, t in task_group_sub_tasks - if ti in failed_and_skipped_tasks - } - log.debug( - "Tasks to repair in databricks job %s : %s", - task_group.group_id, - tasks_to_run, +class RepairDatabricksTasks(UpstreamRepairDatabricksTasks, LoggingMixin): + def __init__(self, *args: Any, **kwargs: Any): + self.log.warning( + "RepairDatabricksTasks is deprecated and will be removed in future versions. " + "Please use `airflow.providers.databricks.plugins.databricks_workflow.RepairDatabricksTasks` instead." ) - tasks_str = ",".join( - get_databricks_task_ids(task_group.group_id, tasks_to_run, log) - ) - - return tasks_str - - def _get_failed_and_skipped_tasks(self, dr: DagRun) -> list[str]: - """ - Returns a list of task IDs for tasks that have failed or have been skipped in the given DagRun. - - :param dr: The DagRun object for which to retrieve failed and skipped tasks. - - :return: A list of task IDs for tasks that have failed or have been skipped. - """ - return [ - t.task_id - for t in dr.get_task_instances( - state=["failed", "skipped", "up_for_retry", "upstream_failed", None], - ) - ] - - -class DatabricksJobRepairSingleFailedLink(BaseOperatorLink, LoggingMixin): - """Constructs a link to send a repair request for a single databricks task.""" - - name = "Repair a single failed task" - - def get_link( - self, - operator, - dttm=None, - *, - ti_key: TaskInstanceKey | None = None, - ) -> str: - ti = None - if not ti_key: - ti = get_task_instance(operator, dttm) - ti_key = ti.key - - task_group = get_task_group(operator) - - self.log.info( - "Creating link to repair a single task for databricks job run %s task %s", - task_group.group_id, - ti_key.task_id, - ) - dag = _get_flask_app().dag_bag.get_dag(ti_key.dag_id) - task = dag.get_task(ti_key.task_id) - # Should we catch the exception here if there is no return value? - if ".launch" not in ti_key.task_id: - launch_task_id = get_launch_task_id(task_group) - ti_key = _get_launch_task_key(ti_key, task_id=launch_task_id) - metadata = get_xcom_result(ti_key, "return_value", ti) - - return ( - f"/repair_databricks_job?dag_id={ti_key.dag_id}&" - f"databricks_conn_id={metadata.databricks_conn_id}&" - f"databricks_run_id={metadata.databricks_run_id}&" - f"tasks_to_repair={_get_databricks_task_id(task)}&" - f"run_id={ti_key.run_id}" - ) - - -class RepairDatabricksTasks(AirflowBaseView, LoggingMixin): - default_view = "repair" - - @expose("/repair_databricks_job", methods=["GET"]) - @has_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - ] - ) - def repair(self): - databricks_conn_id, databricks_run_id, dag_id, tasks_to_repair = itemgetter( - "databricks_conn_id", "databricks_run_id", "dag_id", "tasks_to_repair" - )(request.values) - view = conf.get("webserver", "dag_default_view") - return_url = self._get_return_url(dag_id, view) - run_id = request.values.get("run_id").replace( - " ", "+" - ) # get run id separately since we need to modify it - if tasks_to_repair == "": - # If there are no tasks to repair, we return. - flash("No tasks to repair. Not sending repair request.") - return redirect(return_url) - self.log.info("Tasks to repair: %s", tasks_to_repair) - self.log.info("Repairing databricks job %s", databricks_run_id) - res = _repair_task( - databricks_conn_id=databricks_conn_id, - databricks_run_id=databricks_run_id, - tasks_to_repair=tasks_to_repair.split(","), - log=self.log, - ) - self.log.info( - "Repairing databricks job query for run %s sent", databricks_run_id - ) - self.log.info("Clearing tasks to rerun in airflow") - _clear_task_instances(dag_id, run_id, tasks_to_repair.split(","), self.log) - flash(f"Databricks repair job is starting!: {res}") - return redirect(return_url) - - @staticmethod - def _get_return_url(dag_id, view): - if version.parse(airflow_version) < version.parse("2.3.0"): - return_url = f"/{view}?dag_id={dag_id}" - else: - return_url = f"/dags/{dag_id}/{view}" - return return_url + super().__init__(*args, **kwargs) repair_databricks_view = RepairDatabricksTasks() @@ -509,11 +80,11 @@ def _get_return_url(dag_id, view): } -class AstroDatabricksPlugin(AirflowPlugin): - name = "astro_databricks_plugin" - operator_extra_links = [ - DatabricksJobRepairAllFailedLink(), - DatabricksJobRepairSingleFailedLink(), - DatabricksJobRunLink(), - ] - appbuilder_views = [repair_databricks_package] +class AstroDatabricksPlugin(DatabricksWorkflowPlugin, LoggingMixin): + def __init__(self, *args: Any, **kwargs: Any): + self.log.warning( + "AstroDatabricksPlugin is deprecated and will be removed in future versions. " + "Please use `airflow.providers.databricks.plugins.databricks_workflow.DatabricksWorkflowPlugin` instead." + ) + + super().__init__(*args, **kwargs) diff --git a/src/astro_databricks/settings.py b/src/astro_databricks/settings.py index 0e93c75..f91b460 100644 --- a/src/astro_databricks/settings.py +++ b/src/astro_databricks/settings.py @@ -1,3 +1,4 @@ +"""This module is deprecated and will be removed in future versions.""" import os DATABRICKS_JOBS_API_VERSION = os.getenv("DATABRICKS_JOBS_API_VERSION", "2.1") diff --git a/tests/databricks/test_common.py b/tests/databricks/test_common.py index 3204bae..40e7043 100644 --- a/tests/databricks/test_common.py +++ b/tests/databricks/test_common.py @@ -1,22 +1,11 @@ -from unittest import mock -from unittest.mock import MagicMock - -import pytest -from airflow.exceptions import AirflowException -from astro_databricks.operators.common import DatabricksTaskOperator -from astro_databricks.operators.workflow import ( - DatabricksWorkflowTaskGroup, +from airflow.providers.databricks.operators.databricks import ( + DatabricksTaskOperator as UpstreamDatabricksTaskOperator, ) +from astro_databricks.operators.common import DatabricksTaskOperator -@pytest.fixture -def mock_runs_api(): - return MagicMock() - - -@pytest.fixture -def databricks_task_operator(): - return DatabricksTaskOperator( +def test_init(): + operator = DatabricksTaskOperator( task_id="test_task", databricks_conn_id="foo", job_cluster_key="foo", @@ -31,404 +20,4 @@ def databricks_task_operator(): }, ) - -@mock.patch("astro_databricks.operators.common.DatabricksTaskOperator.launch_task_job") -@mock.patch( - "astro_databricks.operators.common.DatabricksTaskOperator.monitor_databricks_job" -) -def test_databricks_task_operator_without_taskgroup(mock_monitor, mock_launch, dag): - with dag: - task = DatabricksTaskOperator( - task_id="test_task", - databricks_conn_id="foo", - job_cluster_key="foo", - task_config={ - "notebook_task": { - "notebook_path": "foo", - "source": "WORKSPACE", - "base_parameters": { - "foo": "bar", - }, - }, - }, - ) - assert task.task_id == "test_task" - assert task.databricks_conn_id == "foo" - assert task.job_cluster_key == "foo" - assert task.task_config == { - "notebook_task": { - "notebook_path": "foo", - "source": "WORKSPACE", - "base_parameters": { - "foo": "bar", - }, - }, - } - dag.test() - mock_launch.assert_called_once() - mock_monitor.assert_called_once() - - -@mock.patch("astro_databricks.operators.common.DatabricksTaskOperator.launch_task_job") -@mock.patch( - "astro_databricks.operators.common.DatabricksTaskOperator.monitor_databricks_job" -) -@mock.patch( - "astro_databricks.operators.workflow._CreateDatabricksWorkflowOperator.execute" -) -def test_databricks_task_operator_with_taskgroup( - mock_create, mock_monitor, mock_launch, dag -): - mock_create.return_value = { - "databricks_job_id": "job_id", - "databricks_run_id": "run_id", - "databricks_conn_id": "conn_id", - } - with dag: - task_group = DatabricksWorkflowTaskGroup( - group_id="test_workflow", - databricks_conn_id="foo", - job_clusters=[{"job_cluster_key": "foo"}], - ) - with task_group: - task = DatabricksTaskOperator( - task_id="test_task", - databricks_conn_id="foo", - job_cluster_key="foo", - task_config={ - "notebook_task": { - "notebook_path": "foo", - "source": "WORKSPACE", - "base_parameters": { - "foo": "bar", - }, - }, - }, - ) - assert task.task_id == "test_workflow.test_task" - assert task.databricks_conn_id == "foo" - assert task.job_cluster_key == "foo" - assert task.task_config == { - "notebook_task": { - "notebook_path": "foo", - "source": "WORKSPACE", - "base_parameters": { - "foo": "bar", - }, - }, - } - dag.test() - mock_launch.assert_not_called() - mock_monitor.assert_called_once() - - -@mock.patch( - "astro_databricks.operators.common.DatabricksTaskOperator.monitor_databricks_job" -) -@mock.patch( - "astro_databricks.operators.common.DatabricksTaskOperator._get_databricks_task_id" -) -@mock.patch("astro_databricks.operators.common.DatabricksTaskOperator._get_api_client") -@mock.patch("astro_databricks.operators.common.RunsApi") -def test_databricks_task_operator_without_taskgroup_new_cluster( - mock_runs_api, mock_api_client, mock_get_databricks_task_id, mock_monitor, dag -): - mock_get_databricks_task_id.return_value = "1234" - mock_runs_api.return_value = mock.MagicMock() - with dag: - DatabricksTaskOperator( - task_id="test_task", - databricks_conn_id="foo", - job_cluster_key="foo", - new_cluster={"foo": "bar"}, - task_config={ - "notebook_task": { - "notebook_path": "foo", - "source": "WORKSPACE", - "base_parameters": { - "foo": "bar", - }, - }, - }, - ) - dag.test() - mock_runs_api.return_value.submit_run.assert_called_once_with( - { - "run_name": "1234", - "new_cluster": {"foo": "bar"}, - "notebook_task": { - "notebook_path": "foo", - "source": "WORKSPACE", - "base_parameters": { - "foo": "bar", - }, - }, - } - ) - mock_monitor.assert_called_once() - - -@mock.patch( - "astro_databricks.operators.common.DatabricksTaskOperator.monitor_databricks_job" -) -@mock.patch( - "astro_databricks.operators.common.DatabricksTaskOperator._get_databricks_task_id" -) -@mock.patch("astro_databricks.operators.common.DatabricksTaskOperator._get_api_client") -@mock.patch("astro_databricks.operators.common.RunsApi") -def test_databricks_task_operator_without_taskgroup_existing_cluster( - mock_runs_api, mock_api_client, mock_get_databricks_task_id, mock_monitor, dag -): - mock_get_databricks_task_id.return_value = "1234" - mock_runs_api.return_value = mock.MagicMock() - with dag: - DatabricksTaskOperator( - task_id="test_task", - databricks_conn_id="foo", - job_cluster_key="foo", - task_config={ - "notebook_task": { - "notebook_path": "foo", - "source": "WORKSPACE", - "base_parameters": { - "foo": "bar", - }, - }, - }, - existing_cluster_id="123", - ) - dag.test() - mock_runs_api.return_value.submit_run.assert_called_once_with( - { - "run_name": "1234", - "notebook_task": { - "notebook_path": "foo", - "source": "WORKSPACE", - "base_parameters": { - "foo": "bar", - }, - }, - "existing_cluster_id": "123", - } - ) - mock_monitor.assert_called_once() - - -@mock.patch( - "astro_databricks.operators.common.DatabricksTaskOperator.monitor_databricks_job" -) -@mock.patch("astro_databricks.operators.common.DatabricksTaskOperator._get_api_client") -@mock.patch("astro_databricks.operators.common.RunsApi") -def test_databricks_task_operator_without_taskgroup_existing_cluster_and_new_cluster( - mock_runs_api, mock_api_client, mock_monitor, dag -): - with pytest.raises(ValueError): - with dag: - DatabricksTaskOperator( - task_id="test_task", - databricks_conn_id="foo", - job_cluster_key="foo", - task_config={ - "notebook_task": { - "notebook_path": "foo", - "source": "WORKSPACE", - "base_parameters": { - "foo": "bar", - }, - }, - }, - existing_cluster_id="123", - new_cluster={"foo": "bar"}, - ) - dag.test() - - -@mock.patch( - "astro_databricks.operators.common.DatabricksTaskOperator.monitor_databricks_job" -) -@mock.patch("astro_databricks.operators.common.DatabricksTaskOperator._get_api_client") -@mock.patch("astro_databricks.operators.common.RunsApi") -def test_databricks_task_operator_without_taskgroup_no_cluster( - mock_runs_api, mock_api_client, mock_monitor, dag -): - with pytest.raises(ValueError): - with dag: - DatabricksTaskOperator( - task_id="test_task", - databricks_conn_id="foo", - job_cluster_key="foo", - task_config={ - "notebook_task": { - "notebook_path": "foo", - "source": "WORKSPACE", - "base_parameters": { - "foo": "bar", - }, - }, - }, - ) - dag.test() - - -def test_handle_final_state_success(databricks_task_operator): - final_state = { - "life_cycle_state": "TERMINATED", - "result_state": "SUCCESS", - "state_message": "Job succeeded", - } - databricks_task_operator._handle_final_state(final_state) - - -def test_handle_final_state_failure(databricks_task_operator): - final_state = { - "life_cycle_state": "TERMINATED", - "result_state": "FAILED", - "state_message": "Job failed", - } - with pytest.raises(AirflowException): - databricks_task_operator._handle_final_state(final_state) - - -def test_handle_final_state_exception(databricks_task_operator): - final_state = { - "life_cycle_state": "SKIPPED", - "state_message": "Job skipped", - } - with pytest.raises(AirflowException): - databricks_task_operator._handle_final_state(final_state) - - -@mock.patch("astro_databricks.operators.common.RunsApi") -@mock.patch("time.sleep") -def test_wait_for_pending_task(mock_sleep, mock_runs_api, databricks_task_operator): - # create a mock current task with "PENDING" state - current_task = {"run_id": "123", "state": {"life_cycle_state": "PENDING"}} - mock_runs_api.get_run.side_effect = [ - {"state": {"life_cycle_state": "PENDING"}}, - {"state": {"life_cycle_state": "RUNNING"}}, - ] - databricks_task_operator._wait_for_pending_task(current_task, mock_runs_api) - mock_runs_api.get_run.assert_called_with("123", version="2.1") - assert mock_runs_api.get_run.call_count == 2 - mock_runs_api.reset_mock() - - -@mock.patch("astro_databricks.operators.common.RunsApi") -@mock.patch("time.sleep") -def test_wait_for_terminating_task(mock_sleep, mock_runs_api, databricks_task_operator): - current_task = {"run_id": "123", "state": {"life_cycle_state": "PENDING"}} - mock_runs_api.get_run.side_effect = [ - {"state": {"life_cycle_state": "TERMINATING"}}, - {"state": {"life_cycle_state": "TERMINATING"}}, - {"state": {"life_cycle_state": "TERMINATED"}}, - ] - databricks_task_operator._wait_for_terminating_task(current_task, mock_runs_api) - mock_runs_api.get_run.assert_called_with("123", version="2.1") - assert mock_runs_api.get_run.call_count == 3 - mock_runs_api.reset_mock() - - -@mock.patch("astro_databricks.operators.common.RunsApi") -@mock.patch("time.sleep") -def test_wait_for_running_task(mock_sleep, mock_runs_api, databricks_task_operator): - current_task = {"run_id": "123", "state": {"life_cycle_state": "PENDING"}} - mock_runs_api.get_run.side_effect = [ - {"state": {"life_cycle_state": "RUNNING"}}, - {"state": {"life_cycle_state": "RUNNING"}}, - {"state": {"life_cycle_state": "TERMINATED"}}, - ] - databricks_task_operator._wait_for_running_task(current_task, mock_runs_api) - mock_runs_api.get_run.assert_called_with("123", version="2.1") - assert mock_runs_api.get_run.call_count == 3 - mock_runs_api.reset_mock() - - -def test_get_lifestyle_state(databricks_task_operator): - runs_api_mock = MagicMock() - runs_api_mock.get_run.return_value = {"state": {"life_cycle_state": "TERMINATING"}} - - task_info = {"run_id": "test_run_id"} - - assert ( - databricks_task_operator._get_lifestyle_state(task_info, runs_api_mock) - == "TERMINATING" - ) - - -@mock.patch("astro_databricks.operators.common.DatabricksHook") -@mock.patch("astro_databricks.operators.common.RunsApi") -@mock.patch("astro_databricks.operators.common.DatabricksTaskOperator._get_api_client") -@mock.patch( - "astro_databricks.operators.common.DatabricksTaskOperator._get_databricks_task_id" -) -def test_monitor_databricks_job_success( - mock_get_databricks_task_id, - mock_get_api_client, - mock_runs_api, - mock_databricks_hook, - databricks_task_operator, - caplog, -): - mock_get_databricks_task_id.return_value = "1" - # Define the expected response - response = { - "run_page_url": "https://databricks-instance-xyz.cloud.databricks.com/#job/1234/run/1", - "state": { - "life_cycle_state": "TERMINATED", - "result_state": "SUCCESS", - "state_message": "Ran successfully", - }, - "tasks": [ - { - "run_id": "1", - "task_key": "1", - } - ], - } - mock_runs_api.return_value.get_run.return_value = response - - databricks_task_operator.databricks_run_id = "1" - databricks_task_operator.monitor_databricks_job() - mock_runs_api.return_value.get_run.assert_called_with( - databricks_task_operator.databricks_run_id, version="2.1" - ) - assert ( - "Check the job run in Databricks: https://databricks-instance-xyz.cloud.databricks.com/#job/1234/run/1" - in caplog.messages - ) - - -@mock.patch("astro_databricks.operators.common.DatabricksHook") -@mock.patch("astro_databricks.operators.common.RunsApi") -@mock.patch("astro_databricks.operators.common.DatabricksTaskOperator._get_api_client") -@mock.patch( - "astro_databricks.operators.common.DatabricksTaskOperator._get_databricks_task_id" -) -def test_monitor_databricks_job_fail( - mock_get_databricks_task_id, - mock_get_api_client, - mock_runs_api, - mock_databricks_hook, - databricks_task_operator, -): - mock_get_databricks_task_id.return_value = "1" - # Define the expected response - response = { - "run_page_url": "https://databricks-instance-xyz.cloud.databricks.com/#job/1234/run/1", - "state": { - "life_cycle_state": "TERMINATED", - "result_state": "FAILED", - "state_message": "job failed", - }, - "tasks": [ - { - "run_id": "1", - "task_key": "1", - } - ], - } - mock_runs_api.return_value.get_run.return_value = response - - databricks_task_operator.databricks_run_id = "1" - with pytest.raises(AirflowException): - databricks_task_operator.monitor_databricks_job() + assert isinstance(operator, UpstreamDatabricksTaskOperator) diff --git a/tests/databricks/test_notebook.py b/tests/databricks/test_notebook.py index d905afc..33fd91f 100644 --- a/tests/databricks/test_notebook.py +++ b/tests/databricks/test_notebook.py @@ -1,24 +1,11 @@ -import os -from unittest import mock -from unittest.mock import MagicMock - -import pytest -from airflow.exceptions import AirflowException -from astro_databricks.operators.notebook import DatabricksNotebookOperator -from astro_databricks.operators.workflow import ( - DatabricksWorkflowTaskGroup, +from airflow.providers.databricks.operators.databricks import ( + DatabricksNotebookOperator as UpstreamDatabricksNotebookOperator, ) -from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION - - -@pytest.fixture -def mock_runs_api(): - return MagicMock() +from astro_databricks.operators.notebook import DatabricksNotebookOperator -@pytest.fixture -def databricks_notebook_operator(): - return DatabricksNotebookOperator( +def test_init(): + operator = DatabricksNotebookOperator( task_id="notebook", databricks_conn_id="foo", notebook_path="/foo/bar", @@ -28,417 +15,7 @@ def databricks_notebook_operator(): "foo": "bar", }, notebook_packages=[{"nb_index": {"package": "nb_package"}}], + existing_cluster_id="123", ) - -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator.launch_notebook_job" -) -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator.monitor_databricks_job" -) -def test_databricks_notebook_operator_without_taskgroup(mock_monitor, mock_launch, dag): - with dag: - notebook = DatabricksNotebookOperator( - task_id="notebook", - databricks_conn_id="foo", - notebook_path="/foo/bar", - source="WORKSPACE", - job_cluster_key="foo", - notebook_params={ - "foo": "bar", - }, - notebook_packages=[{"nb_index": {"package": "nb_package"}}], - ) - assert notebook.task_id == "notebook" - assert notebook.databricks_conn_id == "foo" - assert notebook.notebook_path == "/foo/bar" - assert notebook.source == "WORKSPACE" - assert notebook.job_cluster_key == "foo" - assert notebook.notebook_params == {"foo": "bar"} - assert notebook.notebook_packages == [{"nb_index": {"package": "nb_package"}}] - dag.test() - mock_launch.assert_called_once() - mock_monitor.assert_called_once() - - -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator.launch_notebook_job" -) -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator.monitor_databricks_job" -) -@mock.patch( - "astro_databricks.operators.workflow._CreateDatabricksWorkflowOperator.execute" -) -def test_databricks_notebook_operator_with_taskgroup( - mock_create, mock_monitor, mock_launch, dag -): - mock_create.return_value = { - "databricks_job_id": "job_id", - "databricks_run_id": "run_id", - "databricks_conn_id": "conn_id", - } - with dag: - task_group = DatabricksWorkflowTaskGroup( - group_id="test_workflow", - databricks_conn_id="foo", - job_clusters=[{"job_cluster_key": "foo"}], - notebook_params=[{"notebook_path": "/foo/bar"}], - ) - with task_group: - notebook = DatabricksNotebookOperator( - task_id="notebook", - databricks_conn_id="foo", - notebook_path="/foo/bar", - source="WORKSPACE", - job_cluster_key="foo", - notebook_params={ - "foo": "bar", - }, - notebook_packages=[{"nb_index": {"package": "nb_package"}}], - ) - assert notebook.task_id == "test_workflow.notebook" - assert notebook.databricks_conn_id == "foo" - assert notebook.notebook_path == "/foo/bar" - assert notebook.source == "WORKSPACE" - assert notebook.job_cluster_key == "foo" - assert notebook.notebook_params == {"foo": "bar"} - assert notebook.notebook_packages == [ - {"nb_index": {"package": "nb_package"}} - ] - dag.test() - mock_launch.assert_not_called() - mock_monitor.assert_called_once() - - -@pytest.mark.parametrize("api_version", ["3.2", "2.1"]) -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator.monitor_databricks_job" -) -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator._get_databricks_task_id" -) -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator._get_api_client" -) -@mock.patch("astro_databricks.operators.notebook.RunsApi") -def test_databricks_notebook_operator_without_taskgroup_new_cluster( - mock_runs_api, - mock_api_client, - mock_get_databricks_task_id, - mock_monitor, - dag, - api_version, -): - mock_get_databricks_task_id.return_value = "1234" - mock_runs_api.return_value = mock.MagicMock() - with mock.patch.dict(os.environ, {"DATABRICKS_JOBS_API_VERSION": api_version}): - import importlib - - import astro_databricks - - importlib.reload(astro_databricks.settings) - - with dag: - DatabricksNotebookOperator( - task_id="notebook", - databricks_conn_id="foo", - notebook_path="/foo/bar", - source="WORKSPACE", - job_cluster_key="foo", - notebook_params={ - "foo": "bar", - }, - notebook_packages=[{"nb_index": {"package": "nb_package"}}], - new_cluster={"foo": "bar"}, - ) - dag.test() - mock_runs_api.return_value.submit_run.assert_called_once_with( - { - "run_name": "1234", - "notebook_task": { - "notebook_path": "/foo/bar", - "source": "WORKSPACE", - "base_parameters": {"foo": "bar"}, - }, - "new_cluster": {"foo": "bar"}, - "libraries": [{"nb_index": {"package": "nb_package"}}], - "timeout_seconds": 0, - "email_notifications": {}, - }, - version=api_version, - ) - mock_monitor.assert_called_once() - - -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator.monitor_databricks_job" -) -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator._get_databricks_task_id" -) -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator._get_api_client" -) -@mock.patch("astro_databricks.operators.notebook.RunsApi") -def test_databricks_notebook_operator_without_taskgroup_existing_cluster( - mock_runs_api, mock_api_client, mock_get_databricks_task_id, mock_monitor, dag -): - mock_get_databricks_task_id.return_value = "1234" - mock_runs_api.return_value = mock.MagicMock() - with dag: - DatabricksNotebookOperator( - task_id="notebook", - databricks_conn_id="foo", - notebook_path="/foo/bar", - source="WORKSPACE", - job_cluster_key="foo", - notebook_params={ - "foo": "bar", - }, - notebook_packages=[{"nb_index": {"package": "nb_package"}}], - existing_cluster_id="123", - ) - dag.test() - mock_runs_api.return_value.submit_run.assert_called_once_with( - { - "run_name": "1234", - "notebook_task": { - "notebook_path": "/foo/bar", - "source": "WORKSPACE", - "base_parameters": {"foo": "bar"}, - }, - "existing_cluster_id": "123", - "libraries": [{"nb_index": {"package": "nb_package"}}], - "timeout_seconds": 0, - "email_notifications": {}, - }, - version=DATABRICKS_JOBS_API_VERSION, - ) - mock_monitor.assert_called_once() - - -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator.monitor_databricks_job" -) -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator._get_api_client" -) -@mock.patch("astro_databricks.operators.notebook.RunsApi") -def test_databricks_notebook_operator_without_taskgroup_existing_cluster_and_new_cluster( - mock_runs_api, mock_api_client, mock_monitor, dag -): - with pytest.raises(ValueError): - with dag: - DatabricksNotebookOperator( - task_id="notebook", - databricks_conn_id="foo", - notebook_path="/foo/bar", - source="WORKSPACE", - job_cluster_key="foo", - notebook_params={ - "foo": "bar", - }, - notebook_packages=[{"nb_index": {"package": "nb_package"}}], - existing_cluster_id="123", - new_cluster={"foo": "bar"}, - ) - dag.test() - - -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator.monitor_databricks_job" -) -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator._get_api_client" -) -@mock.patch("astro_databricks.operators.notebook.RunsApi") -def test_databricks_notebook_operator_without_taskgroup_no_cluster( - mock_runs_api, mock_api_client, mock_monitor, dag -): - with pytest.raises(ValueError): - with dag: - DatabricksNotebookOperator( - task_id="notebook", - databricks_conn_id="foo", - notebook_path="/foo/bar", - source="WORKSPACE", - job_cluster_key="foo", - notebook_params={ - "foo": "bar", - }, - notebook_packages=[{"nb_index": {"package": "nb_package"}}], - ) - dag.test() - - -def test_handle_final_state_success(databricks_notebook_operator): - final_state = { - "life_cycle_state": "TERMINATED", - "result_state": "SUCCESS", - "state_message": "Job succeeded", - } - databricks_notebook_operator._handle_final_state(final_state) - - -def test_handle_final_state_failure(databricks_notebook_operator): - final_state = { - "life_cycle_state": "TERMINATED", - "result_state": "FAILED", - "state_message": "Job failed", - } - with pytest.raises(AirflowException): - databricks_notebook_operator._handle_final_state(final_state) - - -def test_handle_final_state_exception(databricks_notebook_operator): - final_state = { - "life_cycle_state": "SKIPPED", - "state_message": "Job skipped", - } - with pytest.raises(AirflowException): - databricks_notebook_operator._handle_final_state(final_state) - - -@mock.patch("astro_databricks.operators.notebook.RunsApi") -@mock.patch("time.sleep") -def test_wait_for_pending_task(mock_sleep, mock_runs_api, databricks_notebook_operator): - # create a mock current task with "PENDING" state - current_task = {"run_id": "123", "state": {"life_cycle_state": "PENDING"}} - mock_runs_api.get_run.side_effect = [ - {"state": {"life_cycle_state": "PENDING"}}, - {"state": {"life_cycle_state": "RUNNING"}}, - ] - databricks_notebook_operator._wait_for_pending_task(current_task, mock_runs_api) - mock_runs_api.get_run.assert_called_with("123", version=DATABRICKS_JOBS_API_VERSION) - assert mock_runs_api.get_run.call_count == 2 - mock_runs_api.reset_mock() - - -@mock.patch("astro_databricks.operators.notebook.RunsApi") -@mock.patch("time.sleep") -def test_wait_for_terminating_task( - mock_sleep, mock_runs_api, databricks_notebook_operator -): - current_task = {"run_id": "123", "state": {"life_cycle_state": "PENDING"}} - mock_runs_api.get_run.side_effect = [ - {"state": {"life_cycle_state": "TERMINATING"}}, - {"state": {"life_cycle_state": "TERMINATING"}}, - {"state": {"life_cycle_state": "TERMINATED"}}, - ] - databricks_notebook_operator._wait_for_terminating_task(current_task, mock_runs_api) - mock_runs_api.get_run.assert_called_with("123", version=DATABRICKS_JOBS_API_VERSION) - assert mock_runs_api.get_run.call_count == 3 - mock_runs_api.reset_mock() - - -@mock.patch("astro_databricks.operators.notebook.RunsApi") -@mock.patch("time.sleep") -def test_wait_for_running_task(mock_sleep, mock_runs_api, databricks_notebook_operator): - current_task = {"run_id": "123", "state": {"life_cycle_state": "PENDING"}} - mock_runs_api.get_run.side_effect = [ - {"state": {"life_cycle_state": "RUNNING"}}, - {"state": {"life_cycle_state": "RUNNING"}}, - {"state": {"life_cycle_state": "TERMINATED"}}, - ] - databricks_notebook_operator._wait_for_running_task(current_task, mock_runs_api) - mock_runs_api.get_run.assert_called_with("123", version=DATABRICKS_JOBS_API_VERSION) - assert mock_runs_api.get_run.call_count == 3 - mock_runs_api.reset_mock() - - -def test_get_lifestyle_state(databricks_notebook_operator): - runs_api_mock = MagicMock() - runs_api_mock.get_run.return_value = {"state": {"life_cycle_state": "TERMINATING"}} - - task_info = {"run_id": "test_run_id"} - - assert ( - databricks_notebook_operator._get_lifestyle_state(task_info, runs_api_mock) - == "TERMINATING" - ) - - -@mock.patch("astro_databricks.operators.notebook.DatabricksHook") -@mock.patch("astro_databricks.operators.notebook.RunsApi") -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator._get_api_client" -) -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator._get_databricks_task_id" -) -def test_monitor_databricks_job_success( - mock_get_databricks_task_id, - mock_get_api_client, - mock_runs_api, - mock_databricks_hook, - databricks_notebook_operator, - caplog, -): - mock_get_databricks_task_id.return_value = "1" - # Define the expected response - response = { - "run_page_url": "https://databricks-instance-xyz.cloud.databricks.com/#job/1234/run/1", - "state": { - "life_cycle_state": "TERMINATED", - "result_state": "SUCCESS", - "state_message": "Ran successfully", - }, - "tasks": [ - { - "run_id": "1", - "task_key": "1", - } - ], - } - mock_runs_api.return_value.get_run.return_value = response - - databricks_notebook_operator.databricks_run_id = "1" - databricks_notebook_operator.monitor_databricks_job() - mock_runs_api.return_value.get_run.assert_called_with( - databricks_notebook_operator.databricks_run_id, - version=DATABRICKS_JOBS_API_VERSION, - ) - assert ( - "Check the job run in Databricks: https://databricks-instance-xyz.cloud.databricks.com/#job/1234/run/1" - in caplog.messages - ) - - -@mock.patch("astro_databricks.operators.notebook.DatabricksHook") -@mock.patch("astro_databricks.operators.notebook.RunsApi") -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator._get_api_client" -) -@mock.patch( - "astro_databricks.operators.notebook.DatabricksNotebookOperator._get_databricks_task_id" -) -def test_monitor_databricks_job_fail( - mock_get_databricks_task_id, - mock_get_api_client, - mock_runs_api, - mock_databricks_hook, - databricks_notebook_operator, -): - mock_get_databricks_task_id.return_value = "1" - # Define the expected response - response = { - "run_page_url": "https://databricks-instance-xyz.cloud.databricks.com/#job/1234/run/1", - "state": { - "life_cycle_state": "TERMINATED", - "result_state": "FAILED", - "state_message": "job failed", - }, - "tasks": [ - { - "run_id": "1", - "task_key": "1", - } - ], - } - mock_runs_api.return_value.get_run.return_value = response - - databricks_notebook_operator.databricks_run_id = "1" - with pytest.raises(AirflowException): - databricks_notebook_operator.monitor_databricks_job() + assert isinstance(operator, UpstreamDatabricksNotebookOperator) diff --git a/tests/databricks/test_plugin.py b/tests/databricks/test_plugin.py index 6e92a38..700594f 100644 --- a/tests/databricks/test_plugin.py +++ b/tests/databricks/test_plugin.py @@ -1,421 +1,46 @@ from __future__ import annotations -import uuid -from unittest import mock -from unittest.mock import MagicMock, patch - -import pytest -from airflow import DAG -from airflow.models.dagbag import DagBag -from airflow.models.dagrun import DagRun -from airflow.models.taskinstance import TaskInstanceKey -from airflow.operators.dummy import DummyOperator -from airflow.providers.databricks.hooks.databricks import DatabricksHook -from airflow.utils.dates import days_ago -from airflow.utils.db import create_session -from airflow.utils.state import State -from airflow.utils.task_group import TaskGroup -from databricks_cli.sdk.service import JobsService - - -from astro_databricks.operators.notebook import DatabricksNotebookOperator -from astro_databricks.operators.workflow import DatabricksWorkflowTaskGroup +from airflow.providers.databricks.plugins.databricks_workflow import ( + DatabricksWorkflowPlugin as UpstreamDatabricksWorkflowPlugin, + RepairDatabricksTasks as UpstreamRepairDatabricksTasks, + WorkflowJobRepairAllFailedLink, + WorkflowJobRepairSingleTaskLink, + WorkflowJobRunLink, +) from astro_databricks.plugins.plugin import ( DatabricksJobRepairAllFailedLink, DatabricksJobRepairSingleFailedLink, DatabricksJobRunLink, - _clear_task_instances, - _get_dagrun, - _get_databricks_task_id, - _repair_task, - get_launch_task_id + DatabricksWorkflowPlugin, + RepairDatabricksTasks, ) -@pytest.fixture -def mock_dag(): - dag = MagicMock(spec=DAG) - dag.dag_id = "my_dag" - dag.get_task.return_value = MagicMock( - task_group=MagicMock( - children={"task_1": "task1", "task_2": "task2", "task_5": "task5"}, - group_id="test_group", - ) - ) - return dag - - -@pytest.fixture -def mock_session(): - return MagicMock() - - -@patch("astro_databricks.plugins.plugin.get_airflow_app") -@patch("astro_databricks.plugins.plugin._get_dagrun") -@patch("astro_databricks.plugins.plugin.clear_task_instances") -def test_clear_task_instances( - mock_clear_tasks, mock_dagrun, mock_get_airflow_app, session -): - mock_get_airflow_app.return_value.dag_bag = MagicMock() - - dag_id = "test_dag" - run_id = "test_run" - task_ids = ["task_1", "task_2", "task_3"] - - # Create mock objects for testing - dag = MagicMock() - dag.get_task_instances.return_value = [ - MagicMock(task_id=task_id, dag_id=dag_id) for task_id in task_ids - ] - - dag_run = MagicMock(spec=DagRun) - dag_run.get_task_instances.return_value = dag.get_task_instances() - mock_dagrun.return_value = dag_run - mock_session = MagicMock() - databricks_task_ids = [dag_id + "__" + task_id for task_id in task_ids] - _clear_task_instances( - dag_id, run_id, databricks_task_ids, log=MagicMock(), session=mock_session - ) - - # Assert that clear_task_instances was called with the correct list of TaskInstances - expected_tis = dag_run.get_task_instances.return_value - mock_clear_tasks.assert_called_once_with(expected_tis, mock_session) - - -@patch("astro_databricks.plugins.plugin.get_airflow_app") -@patch("astro_databricks.plugins.plugin.XCom.get_value") -@patch("astro_databricks.plugins.plugin.DatabricksHook") -@patch("astro_databricks.plugins.plugin.get_task_group") -def test_databricks_job_run_link( - mock_get_task_group, mock_hook, mock_xcom, mock_get_airflow_app, mock_dag -): - mock_dag_bag = MagicMock() - mock_dag_bag.get_dag.return_value = mock_dag - mock_get_airflow_app.return_value.dag_bag = mock_dag_bag - - mock_xcom.return_value = { - "databricks_job_id": "test_job", - "databricks_run_id": "test_run", - "databricks_conn_id": "test_conn", - } - - mock_hook.return_value.host = "test_host" - mock_get_task_group.return_value.get_child_by_label.return_value.task_id = ( - "test_group.launch" - ) +def test_databricks_job_run_link_init(): link = DatabricksJobRunLink() - operator = MagicMock( - task_id="dummy_task", - dag=mock_dag, - task_group=MagicMock(group_id="test_group", default_args={}), - ) - ti_key = TaskInstanceKey(dag_id="test_dag", task_id="dummy_task", run_id="test_run") - result = link.get_link(operator=operator, ti_key=ti_key) - - mock_dag_bag.get_dag.assert_called_once_with("test_dag") - mock_xcom.assert_called_once_with( - ti_key=TaskInstanceKey( - task_id="test_group.launch", dag_id="test_dag", run_id="test_run" - ), - key="return_value", - ) - mock_hook.assert_called_once_with("test_conn") - expected_result = "https://test_host/#job/test_job/run/test_run" - assert result == expected_result + assert isinstance(link, WorkflowJobRunLink) -@pytest.fixture -def mock_dagrun(): - dagrun = MagicMock(spec=DagRun) - dagrun.get_task_instances.return_value = [ - MagicMock(state="failed", task_id="task_1"), - MagicMock(state="skipped", task_id="task_2"), - MagicMock(state="up_for_retry", task_id="task_3"), - MagicMock(state="upstream_failed", task_id="task_4"), - MagicMock(state="success", task_id="task_5"), - ] - return dagrun - - -@mock.patch("astro_databricks.plugins.plugin.XCom") -def test_repair_all_get_link(mock_xcom, mock_dagrun, mock_dag, mock_session): - # Arrange - task_instance_key = TaskInstanceKey( - dag_id="my_dag", - task_id="my_task", - run_id="run_id", - ) - mock_xcom.get_value.return_value = { - "databricks_job_id": "job_id", - "databricks_run_id": "run_id", - "databricks_conn_id": "databricks_conn", - } - +def test_repair_all_get_link_init(): link = DatabricksJobRepairAllFailedLink() - link.get_dagrun = MagicMock(return_value=mock_dagrun) - link.get_dag = MagicMock(return_value=mock_dag) - link.get_tasks_to_run = MagicMock(return_value="task_1,task_2") - mock_operator = MagicMock(task_group=MagicMock(group_id="test_group")) - # Act - result = link.get_link(mock_operator, None, ti_key=task_instance_key) + assert isinstance(link, WorkflowJobRepairAllFailedLink) - # Assert - assert ( - result == "/repair_databricks_job?dag_id=my_dag&" - "databricks_conn_id=databricks_conn&" - "databricks_run_id=run_id&" - "run_id=run_id&" - "tasks_to_repair=task_1,task_2" - ) - -@mock.patch("astro_databricks.plugins.plugin.get_airflow_app") -@mock.patch("astro_databricks.plugins.plugin._get_dagrun") -def test_get_tasks_to_run(mock_dagrun, mock_airflow_app): - link = DatabricksJobRepairAllFailedLink() - ti_key = TaskInstanceKey(dag_id="test_dag", task_id="test_task", run_id="test_run") - dag = DAG("test_dag") - task_group_children = { - "test_group.test_task": MagicMock( - task_id="test_group.test_task", dag_id="test_dag" - ), - "test_group.test_task_2": MagicMock( - task_id="test_group.test_task_2", dag_id="test_dag" - ), - "test_group.test_task_3": MagicMock( - task_id="test_group.test_task_3", dag_id="test_dag" - ), - } - task = MagicMock( - task_id="test_task", - task_group=MagicMock( - dag_id="test_dag", group_id="test_group", children=task_group_children - ), - ) - mock_airflow_app.return_value.dag_bag.get_dag.return_value = dag - dag.add_task(task) - - def generate_mock_dagrun(task_map: dict[str, str]): - dagrun = MagicMock(spec=DagRun) - dagrun.get_task_instances.return_value = [ - MagicMock(state=state, task_id=task_id) - for task_id, state in task_map.items() - ] - return dagrun - - # Case 1: No failed or skipped tasks - mock_dagrun.return_value = generate_mock_dagrun({}) - tasks_str = link.get_tasks_to_run(ti_key, operator=task, log=MagicMock()) - assert tasks_str == "" - - # # Case 2: One failed task - mock_dagrun.return_value = generate_mock_dagrun({"test_group.test_task": "failed"}) - tasks_str = link.get_tasks_to_run(ti_key, task, log=MagicMock()) - assert tasks_str == "test_dag__test_group__test_task" - # - # # Case 3: One skipped task - mock_dagrun.return_value = generate_mock_dagrun({"test_group.test_task": "skipped"}) - tasks_str = link.get_tasks_to_run(ti_key, task, log=MagicMock()) - assert tasks_str == "test_dag__test_group__test_task" - # - # # Case 4: Multiple failed and skipped tasks - mock_dagrun.return_value = generate_mock_dagrun( - {"test_group.test_task": "failed", "test_group.test_task_2": "skipped"} - ) - tasks_str = link.get_tasks_to_run(ti_key, task, log=MagicMock()) - assert ( - tasks_str == "test_dag__test_group__test_task,test_dag__test_group__test_task_2" - ) - - -@pytest.fixture -def session(): - with create_session() as session: - yield session - - -def test_get_dagrun(session, dag): - # Create a DagRun object and add it to the database - DagBag() - run_id = "test_run_id" + uuid.uuid4().hex - dr = dag.create_dagrun(run_id=run_id, state=State.RUNNING) - session.add(dr) - session.commit() - - # Call the function and ensure it returns the correct DagRun object - result = _get_dagrun(dag, run_id, session=session) - assert result == dr - - -@mock.patch("astro_databricks.plugins.plugin.DatabricksHook") -@mock.patch("astro_databricks.plugins.plugin.JobsService") -@mock.patch("astro_databricks.plugins.plugin.ApiClient") -def test_repair_task(mock_api_client, mock_jobs_service, mock_hook): - databricks_conn_id = "my_databricks_conn" - databricks_run_id = "my_databricks_run" - tasks_to_repair = ["task1", "task2"] - - # Mock the Databricks hook and API client - mock_hook.return_value = MagicMock(spec=DatabricksHook) - mock_api_client = MagicMock() - mock_hook.get_conn.return_value = mock_api_client - - # Mock the JobsService and its methods - mock_jobs_service.return_value = MagicMock(spec=JobsService) - mock_jobs_service.return_value.get_run.return_value = { - "job_id": 1234, - "run_id": databricks_run_id, - "state": "RUNNING", - "start_time": "2022-02-27T00:00:00Z", - "end_time": None, - "tasks": [], - "state_message": None, - "creator_user_name": "airflow", - "run_name": None, - "run_page_url": None, - "run_type": None, - "spark_context_id": None, - "retry_number": 0, - "previous_run_id": None, - "trigger": {}, - "is_completed": False, - "is_active": True, - "is_queued": False, - "cluster_spec": {}, - "overriding_parameters": {}, - "start_time_epoch": 1645958400, - } - mock_jobs_service.return_value.repair.return_value = None - - # Patch the DatabricksHook and JobsService constructors - _repair_task( - databricks_conn_id, databricks_run_id, tasks_to_repair, log=MagicMock() - ) - - # # Check that the JobsService methods were called correctly - mock_hook.return_value.get_conn.assert_called_once_with() - mock_jobs_service.return_value.get_run.assert_called_once_with( - run_id=databricks_run_id, include_history=True - ) - mock_jobs_service.return_value.repair.assert_called_once_with( - run_id=databricks_run_id, - version="2.1", - latest_repair_id=None, - rerun_tasks=tasks_to_repair, - ) - - -@patch("astro_databricks.plugins.plugin.get_airflow_app") -def test_databricks_job_repair_single_failed_link(mock_get_airflow_app, dag): - mock_dag_bag = MagicMock() - mock_task = MagicMock( - task_id="test_task", - task_group=MagicMock(group_id="test_group", default_args={}), - ) - test_dag = DAG("test_dag", start_date=days_ago(1)) - mock_dag_bag.get_dag.return_value = test_dag - test_dag.get_task = MagicMock(return_value=mock_task) - mock_get_airflow_app.return_value.dag_bag = mock_dag_bag +def test_databricks_job_repair_single_failed_link_init(): link = DatabricksJobRepairSingleFailedLink() - dag_id = "test_dag" - task_id = "test_task" - run_id = "test_run" - databricks_conn_id = "test_conn" - databricks_run_id = "test_run_id" + assert isinstance(link, WorkflowJobRepairSingleTaskLink) - ti_key = TaskInstanceKey(dag_id, task_id, run_id) - metadata = { - "databricks_conn_id": databricks_conn_id, - "databricks_run_id": databricks_run_id, - "databricks_job_id": 1234, - } - - mock_xcom = MagicMock() - mock_xcom.get_value.return_value = metadata - - with patch("astro_databricks.plugins.plugin.XCom", mock_xcom): - link.get_link(mock_task, dttm=None, ti_key=ti_key) - f"/repair_databricks_job?dag_id={dag_id}&databricks_conn_id={databricks_conn_id}&databricks_run_id={databricks_run_id}&tasks_to_repair={_get_databricks_task_id(mock_task)}" - - -@mock.patch("astro_databricks.operators.workflow.DatabricksHook") -@mock.patch("astro_databricks.operators.workflow.ApiClient") -@mock.patch("astro_databricks.operators.workflow.JobsApi") -@mock.patch("astro_databricks.operators.workflow._get_job_by_name") -@mock.patch( - "astro_databricks.operators.workflow.RunsApi.get_run", - return_value={"state": {"life_cycle_state": "RUNNING"}}, -) -def test_create_workflow_with_nested_task_groups( - mock_run_api, mock_get_jobs, mock_jobs_api, mock_api, mock_hook, dag -): - mock_get_jobs.return_value = {"job_id": 862519602273592} - extra_job_params = { - "timeout_seconds": 10, # default: 0 - "webhook_notifications": { - "on_failure": [{"id": "b0aea8ab-ea8c-4a45-a2e9-9a26753fd702"}], - }, - "email_notifications": { - "no_alert_for_skipped_runs": True, # default: False - "on_start": ["user.name@databricks.com"], - }, - "git_source": { # no default value - "git_url": "https://github.com/astronomer/astro-provider-databricks", - "git_provider": "gitHub", - "git_branch": "main", - }, - } - with dag: - outer_task_group = DatabricksWorkflowTaskGroup( - group_id="test_workflow", - databricks_conn_id="foo", - job_clusters=[{"job_cluster_key": "foo"}], - notebook_params={"notebook_path": "/foo/bar"}, - extra_job_params=extra_job_params, - notebook_packages=[ - {"pypi": {"package": "mlflow==2.4.0"}}, - ] - ) - with outer_task_group: - direct_notebook = DatabricksNotebookOperator( - task_id="direct_notebook", - databricks_conn_id="foo", - notebook_path="/foo/bar", - source="WORKSPACE", - job_cluster_key="foo", - ) +def test_repair_databricks_task_init(): + repair_databricks_task = RepairDatabricksTasks() - with TaskGroup("middle_task_group") as middle_task_group: - with TaskGroup("inner_task_group") as inner_task_group: - inner_notebook = DatabricksNotebookOperator( - task_id="inner_notebook", - databricks_conn_id="foo", - notebook_path="/foo/bar", - source="WORKSPACE", - job_cluster_key="foo", - ) - inner_notebook - inner_task_group - direct_notebook >> middle_task_group + assert isinstance(repair_databricks_task, UpstreamRepairDatabricksTasks) - task_id = get_launch_task_id(outer_task_group) - assert task_id == "test_workflow.launch" -def test_get_task_group_children(dag): - repair_all_link = DatabricksJobRepairAllFailedLink() - with dag: - with TaskGroup("parent_task_group") as parent_task_group: - parent_task = DummyOperator(task_id="parent_task") - with TaskGroup("inner_task_group") as inner_task_group: - inner_task = DummyOperator(task_id="inner_task") - parent_task >> inner_task_group +def test_plugin_init(): + plugin = DatabricksWorkflowPlugin() - children = repair_all_link.get_task_group_children(parent_task_group) - children_keys = children.keys() - assert len(children_keys) == 2 - assert 'parent_task_group.parent_task' in children_keys - assert 'parent_task_group.inner_task_group.inner_task' in children_keys + assert isinstance(plugin, UpstreamDatabricksWorkflowPlugin) diff --git a/tests/databricks/test_workflow.py b/tests/databricks/test_workflow.py index 176c880..cff61be 100644 --- a/tests/databricks/test_workflow.py +++ b/tests/databricks/test_workflow.py @@ -1,402 +1,12 @@ from __future__ import annotations -import logging -from unittest import mock - -import pytest -import copy -from airflow.exceptions import AirflowException -from airflow.utils.task_group import TaskGroup -from astro_databricks.operators.notebook import DatabricksNotebookOperator -from astro_databricks.operators.workflow import DatabricksWorkflowTaskGroup -from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION - -expected_workflow_json = { - "name": "unit_test_dag.test_workflow", - "email_notifications": {"no_alert_for_skipped_runs": False}, - "format": "MULTI_TASK", - "job_clusters": [{"job_cluster_key": "foo"}], - "max_concurrent_runs": 1, - "tasks": [ - { - "depends_on": [], - "email_notifications": {}, - "job_cluster_key": "foo", - "libraries": [ - {"nb_index": {"package": "nb_package"}}, - {"tg_index": {"package": "tg_package"}}, - ], - "notebook_task": { - "base_parameters": {"notebook_path": "/foo/bar"}, - "notebook_path": "/foo/bar", - "source": "WORKSPACE", - }, - "task_key": "unit_test_dag__test_workflow__notebook_1", - "timeout_seconds": 0, - }, - { - "depends_on": [{"task_key": "unit_test_dag__test_workflow__notebook_1"}], - "email_notifications": {}, - "job_cluster_key": "foo", - "libraries": [{"tg_index": {"package": "tg_package"}}], - "notebook_task": { - "base_parameters": {"foo": "bar", "notebook_path": "/foo/bar"}, - "notebook_path": "/foo/bar", - "source": "WORKSPACE", - }, - "task_key": "unit_test_dag__test_workflow__notebook_2", - "timeout_seconds": 0, - }, - ], - "timeout_seconds": 0, -} - -expected_workflow_json_existing_cluster_id = copy.deepcopy(expected_workflow_json) -# remove job_cluster_key and add existing_cluster_id -expected_workflow_json_existing_cluster_id['tasks'][1].pop('job_cluster_key') -expected_workflow_json_existing_cluster_id['tasks'][1]['existing_cluster_id'] = 'foo' - -@mock.patch("astro_databricks.operators.workflow.DatabricksHook") -@mock.patch("astro_databricks.operators.workflow.ApiClient") -@mock.patch("astro_databricks.operators.workflow.JobsApi") -@mock.patch( - "astro_databricks.operators.workflow.RunsApi.get_run", - return_value={"state": {"life_cycle_state": "SKIPPED"}}, +from airflow.providers.databricks.operators.databricks_workflow import ( + DatabricksWorkflowTaskGroup as UpstreamDatabricksWorkflowTaskGroup, ) -def test_create_workflow_from_notebooks_raises_exception_due_to_job_being_skipped( - mock_run_api, mock_jobs_api, mock_api, mock_hook, dag -): - mock_jobs_api.return_value.create_job.return_value = {"job_id": 1} - with dag: - task_group = DatabricksWorkflowTaskGroup( - group_id="test_workflow", - databricks_conn_id="foo", - job_clusters=[{"job_cluster_key": "foo"}], - notebook_params={"notebook_path": "/foo/bar"}, - notebook_packages=[{"tg_index": {"package": "tg_package"}}], - ) - with task_group: - notebook_1 = DatabricksNotebookOperator( - task_id="notebook_1", - databricks_conn_id="foo", - notebook_path="/foo/bar", - notebook_packages=[{"nb_index": {"package": "nb_package"}}], - source="WORKSPACE", - job_cluster_key="foo", - ) - notebook_2 = DatabricksNotebookOperator( - task_id="notebook_2", - databricks_conn_id="foo", - notebook_path="/foo/bar", - source="WORKSPACE", - job_cluster_key="foo", - notebook_params={ - "foo": "bar", - }, - ) - notebook_1 >> notebook_2 - - assert len(task_group.children) == 3 - with pytest.raises(AirflowException) as exc_info: - task_group.children["test_workflow.launch"].execute(context={}) - assert ( - str(exc_info.value) == "Could not start the workflow job, it had state SKIPPED" - ) - - -@mock.patch("astro_databricks.operators.workflow.DatabricksHook") -@mock.patch("astro_databricks.operators.workflow.ApiClient") -@mock.patch("astro_databricks.operators.workflow.JobsApi") -@mock.patch( - "astro_databricks.operators.workflow.RunsApi.get_run", - return_value={"state": {"life_cycle_state": "RUNNING"}}, -) -def test_create_workflow_from_notebooks_with_create( - mock_run_api, mock_jobs_api, mock_api, mock_hook, dag -): - mock_jobs_api.return_value.create_job.return_value = {"job_id": 1} - # In unittest, this function returns a MagicMock object by default, which updates an existing workflow instead of creating a new one. - # This causes the create_job assertion to fail. To prevent this, the function's return value should be overridden to an empty list. - mock_jobs_api.return_value.list_jobs.return_value.get.return_value = [] - - with dag: - task_group = DatabricksWorkflowTaskGroup( - group_id="test_workflow", - databricks_conn_id="foo", - job_clusters=[{"job_cluster_key": "foo"}], - notebook_params={"notebook_path": "/foo/bar"}, - notebook_packages=[{"tg_index": {"package": "tg_package"}}], - ) - with task_group: - notebook_1 = DatabricksNotebookOperator( - task_id="notebook_1", - databricks_conn_id="foo", - notebook_path="/foo/bar", - notebook_packages=[{"nb_index": {"package": "nb_package"}}], - source="WORKSPACE", - job_cluster_key="foo", - ) - notebook_2 = DatabricksNotebookOperator( - task_id="notebook_2", - databricks_conn_id="foo", - notebook_path="/foo/bar", - source="WORKSPACE", - job_cluster_key="foo", - notebook_params={ - "foo": "bar", - }, - ) - notebook_1 >> notebook_2 - - assert len(task_group.children) == 3 - task_group.children["test_workflow.launch"].execute(context={}) - mock_jobs_api.return_value.create_job.assert_called_once_with( - json=expected_workflow_json, - version=DATABRICKS_JOBS_API_VERSION, - ) - mock_jobs_api.return_value.run_now.assert_called_once_with( - job_id=1, - jar_params=[], - notebook_params={"notebook_path": "/foo/bar"}, - python_params=[], - spark_submit_params=[], - version=DATABRICKS_JOBS_API_VERSION, - ) - - -@mock.patch("astro_databricks.operators.workflow.DatabricksHook") -@mock.patch("astro_databricks.operators.workflow.ApiClient") -@mock.patch("astro_databricks.operators.workflow._get_job_by_name") -@mock.patch( - "astro_databricks.operators.workflow.RunsApi.get_run", - side_effect=[ - {"state": {"life_cycle_state": "BLOCKED"}}, - {"state": {"life_cycle_state": "BLOCKED"}}, - {"state": {"life_cycle_state": "RUNNING"}}, - ], -) -@mock.patch("astro_databricks.operators.workflow.JobsApi.run_now") -@mock.patch("astro_databricks.operators.workflow.JobsApi.create_job") -def test_create_workflow_from_notebooks_job_templates_notebook_jobs( - mock_create_job, - mock_run_now, - mock_get_run, - mock_get_jobs, - mock_api, - mock_hook, - dag, - caplog, -): - mock_get_jobs.return_value = {"job_id": None} - caplog.set_level(logging.INFO) - with dag: - task_group = DatabricksWorkflowTaskGroup( - group_id="test_workflow", - databricks_conn_id="foo", - job_clusters=[{"job_cluster_key": "foo"}], - notebook_params={"notebook_path": "/foo/bar", "ts": "{{ ts }}"}, - notebook_packages=[{"tg_index": {"package": "tg_package"}}], - ) - with task_group: - notebook_1 = DatabricksNotebookOperator( - task_id="notebook_1", - databricks_conn_id="foo", - notebook_path="/foo/bar", - notebook_packages=[{"nb_index": {"package": "nb_package"}}], - notebook_params={"ds": "{{ ds }}"}, - source="WORKSPACE", - job_cluster_key="foo", - ) - - notebook_1 - - assert len(task_group.children) == 2 - context = { - "ds": "yyyy-mm-dd", - "ts": "hh:mm", - "ti": mock.MagicMock(), - "expanded_ti_count": 0, - } - task_group.children["test_workflow.launch"].execute(context=context) - assert mock_get_run.call_count == 3 - assert "Job state: BLOCKED" in caplog.messages - - notebook_job_parameters = mock_create_job.call_args.kwargs["json"]["tasks"][0][ - "notebook_task" - ]["base_parameters"] - assert notebook_job_parameters["ds"] == "yyyy-mm-dd" - assert notebook_job_parameters["ts"] == "hh:mm" - - -@mock.patch("astro_databricks.operators.workflow.DatabricksHook") -@mock.patch("astro_databricks.operators.workflow.ApiClient") -@mock.patch("astro_databricks.operators.workflow.JobsApi") -@mock.patch("astro_databricks.operators.workflow._get_job_by_name") -@mock.patch( - "astro_databricks.operators.workflow.RunsApi.get_run", - return_value={"state": {"life_cycle_state": "RUNNING"}}, -) -def test_create_workflow_with_arbitrary_extra_job_params( - mock_run_api, mock_get_jobs, mock_jobs_api, mock_api, mock_hook, dag -): - mock_get_jobs.return_value = {"job_id": 862519602273592} - - extra_job_params = { - "timeout_seconds": 10, # default: 0 - "webhook_notifications": { - "on_failure": [{"id": "b0aea8ab-ea8c-4a45-a2e9-9a26753fd702"}], - }, - "email_notifications": { - "no_alert_for_skipped_runs": True, # default: False - "on_start": ["user.name@databricks.com"], - }, - "git_source": { # no default value - "git_url": "https://github.com/astronomer/astro-provider-databricks", - "git_provider": "gitHub", - "git_branch": "main", - }, - } - with dag: - task_group = DatabricksWorkflowTaskGroup( - group_id="test_workflow", - databricks_conn_id="foo", - job_clusters=[{"job_cluster_key": "foo"}], - notebook_params={"notebook_path": "/foo/bar"}, - extra_job_params=extra_job_params, - ) - with task_group: - notebook_with_extra = DatabricksNotebookOperator( - task_id="notebook_with_extra", - databricks_conn_id="foo", - notebook_path="/foo/bar", - source="WORKSPACE", - job_cluster_key="foo", - ) - notebook_with_extra - - assert len(task_group.children) == 2 - - task_group.children["test_workflow.launch"].create_workflow_json() - task_group.children["test_workflow.launch"].execute(context={}) - - mock_jobs_api.return_value.reset_job.assert_called_once() - kwargs = mock_jobs_api.return_value.reset_job.call_args_list[0].kwargs["json"] - - assert kwargs["job_id"] == 862519602273592 - assert ( - kwargs["new_settings"]["email_notifications"] - == extra_job_params["email_notifications"] - ) - assert ( - kwargs["new_settings"]["timeout_seconds"] == extra_job_params["timeout_seconds"] - ) - assert kwargs["new_settings"]["git_source"] == extra_job_params["git_source"] - assert ( - kwargs["new_settings"]["webhook_notifications"] - == extra_job_params["webhook_notifications"] - ) - assert ( - kwargs["new_settings"]["email_notifications"] - == extra_job_params["email_notifications"] - ) - - -@mock.patch("astro_databricks.operators.workflow.DatabricksHook") -@mock.patch("astro_databricks.operators.workflow.ApiClient") -@mock.patch("astro_databricks.operators.workflow.JobsApi") -@mock.patch("astro_databricks.operators.workflow._get_job_by_name") -@mock.patch( - "astro_databricks.operators.workflow.RunsApi.get_run", - return_value={"state": {"life_cycle_state": "RUNNING"}}, -) -def test_create_workflow_with_nested_task_groups( - mock_run_api, mock_get_jobs, mock_jobs_api, mock_api, mock_hook, dag -): - mock_get_jobs.return_value = {"job_id": 862519602273592} - - extra_job_params = { - "timeout_seconds": 10, # default: 0 - "webhook_notifications": { - "on_failure": [{"id": "b0aea8ab-ea8c-4a45-a2e9-9a26753fd702"}], - }, - "email_notifications": { - "no_alert_for_skipped_runs": True, # default: False - "on_start": ["user.name@databricks.com"], - }, - "git_source": { # no default value - "git_url": "https://github.com/astronomer/astro-provider-databricks", - "git_provider": "gitHub", - "git_branch": "main", - }, - } - with dag: - outer_task_group = DatabricksWorkflowTaskGroup( - group_id="test_workflow", - databricks_conn_id="foo", - job_clusters=[{"job_cluster_key": "foo"}], - notebook_params={"notebook_path": "/foo/bar"}, - extra_job_params=extra_job_params, - notebook_packages=[ - {"pypi": {"package": "mlflow==2.4.0"}}, - ], - ) - with outer_task_group: - direct_notebook = DatabricksNotebookOperator( - task_id="direct_notebook", - databricks_conn_id="foo", - notebook_path="/foo/bar", - source="WORKSPACE", - job_cluster_key="foo", - ) - - with TaskGroup("middle_task_group") as middle_task_group: - with TaskGroup("inner_task_group") as inner_task_group: - inner_notebook = DatabricksNotebookOperator( - task_id="inner_notebook", - databricks_conn_id="foo", - notebook_path="/foo/bar", - source="WORKSPACE", - job_cluster_key="foo", - ) - inner_notebook - inner_task_group - direct_notebook >> middle_task_group - - assert len(outer_task_group.children) == 3 - - outer_task_group.children["test_workflow.launch"].create_workflow_json() - outer_task_group.children["test_workflow.launch"].execute(context={}) - - kwargs = mock_jobs_api.return_value.reset_job.call_args_list[0].kwargs["json"] - - inner_notebook_json = kwargs["new_settings"]["tasks"][0] - outer_notebook_json = kwargs["new_settings"]["tasks"][1] - - assert ( - inner_notebook_json["task_key"] - == "unit_test_dag__test_workflow__direct_notebook" - ) - assert inner_notebook_json["libraries"] == [{"pypi": {"package": "mlflow==2.4.0"}}] +from astro_databricks.operators.workflow import DatabricksWorkflowTaskGroup - assert ( - outer_notebook_json["task_key"] - == "unit_test_dag__test_workflow__middle_task_group__inner_task_group__inner_notebook" - ) - assert outer_notebook_json["libraries"] == [{"pypi": {"package": "mlflow==2.4.0"}}] -@mock.patch("astro_databricks.operators.workflow.DatabricksHook") -@mock.patch("astro_databricks.operators.workflow.ApiClient") -@mock.patch("astro_databricks.operators.workflow.JobsApi") -@mock.patch( - "astro_databricks.operators.workflow.RunsApi.get_run", - return_value={"state": {"life_cycle_state": "RUNNING"}}, -) -def test_create_workflow_from_notebooks_with_different_clusters( - mock_run_api, mock_jobs_api, mock_api, mock_hook, dag -): - mock_jobs_api.return_value.create_job.return_value = {"job_id": 1} - mock_jobs_api.return_value.list_jobs.return_value.get.return_value = [] - +def test_workflow_init(dag): with dag: task_group = DatabricksWorkflowTaskGroup( group_id="test_workflow", @@ -405,38 +15,5 @@ def test_create_workflow_from_notebooks_with_different_clusters( notebook_params={"notebook_path": "/foo/bar"}, notebook_packages=[{"tg_index": {"package": "tg_package"}}], ) - with task_group: - notebook_1 = DatabricksNotebookOperator( - task_id="notebook_1", - databricks_conn_id="foo", - notebook_path="/foo/bar", - notebook_packages=[{"nb_index": {"package": "nb_package"}}], - source="WORKSPACE", - job_cluster_key="foo", - ) - notebook_2 = DatabricksNotebookOperator( - task_id="notebook_2", - databricks_conn_id="foo", - notebook_path="/foo/bar", - source="WORKSPACE", - existing_cluster_id="foo", - notebook_params={ - "foo": "bar", - }, - ) - notebook_1 >> notebook_2 - assert len(task_group.children) == 3 - task_group.children["test_workflow.launch"].execute(context={}) - mock_jobs_api.return_value.create_job.assert_called_once_with( - json=expected_workflow_json_existing_cluster_id, - version=DATABRICKS_JOBS_API_VERSION, - ) - mock_jobs_api.return_value.run_now.assert_called_once_with( - job_id=1, - jar_params=[], - notebook_params={"notebook_path": "/foo/bar"}, - python_params=[], - spark_submit_params=[], - version=DATABRICKS_JOBS_API_VERSION, - ) + assert isinstance(task_group, UpstreamDatabricksWorkflowTaskGroup) From b08de38f4ea03f8d72852fcdb3d30a3ce65a675a Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Mon, 5 Aug 2024 15:00:06 +0530 Subject: [PATCH 2/5] Update Airflow matrix in CI and specify current branch to run CI on --- .github/workflows/ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ad0d2a5..5a8fbd8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,7 +2,7 @@ name: Build and test astro databricks provider on: # yamllint disable-line rule:truthy push: - branches: [main] + branches: [main, deprecate-provider] pull_request: branches: [main, 'release-**'] @@ -78,7 +78,7 @@ jobs: fail-fast: false matrix: python: ['3.8', '3.9', '3.10'] - airflow: [2.5] + airflow: [2.7] if: >- github.event_name == 'push' || ( @@ -125,7 +125,7 @@ jobs: fail-fast: false matrix: python: ['3.8', '3.9', '3.10'] - airflow: ['2.3', '2.4', '2.5', '2.6', '2.7', '2.8'] + airflow: ['2.7', '2.8'] if: >- github.event_name == 'push' || @@ -172,7 +172,7 @@ jobs: path: ./.coverage env: DATABRICKS_CONN_TOKEN: ${{ secrets.DATABRICKS_CONN_TOKEN }} - DATABRICKS_CONN_HOST: {{ secrets.DATABRICKS_CONN_HOST }} + DATABRICKS_CONN_HOST: ${{ secrets.DATABRICKS_CONN_HOST }} DATABRICKS_CONN: ${{ secrets.AIRFLOW_CONN_DATABRICKS_DEFAULT }} From a4bd3f7c52326be35328d493d5529669f4e7fb19 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 6 Aug 2024 15:30:01 +0530 Subject: [PATCH 3/5] Enhance docs --- README.md | 15 ++++++++++++++- docs/index.rst | 28 ++++++++++++++++------------ 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 9524b91..ba5fc77 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,18 @@ ## Deprecation Notice -This provider is now deprecated since version 0.3.0 and will not be maintained. Please use the official [Apache Airflow Databricks Provider >= 6.8.0](https://airflow.apache.org/docs/apache-airflow-providers-databricks/stable/index.html) instead. + +With the release ``0.3.0`` of the ``astro-provider-databricks`` package, this provider stands deprecated and will +no longer receive updates. We recommend migrating to the official ``apache-airflow-providers-databricks>= 6.8.0`` for the latest features and support. +For the operators and sensors that are deprecated in this repository, migrating to the official Apache Airflow Databricks Provider +is as simple as changing the import path in your DAG code as per the below examples. + +| Previous import path used (Deprecated now) | Suggested import path to use | +|-------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------| +| `from astro_databricks.operators.notebook import DatabricksNotebookOperator` | `from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator` | +| `from astro_databricks.operators.workflow import DatabricksWorkflowTaskGroup` | `from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup` | +| `from astro_databricks.operators.common import DatabricksTaskOperator` | `from airflow.providers.databricks.operators.databricks import DatabricksTaskOperator` | +| `from astro_databricks.plugins.plugin import AstroDatabricksPlugin` | `from airflow.providers.airflow.providers.databricks.plugins.databricks_workflow import DatabricksWorkflowPlugin` | + +# Archives

Databricks Workflows in Airflow diff --git a/docs/index.rst b/docs/index.rst index 4bac7a9..ed54dce 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,19 +6,23 @@ Deprecation Notice ------------------ With the release ``0.3.0`` of the ``astro-provider-databricks`` package, this provider stands deprecated and will -no longer receive updates. We recommend migrating to the official Apache Airflow Databricks Provider for the latest features and support. +no longer receive updates. We recommend migrating to the official ``apache-airflow-providers-databricks>=6.8.0`` for the latest features and support. For the operators and sensors that are deprecated in this repository, migrating to the official Apache Airflow Databricks Provider -is as simple as changing the import path from - -.. code-block:: - - from astro_databricks import import SomeOperator - -to - -.. code-block:: - - from airflow.providers.databricks.operators.operator_module import SomeOperator +is as simple as changing the import path in your DAG code as per the below examples. + +.. list-table:: Import paths to change for migrating to the official Apache Airflow Databricks Provider + :header-rows: 1 + + * - Previous import path used + - Newer import path to use + * - from astro_databricks.operators.notebook import DatabricksNotebookOperator + - from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator + * - from astro_databricks.operators.workflow import DatabricksWorkflowTaskGroup + - from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup + * - from astro_databricks.operators.common import DatabricksTaskOperator + - from airflow.providers.databricks.operators.databricks import DatabricksTaskOperator + * - from astro_databricks.plugins.plugin import AstroDatabricksPlugin + - from airflow.providers.airflow.providers.databricks.plugins.databricks_workflow import DatabricksWorkflowPlugin Astro Databricks Provider From 949b4103458f49691780cdd8ea0fead949246267 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Tue, 6 Aug 2024 15:30:31 +0530 Subject: [PATCH 4/5] Release 0.3.0rc1 --- CHANGELOG.rst | 8 ++++++++ src/astro_databricks/__init__.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index dabbbb6..50fb490 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,14 @@ Changelog ========= +0.3.0rc1 (06-08-24) +------------------- + +Deprecations + +* Deprecate the provider and proxy instantiations to upstream official Apache Airflow Databricks provider (PR `#84 `_ by @pankajkoti) + + 0.2.2 (16-04-24) ---------------- diff --git a/src/astro_databricks/__init__.py b/src/astro_databricks/__init__.py index 7d435e7..81e6ff7 100644 --- a/src/astro_databricks/__init__.py +++ b/src/astro_databricks/__init__.py @@ -3,7 +3,7 @@ from astro_databricks.operators.notebook import DatabricksNotebookOperator from astro_databricks.operators.workflow import DatabricksWorkflowTaskGroup -__version__ = "0.3.0" +__version__ = "0.3.0rc1" __all__ = [ "DatabricksNotebookOperator", "DatabricksWorkflowTaskGroup", From abc416fc12712a5d82726b569318e98466aec46a Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Wed, 7 Aug 2024 10:51:29 +0530 Subject: [PATCH 5/5] Use stable version 6.8.0 for Airflow Databricks provider since RC has been released --- .github/workflows/ci.yml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5a8fbd8..1cbcb8d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,7 +2,7 @@ name: Build and test astro databricks provider on: # yamllint disable-line rule:truthy push: - branches: [main, deprecate-provider] + branches: [main] pull_request: branches: [main, 'release-**'] diff --git a/pyproject.toml b/pyproject.toml index 8b619bb..8f83b68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ classifiers = [ ] dependencies = [ "apache-airflow>=2.7", - "apache-airflow-providers-databricks>=6.8.0rc1", + "apache-airflow-providers-databricks>=6.8.0", ] [project.optional-dependencies]