diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py
index 0f41938c88..e9c57cfb55 100644
--- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py
@@ -76,6 +76,7 @@
     resolve_model_class,
 )
 from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker
+from nemo_rl.models.policy.workers.patches import apply_torch_aten_alias_tensor_patch
 from nemo_rl.utils.native_checkpoint import (
     load_checkpoint,
     save_checkpoint,
@@ -157,6 +158,9 @@ def __init__(
         init_reference_model: bool = True,
         **kwargs: Any,
     ):
         """Initialize the DTensorPolicyWorker."""
+        # Apply patch to work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered'
+        apply_torch_aten_alias_tensor_patch()
+
         self.tokenizer = tokenizer
         self.processor = processor
@@ -296,10 +300,6 @@ def __init__(
             print(
                 "[WARNING]: sequence_parallel=True, but tp_size=1 which has no effect. Enable tp_size > 1 to use sequence parallelism."
             )
-        elif sequence_parallel_enabled and tp_size > 1:
-            raise RuntimeError(
-                "Sequence parallel + tp_size >1 is currently broken in torch==2.8.0. See https://github.com/NVIDIA-NeMo/Automodel/issues/652 for more details."
-            )
 
         if cp_size > 1:
             assert not isinstance(self.model, Gemma3ForCausalLM), (
diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
index 785568cc76..0cc3d7495c 100644
--- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
+++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
@@ -86,7 +86,10 @@
     resolve_model_class,
 )
 from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker
-from nemo_rl.models.policy.workers.patches import apply_transformer_engine_patch
+from nemo_rl.models.policy.workers.patches import (
+    apply_torch_aten_alias_tensor_patch,
+    apply_transformer_engine_patch,
+)
 from nemo_rl.utils.automodel_checkpoint import AutomodelCheckpointManager
 from nemo_rl.utils.checkpoint import CheckpointingConfig
 from nemo_rl.utils.nsys import wrap_with_nvtx_name
@@ -127,6 +130,8 @@ def __init__(
         """Initialize the DTensorPolicyWorkerV2."""
         # Apply TE patch until TE is upgraded to 2.10.0
         apply_transformer_engine_patch()
+        # Apply patch to work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered'
+        apply_torch_aten_alias_tensor_patch()
 
         self.tokenizer = tokenizer
         self.processor = processor
@@ -338,10 +343,6 @@ def __init__(
             print(
                 "[WARNING]: sequence_parallel=True, but tp_size=1 which has no effect. Enable tp_size > 1 to use sequence parallelism."
             )
-        elif sequence_parallel_enabled and tp_size > 1:
-            raise RuntimeError(
-                "Sequence parallel + tp_size >1 is currently broken in torch==2.8.0. See https://github.com/NVIDIA-NeMo/Automodel/issues/652 for more details."
-            )
 
         if cp_size > 1:
             assert not isinstance(self.model, Gemma3ForCausalLM), (
diff --git a/nemo_rl/models/policy/workers/patches.py b/nemo_rl/models/policy/workers/patches.py
index 5a0d5b0ab8..aa140548aa 100644
--- a/nemo_rl/models/policy/workers/patches.py
+++ b/nemo_rl/models/policy/workers/patches.py
@@ -15,6 +15,10 @@
 import os
 from importlib.util import find_spec
 
+import torch
+from torch.distributed.tensor._ops._tensor_ops import propagate_single_input_strategy
+from torch.distributed.tensor._ops.utils import register_op_strategy
+
 
 def _get_transformer_engine_file(relative_path: str) -> str:
     """Return absolute path to a Transformer Engine file or raise if it cannot be found.
@@ -104,3 +108,21 @@
 
     except Exception as e:
         print(f"Error checking/patching transformer_engine: {e}")
+
+
+def apply_torch_aten_alias_tensor_patch():
+    """Register a sharding rule for `torch.ops.aten.alias.default`.
+
+    Works around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered'
+    in PyTorch 2.9. See https://github.com/pytorch/pytorch/pull/166867 for the upstream fix.
+    This patch can be removed once torch is upgraded to a version that includes that fix.
+    """
+    assert torch.__version__.startswith("2.9.0"), (
+        "This patch is needed for torch 2.9.0. If you upgrade to a newer torch, retest and remove this patch."
+    )
+    try:
+        register_op_strategy(torch.ops.aten.alias.default)(
+            propagate_single_input_strategy
+        )
+    except Exception as e:
+        print(f"Error applying torch.ops.aten.alias.default patch: {e}")
diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py
index 4ddd320bfe..a750a78f9a 100644
--- a/tests/unit/models/policy/test_dtensor_worker.py
+++ b/tests/unit/models/policy/test_dtensor_worker.py
@@ -581,6 +581,9 @@ def policy_setup(self, request, two_gpu_cluster, tiny_llama_model_path):
             ("tiny_nemotron5_h_model_path", 1, 1, False, False, False),
             ("tiny_nemotron5_h_model_path", 1, 1, False, True, True),
             # nemotron5_h doesn't support cp
+            # TP2, SP=True
+            ("tiny_llama_model_path", 2, 1, True, False, False),
+            ("tiny_qwen2_model_path", 2, 1, True, False, False),
         ]
     )
     def training_setup(self, request, two_gpu_cluster):
diff --git a/tests/unit/models/policy/test_patches.py b/tests/unit/models/policy/test_patches.py
index e8cacbcd4a..0505768bf1 100644
--- a/tests/unit/models/policy/test_patches.py
+++ b/tests/unit/models/policy/test_patches.py
@@ -18,9 +18,11 @@
 from unittest.mock import MagicMock, mock_open, patch
 
 import pytest
+import torch
 
 from nemo_rl.models.policy.workers.patches import (
     _get_transformer_engine_file,
+    apply_torch_aten_alias_tensor_patch,
     apply_transformer_engine_patch,
 )
 
@@ -445,3 +447,44 @@
         assert captured.out.count("Successfully patched") == 1
     finally:
         os.unlink(tmp_path)
+
+
+def build_sharded_3d(rank: int, world_size: int):
+    """
+    Build a tensor and a DTensor, and exercise the sharding rule for torch.ops.aten.alias.default.
+    """
+    from torch.distributed.tensor import DeviceMesh, DTensor
+
+    mesh = DeviceMesh("cuda", list(range(world_size)))
+    global_shape = (4, 2, 4)
+    tensor = torch.arange(
+        torch.tensor(global_shape).prod(), dtype=torch.float32
+    ).reshape(global_shape)
+
+    try:
+        dtensor = DTensor.from_local(
+            tensor,
+            device_mesh=mesh,
+        )
+        alias_dtensor = torch.ops.aten.alias.default(dtensor)
+        assert False, (
+            "Torch==2.9 should raise 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered', but it didn't.\n"
+            "You can:\n"
+            "1. Check whether your bumped torch version contains the fix https://github.com/pytorch/pytorch/pull/166867\n"
+            "2. If it does, remove the patch apply_torch_aten_alias_tensor_patch in nemo_rl/models/policy/workers/patches.py\n"
+            "3. Remove the patching calls in nemo_rl/models/policy/workers/dtensor_policy_worker.py and nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py\n"
+            "4. Delete this test"
+        )
+    except NotImplementedError:
+        apply_torch_aten_alias_tensor_patch()
+        alias_dtensor = torch.ops.aten.alias.default(dtensor)
+        assert alias_dtensor.shape == global_shape
+        assert torch.allclose(alias_dtensor.to_local().cpu(), tensor.cpu())
+
+
+@pytest.mark.cuda
+def test_aten_alias_sharding_still_missing_upstream(
+    distributed_test_runner, world_size=2
+):
+    """Test that the sharding rule for aten.alias.default is still missing upstream."""
+    distributed_test_runner(build_sharded_3d, world_size=world_size)