Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions nemo_reinforcer/distributed/virtual_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
import ray
import logging
import time
from ray.util.placement_group import placement_group, remove_placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

Expand Down Expand Up @@ -101,6 +102,10 @@ def init_ray(log_dir: Optional[str] = None):
logger.info(f"Started local cluster with: {ray.cluster_resources()}")


class ResourceInsufficientError(Exception):
"""Exception raised when the cluster does not have enough resources to satisfy the requested configuration."""


class RayVirtualCluster:
"""Creates a virtual distributed cluster using Ray placement groups.

Expand Down Expand Up @@ -146,7 +151,25 @@ def __init__(
)
self.max_colocated_worker_groups = max_colocated_worker_groups
self.name = name
self._init_placement_groups(placement_group_strategy)
max_retries = int(os.environ.get("NRL_VIRTUAL_CLUSTER_MAX_RETRIES", 6))
assert max_retries > 0, (
f"NRL_VIRTUAL_CLUSTER_MAX_RETRIES={max_retries} must be an integer greater than 0"
)
for i in range(max_retries):
try:
self._init_placement_groups(placement_group_strategy)
# Reaching here means we were successful
break
except ResourceInsufficientError:
print(
f"Retrying placement group creation... {i + 1}/{max_retries}. Next retry in {2**i} seconds."
)
time.sleep(2**i)
continue
else:
raise ResourceInsufficientError(
f"Maximum number of retries reached ({max_retries}). Cluster resources may be insufficient or cluster itself is highly unstable. Please check your cluster configuration and your cluster logs."
)

def _init_placement_groups(self, strategy: str):
"""Creates placement groups for each node in the cluster. Has empty groups for nodes that don't have any bundles.
Expand Down Expand Up @@ -175,12 +198,12 @@ def _init_placement_groups(self, strategy: str):

# Validate resources
if self.use_gpus and total_requested_gpus > total_available_gpus:
raise ValueError(
raise ResourceInsufficientError(
f"Not enough GPUs available. Requested {total_requested_gpus} GPUs, but only {total_available_gpus} are available in the cluster."
)

if total_requested_cpus > total_available_cpus:
raise ValueError(
raise ResourceInsufficientError(
f"Not enough CPUs available. Requested {total_requested_cpus} CPUs, but only {total_available_cpus} are available in the cluster."
)

Expand Down
81 changes: 81 additions & 0 deletions tests/unit/distributed/test_virtual_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,14 @@
from nemo_reinforcer.distributed.virtual_cluster import (
_get_node_ip_and_free_port,
PY_EXECUTABLES,
RayVirtualCluster,
ResourceInsufficientError,
)
import ray
import pytest
import os
from unittest.mock import patch, MagicMock
import importlib


def test_get_node_ip_and_free_port_does_not_start_with_zero():
Expand All @@ -30,3 +36,78 @@ def test_get_node_ip_and_free_port_does_not_start_with_zero():
).remote()
)
assert not node_ip.startswith("0."), "Node IP should not start with 0.*.*.*"


def test_env_max_retries_invalid_value():
"""Test that NRL_VIRTUAL_CLUSTER_MAX_RETRIES rejects invalid values (less than or equal to zero)."""

# Mock environment with invalid max_retries value
env_vars = {"NRL_VIRTUAL_CLUSTER_MAX_RETRIES": "0"}

with patch.dict(os.environ, env_vars, clear=True):
with pytest.raises(AssertionError):
RayVirtualCluster(bundle_ct_per_node_list=[1])


def test_env_max_retries_non_integer():
"""Test that NRL_VIRTUAL_CLUSTER_MAX_RETRIES handles non-integer values properly."""

# Mock environment with non-integer max_retries value
env_vars = {"NRL_VIRTUAL_CLUSTER_MAX_RETRIES": "not_a_number"}

with patch.dict(os.environ, env_vars, clear=True):
with pytest.raises(ValueError):
RayVirtualCluster(bundle_ct_per_node_list=[1])


def test_env_max_retries_default_value():
"""Test that default value for NRL_VIRTUAL_CLUSTER_MAX_RETRIES is used when not set."""

# Ensure environment variable is not set
with (
patch.dict(os.environ, {}, clear=True),
patch(
"nemo_reinforcer.distributed.virtual_cluster.RayVirtualCluster._init_placement_groups"
) as mock_init,
):
# Mock successful initialization
mock_init.return_value = [MagicMock()]

# Create cluster
cluster = RayVirtualCluster(bundle_ct_per_node_list=[1])

# Default value should be 6 (as seen in the code)
# We can't directly verify this, but we can check that initialization was attempted
assert mock_init.call_count == 1


def test_env_max_retries_exhausted():
"""Test that NRL_VIRTUAL_CLUSTER_MAX_RETRIES correctly handles the case where all retries fail."""

# Set specific retry count to 4
retry_count = 4
env_vars = {"NRL_VIRTUAL_CLUSTER_MAX_RETRIES": str(retry_count)}

with (
patch.dict(os.environ, env_vars, clear=True),
patch(
"nemo_reinforcer.distributed.virtual_cluster.RayVirtualCluster._init_placement_groups"
) as mock_init,
patch("time.sleep") as mock_sleep,
):
# Make _init_placement_groups raise ResourceInsufficientError each time
mock_init.side_effect = ResourceInsufficientError("Not enough resources")

# Create cluster - should retry retry_count times and then fail
with pytest.raises(ResourceInsufficientError):
RayVirtualCluster(bundle_ct_per_node_list=[1])

# Verify _init_placement_groups was called exactly retry_count times
assert mock_init.call_count == retry_count

# Verify time.sleep was called with exponentially increasing values
assert mock_sleep.call_count == retry_count
mock_sleep.assert_any_call(1) # 2^0
mock_sleep.assert_any_call(2) # 2^1
mock_sleep.assert_any_call(4) # 2^2
mock_sleep.assert_any_call(8) # 2^3