diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 46e1e8a52a..6872250d10 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -41,7 +41,6 @@ from nemo_rl.algorithms.interfaces import LossFunction, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.worker_group_utils import get_nsight_config_if_pattern_matches from nemo_rl.models.dtensor.parallelize import ( _parallelize_model, clip_grad_by_total_norm_, @@ -57,6 +56,7 @@ ) from nemo_rl.models.policy.utils import ( get_gpu_info, + get_runtime_env_for_policy_worker, import_class_from_path, sliding_window_overwrite, ) @@ -114,13 +114,7 @@ def get_cpu_state_dict( return new_state_dict -@ray.remote( - runtime_env={ - # TODO: This option causes a crash on Ampere. It's okay to enable on Hopper. - # "env_vars": {"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"}, - **get_nsight_config_if_pattern_matches("dtensor_policy_worker"), - } -) +@ray.remote(runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker")) class DTensorPolicyWorker: def __repr__(self) -> str: """Customizes the actor's prefix in the Ray logs. 
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 89eb263674..9113723af0 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -89,7 +89,6 @@ from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import from_parallel_logits_to_logprobs from nemo_rl.distributed.named_sharding import NamedSharding -from nemo_rl.distributed.worker_group_utils import get_nsight_config_if_pattern_matches from nemo_rl.models.generation.interfaces import ( GenerationDatumSpec, GenerationOutputSpec, @@ -112,7 +111,7 @@ LogprobOutputSpec, ReferenceLogprobOutputSpec, ) -from nemo_rl.models.policy.utils import get_gpu_info +from nemo_rl.models.policy.utils import get_gpu_info, get_runtime_env_for_policy_worker TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) @@ -322,11 +321,7 @@ def destroy_parallel_state(): pass -@ray.remote( - runtime_env={ - **get_nsight_config_if_pattern_matches("megatron_policy_worker"), - } -) +@ray.remote(runtime_env=get_runtime_env_for_policy_worker("megatron_policy_worker")) class MegatronPolicyWorker: def __repr__(self): """Customizes the actor's prefix in the Ray logs. diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index f19914576d..485dea9011 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -19,6 +19,8 @@ import torch from transformers import AutoConfig +from nemo_rl.distributed.worker_group_utils import get_nsight_config_if_pattern_matches + def import_class_from_path(name: str) -> Any: """Import a class from a string path (e.g. 'torch.optim.AdamW'). 
@@ -127,3 +129,27 @@ def sliding_window_overwrite(model_name: str) -> dict[str, Any]: ) return overwrite_dict + + +def get_runtime_env_for_policy_worker(policy_worker_name: str) -> dict[str, Any]: + """Get the Ray runtime environment configuration for a policy worker (e.g. the DTensor or Megatron policy worker). + + Conditionally enables expandable_segments on Hopper GPUs only, + as it causes crashes on Ampere GPUs. + """ + runtime_env = { + **get_nsight_config_if_pattern_matches(policy_worker_name), + } + + # Only enable expandable_segments on Hopper and newer architectures (compute capability 9.x+) + try: + compute_capability = torch.cuda.get_device_properties(0).major + if compute_capability >= 9: # Hopper+ + runtime_env["env_vars"] = { + "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True" + } + except Exception: + # If we can't detect GPU capability, don't enable expandable_segments for safety + pass + + return runtime_env