Skip to content
Closed
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Check out the Getting Started guide in our [docs](https://astronomer.github.io/a
## Sample DAGs

### Example 1: Using @ray.task for job life cycle
The below example showcases how to use the ``@ray.task`` decorator to manage the full lifecycle of a Ray cluster: setup, job execution, and teardown.
The below example showcases how to use the ``@ray.task`` decorator to manage the full lifecycle of a Ray cluster: setup, job execution, and teardown. The configuration for the decorator can be provided statically or at runtime.

This approach is ideal for jobs that require a dedicated, short-lived cluster, optimizing resource usage by cleaning up after task completion.

Expand Down
3 changes: 3 additions & 0 deletions docs/getting_started/code_samples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ The below example showcases how to use the ``@ray.task`` decorator to manage the

This approach is ideal for jobs that require a dedicated, short-lived cluster, optimizing resource usage by cleaning up after task completion.

.. note::
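    The configuration can be given either as a static dictionary or as a callable that returns a dictionary at runtime. Additional keyword arguments passed to the decorator are forwarded to that callable when its signature names them, so they can be used while generating the configuration. See the example DAGs for reference.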

.. literalinclude:: ../../example_dags/ray_taskflow_example.py
:language: python
:linenos:
Expand Down
60 changes: 60 additions & 0 deletions example_dags/ray_taskflow_example_dynamic_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from datetime import datetime
from pathlib import Path

from airflow.decorators import dag, task

from ray_provider.decorators.ray import ray


def generate_config(custom_memory: int, **context):
    """Build the Ray job configuration for ``@ray.task`` at runtime.

    :param custom_memory: Value stored under the ``memory`` key of the returned configuration.
    :param context: Either the runtime context spread as keyword arguments, or a single
        ``context=<mapping>`` keyword argument wrapping it.
    :return: Dictionary of configuration values for the decorated Ray task.
    """
    # ``_RayDecoratedOperator.get_config`` passes the context as ONE keyword argument
    # named ``context`` whenever the callable has a parameter with that name. With a
    # ``**context`` signature that yields ``{"context": <context>}``, so unwrap it;
    # otherwise every lookup below would silently return ``None``.
    runtime_context = context.get("context") or context

    conn_id = "ray_conn"
    # Paths are resolved relative to this file's directory.
    base_dir = Path(__file__).parent
    ray_spec = base_dir / "scripts/ray.yaml"
    folder_path = base_dir / "ray_scripts"

    return {
        "conn_id": conn_id,
        "runtime_env": {"working_dir": str(folder_path), "pip": ["numpy"]},
        "num_cpus": 1,
        "num_gpus": 0,
        "memory": custom_memory,
        "poll_interval": 5,
        "ray_cluster_yaml": str(ray_spec),
        "xcom_task_key": "dashboard",
        "execution_date": str(runtime_context.get("execution_date")),
    }


@dag(
    dag_id="Ray_Taskflow_Example_Dynamic_Config",
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
    tags=["ray", "example"],
)
def ray_taskflow_dag():
    # Two-step pipeline: a plain task produces a small list, then a Ray task
    # (configured through the ``generate_config`` callable) processes it.

    @task
    def generate_data():
        # Fixed sample population handed to the Ray task.
        return [1, 2, 3]

    @ray.task(config=generate_config, custom_memory=1024)
    def process_data_with_ray(data):
        # Inside this function ``ray`` is rebound to the Ray library by the local
        # import below; it no longer refers to the decorator namespace imported
        # at module level.
        import numpy as np
        import ray

        @ray.remote
        def square(x):
            return x**2

        values = np.array(data)
        # Launch one remote call per element, then collect all results.
        pending = [square.remote(value) for value in values]
        squared = ray.get(pending)
        mean = np.mean(squared)
        print(f"Mean of this population is {mean}")
        return mean

    process_data_with_ray(generate_data())


# Call the ``@dag``-decorated factory and keep the result in a module-level name.
ray_example_dag = ray_taskflow_dag()
2 changes: 1 addition & 1 deletion ray_provider/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

__version__ = "0.2.1"
__version__ = "0.3.0a6"

from typing import Any

Expand Down
106 changes: 57 additions & 49 deletions ray_provider/decorators/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,54 +31,23 @@ class _RayDecoratedOperator(DecoratedOperator, SubmitRayJob):

template_fields: Any = (*SubmitRayJob.template_fields, "op_args", "op_kwargs")

def __init__(self, config: dict[str, Any], **kwargs: Any) -> None:
self.conn_id: str = config.get("conn_id", "")
self.is_decorated_function = False if "entrypoint" in config else True
self.entrypoint: str = config.get("entrypoint", "python script.py")
self.runtime_env: dict[str, Any] = config.get("runtime_env", {})

self.num_cpus: int | float = config.get("num_cpus", 1)
self.num_gpus: int | float = config.get("num_gpus", 0)
self.memory: int | float = config.get("memory", None)
self.ray_resources: dict[str, Any] | None = config.get("resources", None)
self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml", None)
self.update_if_exists: bool = config.get("update_if_exists", False)
self.kuberay_version: str = config.get("kuberay_version", "1.0.0")
self.gpu_device_plugin_yaml: str = config.get(
"gpu_device_plugin_yaml",
"https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml",
)
self.fetch_logs: bool = config.get("fetch_logs", True)
self.wait_for_completion: bool = config.get("wait_for_completion", True)
def __init__(self, config: dict[str, Any] | Callable[..., dict[str, Any]], **kwargs: Any) -> None:
job_timeout_seconds: int = config.get("job_timeout_seconds", 600)
self.poll_interval: int = config.get("poll_interval", 60)
self.xcom_task_key: str | None = config.get("xcom_task_key", None)

self.config = config
self.kwargs = kwargs
super().__init__(conn_id="", entrypoint="python script.py", runtime_env={}, **kwargs)

if not isinstance(self.num_cpus, (int, float)):
raise TypeError("num_cpus should be an integer or float value")
if not isinstance(self.num_gpus, (int, float)):
raise TypeError("num_gpus should be an integer or float value")

super().__init__(
conn_id=self.conn_id,
entrypoint=self.entrypoint,
runtime_env=self.runtime_env,
num_cpus=self.num_cpus,
num_gpus=self.num_gpus,
memory=self.memory,
resources=self.ray_resources,
ray_cluster_yaml=self.ray_cluster_yaml,
update_if_exists=self.update_if_exists,
kuberay_version=self.kuberay_version,
gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
fetch_logs=self.fetch_logs,
wait_for_completion=self.wait_for_completion,
job_timeout_seconds=job_timeout_seconds,
poll_interval=self.poll_interval,
xcom_task_key=self.xcom_task_key,
**kwargs,
)
def get_config(self, context: Context, config: Callable[..., dict[str, Any]], **kwargs: Any) -> dict[str, Any]:
    """
    Call a configuration callable with only the arguments its signature names.

    :param context: Value passed as the ``context`` keyword argument when the callable
        declares a parameter with that name.
    :param config: Callable that returns the configuration dictionary.
    :param kwargs: Candidate keyword arguments; entries whose names do not appear in the
        callable's signature are dropped.
    :return: Whatever ``config`` returns for the selected arguments.
    """
    accepted = inspect.signature(config).parameters

    call_args: dict[str, Any] = {}
    for name, value in kwargs.items():
        # ``context`` is never taken from kwargs; it is supplied separately below.
        if name == "context" or name not in accepted:
            continue
        call_args[name] = value

    if "context" in accepted:
        call_args["context"] = context

    return config(**call_args)

def execute(self, context: Context) -> Any:
"""
Expand All @@ -88,8 +57,42 @@ def execute(self, context: Context) -> Any:
:return: The result of the Ray job execution.
:raises AirflowException: If job submission fails.
"""
tmp_dir = None
temp_dir = None
try:
# Generate the configuration
if callable(self.config):
config = self.get_config(context=context, config=self.config, **self.kwargs)
else:
config = self.config

# Prepare Ray job parameters
self.conn_id: str = config.get("conn_id", "")
self.is_decorated_function = False if "entrypoint" in config else True
self.entrypoint: str = config.get("entrypoint", "python script.py")
self.runtime_env: dict[str, Any] = config.get("runtime_env", {})

self.num_cpus: int | float = config.get("num_cpus", 1)
self.num_gpus: int | float = config.get("num_gpus", 0)
self.memory: int | float = config.get("memory", None)
self.ray_resources: dict[str, Any] | None = config.get("resources", None)
self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml", None)
self.update_if_exists: bool = config.get("update_if_exists", False)
self.kuberay_version: str = config.get("kuberay_version", "1.0.0")
self.gpu_device_plugin_yaml: str = config.get(
"gpu_device_plugin_yaml",
"https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml",
)
self.fetch_logs: bool = config.get("fetch_logs", True)
self.wait_for_completion: bool = config.get("wait_for_completion", True)
self.job_timeout_seconds: int = config.get("job_timeout_seconds", 600)
self.poll_interval: int = config.get("poll_interval", 60)
self.xcom_task_key: str | None = config.get("xcom_task_key", None)

if not isinstance(self.num_cpus, (int, float)):
raise TypeError("num_cpus should be an integer or float value")
if not isinstance(self.num_gpus, (int, float)):
raise TypeError("num_gpus should be an integer or float value")

if self.is_decorated_function:
self.log.info(
f"Entrypoint is not provided, is_decorated_function is set to {self.is_decorated_function}"
Expand Down Expand Up @@ -126,8 +129,8 @@ def execute(self, context: Context) -> Any:
self.log.error(f"Failed during execution with error: {e}")
raise AirflowException("Job submission failed") from e
finally:
if tmp_dir and os.path.exists(tmp_dir):
shutil.rmtree(tmp_dir)
if temp_dir and os.path.exists(temp_dir):
shutil.rmtree(temp_dir)

def _extract_function_body(self, source: str) -> str:
"""Extract the function, excluding only the ray.task decorator."""
Expand All @@ -146,19 +149,24 @@ class ray:
def task(
    python_callable: Callable[..., Any] | None = None,
    multiple_outputs: bool | None = None,
    config: Callable[..., dict[str, Any]] | dict[str, Any] | None = None,
    **kwargs: Any,
) -> TaskDecorator:
    """
    Decorator to define a task that submits a Ray job.

    :param python_callable: The callable function to decorate.
    :param multiple_outputs: If True, will return multiple outputs.
    :param config: A dictionary of configuration, or a callable that returns such a
        dictionary. The callable may accept arguments; ``None`` is treated as an
        empty dictionary.
    :param kwargs: Additional keyword arguments, forwarded unchanged to
        ``task_decorator_factory``.
    :return: The decorated task.
    """
    return task_decorator_factory(
        python_callable=python_callable,
        multiple_outputs=multiple_outputs,
        decorated_operator_class=_RayDecoratedOperator,
        # Substitute a fresh empty dict so the operator always receives a usable config.
        config={} if config is None else config,
        **kwargs,
    )
11 changes: 9 additions & 2 deletions ray_provider/operators/ray.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import traceback
from datetime import timedelta
from functools import cached_property
from typing import Any
Expand Down Expand Up @@ -281,6 +282,10 @@ def execute(self, context: Context) -> str:
current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
self.log.info(f"Current job status for {self.job_id} is: {current_status}")

job_timeout_seconds = self.job_timeout_seconds
if isinstance(self.job_timeout_seconds, int):
job_timeout_seconds = timedelta(seconds=self.job_timeout_seconds) if self.job_timeout_seconds > 0 else None

if current_status not in self.terminal_states:
self.log.info("Deferring the polling to RayJobTrigger...")
self.defer(
Expand All @@ -294,7 +299,7 @@ def execute(self, context: Context) -> str:
fetch_logs=self.fetch_logs,
),
method_name="execute_complete",
timeout=self.job_timeout_seconds,
timeout=job_timeout_seconds,
)
elif current_status == JobStatus.SUCCEEDED:
self.log.info("Job %s completed successfully", self.job_id)
Expand All @@ -308,8 +313,10 @@ def execute(self, context: Context) -> str:
)
return self.job_id
except Exception as e:
self._delete_cluster()
error_details = traceback.format_exc()
self.log.info(error_details)
raise AirflowException(f"SubmitRayJob operator failed due to {e}. Cleaning up resources...")
self._delete_cluster()

def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
"""
Expand Down