Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions nemo_rl/models/policy/workers/dtensor_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
resolve_model_class,
)
from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker
from nemo_rl.models.policy.workers.patches import apply_torch_aten_alias_tensor_patch
from nemo_rl.utils.native_checkpoint import (
Comment on lines +79 to 80
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

head -170 nemo_rl/models/policy/workers/dtensor_policy_worker.py | tail -30

Repository: NVIDIA-NeMo/RL

Length of output: 1191


🏁 Script executed:

sed -n '155,175p' nemo_rl/models/policy/workers/dtensor_policy_worker.py

Repository: NVIDIA-NeMo/RL

Length of output: 995


🏁 Script executed:

sed -n '75,85p' nemo_rl/models/policy/workers/dtensor_policy_worker.py

Repository: NVIDIA-NeMo/RL

Length of output: 498


Move __init__ docstring before the patch call (docstring must be the first statement).

The """Initialize the DTensorPolicyWorker.""" string is no longer recognized as a docstring because apply_torch_aten_alias_tensor_patch() is called before it. In Python, docstrings must be the first statement in a function or method body to be recognized as such. Move the docstring to the first line of the method and place the patch call after it.

Proposed fix
     def __init__(
         self,
         config: PolicyConfig,
         tokenizer: AutoTokenizer,
         processor: Optional[AutoProcessor] = None,
         weights_path: Optional[str] = None,
         optimizer_path: Optional[str] = None,
         init_optimizer: bool = True,
         init_reference_model: bool = True,
         **kwargs: Any,
     ):
+        """Initialize the DTensorPolicyWorker."""
+
         # Apply patch to work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered'
         apply_torch_aten_alias_tensor_patch()
-
-        """Initialize the DTensorPolicyWorker."""
         self.tokenizer = tokenizer
🤖 Prompt for AI Agents
In @nemo_rl/models/policy/workers/dtensor_policy_worker.py around lines 79 - 80,
The __init__ docstring in DTensorPolicyWorker is not the first statement because
apply_torch_aten_alias_tensor_patch() is called before it; move the string
literal """Initialize the DTensorPolicyWorker.""" to be the very first line
inside the __init__ method (so it is recognized as the docstring), then call
apply_torch_aten_alias_tensor_patch() after that docstring and proceed with the
rest of the initialization.

load_checkpoint,
save_checkpoint,
Expand Down Expand Up @@ -157,6 +158,9 @@ def __init__(
init_reference_model: bool = True,
**kwargs: Any,
):
# Apply patch to work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered'
apply_torch_aten_alias_tensor_patch()

"""Initialize the DTensorPolicyWorker."""
self.tokenizer = tokenizer
self.processor = processor
Expand Down Expand Up @@ -296,10 +300,6 @@ def __init__(
print(
"[WARNING]: sequence_parallel=True, but tp_size=1 which has no effect. Enable tp_size > 1 to use sequence parallelism."
)
elif sequence_parallel_enabled and tp_size > 1:
raise RuntimeError(
"Sequence parallel + tp_size >1 is currently broken in torch==2.8.0. See https://github.com/NVIDIA-NeMo/Automodel/issues/652 for more details."
)

if cp_size > 1:
assert not isinstance(self.model, Gemma3ForCausalLM), (
Expand Down
11 changes: 6 additions & 5 deletions nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@
resolve_model_class,
)
from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker
from nemo_rl.models.policy.workers.patches import apply_transformer_engine_patch
from nemo_rl.models.policy.workers.patches import (
apply_torch_aten_alias_tensor_patch,
apply_transformer_engine_patch,
)
from nemo_rl.utils.automodel_checkpoint import AutomodelCheckpointManager
from nemo_rl.utils.checkpoint import CheckpointingConfig
from nemo_rl.utils.nsys import wrap_with_nvtx_name
Expand Down Expand Up @@ -127,6 +130,8 @@ def __init__(
"""Initialize the DTensorPolicyWorkerV2."""
# Apply TE patch until TE is upgraded to 2.10.0
apply_transformer_engine_patch()
# Apply patch to work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered'
apply_torch_aten_alias_tensor_patch()

self.tokenizer = tokenizer
self.processor = processor
Expand Down Expand Up @@ -338,10 +343,6 @@ def __init__(
print(
"[WARNING]: sequence_parallel=True, but tp_size=1 which has no effect. Enable tp_size > 1 to use sequence parallelism."
)
elif sequence_parallel_enabled and tp_size > 1:
raise RuntimeError(
"Sequence parallel + tp_size >1 is currently broken in torch==2.8.0. See https://github.com/NVIDIA-NeMo/Automodel/issues/652 for more details."
)

if cp_size > 1:
assert not isinstance(self.model, Gemma3ForCausalLM), (
Expand Down
22 changes: 22 additions & 0 deletions nemo_rl/models/policy/workers/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
import os
from importlib.util import find_spec

import torch
from torch.distributed.tensor._ops._tensor_ops import propagate_single_input_strategy
from torch.distributed.tensor._ops.utils import register_op_strategy

Comment on lines +18 to +21
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find . -type f -name "patches.py" | grep -E "nemo_rl.*workers" | head -20

Repository: NVIDIA-NeMo/RL

Length of output: 100


🏁 Script executed:

cat -n nemo_rl/models/policy/workers/patches.py | head -150

Repository: NVIDIA-NeMo/RL

Length of output: 6194


Guard private torch imports, replace assert with explicit version check, and raise on patch failure.

Lines 18–21 import from private torch.distributed.tensor._ops.* modules without error handling, which will fail on incompatible torch versions. Lines 120–128 use assert for version gating (which can be stripped with python -O) and silently swallow registration failures with a bare Exception catch that only prints.

Suggested fix
@@
-import torch
-from torch.distributed.tensor._ops._tensor_ops import propagate_single_input_strategy
-from torch.distributed.tensor._ops.utils import register_op_strategy
+import torch
+try:
+    # NOTE: These are internal torch APIs; keep guarded and remove when upstream fix is available.
+    from torch.distributed.tensor._ops._tensor_ops import propagate_single_input_strategy
+    from torch.distributed.tensor._ops.utils import register_op_strategy
+except ImportError as e:
+    propagate_single_input_strategy = None  # type: ignore[assignment]
+    register_op_strategy = None  # type: ignore[assignment]
+    _DTENSOR_OPS_IMPORT_ERROR = e
+else:
+    _DTENSOR_OPS_IMPORT_ERROR = None
@@
 def apply_torch_aten_alias_tensor_patch():
     """Register a sharding rule for `torch.ops.aten.alias.default`.
@@
-    assert torch.__version__.startswith("2.9.0"), (
-        "This patch is needed for torch 2.9.0. Please retest if you upgrade torch to a newer version and remove this patch."
-    )
-    try:
-        register_op_strategy(torch.ops.aten.alias.default)(
-            propagate_single_input_strategy
-        )
-    except Exception as e:
-        print(f"Error applying torch.ops.aten.alias.default patch: {e}")
+    if _DTENSOR_OPS_IMPORT_ERROR is not None:
+        raise RuntimeError(
+            "DTensor internal APIs needed for the aten.alias.default sharding patch "
+            "are not available in this torch build."
+        ) from _DTENSOR_OPS_IMPORT_ERROR
+
+    if not torch.__version__.startswith("2.9.0"):
+        # Patch is intended for torch 2.9.x; avoid crashing worker init on other versions.
+        return
+
+    try:
+        register_op_strategy(torch.ops.aten.alias.default)(propagate_single_input_strategy)
+    except RuntimeError as e:
+        raise RuntimeError(
+            "Failed to register sharding strategy for torch.ops.aten.alias.default"
+        ) from e


def _get_transformer_engine_file(relative_path: str) -> str:
"""Return absolute path to a Transformer Engine file or raise if it cannot be found.
Expand Down Expand Up @@ -104,3 +108,21 @@ def apply_transformer_engine_patch():

except Exception as e:
print(f"Error checking/patching transformer_engine: {e}")


def apply_torch_aten_alias_tensor_patch():
    """Register a sharding rule for `torch.ops.aten.alias.default`.

    Work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered'
    in PyTorch 2.9. See https://github.com/pytorch/pytorch/pull/166867 for the upstream fix.
    We can remove this patch when we upgrade torch to include this fix.

    Raises:
        RuntimeError: If registering the sharding strategy fails on torch 2.9.0.
    """
    # Explicit version gate instead of `assert`: asserts are stripped under
    # `python -O`, which would silently disable this check.
    if not torch.__version__.startswith("2.9.0"):
        # Other torch versions may already contain the upstream fix; skip the
        # patch (with a visible warning) rather than crash worker init.
        print(
            "[WARNING] apply_torch_aten_alias_tensor_patch targets torch 2.9.0 "
            f"but found torch {torch.__version__}; skipping. Please retest and "
            "remove this patch if the upstream fix is included."
        )
        return
    try:
        register_op_strategy(torch.ops.aten.alias.default)(
            propagate_single_input_strategy
        )
    except Exception as e:
        # Fail loudly: merely printing here would let the worker proceed and
        # then hit the original NotImplementedError deep inside training.
        raise RuntimeError(
            "Failed to register sharding strategy for torch.ops.aten.alias.default"
        ) from e
3 changes: 3 additions & 0 deletions tests/unit/models/policy/test_dtensor_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,9 @@ def policy_setup(self, request, two_gpu_cluster, tiny_llama_model_path):
("tiny_nemotron5_h_model_path", 1, 1, False, False, False),
("tiny_nemotron5_h_model_path", 1, 1, False, True, True),
# nemotron5_h doesn't support cp
# TP2, SP=True
("tiny_llama_model_path", 2, 1, True, False, False),
("tiny_qwen2_model_path", 2, 1, True, False, False),
]
)
def training_setup(self, request, two_gpu_cluster):
Expand Down
43 changes: 43 additions & 0 deletions tests/unit/models/policy/test_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
from unittest.mock import MagicMock, mock_open, patch

import pytest
import torch

from nemo_rl.models.policy.workers.patches import (
_get_transformer_engine_file,
apply_torch_aten_alias_tensor_patch,
apply_transformer_engine_patch,
)

Expand Down Expand Up @@ -445,3 +447,44 @@ def permutation_kernel(x):
assert captured.out.count("Successfully patched") == 1
finally:
os.unlink(tmp_path)


def build_sharded_3d(rank: int, world_size: int):
    """Build a DTensor and validate the sharding rule for `torch.ops.aten.alias.default`.

    Runs inside a distributed test runner, one process per rank.

    Args:
        rank: Process rank (unused; kept to match the runner's callback signature).
        world_size: Number of processes participating in the device mesh.
    """
    from torch.distributed.tensor import DeviceMesh, DTensor

    mesh = DeviceMesh("cuda", list(range(world_size)))
    global_shape = (4, 2, 4)
    # The local tensor must live on the mesh device ("cuda"); a CPU tensor
    # paired with a CUDA DeviceMesh is invalid for DTensor.from_local.
    tensor = (
        torch.arange(torch.tensor(global_shape).prod(), dtype=torch.float32)
        .reshape(global_shape)
        .cuda()
    )

    try:
        dtensor = DTensor.from_local(
            tensor,
            device_mesh=mesh,
        )
        alias_dtensor = torch.ops.aten.alias.default(dtensor)
        # Explicit raise instead of `assert False`, which `python -O` strips.
        raise AssertionError(
            "Torch==2.9 should raise 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered', but it didn't."
            "You can:\n "
            "1. Check is you bump your torch version which contain the fix https://github.com/pytorch/pytorch/pull/166867\n"
            "2. If yes, remove patch apply_torch_aten_alias_tensor_patch in nemo_rl/models/policy/workers/patches.py \n"
            "3. Remove the patching call in nemo_rl/models/policy/workers/dtensor_policy_worker.py and nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py \n"
            "4. Delete this test"
        )
    except NotImplementedError:
        # Expected on unpatched torch 2.9: apply the patch, then verify that
        # aliasing a DTensor now works and preserves shape/values.
        apply_torch_aten_alias_tensor_patch()
        alias_dtensor = torch.ops.aten.alias.default(dtensor)
        assert alias_dtensor.shape == global_shape
        assert torch.allclose(alias_dtensor.to_local().cpu(), tensor.cpu())

Comment on lines +452 to +483
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, locate and examine the actual test file
fd -t f test_patches.py -x wc -l {}

Repository: NVIDIA-NeMo/RL

Length of output: 104


🏁 Script executed:

# Read the specific lines mentioned in the review
sed -n '452,490p' tests/unit/models/policy/test_patches.py

Repository: NVIDIA-NeMo/RL

Length of output: 1838


🏁 Script executed:

# Check the imports at the top of the file to understand available tools
head -50 tests/unit/models/policy/test_patches.py

Repository: NVIDIA-NeMo/RL

Length of output: 1920


🏁 Script executed:

# Search for other uses of DTensor in the codebase to understand best practices
rg "DTensor\.from_local" -A 3 -B 1

Repository: NVIDIA-NeMo/RL

Length of output: 13928


Fix CUDA placement, replace assert False, and address unused rank in build_sharded_3d().

The local tensor is created on CPU, but DeviceMesh("cuda", ...) requires a CUDA tensor. Also, assert False is brittle under optimization flags, and rank is unused.

Proposed fix
-def build_sharded_3d(rank: int, world_size: int):
-    """
-    Build tensor, DTensor, and test the sharding rule for torch.ops.aten.alias.default.
-    """
+def build_sharded_3d(_rank: int, world_size: int):
+    """Build tensor/DTensor and validate the sharding rule for `aten.alias.default`."""
     from torch.distributed.tensor import DeviceMesh, DTensor
 
     mesh = DeviceMesh("cuda", list(range(world_size)))
     global_shape = (4, 2, 4)
     tensor = torch.arange(
         torch.tensor(global_shape).prod(), dtype=torch.float32
-    ).reshape(global_shape)
+    ).reshape(global_shape).cuda()
 
     try:
         dtensor = DTensor.from_local(
             tensor,
             device_mesh=mesh,
         )
         alias_dtensor = torch.ops.aten.alias.default(dtensor)
-        assert False, (
+        raise AssertionError(
             "Torch==2.9 should raise 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered', but it didn't."
             "You can:\n "
             "1. Check is you bump your torch version which contain the fix https://github.com/pytorch/pytorch/pull/166867\n"
             "2. If yes, remove patch apply_torch_aten_alias_tensor_patch in nemo_rl/models/policy/workers/patches.py \n"
             "3. Remove the patching call in nemo_rl/models/policy/workers/dtensor_policy_worker.py and nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py \n"
             "4. Delete this test"
         )
     except NotImplementedError:
         apply_torch_aten_alias_tensor_patch()
         alias_dtensor = torch.ops.aten.alias.default(dtensor)
         assert alias_dtensor.shape == global_shape
         assert torch.allclose(alias_dtensor.to_local().cpu(), tensor.cpu())
🧰 Tools
🪛 Ruff (0.14.10)

452-452: Unused function argument: rank

(ARG001)


470-470: Do not assert False (python -O removes these calls), raise AssertionError()

Replace assert False

(B011)


@pytest.mark.cuda
def test_aten_alias_sharding_still_missing_upstream(
    distributed_test_runner, world_size=2
):
    """Verify upstream torch still lacks a sharding rule for aten.alias.default."""
    runner = distributed_test_runner
    runner(build_sharded_3d, world_size=world_size)
Loading