diff --git a/examples/nemo_gym/run_grpo_nemo_gym.py b/examples/nemo_gym/run_grpo_nemo_gym.py index c8d2c911e2..df3c78ddc3 100644 --- a/examples/nemo_gym/run_grpo_nemo_gym.py +++ b/examples/nemo_gym/run_grpo_nemo_gym.py @@ -24,7 +24,6 @@ wandb.util.VALUE_BYTES_LIMIT = 10_000_000 -import ray from omegaconf import OmegaConf from wandb import Table @@ -44,13 +43,8 @@ from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.data.datasets import AllTaskProcessedDataset from nemo_rl.data.interfaces import DatumSpec -from nemo_rl.distributed.ray_actor_environment_registry import ( - get_actor_python_env, -) from nemo_rl.distributed.virtual_cluster import init_ray from nemo_rl.environments.nemo_gym import ( - NemoGym, - NemoGymConfig, nemo_gym_example_to_nemo_rl_datum_spec, setup_nemo_gym_config, ) @@ -233,9 +227,14 @@ def main() -> None: init_ray() + is_trajectory_collection = ( + config["env"]["nemo_gym"].pop("is_trajectory_collection", False) or False + ) + ( policy, policy_generation, + nemo_gym_env, cluster, dataloader, val_dataloader, @@ -246,24 +245,7 @@ def main() -> None: master_config, ) = setup(config, tokenizer, train_dataset, val_dataset) - is_trajectory_collection = ( - config["env"]["nemo_gym"].pop("is_trajectory_collection", False) or False - ) - nemo_gym_config = NemoGymConfig( - model_name=policy_generation.cfg["model_name"], - base_urls=policy_generation.dp_openai_server_base_urls, - initial_global_config_dict=config["env"]["nemo_gym"], - ) - nemo_gym = NemoGym.options( - runtime_env={ - "py_executable": get_actor_python_env( - "nemo_rl.environments.nemo_gym.NemoGym" - ), - } - ).remote(nemo_gym_config) - # Blocking wait for NeMo-Gym to spin up - ray.get(nemo_gym.health_check.remote()) - task_to_env = {"nemo_gym": nemo_gym} + task_to_env = {"nemo_gym": nemo_gym_env} val_task_to_env = task_to_env if is_trajectory_collection: diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index d79b6d2fac..d5dffaf681 100644 --- 
a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -22,7 +22,17 @@ import numpy as np import ray +import ray.util.state import torch +from ray.actor import ActorProxy +from ray.util.placement_group import ( + placement_group, + remove_placement_group, +) +from ray.util.scheduling_strategies import ( + NodeAffinitySchedulingStrategy, + PlacementGroupSchedulingStrategy, +) from torchdata.stateful_dataloader import StatefulDataLoader from transformers import AutoProcessor from transformers.tokenization_utils_base import PreTrainedTokenizerBase @@ -53,8 +63,16 @@ ) from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env -from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster +from nemo_rl.distributed.virtual_cluster import ( + ClusterConfig, + RayClusterSetupHelper, + RayVirtualCluster, +) from nemo_rl.environments.interfaces import EnvironmentInterface +from nemo_rl.environments.nemo_gym import ( + NemoGym, + NemoGymConfig, +) from nemo_rl.experience.rollouts import ( run_async_multi_turn_rollout, run_async_nemo_gym_rollout, @@ -191,6 +209,7 @@ def setup( ) -> tuple[ ColocatablePolicyInterface, Optional[GenerationInterface], + Optional[ActorProxy[NemoGym]], tuple[RayVirtualCluster, RayVirtualCluster], StatefulDataLoader, Optional[StatefulDataLoader], @@ -217,6 +236,7 @@ def setup( data_config = master_config["data"] logger_config = master_config["logger"] cluster_config = master_config["cluster"] + enable_nemo_gym = _should_use_nemo_gym(master_config) assert generation_config is not None, ( "A generation config in the PolicyConfig is required for GRPO" @@ -300,6 +320,20 @@ def setup( ) total_nodes = cluster_config["num_nodes"] + policy_nodes = total_nodes + + if enable_nemo_gym: + nemo_gym_num_nodes = env_configs.get("nemo_gym", {}).get("num_gpu_nodes", 0) + nemo_gym_num_gpus_per_node = cluster_config["gpus_per_node"] + else: + 
nemo_gym_num_nodes = 0 + nemo_gym_num_gpus_per_node = 0 + + if nemo_gym_num_nodes: + assert total_nodes > 1 + assert nemo_gym_num_nodes >= 1 + policy_nodes -= nemo_gym_num_nodes + if reward_model_enabled: rm_resource = env_configs["reward_model"]["resources"] rm_nodes = rm_resource["num_nodes"] @@ -309,14 +343,115 @@ def setup( rm_gpus_per_node = 0 if total_nodes == 1: - policy_nodes = total_nodes - else: - policy_nodes = total_nodes - rm_nodes - assert policy_nodes > 0, ( - "policy_nodes must be > 0, but got " - f"policy_nodes:{policy_nodes} + rm_nodes:{rm_nodes} = total_nodes:{total_nodes}" + # TODO: special case for colocated policy + reward model. + pass + elif rm_nodes: + policy_nodes -= rm_nodes + + print( + f"policy_nodes:{policy_nodes} + nemo_gym_nodes:{nemo_gym_num_nodes} + rm_nodes:{rm_nodes} = total_nodes:{total_nodes}", + flush=True, + ) + + assert policy_nodes > 0, ( + "policy_nodes must be > 0, but got " + f"policy_nodes:{policy_nodes} + nemo_gym_nodes:{nemo_gym_num_nodes} + rm_nodes:{rm_nodes} = total_nodes:{total_nodes}" + ) + + ray_runtime_ctx = ray.get_runtime_context() + ray_cur_node_id = ray_runtime_ctx.get_node_id() + ray_namespace = ray_runtime_ctx.namespace + + all_node_infos = {} + + if nemo_gym_num_nodes: + # Reserve the nemo_gym node(s) here before actually starting nemo_gym. 
+ + ray_nodes = ray.util.state.list_nodes() + for node in ray_nodes: + assert node.node_ip not in all_node_infos + all_node_infos[node.node_ip] = { + "node_id": node.node_id, + "node_ip": node.node_ip, + } + del ray_nodes + + helper_pgs = [] + + for nemo_gym_node_idx in range(nemo_gym_num_nodes): + helper_bundles = [{"GPU": nemo_gym_num_gpus_per_node, "CPU": 1}] + helper_pg = placement_group( + bundles=helper_bundles, + strategy="STRICT_PACK", + name=f"nemo_gym-pnode{nemo_gym_node_idx}", + ) + try: + ray.get(helper_pg.ready(), timeout=30) + except (TimeoutError, ray.exceptions.GetTimeoutError): + try: + remove_placement_group(helper_pg) + except Exception: + pass + raise TimeoutError( + "Timed out waiting for placement groups to be ready. The cluster may not have enough resources " + "to satisfy the requested configuration, or the resources may be busy with other tasks." + ) + helper_pgs.append(helper_pg) + + helpers = [] + nemo_gym_nodes = [] + + for nemo_gym_node_idx in range(nemo_gym_num_nodes): + helper_pg = helper_pgs[nemo_gym_node_idx] + helper_options = {} + helper_options["num_gpus"] = nemo_gym_num_gpus_per_node + helper_options["scheduling_strategy"] = PlacementGroupSchedulingStrategy( + placement_group=helper_pg, + placement_group_capture_child_tasks=True, + ) + helper = RayClusterSetupHelper.options(**helper_options).remote() + helper_node_info = ray.get(helper._get_node_info.remote()) + helpers.append(helper) + nemo_gym_nodes.append(helper_node_info) + + for helper in helpers: + ray.kill(helper, no_restart=True) + + helpers = [] + + for nemo_gym_node_idx in range(nemo_gym_num_nodes): + helper_node_id = nemo_gym_nodes[nemo_gym_node_idx]["node_id"] + helper_node_ip = nemo_gym_nodes[nemo_gym_node_idx]["node_ip"] + if not helper_node_id: + helper_node_id = all_node_infos[helper_node_ip]["node_id"] + assert helper_node_id + nemo_gym_nodes[nemo_gym_node_idx]["node_id"] = helper_node_id + helper_options = {} + helper_options["num_gpus"] = 
nemo_gym_num_gpus_per_node + helper_options["scheduling_strategy"] = NodeAffinitySchedulingStrategy( + node_id=helper_node_id, + soft=False, + ) + helper = RayClusterSetupHelper.options(**helper_options).remote() + helpers.append(helper) + + for helper_pg in helper_pgs: + try: + remove_placement_group(helper_pg) + except Exception: + pass + + nemo_gym_helpers = helpers + + print( + f" ✓ Ray cluster for NeMo Gym reserved with {nemo_gym_num_nodes} nodes", + flush=True, ) + else: + nemo_gym_nodes = [] + nemo_gym_helpers = [] + if colocated_inference: if total_nodes == 1: policy_gpus_per_node = cluster_config["gpus_per_node"] - rm_gpus_per_node @@ -612,6 +747,51 @@ def init_vllm(): ) print(" ✓ force_on_policy_ratio enabled") + if enable_nemo_gym: + # ========================== + # NeMo Gym + # ========================== + print("\n▶ Setting up NeMo Gym...", flush=True) + + nemo_gym_config = NemoGymConfig( + model_name=policy_generation.cfg["model_name"], + base_urls=policy_generation.dp_openai_server_base_urls, + ray_gpu_nodes=[node["node_id"] for node in nemo_gym_nodes], + ray_num_gpus_per_node=nemo_gym_num_gpus_per_node, + ray_namespace=ray_namespace, + initial_global_config_dict=env_configs["nemo_gym"], + ) + nemo_gym_py_exec = get_actor_python_env("nemo_rl.environments.nemo_gym.NemoGym") + if nemo_gym_py_exec.startswith("uv"): + # Lazily build a dedicated venv across all Ray nodes on-demand. 
+ nemo_gym_py_exec = create_local_venv_on_each_node( + nemo_gym_py_exec, "nemo_rl.environments.nemo_gym.NemoGym" + ) + for nemo_gym_helper in nemo_gym_helpers: + ray.kill(nemo_gym_helper, no_restart=True) + del nemo_gym_helpers + nemo_gym_options = {} + if nemo_gym_num_nodes: + nemo_gym_head_node_id = ray_cur_node_id + nemo_gym_options["scheduling_strategy"] = NodeAffinitySchedulingStrategy( + node_id=nemo_gym_head_node_id, + soft=True, + ) + nemo_gym_options["runtime_env"] = { + "py_executable": nemo_gym_py_exec, + "env_vars": { + **os.environ, + "VIRTUAL_ENV": nemo_gym_py_exec, + "UV_PROJECT_ENVIRONMENT": nemo_gym_py_exec, + }, + } + nemo_gym_actor = NemoGym.options(**nemo_gym_options).remote(nemo_gym_config) + # Blocking wait for NeMo Gym to spin up + ray.get(nemo_gym_actor._spinup.remote()) + + else: + nemo_gym_actor = None + # Calculate total setup time total_setup_time = time.perf_counter() - setup_start_time worker_init_timing_metrics["total_setup_time_s"] = total_setup_time @@ -648,6 +828,7 @@ def init_vllm(): return ( policy, policy_generation, + nemo_gym_actor, (train_cluster, inference_cluster), dataloader, val_dataloader, diff --git a/nemo_rl/distributed/virtual_cluster.py b/nemo_rl/distributed/virtual_cluster.py index 3021b760e4..7591278eae 100644 --- a/nemo_rl/distributed/virtual_cluster.py +++ b/nemo_rl/distributed/virtual_cluster.py @@ -18,6 +18,7 @@ from typing import Optional, TypedDict import ray +from ray.util import get_node_ip_address from ray.util.placement_group import ( PlacementGroup, placement_group, @@ -503,3 +504,18 @@ def __del__(self) -> None: user calls shutdown(). 
""" self.shutdown() + + +@ray.remote # pragma: no cover +class RayClusterSetupHelper: + def __init__(self, *init_args, **init_kwargs): + self.init_args = init_args + self.init_kwargs = init_kwargs + + def _get_node_info(self) -> dict: + try: + node_id = ray.get_runtime_context().get_node_id() + except Exception: + node_id = None + node_ip = get_node_ip_address() + return {"node_id": node_id, "node_ip": node_ip} diff --git a/nemo_rl/environments/nemo_gym.py b/nemo_rl/environments/nemo_gym.py index 83e9858b8e..113a37edd3 100644 --- a/nemo_rl/environments/nemo_gym.py +++ b/nemo_rl/environments/nemo_gym.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from pathlib import Path -from typing import Any, Dict, List, TypedDict +from typing import Any, Dict, List, Optional, TypedDict import ray import torch @@ -27,6 +27,9 @@ class NemoGymConfig(TypedDict): model_name: str base_urls: List[str] + ray_gpu_nodes: List[str] + ray_num_gpus_per_node: Optional[int] + ray_namespace: Optional[str] initial_global_config_dict: Dict[str, Any] @@ -37,6 +40,7 @@ class NemoGym(EnvironmentInterface): def __init__(self, cfg: NemoGymConfig): self.cfg = cfg + def _spinup(self) -> None: self.node_ip = _get_node_ip_local() self.head_server_port = _get_free_port_local() @@ -77,6 +81,22 @@ def __init__(self, cfg: NemoGymConfig): initial_global_config_dict["ray_head_node_address"] = ray_context.gcs_address print(f"Ray head node address: {ray_context.gcs_address}") + ray_namespace = self.cfg.get("ray_namespace", None) + if ray_namespace is not None: + initial_global_config_dict["ray_namespace"] = ray_namespace + print(f"Ray namespace: {ray_namespace}") + + initial_global_config_dict["ray_gpu_nodes"] = self.cfg["ray_gpu_nodes"] + initial_global_config_dict["ray_num_gpus_per_node"] = self.cfg[ + "ray_num_gpus_per_node" + ] + print( + f"Ray reserved GPU nodes: {len(initial_global_config_dict['ray_gpu_nodes'])}" + ) + print( + 
f"Ray num GPUs per node: {initial_global_config_dict['ray_num_gpus_per_node']}" + ) + # Head server initial_global_config_dict[HEAD_SERVER_KEY_NAME] = { "host": "0.0.0.0", @@ -100,9 +120,6 @@ def __init__(self, cfg: NemoGymConfig): ) self.rch = RolloutCollectionHelper() - def health_check(self) -> bool: - return True - async def run_rollouts( self, nemo_gym_examples: list[dict],