Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions nemo_rl/distributed/worker_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ def __init__(
name_prefix: str = "",
bundle_indices_list: Optional[list[tuple[int, list[int]]]] = None,
sharding_annotations: Optional[NamedSharding] = None,
env_vars: dict[str, str] = {},
):
"""Initialize a group of distributed Ray workers.

Expand Down Expand Up @@ -391,7 +392,7 @@ def __init__(

# Create workers based on the bundle_indices_list
self._create_workers_from_bundle_indices(
remote_worker_builder, bundle_indices_list
remote_worker_builder, bundle_indices_list, env_vars=env_vars
)

def get_dp_leader_worker_idx(self, dp_shard_idx: int) -> int:
Expand All @@ -407,6 +408,7 @@ def _create_workers_from_bundle_indices(
self,
remote_worker_builder: RayWorkerBuilder,
bundle_indices_list: list[tuple[int, list[int]]],
env_vars: dict[str, str] = {},
) -> None:
"""Create workers based on explicit bundle indices for tied worker groups.

Expand All @@ -421,6 +423,10 @@ def _create_workers_from_bundle_indices(
self.cluster.get_master_address_and_port()
)

# Update env_vars with the current environment variables
env_vars.update(dict(os.environ))

# Get the python environment for the actor
actor_python_env = get_actor_python_env(
remote_worker_builder.ray_actor_class_fqn
)
Expand Down Expand Up @@ -459,8 +465,8 @@ def _create_workers_from_bundle_indices(

for local_rank, bundle_idx in enumerate(local_bundle_indices):
# Set up basic distributed environment variables
env_vars = dict(os.environ)
env_vars.update(
worker_env_vars = deepcopy(env_vars)
worker_env_vars.update(
{
"RANK": str(global_rank),
"LOCAL_RANK": str(bundle_idx),
Expand All @@ -470,7 +476,7 @@ def _create_workers_from_bundle_indices(
"NODE_RANK": str(pg_idx),
}
)
env_vars.pop("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", None)
worker_env_vars.pop("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", None)

# Only the first worker in each group gets bundle_indices
# This ensures only one worker per group is the model owner
Expand All @@ -494,7 +500,10 @@ def _create_workers_from_bundle_indices(
)

# Pass these options to the remote_worker_builder
runtime_env = {"env_vars": env_vars, "py_executable": py_executable}
runtime_env = {
"env_vars": worker_env_vars,
"py_executable": py_executable,
}
runtime_env["env_vars"]["VIRTUAL_ENV"] = py_executable
runtime_env["env_vars"]["UV_PROJECT_ENVIRONMENT"] = py_executable

Expand Down
13 changes: 9 additions & 4 deletions nemo_rl/models/generation/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,10 +319,6 @@ def _patch_vllm_init_workers_ray():
os.environ["VLLM_USE_V1"] = os.environ.get("NRL_VLLM_USE_V1", "1")
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"

if not self.cfg["colocated"]["enabled"]:
os.environ["NCCL_SHM_DISABLE"] = "1"
os.environ["NCCL_P2P_DISABLE"] = "1"

load_format = self.cfg["vllm_cfg"]["load_format"]
if ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(self.model_name):
load_format = "auto"
Expand Down Expand Up @@ -1225,6 +1221,13 @@ def __init__(
"nemo_rl.models.generation.vllm.VllmGenerationWorker", config
)

# It's necessary to set env_vars here to ensure that vllm non-leader workers also have these env_vars
# Disable NCCL SHM if training and generation are not co-located: https://github.com/NVIDIA-NeMo/RL/issues/564
env_vars = {}
if not self.cfg["colocated"]["enabled"]:
env_vars["NCCL_SHM_DISABLE"] = "1"
env_vars["NCCL_P2P_DISABLE"] = "1"

# Check if we need parallelism-aware worker group creation
if self.model_parallel_size > 1:
# For parallelism, create node-aware worker groups
Expand All @@ -1236,6 +1239,7 @@ def __init__(
name_prefix=name_prefix,
bundle_indices_list=node_bundle_indices,
sharding_annotations=self.sharding_annotations,
env_vars=env_vars,
)
else:
# Use standard worker group creation for non-parallel case
Expand All @@ -1245,6 +1249,7 @@ def __init__(
name_prefix=name_prefix,
workers_per_node=workers_per_node,
sharding_annotations=self.sharding_annotations,
env_vars=env_vars,
)

# Number of data parallel groups is the number of tied worker groups
Expand Down
2 changes: 1 addition & 1 deletion nemo_rl/models/policy/dtensor_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def init_collective(self, ip: str, port: int, world_size: int) -> None:
from vllm.distributed.utils import StatelessProcessGroup

# keep the same behavior as vllm
# see https://github.com/vllm-project/vllm/blob/v0.8.5/vllm/env_override.py#L25
# see https://github.com/vllm-project/vllm/blob/v0.9.0/vllm/env_override.py#L25
if not os.path.exists("/dev/nvidia-caps-imex-channels"):
os.environ["NCCL_CUMEM_ENABLE"] = "0"

Expand Down
36 changes: 23 additions & 13 deletions tests/unit/models/generation/test_vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,17 +264,15 @@ def policy_cluster_separate():
print(f"Error during policy_cluster_separate shutdown: {e}")


@pytest.fixture(scope="function")
def generation_cluster_separate():
"""Create a virtual cluster for the VllmGeneration policy, using 1 GPU."""
cluster = _create_ray_virtual_cluster_for_test(
"vllm-test-generation-cluster-separate"
def get_generation_cluster_separate(num_gpus_per_node: int = 1) -> RayVirtualCluster:
    """Build a standalone virtual cluster for the VllmGeneration policy.

    Args:
        num_gpus_per_node: Number of GPU bundles on the single node
            (defaults to 1).

    Returns:
        A ``RayVirtualCluster`` named
        ``"vllm-test-generation-cluster-separate"``. The caller owns the
        cluster and is responsible for calling ``shutdown()`` on it.
    """
    # Assemble the constructor arguments explicitly, then splat them; this
    # keeps the single-node configuration easy to scan in one place.
    cluster_kwargs = {
        "bundle_ct_per_node_list": [num_gpus_per_node],
        "use_gpus": True,
        "max_colocated_worker_groups": 1,
        "num_gpus_per_node": num_gpus_per_node,
        "name": "vllm-test-generation-cluster-separate",
    }
    return RayVirtualCluster(**cluster_kwargs)
yield cluster
try:
cluster.shutdown()
except Exception as e:
print(f"Error during generation_cluster_separate shutdown: {e}")


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -1177,13 +1175,22 @@ def test_vllm_non_divisible_batch_handling(policy):

@pytest.mark.asyncio
@pytest.mark.parametrize("async_engine", [True, False])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
async def test_vllm_refit_non_collocated_update_weights(
policy_cluster_separate,
generation_cluster_separate,
tokenizer,
test_input_data,
async_engine,
tensor_parallel_size,
):
# Skip tensor_parallel_size == 2 until we have resources in CI
if tensor_parallel_size == 2:
pytest.skip(
"Test requires at least three GPUs to run with tensor_parallel_size == 2 on separate clusters."
)

generation_cluster_separate = get_generation_cluster_separate(tensor_parallel_size)

if (
policy_cluster_separate.num_gpus_per_node < 1
or generation_cluster_separate.num_gpus_per_node < 1
Expand All @@ -1194,15 +1201,14 @@ async def test_vllm_refit_non_collocated_update_weights(

# Create Policy on its own cluster
hf_config = get_basic_hf_test_config(enable_dtensor=True)
hf_config["dtensor_cfg"]["tensor_parallel_size"] = 1
hf_config["generation"]["colocated"]["enabled"] = False
lm_policy = Policy(policy_cluster_separate, hf_config, tokenizer)

# Create VllmGeneration policy on its own cluster
vllm_config = deepcopy(basic_vllm_test_config)
vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True)
vllm_config["vllm_cfg"]["async_engine"] = async_engine
vllm_config["vllm_cfg"]["tensor_parallel_size"] = 1
vllm_config["vllm_cfg"]["tensor_parallel_size"] = tensor_parallel_size
vllm_config["colocated"]["enabled"] = False
vllm_generation = VllmGeneration(generation_cluster_separate, vllm_config)

Expand Down Expand Up @@ -1234,6 +1240,10 @@ async def test_vllm_refit_non_collocated_update_weights(
# Clean up
vllm_generation.shutdown()
lm_policy.shutdown()
try:
generation_cluster_separate.shutdown()
except Exception as e:
print(f"Error during generation_cluster_separate shutdown: {e}")


@pytest.mark.timeout(210)
Expand Down
Loading