From 120ed9976ba3a244061986bea643bc84690e31f7 Mon Sep 17 00:00:00 2001
From: Terry Kong
Date: Sat, 19 Apr 2025 14:55:36 -0700
Subject: [PATCH] add retry with backoff to RayVirtualCluster init

Retry placement group creation up to NRL_VIRTUAL_CLUSTER_MAX_RETRIES times
(default 6) with exponential backoff when the cluster reports insufficient
resources, and raise ResourceInsufficientError once all attempts fail.

Signed-off-by: Terry Kong
---
 .../distributed/virtual_cluster.py            | 32 +++++++-
 .../unit/distributed/test_virtual_cluster.py  | 79 +++++++++++++++++++
 2 files changed, 108 insertions(+), 3 deletions(-)

diff --git a/nemo_reinforcer/distributed/virtual_cluster.py b/nemo_reinforcer/distributed/virtual_cluster.py
index 4f19fb821f..8b6600353c 100644
--- a/nemo_reinforcer/distributed/virtual_cluster.py
+++ b/nemo_reinforcer/distributed/virtual_cluster.py
@@ -19,6 +19,7 @@ import os
 
 import ray
 import logging
+import time
 from ray.util.placement_group import placement_group, remove_placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
@@ -101,6 +102,10 @@ def init_ray(log_dir: Optional[str] = None):
         logger.info(f"Started local cluster with: {ray.cluster_resources()}")
 
 
+class ResourceInsufficientError(Exception):
+    """Exception raised when the cluster does not have enough resources to satisfy the requested configuration."""
+
+
 class RayVirtualCluster:
     """Creates a virtual distributed cluster using Ray placement groups.
 
@@ -146,7 +151,28 @@ def __init__(
             )
         self.max_colocated_worker_groups = max_colocated_worker_groups
         self.name = name
-        self._init_placement_groups(placement_group_strategy)
+        max_retries = int(os.environ.get("NRL_VIRTUAL_CLUSTER_MAX_RETRIES", 6))
+        # Validate with an explicit raise: `assert` is stripped under `python -O`.
+        if max_retries <= 0:
+            raise ValueError(
+                f"NRL_VIRTUAL_CLUSTER_MAX_RETRIES={max_retries} must be an integer greater than 0"
+            )
+        for i in range(max_retries):
+            try:
+                self._init_placement_groups(placement_group_strategy)
+                # Reaching here means we were successful
+                break
+            except ResourceInsufficientError as e:
+                if i == max_retries - 1:
+                    # Out of attempts: fail right away instead of sleeping once
+                    # more, and chain the last error to keep the root cause.
+                    raise ResourceInsufficientError(
+                        f"Maximum number of retries reached ({max_retries}). Cluster resources may be insufficient or cluster itself is highly unstable. Please check your cluster configuration and your cluster logs."
+                    ) from e
+                logger.warning(
+                    f"Retrying placement group creation... {i + 1}/{max_retries}. Next retry in {2**i} seconds."
+                )
+                time.sleep(2**i)
 
     def _init_placement_groups(self, strategy: str):
         """Creates placement groups for each node in the cluster. Has empty groups for nodes that don't have any bundles.
@@ -175,12 +201,12 @@ def _init_placement_groups(self, strategy: str):
 
         # Validate resources
         if self.use_gpus and total_requested_gpus > total_available_gpus:
-            raise ValueError(
+            raise ResourceInsufficientError(
                 f"Not enough GPUs available. Requested {total_requested_gpus} GPUs, but only {total_available_gpus} are available in the cluster."
             )
 
         if total_requested_cpus > total_available_cpus:
-            raise ValueError(
+            raise ResourceInsufficientError(
                 f"Not enough CPUs available. Requested {total_requested_cpus} CPUs, but only {total_available_cpus} are available in the cluster."
             )
 
diff --git a/tests/unit/distributed/test_virtual_cluster.py b/tests/unit/distributed/test_virtual_cluster.py
index 4d01dd24f0..bdf24f3b18 100644
--- a/tests/unit/distributed/test_virtual_cluster.py
+++ b/tests/unit/distributed/test_virtual_cluster.py
@@ -14,8 +14,14 @@
 from nemo_reinforcer.distributed.virtual_cluster import (
     _get_node_ip_and_free_port,
     PY_EXECUTABLES,
+    RayVirtualCluster,
+    ResourceInsufficientError,
 )
 import ray
+import pytest
+import os
+from unittest.mock import patch, MagicMock
+import importlib
 
 
 def test_get_node_ip_and_free_port_does_not_start_with_zero():
@@ -30,3 +36,76 @@ def test_get_node_ip_and_free_port_does_not_start_with_zero():
         ).remote()
     )
     assert not node_ip.startswith("0."), "Node IP should not start with 0.*.*.*"
+
+
+def test_env_max_retries_invalid_value():
+    """Test that NRL_VIRTUAL_CLUSTER_MAX_RETRIES rejects invalid values (less than or equal to zero)."""
+
+    # Mock environment with invalid max_retries value
+    env_vars = {"NRL_VIRTUAL_CLUSTER_MAX_RETRIES": "0"}
+
+    with patch.dict(os.environ, env_vars, clear=True):
+        with pytest.raises(ValueError, match="greater than 0"):
+            RayVirtualCluster(bundle_ct_per_node_list=[1])
+
+
+def test_env_max_retries_non_integer():
+    """Test that NRL_VIRTUAL_CLUSTER_MAX_RETRIES handles non-integer values properly."""
+
+    # Mock environment with non-integer max_retries value
+    env_vars = {"NRL_VIRTUAL_CLUSTER_MAX_RETRIES": "not_a_number"}
+
+    with patch.dict(os.environ, env_vars, clear=True):
+        with pytest.raises(ValueError):
+            RayVirtualCluster(bundle_ct_per_node_list=[1])
+
+
+def test_env_max_retries_default_value():
+    """Test that default value for NRL_VIRTUAL_CLUSTER_MAX_RETRIES is used when not set."""
+
+    # Ensure environment variable is not set
+    with (
+        patch.dict(os.environ, {}, clear=True),
+        patch(
+            "nemo_reinforcer.distributed.virtual_cluster.RayVirtualCluster._init_placement_groups"
+        ) as mock_init,
+    ):
+        # Mock successful initialization
+        mock_init.return_value = [MagicMock()]
+
+        # Create cluster
+        cluster = RayVirtualCluster(bundle_ct_per_node_list=[1])
+
+        # Default value should be 6 (as seen in the code)
+        # We can't directly verify this, but we can check that initialization was attempted
+        assert mock_init.call_count == 1
+
+
+def test_env_max_retries_exhausted():
+    """Test that NRL_VIRTUAL_CLUSTER_MAX_RETRIES correctly handles the case where all retries fail."""
+
+    # Set specific retry count to 4
+    retry_count = 4
+    env_vars = {"NRL_VIRTUAL_CLUSTER_MAX_RETRIES": str(retry_count)}
+
+    with (
+        patch.dict(os.environ, env_vars, clear=True),
+        patch(
+            "nemo_reinforcer.distributed.virtual_cluster.RayVirtualCluster._init_placement_groups"
+        ) as mock_init,
+        patch("time.sleep") as mock_sleep,
+    ):
+        # Make _init_placement_groups raise ResourceInsufficientError each time
+        mock_init.side_effect = ResourceInsufficientError("Not enough resources")
+
+        # Create cluster - should retry retry_count times and then fail
+        with pytest.raises(ResourceInsufficientError):
+            RayVirtualCluster(bundle_ct_per_node_list=[1])
+
+        # Verify _init_placement_groups was called exactly retry_count times
+        assert mock_init.call_count == retry_count
+
+        # Verify time.sleep was called with exponentially increasing values, and
+        # that no sleep happens after the final failed attempt
+        sleep_durations = [c.args[0] for c in mock_sleep.call_args_list]
+        assert sleep_durations == [1, 2, 4]  # 2^0, 2^1, 2^2