diff --git a/dev/dags/setup-teardown.py b/dev/dags/setup-teardown.py index 48451c5..69589b8 100644 --- a/dev/dags/setup-teardown.py +++ b/dev/dags/setup-teardown.py @@ -42,3 +42,4 @@ # Create ray cluster and submit ray job setup_cluster.as_setup() >> submit_ray_job >> delete_cluster.as_teardown() setup_cluster >> delete_cluster + diff --git a/docs/getting_started/local_development_setup.rst b/docs/getting_started/local_development_setup.rst index bf2298d..c395156 100644 --- a/docs/getting_started/local_development_setup.rst +++ b/docs/getting_started/local_development_setup.rst @@ -28,6 +28,7 @@ Install the following software: 1. **Create a Kind Cluster** a) If you plan to access the Kind Kubernetes cluster from Airflow using Astro CLI, use the following configuration file, + also available in ``dev/kind-config.yaml``, to create the Kind cluster: .. code-block:: @@ -356,3 +357,4 @@ The most basic setup will look something like below: - Disable SSL: Tick the disable SSL boolean if needed .. image:: ../_static/basic_local_kubernetes_conn.png + diff --git a/ray_provider/constants.py b/ray_provider/constants.py index 9e37f6a..84d5104 100644 --- a/ray_provider/constants.py +++ b/ray_provider/constants.py @@ -1,3 +1,4 @@ from ray.job_submission import JobStatus + TERMINAL_JOB_STATUSES = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED} diff --git a/ray_provider/hooks.py b/ray_provider/hooks.py index 2000ecf..6f30884 100644 --- a/ray_provider/hooks.py +++ b/ray_provider/hooks.py @@ -5,6 +5,7 @@ import subprocess import tempfile import time +from functools import cached_property from typing import Any, AsyncIterator import requests @@ -15,6 +16,10 @@ from kubernetes import client, config from ray.job_submission import JobStatus, JobSubmissionClient +from ray_provider.constants import TERMINAL_JOB_STATUSES + +DEFAULT_NAMESPACE = "default" + class RayHook(KubernetesHook): # type: ignore """ @@ -26,13 +31,11 @@ class RayHook(KubernetesHook): # type: ignore :param conn_id: The connection ID to use when fetching connection info. """ - conn_name_attr = "ray_conn_id" + conn_name_attr = "conn_id" default_conn_name = "ray_default" conn_type = "ray" hook_name = "Ray" - DEFAULT_NAMESPACE = "default" - @classmethod def get_ui_field_behaviour(cls) -> dict[str, Any]: """ @@ -84,7 +87,7 @@ def __init__( self.conn_id = conn_id self.address = self._get_field("address") or os.getenv("RAY_ADDRESS") - self.log.info(f"Ray cluster address is: {self.address}") + self.log.debug(f"Ray cluster address is: {self.address}") self.create_cluster_if_needed = False self.cookies = self._get_field("cookies") self.metadata = self._get_field("metadata") @@ -92,7 +95,7 @@ def __init__( self.verify = self._get_field("verify") or False self.ray_client_instance = None - self.namespace = self.get_namespace() or self.DEFAULT_NAMESPACE + self.default_namespace = self.get_namespace() or DEFAULT_NAMESPACE self.kubeconfig: str | None = None self.in_cluster: bool | None = None self.client_configuration = None @@ -104,9 +107,41 @@ def __init__( self.cluster_context = self._get_field("cluster_context") self.kubeconfig_path = self._get_field("kube_config_path") self.kubeconfig_content = self._get_field("kube_config") + self.ray_cluster_yaml: None | str = None self._setup_kubeconfig(self.kubeconfig_path, self.kubeconfig_content, self.cluster_context) + # Create a PR for this + @cached_property + def namespace(self) -> str: + if self.ray_cluster_yaml is None: + return self.default_namespace + cluster_spec = self.load_yaml_content(self.ray_cluster_yaml) + return cluster_spec["metadata"].get("namespace") or self.default_namespace + + # Create another PR for this + def test_connection(self) -> (bool, str): + job_client = self.ray_client(self.address) + + job_id = job_client.submit_job(entrypoint="import ray; ray.init(); print(ray.cluster_resources())") + self.log.info(f"Ray test connection: Submitted job with ID: {job_id}") + + job_completed = False + connection_attempt = 10 + while not job_completed and connection_attempt: + time.sleep(0.5) + job_status = job_client.get_job_status(job_id) + self.log.info(f"Ray test connection: Job {job_id} status {job_status}") + if job_status in TERMINAL_JOB_STATUSES: + job_completed = True + connection_attempt -= 1 + + if job_status != JobStatus.SUCCEEDED: + return False, f"Ray test connection failed: Job {job_id} status {job_status}" + + return True, job_status + # TODO: check webserver logs + def _setup_kubeconfig( self, kubeconfig_path: str | None, kubeconfig_content: str | None, cluster_context: str | None ) -> None: @@ -151,19 +186,16 @@ def ray_client(self, dashboard_url: str | None = None) -> JobSubmissionClient: :raises AirflowException: If the connection fails. """ if not self.ray_client_instance: - try: - self.log.info(f"Address URL is: {self.address}") - self.log.info(f"Dashboard URL is: {dashboard_url}") - self.ray_client_instance = JobSubmissionClient( - address=dashboard_url or self.address, - create_cluster_if_needed=self.create_cluster_if_needed, - cookies=self.cookies, - metadata=self.metadata, - headers=self.headers, - verify=self.verify, - ) - except Exception as e: - raise AirflowException(f"Failed to create Ray JobSubmissionClient: {e}") + self.log.info(f"Address URL is: {self.address}") + self.log.info(f"Dashboard URL is: {dashboard_url}") + self.ray_client_instance = JobSubmissionClient( + address=dashboard_url or self.address, + create_cluster_if_needed=self.create_cluster_if_needed, + cookies=self.cookies, + metadata=self.metadata, + headers=self.headers, + verify=self.verify, + ) return self.ray_client_instance def submit_ray_job( @@ -255,6 +287,7 @@ def _is_port_open(self, host: str, port: int) -> bool: :param port: The port number to check. :return: True if the port is open, False otherwise. """ + self.log.info(f"_is_port_open: {host} {port}") with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.settimeout(1) try: @@ -331,11 +364,13 @@ def _wait_for_load_balancer( :raises AirflowException: If the LoadBalancer does not become ready within the specified retries. """ for attempt in range(1, max_retries + 1): - self.log.info(f"Attempt {attempt}: Checking LoadBalancer status...") + self.log.info(f"Attempt {attempt}: Checking LoadBalancer status {service_name} in {namespace}...") try: service: client.V1Service = self._get_service(service_name, namespace) + self.log.info(f"service: {service}") lb_details: dict[str, Any] | None = self._get_load_balancer_details(service) + self.log.info(f"lb_details: {lb_details}") if not lb_details: self.log.info("LoadBalancer details not available yet.") @@ -385,7 +420,7 @@ def _create_or_update_cluster( name: str, namespace: str, cluster_spec: dict[str, Any], - ) -> None: + ) -> str: """ Create or update the Ray cluster based on the cluster specification. @@ -398,21 +433,45 @@ def _create_or_update_cluster( :param cluster_spec: The specification of the Ray cluster. :raises AirflowException: If there's an error accessing or creating the Ray cluster. """ - try: + if update_if_exists: + + self.log.info(f"Updating existing Ray cluster: {name}") self.get_custom_object(group=group, version=version, plural=plural, name=name, namespace=namespace) - if update_if_exists: - self.log.info(f"Updating existing Ray cluster: {name}") - self.custom_object_client.patch_namespaced_custom_object( - group=group, version=version, namespace=namespace, plural=plural, name=name, body=cluster_spec - ) - except client.exceptions.ApiException as e: - if e.status == 404: - self.log.info(f"Creating new Ray cluster: {name}") - self.create_custom_object( - group=group, version=version, namespace=namespace, plural=plural, body=cluster_spec + self.custom_object_client.patch_namespaced_custom_object( + group=group, version=version, namespace=namespace, plural=plural, name=name, body=cluster_spec + ) + + self.log.info(f"Creating new Ray cluster: {name}") + + response = self.create_custom_object( + group=group, version=version, namespace=namespace, plural=plural, body=cluster_spec + ) + self.log.info(f"Resource created. Response: {response}") + + # TODO: may go to a different PR + start_time = time.time() + wait_timeout = 300 + poll_interval = 5 + + while time.time() - start_time < wait_timeout: + try: + resource = self.get_custom_object( + group=group, version=version, plural=plural, name=name, namespace=namespace ) + except client.exceptions.ApiException as e: + self.log.warning(f"Error fetching resource status: {e}") else: - raise AirflowException(f"Error accessing Ray cluster '{name}': {e}") + status = resource.get("status", {}) + self.log.info(f"Current status: {status}") + if status.get("state") == "ready": + self.log.info(f"Resource {name} of group {group} is now ready.") + return status + + time.sleep(poll_interval) + + raise TimeoutError( + f"Resource {name} of group {group} did not reach the desired state within {wait_timeout} seconds." + ) def _setup_gpu_driver(self, gpu_device_plugin_yaml: str) -> None: """ @@ -420,12 +479,14 @@ def _setup_gpu_driver(self, gpu_device_plugin_yaml: str) -> None: :param gpu_device_plugin_yaml: Path or URL to the GPU device plugin YAML. """ - gpu_driver = self.load_yaml_content(gpu_device_plugin_yaml) - gpu_driver_name = gpu_driver["metadata"]["name"] + self.log.info("Trying to setup gpu_device_plugin_yaml %s", gpu_device_plugin_yaml) + if gpu_device_plugin_yaml: + gpu_driver = self.load_yaml_content(gpu_device_plugin_yaml) + gpu_driver_name = gpu_driver["metadata"]["name"] - if not self.get_daemon_set(gpu_driver_name): - self.log.info("Creating DaemonSet for NVIDIA device plugin...") - self.create_daemon_set(gpu_driver_name, gpu_driver) + if not self.get_daemon_set(gpu_driver_name): + self.log.info("Creating DaemonSet for NVIDIA device plugin...") + self.create_daemon_set(gpu_driver_name, gpu_driver) def _setup_load_balancer(self, name: str, namespace: str, context: Context) -> None: """ @@ -464,24 +525,25 @@ def setup_ray_cluster( :param update_if_exists: Whether to update the cluster if it already exists. :raises AirflowException: If there's an error setting up the Ray cluster. """ - try: - self._validate_yaml_file(ray_cluster_yaml) + self._validate_yaml_file(ray_cluster_yaml) + self.ray_cluster_yaml = ray_cluster_yaml - self.log.info("::group::Add KubeRay operator") - self.install_kuberay_operator(version=kuberay_version) - self.log.info("::endgroup::") + self.log.info("::group:: (Setup 1/3) Add KubeRay operator") + self.install_kuberay_operator(version=kuberay_version) + self.log.info("::endgroup::") - self.log.info("::group::Create Ray Cluster") - self.log.info("Loading yaml content for Ray cluster CRD...") - cluster_spec = self.load_yaml_content(ray_cluster_yaml) + self.log.info("::group:: (Setup 2/3) Create Ray Cluster") + self.log.info("Loading yaml content for Ray cluster CRD...") + cluster_spec = self.load_yaml_content(ray_cluster_yaml) - kind = cluster_spec["kind"] - plural = f"{kind.lower()}s" if kind == "RayCluster" else kind - name = cluster_spec["metadata"]["name"] - namespace = self.namespace - api_version = cluster_spec["apiVersion"] - group, version = api_version.split("/") if "/" in api_version else ("", api_version) + kind = cluster_spec["kind"] + plural = f"{kind.lower()}s" if kind == "RayCluster" else kind + name = cluster_spec["metadata"]["name"] + namespace = cluster_spec["metadata"].get("namespace") or self.namespace + api_version = cluster_spec["apiVersion"] + group, version = api_version.split("/") if "/" in api_version else ("", api_version) + try: self._create_or_update_cluster( update_if_exists=update_if_exists, group=group, @@ -491,17 +553,21 @@ def setup_ray_cluster( namespace=namespace, cluster_spec=cluster_spec, ) - self.log.info("::endgroup::") + except TimeoutError as e: + self._delete_ray_cluster_crd(ray_cluster_yaml) + raise e + self.log.info("::endgroup::") - self._setup_gpu_driver(gpu_device_plugin_yaml=gpu_device_plugin_yaml) + self._setup_gpu_driver(gpu_device_plugin_yaml=gpu_device_plugin_yaml) - self.log.info("::group::Setup Load Balancer service") - self._setup_load_balancer(name, namespace, context) - self.log.info("::endgroup::") + # TODO: separate PR + # self.log.info("::group:: (Step 3/3) Setup Node Port service") + # self._setup_node_port(name, namespace, context) + # self.log.info("::endgroup::") - except Exception as e: - self.log.error(f"Error setting up Ray cluster: {e}") - raise AirflowException(f"Failed to set up Ray cluster: {e}") + self.log.info("::group:: (Setup 3/3) Setup Load Balancer service") + self._setup_load_balancer(name, namespace, context) + self.log.info("::endgroup::") def _delete_ray_cluster_crd(self, ray_cluster_yaml: str) -> None: """ @@ -510,6 +576,7 @@ def _delete_ray_cluster_crd(self, ray_cluster_yaml: str) -> None: :param ray_cluster_yaml: Path to the YAML file defining the Ray cluster. :raises AirflowException: If there's an error deleting the Ray cluster. """ + self.log.info("Attempting to delete a ray cluster...") self.log.info("Loading yaml content for Ray cluster CRD...") cluster_spec = self.load_yaml_content(ray_cluster_yaml) @@ -521,14 +588,15 @@ def _delete_ray_cluster_crd(self, ray_cluster_yaml: str) -> None: group, version = api_version.split("/") if "/" in api_version else ("", api_version) try: - if self.get_custom_object(group=group, version=version, plural=plural, name=name, namespace=namespace): - self.delete_custom_object(group=group, version=version, name=name, namespace=namespace, plural=plural) - self.log.info(f"Deleted Ray cluster: {name}") - else: - self.log.info(f"Ray cluster: {name} not found. Skipping the delete step.") + self.get_custom_object(group=group, version=version, plural=plural, name=name, namespace=namespace) except client.exceptions.ApiException as e: - if e.status != 404: - raise AirflowException(f"Error deleting Ray cluster '{name}': {e}") + if e.status == 404: + self.log.info(f"Ray cluster: {name} not found. Skipping the delete step.") + else: + self.log.exception(f"Issue retrieving Ray cluster: {name}. Unable to delete it.") + else: + self.delete_custom_object(group=group, version=version, name=name, namespace=namespace, plural=plural) + self.log.info(f"Deleted Ray cluster: {name}") def delete_ray_cluster(self, ray_cluster_yaml: str, gpu_device_plugin_yaml: str) -> None: """ @@ -538,10 +606,10 @@ def delete_ray_cluster(self, ray_cluster_yaml: str, gpu_device_plugin_yaml: str) :param gpu_device_plugin_yaml: Path or URL to the GPU device plugin YAML. Defaults to NVIDIA's plugin :raises AirflowException: If there's an error deleting the Ray cluster. """ - try: - self._validate_yaml_file(ray_cluster_yaml) + self._validate_yaml_file(ray_cluster_yaml) - """Delete the NVIDIA GPU device plugin DaemonSet if it exists.""" + if gpu_device_plugin_yaml: + # Delete the NVIDIA GPU device plugin DaemonSet if it exists. gpu_driver = self.load_yaml_content(gpu_device_plugin_yaml) gpu_driver_name = gpu_driver["metadata"]["name"] @@ -549,13 +617,18 @@ def delete_ray_cluster(self, ray_cluster_yaml: str, gpu_device_plugin_yaml: str) self.log.info("Deleting DaemonSet for NVIDIA device plugin...") self.delete_daemon_set(gpu_driver_name) - self.log.info("::group:: Delete Ray Cluster") - self._delete_ray_cluster_crd(ray_cluster_yaml=ray_cluster_yaml) - self.log.info("::endgroup::") - self.uninstall_kuberay_operator() - except Exception as e: - self.log.error(f"Error deleting Ray cluster: {e}") - raise AirflowException(f"Failed to delete Ray cluster: {e}") + self.log.info("::group:: Delete Ray Cluster") + self._delete_ray_cluster_crd(ray_cluster_yaml=ray_cluster_yaml) + self.log.info("::endgroup::") + + # TODO: review this previous behaviour of the code + # It can be problematic for us to uninstall the Kuberay operator that might have been previously installed: + self.log.info("::group:: Delete Kuberay operator") + self.uninstall_kuberay_operator() + self.log.info("::endgroup::") + # except Exception as e: + # self.log.error(f"Error deleting Ray cluster: {e}") + # raise AirflowException(f"Failed to delete Ray cluster: {e}") def _run_bash_command(self, command: str, env: dict[str, str] | None = None) -> tuple[str | None, str | None]: """ @@ -571,14 +644,17 @@ def _run_bash_command(self, command: str, env: dict[str, str] | None = None) -> try: result = subprocess.run(command, shell=True, check=True, text=True, capture_output=True, env=custom_env) - self.log.info("Standard Output: %s", result.stdout) - self.log.info("Standard Error: %s", result.stderr) + + if result.stderr: + self.log.info("Standard Error: %s", result.stderr) + else: + self.log.info("Standard Output: %s", result.stdout) return result.stdout, result.stderr except subprocess.CalledProcessError as e: - self.log.error("An error occurred while executing the command: %s", e) - self.log.error("Return code: %s", e.returncode) - self.log.error("Standard Output: %s", e.stdout) - self.log.error("Standard Error: %s", e.stderr) + self.log.warning("An error occurred while executing the command: %s", e) + self.log.warning("Return code: %s", e.returncode) + self.log.warning("Standard Output: %s", e.stdout) + self.log.warning("Standard Error: %s", e.stderr) return None, None def install_kuberay_operator( @@ -597,6 +673,7 @@ def install_kuberay_operator( helm upgrade --install kuberay-operator kuberay/kuberay-operator \ --version {version} --create-namespace --namespace {self.namespace} --kubeconfig {self.kubeconfig} """ + self.log.info(helm_command) result = self._run_bash_command(helm_command, env) self.log.info(result) return result @@ -621,6 +698,7 @@ def get_daemon_set(self, name: str) -> client.V1DaemonSet | None: :param name: The name of the DaemonSet. :return: The DaemonSet resource if found, None otherwise. """ + self.log.warning(f"Trying to find DaemonSet not found: {name}") try: api_response = self.apps_v1_client.read_namespaced_daemon_set(name, self.namespace) self.log.info(f"DaemonSet {api_response.metadata.name} retrieved.") @@ -667,3 +745,125 @@ def delete_daemon_set(self, name: str) -> client.V1Status | None: except client.exceptions.ApiException as e: self.log.error(f"Exception when deleting DaemonSet: {e}") return None + + # Add this to yet another PR + def _get_node_ip(self) -> str: + """ + Retrieve the IP address of a Kubernetes node. + + :return: The IP address of a node in the Kubernetes cluster. + """ + # Example: Retrieve the first node's IP (adjust based on your cluster setup) + nodes = self.core_v1_client.list_node().items + self.log.info(f"Nodes: {nodes}") + for node in nodes: + self.log.info(f"Node address: {node.status.addresses}") + for address in node.status.addresses: + if address.type == "ExternalIP": + return address.address + + for node in nodes: + self.log.info(f"Node address: {node.status.addresses}") + for address in node.status.addresses: + if address.type == "InternalIP": + return address.address + + raise AirflowException("No valid node IP found in the cluster.") + + # Add this to yet another PR + def _setup_node_port(self, name: str, namespace: str, context: dict) -> None: + """ + Set up the NodePort service and push URLs to XCom. + + :param name: The name of the Ray cluster. + :param namespace: The namespace where the cluster is deployed. + :param context: The Airflow task context. + """ + node_port_details: dict[str, Any] = self._wait_for_node_port_service( + service_name=f"{name}-head-svc", namespace=namespace + ) + + if node_port_details: + self.log.info(node_port_details) + + node_ports = node_port_details["node_ports"] + # Example: Assuming `node_ip` is provided as an environment variable or a known cluster node. + node_ip = self._get_node_ip() # Implement this method to return a valid node IP or DNS. + + for port in node_ports: + url = f"http://{node_ip}:{port['port']}" + context["task_instance"].xcom_push(key=port["name"], value=url) + self.log.info(f"Pushed URL to XCom: {url}") + else: + self.log.info("No NodePort URLs to push to XCom.") + + # Add this to yet another PR + def _wait_for_node_port_service( + self, + service_name: str, + namespace: str = "default", + max_retries: int = 30, + retry_interval: int = 10, + ) -> dict[str, Any]: + """ + Wait for the NodePort service to be ready and return its details. + + :param service_name: The name of the NodePort service. + :param namespace: The namespace of the service. + :param max_retries: Maximum number of retries. + :param retry_interval: Interval between retries in seconds. + :return: A dictionary containing NodePort service details. + :raises AirflowException: If the service does not become ready within the specified retries. + """ + for attempt in range(1, max_retries + 1): + self.log.info(f"Attempt {attempt}: Checking NodePort service status...") + + try: + service: client.V1Service = self._get_service(service_name, namespace) + service_details: dict[str, Any] | None = self._get_node_port_details(service) + + if service_details: + self.log.info("NodePort service is ready.") + return service_details + + self.log.info("NodePort details not available yet. Retrying...") + except AirflowException: + self.log.info("Service is not available yet.") + + time.sleep(retry_interval) + + raise AirflowException(f"Service did not become ready after {max_retries} attempts") + + # Add this to yet another PR + def _get_node_port_details(self, service: client.V1Service) -> dict[str, Any] | None: + """ + Extract NodePort details from the service. + + :param service: The Kubernetes service object. + :return: A dictionary containing NodePort details if available, None otherwise. + """ + node_ports = [] + for port in service.spec.ports: + if port.node_port: + node_ports.append({"name": port.name, "port": port.node_port}) + + if node_ports: + return {"node_ports": node_ports} + + return None + + # Add this to yet another PR + def _check_node_port_connectivity(self, node_ports: list[dict[str, Any]]) -> bool: + """ + Check if the NodePort is reachable. + + :param node_ports: List of NodePort details. + :return: True if at least one NodePort is accessible, False otherwise. + """ + for port_info in node_ports: + # Replace with actual logic to test connectivity if needed. + self.log.info(f"Checking connectivity for NodePort {port_info['port']}") + # Example: Simulate readiness check. + if self._is_port_open("example-node-ip", port_info["port"]): + return True + return False diff --git a/ray_provider/operators.py b/ray_provider/operators.py index 3a0776e..1217b4c 100644 --- a/ray_provider/operators.py +++ b/ray_provider/operators.py @@ -32,7 +32,8 @@ def __init__( conn_id: str, ray_cluster_yaml: str, kuberay_version: str = "1.0.0", - gpu_device_plugin_yaml: str = "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml", + #gpu_device_plugin_yaml: str = "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml", + gpu_device_plugin_yaml: str = "", update_if_exists: bool = False, **kwargs: Any, ) -> None: @@ -55,6 +56,7 @@ def execute(self, context: Context) -> None: :param context: The context in which the operator is being executed. """ self.log.info(f"Trying to setup the ray cluster defined in {self.ray_cluster_yaml}") + self.hook.setup_ray_cluster( context=context, ray_cluster_yaml=self.ray_cluster_yaml, @@ -62,6 +64,7 @@ def execute(self, context: Context) -> None: gpu_device_plugin_yaml=self.gpu_device_plugin_yaml, update_if_exists=self.update_if_exists, ) + self.log.info("Finished setting up the ray cluster.") @@ -282,7 +285,9 @@ def execute(self, context: Context) -> str: self.log.info("::endgroup::") self.log.info("::group:: (SubmitJob 3/5) Submit job") - self.log.info(f"Ray job with id {self.job_id} submitted") + + self.log.info(f"Ray job submitted with id: {self.job_id}") + self.job_id = self.hook.submit_ray_job( dashboard_url=self.dashboard_url, entrypoint=self.entrypoint,