Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 6 additions & 24 deletions examples/nemo_gym/run_grpo_nemo_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

wandb.util.VALUE_BYTES_LIMIT = 10_000_000

import ray
from omegaconf import OmegaConf
from wandb import Table

Expand All @@ -44,13 +43,8 @@
from nemo_rl.algorithms.utils import get_tokenizer
from nemo_rl.data.datasets import AllTaskProcessedDataset
from nemo_rl.data.interfaces import DatumSpec
from nemo_rl.distributed.ray_actor_environment_registry import (
get_actor_python_env,
)
from nemo_rl.distributed.virtual_cluster import init_ray
from nemo_rl.environments.nemo_gym import (
NemoGym,
NemoGymConfig,
nemo_gym_example_to_nemo_rl_datum_spec,
setup_nemo_gym_config,
)
Expand Down Expand Up @@ -233,9 +227,14 @@ def main() -> None:

init_ray()

is_trajectory_collection = (
config["env"]["nemo_gym"].pop("is_trajectory_collection", False) or False
)

(
policy,
policy_generation,
nemo_gym_env,
cluster,
dataloader,
val_dataloader,
Expand All @@ -246,24 +245,7 @@ def main() -> None:
master_config,
) = setup(config, tokenizer, train_dataset, val_dataset)

is_trajectory_collection = (
config["env"]["nemo_gym"].pop("is_trajectory_collection", False) or False
)
nemo_gym_config = NemoGymConfig(
model_name=policy_generation.cfg["model_name"],
base_urls=policy_generation.dp_openai_server_base_urls,
initial_global_config_dict=config["env"]["nemo_gym"],
)
nemo_gym = NemoGym.options(
runtime_env={
"py_executable": get_actor_python_env(
"nemo_rl.environments.nemo_gym.NemoGym"
),
}
).remote(nemo_gym_config)
# Blocking wait for NeMo-Gym to spin up
ray.get(nemo_gym.health_check.remote())
task_to_env = {"nemo_gym": nemo_gym}
task_to_env = {"nemo_gym": nemo_gym_env}
val_task_to_env = task_to_env

if is_trajectory_collection:
Expand Down
195 changes: 188 additions & 7 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,17 @@

import numpy as np
import ray
import ray.util.state
import torch
from ray.actor import ActorProxy
from ray.util.placement_group import (
placement_group,
remove_placement_group,
)
from ray.util.scheduling_strategies import (
NodeAffinitySchedulingStrategy,
PlacementGroupSchedulingStrategy,
)
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoProcessor
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
Expand Down Expand Up @@ -53,8 +63,16 @@
)
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env
from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
from nemo_rl.distributed.virtual_cluster import (
ClusterConfig,
RayClusterSetupHelper,
RayVirtualCluster,
)
from nemo_rl.environments.interfaces import EnvironmentInterface
from nemo_rl.environments.nemo_gym import (
NemoGym,
NemoGymConfig,
)
from nemo_rl.experience.rollouts import (
run_async_multi_turn_rollout,
run_async_nemo_gym_rollout,
Expand Down Expand Up @@ -191,6 +209,7 @@ def setup(
) -> tuple[
ColocatablePolicyInterface,
Optional[GenerationInterface],
Optional[ActorProxy[NemoGym]],
tuple[RayVirtualCluster, RayVirtualCluster],
StatefulDataLoader,
Optional[StatefulDataLoader],
Expand All @@ -217,6 +236,7 @@ def setup(
data_config = master_config["data"]
logger_config = master_config["logger"]
cluster_config = master_config["cluster"]
enable_nemo_gym = _should_use_nemo_gym(master_config)

assert generation_config is not None, (
"A generation config in the PolicyConfig is required for GRPO"
Expand Down Expand Up @@ -300,6 +320,20 @@ def setup(
)

total_nodes = cluster_config["num_nodes"]
policy_nodes = total_nodes

if enable_nemo_gym:
nemo_gym_num_nodes = env_configs.get("nemo_gym", {}).get("num_gpu_nodes", 0)
nemo_gym_num_gpus_per_node = cluster_config["gpus_per_node"]
else:
nemo_gym_num_nodes = 0
nemo_gym_num_gpus_per_node = 0

if nemo_gym_num_nodes:
assert total_nodes > 1
assert nemo_gym_num_nodes >= 1
policy_nodes -= nemo_gym_num_nodes

if reward_model_enabled:
rm_resource = env_configs["reward_model"]["resources"]
rm_nodes = rm_resource["num_nodes"]
Expand All @@ -309,14 +343,115 @@ def setup(
rm_gpus_per_node = 0

if total_nodes == 1:
policy_nodes = total_nodes
else:
policy_nodes = total_nodes - rm_nodes
assert policy_nodes > 0, (
"policy_nodes must be > 0, but got "
f"policy_nodes:{policy_nodes} + rm_nodes:{rm_nodes} = total_nodes:{total_nodes}"
# TODO: special case for colocated policy + reward model.
pass
elif rm_nodes:
policy_nodes -= rm_nodes

print(
f"policy_nodes:{policy_nodes} + nemo_gym_nodes:{nemo_gym_num_nodes} + rm_nodes:{rm_nodes} = total_nodes:{total_nodes}",
flush=True,
)

assert policy_nodes > 0, (
"policy_nodes must be > 0, but got "
f"policy_nodes:{policy_nodes} + nemo_gym_nodes:{nemo_gym_num_nodes} + rm_nodes:{rm_nodes} = total_nodes:{total_nodes}"
)

ray_runtime_ctx = ray.get_runtime_context()
ray_cur_node_id = ray_runtime_ctx.get_node_id()
ray_namespace = ray_runtime_ctx.namespace

all_node_infos = {}

if nemo_gym_num_nodes:
# Reserve the nemo_gym node(s) here before actually starting nemo_gym.

ray_nodes = ray.util.state.list_nodes()
for node in ray_nodes:
assert node.node_ip not in all_node_infos
all_node_infos[node.node_ip] = {
"node_id": node.node_id,
"node_ip": node.node_ip,
}
del ray_nodes

helper_pgs = []

for nemo_gym_node_idx in range(nemo_gym_num_nodes):
helper_bundles = [{"GPU": nemo_gym_num_gpus_per_node, "CPU": 1}]
helper_pg = placement_group(
bundles=helper_bundles,
strategy="STRICT_PACK",
name=f"nemo_gym-pnode{nemo_gym_node_idx}",
)
try:
ray.get(helper_pg.ready(), timeout=30)
except (TimeoutError, ray.exceptions.GetTimeoutError):
try:
remove_placement_group(helper_pg)
except Exception:
pass
raise TimeoutError(
"Timed out waiting for placement groups to be ready. The cluster may not have enough resources "
"to satisfy the requested configuration, or the resources may be busy with other tasks."
)
helper_pgs.append(helper_pg)

helpers = []
nemo_gym_nodes = []

for nemo_gym_node_idx in range(nemo_gym_num_nodes):
helper_pg = helper_pgs[nemo_gym_node_idx]
helper_options = {}
helper_options["num_gpus"] = nemo_gym_num_gpus_per_node
helper_options["scheduling_strategy"] = PlacementGroupSchedulingStrategy(
placement_group=helper_pg,
placement_group_capture_child_tasks=True,
)
helper = RayClusterSetupHelper.options(**helper_options).remote()
helper_node_info = ray.get(helper._get_node_info.remote())
helpers.append(helper)
nemo_gym_nodes.append(helper_node_info)

for helper in helpers:
ray.kill(helper, no_restart=True)

helpers = []

for nemo_gym_node_idx in range(nemo_gym_num_nodes):
helper_node_id = nemo_gym_nodes[nemo_gym_node_idx]["node_id"]
helper_node_ip = nemo_gym_nodes[nemo_gym_node_idx]["node_ip"]
if not helper_node_id:
helper_node_id = all_node_infos[helper_node_ip]["node_id"]
assert helper_node_id
nemo_gym_nodes[nemo_gym_node_idx]["node_id"] = helper_node_id
helper_options = {}
helper_options["num_gpus"] = nemo_gym_num_gpus_per_node
helper_options["scheduling_strategy"] = NodeAffinitySchedulingStrategy(
node_id=helper_node_id,
soft=False,
)
helper = RayClusterSetupHelper.options(**helper_options).remote()
helpers.append(helper)

for helper_pg in helper_pgs:
try:
remove_placement_group(helper_pg)
except Exception:
pass

nemo_gym_helpers = helpers

print(
f" ✓ Ray cluster for NeMo Gym reserved with {nemo_gym_num_nodes} nodes",
flush=True,
)

else:
nemo_gym_nodes = []
nemo_gym_helpers = []

if colocated_inference:
if total_nodes == 1:
policy_gpus_per_node = cluster_config["gpus_per_node"] - rm_gpus_per_node
Expand Down Expand Up @@ -612,6 +747,51 @@ def init_vllm():
)
print(" ✓ force_on_policy_ratio enabled")

if enable_nemo_gym:
# ==========================
# NeMo Gym
# ==========================
print("\n▶ Setting up NeMo Gym...", flush=True)

nemo_gym_config = NemoGymConfig(
model_name=policy_generation.cfg["model_name"],
base_urls=policy_generation.dp_openai_server_base_urls,
ray_gpu_nodes=[node["node_id"] for node in nemo_gym_nodes],
ray_num_gpus_per_node=nemo_gym_num_gpus_per_node,
ray_namespace=ray_namespace,
initial_global_config_dict=env_configs["nemo_gym"],
)
nemo_gym_py_exec = get_actor_python_env("nemo_rl.environments.nemo_gym.NemoGym")
if nemo_gym_py_exec.startswith("uv"):
# Lazily build a dedicated venv across all Ray nodes on-demand.
nemo_gym_py_exec = create_local_venv_on_each_node(
nemo_gym_py_exec, "nemo_rl.environments.nemo_gym.NemoGym"
)
for nemo_gym_helper in nemo_gym_helpers:
ray.kill(nemo_gym_helper, no_restart=True)
del nemo_gym_helpers
nemo_gym_options = {}
if nemo_gym_num_nodes:
nemo_gym_head_node_id = ray_cur_node_id
nemo_gym_options["scheduling_strategy"] = NodeAffinitySchedulingStrategy(
node_id=nemo_gym_head_node_id,
soft=True,
)
nemo_gym_options["runtime_env"] = {
"py_executable": nemo_gym_py_exec,
"env_vars": {
**os.environ,
"VIRTUAL_ENV": nemo_gym_py_exec,
"UV_PROJECT_ENVIRONMENT": nemo_gym_py_exec,
},
}
nemo_gym_actor = NemoGym.options(**nemo_gym_options).remote(nemo_gym_config)
# Blocking wait for NeMo Gym to spin up
ray.get(nemo_gym_actor._spinup.remote())

else:
nemo_gym_actor = None

# Calculate total setup time
total_setup_time = time.perf_counter() - setup_start_time
worker_init_timing_metrics["total_setup_time_s"] = total_setup_time
Expand Down Expand Up @@ -648,6 +828,7 @@ def init_vllm():
return (
policy,
policy_generation,
nemo_gym_actor,
(train_cluster, inference_cluster),
dataloader,
val_dataloader,
Expand Down
16 changes: 16 additions & 0 deletions nemo_rl/distributed/virtual_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Optional, TypedDict

import ray
from ray.util import get_node_ip_address
from ray.util.placement_group import (
PlacementGroup,
placement_group,
Expand Down Expand Up @@ -503,3 +504,18 @@ def __del__(self) -> None:
user calls shutdown().
"""
self.shutdown()


@ray.remote  # pragma: no cover
class RayClusterSetupHelper:
    """Minimal Ray actor used as a placement probe/reservation.

    The actor performs no work of its own: callers schedule it with a
    resource request (presumably to reserve GPUs/nodes — confirm against
    callers) and then ask ``_get_node_info`` which node Ray placed it on.
    """

    def __init__(self, *init_args, **init_kwargs):
        # Arguments are stored verbatim and never used by this class;
        # they exist only so callers may pass through opaque payloads.
        self.init_args = init_args
        self.init_kwargs = init_kwargs

    def _get_node_info(self) -> dict:
        """Report where this actor is running.

        Returns:
            dict with keys ``node_id`` (Ray node id string, or ``None``
            when the runtime-context lookup fails for any reason) and
            ``node_ip`` (this node's IP address).
        """
        current_node_id = None
        try:
            current_node_id = ray.get_runtime_context().get_node_id()
        except Exception:
            # Best-effort: a missing node id is tolerated; callers can
            # fall back to resolving it from the IP address.
            pass
        return {
            "node_id": current_node_id,
            "node_ip": get_node_ip_address(),
        }
Loading
Loading