Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions nemo_rl/models/policy/dtensor_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@

from nemo_rl.algorithms.interfaces import LossFunction, LossType
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.worker_group_utils import get_nsight_config_if_pattern_matches
from nemo_rl.models.dtensor.parallelize import (
_parallelize_model,
clip_grad_by_total_norm_,
Expand All @@ -57,6 +56,7 @@
)
from nemo_rl.models.policy.utils import (
get_gpu_info,
get_runtime_env_for_policy_worker,
import_class_from_path,
sliding_window_overwrite,
)
Expand Down Expand Up @@ -114,13 +114,7 @@ def get_cpu_state_dict(
return new_state_dict


@ray.remote(
runtime_env={
# TODO: This option causes a crash on Ampere. It's okay to enable on Hopper.
# "env_vars": {"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"},
**get_nsight_config_if_pattern_matches("dtensor_policy_worker"),
}
)
@ray.remote(runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker"))
class DTensorPolicyWorker:
def __repr__(self) -> str:
"""Customizes the actor's prefix in the Ray logs.
Expand Down
9 changes: 2 additions & 7 deletions nemo_rl/models/policy/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.model_utils import from_parallel_logits_to_logprobs
from nemo_rl.distributed.named_sharding import NamedSharding
from nemo_rl.distributed.worker_group_utils import get_nsight_config_if_pattern_matches
from nemo_rl.models.generation.interfaces import (
GenerationDatumSpec,
GenerationOutputSpec,
Expand All @@ -112,7 +111,7 @@
LogprobOutputSpec,
ReferenceLogprobOutputSpec,
)
from nemo_rl.models.policy.utils import get_gpu_info
from nemo_rl.models.policy.utils import get_gpu_info, get_runtime_env_for_policy_worker

TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase)

Expand Down Expand Up @@ -322,11 +321,7 @@ def destroy_parallel_state():
pass


@ray.remote(
runtime_env={
**get_nsight_config_if_pattern_matches("megatron_policy_worker"),
}
)
@ray.remote(runtime_env=get_runtime_env_for_policy_worker("megatron_policy_worker"))
class MegatronPolicyWorker:
def __repr__(self):
"""Customizes the actor's prefix in the Ray logs.
Expand Down
26 changes: 26 additions & 0 deletions nemo_rl/models/policy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import torch
from transformers import AutoConfig

from nemo_rl.distributed.worker_group_utils import get_nsight_config_if_pattern_matches


def import_class_from_path(name: str) -> Any:
"""Import a class from a string path (e.g. 'torch.optim.AdamW').
Expand Down Expand Up @@ -127,3 +129,27 @@ def sliding_window_overwrite(model_name: str) -> dict[str, Any]:
)

return overwrite_dict


def get_runtime_env_for_policy_worker(policy_worker_name: str) -> dict[str, Any]:
    """Build the Ray ``runtime_env`` dict for a policy worker actor.

    Generic over worker type: used by both the DTensor and Megatron policy
    workers. Merges any nsight profiling configuration matching
    *policy_worker_name*, and enables the CUDA allocator's
    ``expandable_segments`` option only on Hopper-or-newer GPUs
    (compute capability 9.x+), since it is known to crash on Ampere.

    Args:
        policy_worker_name: Pattern name forwarded to
            ``get_nsight_config_if_pattern_matches`` (e.g.
            ``"dtensor_policy_worker"``, ``"megatron_policy_worker"``).

    Returns:
        A dict suitable for passing as ``ray.remote(runtime_env=...)``.
    """
    runtime_env: dict[str, Any] = {
        **get_nsight_config_if_pattern_matches(policy_worker_name),
    }

    # expandable_segments crashes on Ampere; gate it on the compute-capability
    # major version so it is only turned on for Hopper (9.x) and newer.
    try:
        major, _minor = torch.cuda.get_device_capability(0)
        if major >= 9:  # Hopper+
            # NOTE(review): this replaces any "env_vars" entry the nsight
            # config may have contributed — confirm that is intended.
            runtime_env["env_vars"] = {
                "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"
            }
    except Exception:
        # No visible GPU / CUDA unavailable: leave expandable_segments
        # disabled, which is the safe default.
        pass

    return runtime_env
Loading