
cp: fix: patch pytorch aten.alias.default shard strategy (1728) into r0.5.0#1753

Merged
yuki-97 merged 1 commit into r0.5.0 from cherry-pick-1728-r0.5.0 on Jan 9, 2026

Conversation

Contributor

@chtruong814 commented Jan 9, 2026

beep boop [🤖]: Hi @RayenTian 👋,

we've cherry-picked #1728 into r0.5.0 for you! 🚀

Please review and approve this cherry-pick at your convenience!

Summary by CodeRabbit

  • New Features

    • Sequence parallel now works with larger tensor parallel configurations (previously blocked)
  • Bug Fixes

    • Resolved tensor sharding compatibility issue with PyTorch 2.9.0
  • Tests

    • Added test coverage for tensor parallel and sequence parallel configurations

✏️ Tip: You can customize this high-level summary in your review settings.

Signed-off-by: ruit <ruit@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>

github-actions bot commented Jan 9, 2026

ℹ️ File Consistency Check

Check based on commit: be6b54e (PR #1753 from cherry-pick-1728-r0.5.0)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@yuki-97 added the CI:L1 (Run doctests, unit tests, and functional tests) label on Jan 9, 2026
@yuki-97 enabled auto-merge (squash) on January 9, 2026 09:26
Contributor

coderabbitai bot commented Jan 9, 2026

📝 Walkthrough

Walkthrough

The PR introduces a new patch function to handle torch.ops.aten.alias.default sharding and applies it to DTensorPolicyWorker implementations. The patch is invoked during worker initialization, and the prior runtime error blocking sequence parallel with tp_size > 1 is removed. Tests are added to validate the patch behavior.

Changes

Cohort / File(s) — Summary

Core patch implementation and worker integration (nemo_rl/models/policy/workers/patches.py)
Added apply_torch_aten_alias_tensor_patch(), which registers a sharding strategy for torch.ops.aten.alias.default using propagate_single_input_strategy, with a PyTorch 2.9.0 version assertion and error handling. Includes the necessary imports: torch, propagate_single_input_strategy, register_op_strategy.

Worker initialization updates (nemo_rl/models/policy/workers/dtensor_policy_worker.py, nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py)
Updated imports to include apply_torch_aten_alias_tensor_patch and added the patch invocation in __init__. Removed the runtime-error branch that previously blocked sequence_parallel with tp_size > 1, allowing this configuration to proceed.

Test coverage (tests/unit/models/policy/test_dtensor_worker.py, tests/unit/models/policy/test_patches.py)
Extended training_setup parametrization with test cases for TP=2 and SP=True configurations. Added aten.alias sharding integration tests and a build_sharded_3d helper utility to validate patch application and tensor equivalence.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

r0.5.0, CI:L1

Suggested reviewers

  • RayenTian
  • terrykong
  • yuki-97
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Test Results For Major Changes — ⚠️ Warning: the PR description lacks test results or validation information despite non-trivial changes to core tensor sharding operations and newly enabled functionality. Resolution: update the PR description with test execution results, validation that the new parameter combinations work without regressions, and performance/numeric validation demonstrating no regression.

✅ Passed checks (3 passed)

  • Description Check — ✅ Passed: check skipped; CodeRabbit's high-level summary is enabled.
  • Title check — ✅ Passed: the pull request title accurately describes the main change: patching PyTorch's aten.alias.default sharding strategy.
  • Docstring Coverage — ✅ Passed: docstring coverage is 88.89%, above the required threshold of 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



@coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In @nemo_rl/models/policy/workers/dtensor_policy_worker.py:
- Around line 79-80: The __init__ docstring in DTensorPolicyWorker is not the
first statement because apply_torch_aten_alias_tensor_patch() is called before
it; move the string literal """Initialize the DTensorPolicyWorker.""" to be the
very first line inside the __init__ method (so it is recognized as the
docstring), then call apply_torch_aten_alias_tensor_patch() after that docstring
and proceed with the rest of the initialization.
🧹 Nitpick comments (1)
tests/unit/models/policy/test_patches.py (1)

485-490: Consider version-gating or marking expected behavior explicitly to reduce future maintenance churn.

Once PyTorch includes the upstream fix (or if a different 2.9.x build behaves differently), this test becomes a forced failure. If the intent is “fail loudly so we remove the patch,” that’s fine—but then it should at least be guarded by the exact torch version(s) where the failure is expected to avoid surprising breakage during upgrades.
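The version-gating suggested here could look like the following sketch (assuming the suite uses pytest; the marker name and version predicate are illustrative, not taken from the repository):

```python
# Hypothetical version gate for the expected-failure test: only torch 2.9.0
# builds are expected to lack the aten.alias.default sharding strategy, so
# other versions skip instead of failing the suite during upgrades.
import pytest

try:
    import torch
    TORCH_VERSION = torch.__version__
except ImportError:
    TORCH_VERSION = None

requires_unpatched_alias_op = pytest.mark.skipif(
    TORCH_VERSION is None or not TORCH_VERSION.startswith("2.9.0"),
    reason="aten.alias.default sharding gap is only expected on torch 2.9.0",
)


@requires_unpatched_alias_op
def test_alias_requires_patch():
    # The existing NotImplementedError expectation (build_sharded_3d) would
    # run here, now confined to the torch versions where it is expected.
    pass
```

With this gate, a torch upgrade turns the test into an explicit skip rather than a surprise failure, while still failing loudly on 2.9.0 if the patch stops being needed.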

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 53342b1 and be6b54e.

📒 Files selected for processing (5)
  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • nemo_rl/models/policy/workers/patches.py
  • tests/unit/models/policy/test_dtensor_worker.py
  • tests/unit/models/policy/test_patches.py
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code

Files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • tests/unit/models/policy/test_patches.py
  • tests/unit/models/policy/test_dtensor_worker.py
  • nemo_rl/models/policy/workers/patches.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes

Files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/models/policy/workers/patches.py
!(**/tests/**|**/test_*.py|**/test_*.sh)

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year

Files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • tests/unit/models/policy/test_patches.py
  • tests/unit/models/policy/test_dtensor_worker.py
  • nemo_rl/models/policy/workers/patches.py
**/*.{py,sh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)

Files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • tests/unit/models/policy/test_patches.py
  • tests/unit/models/policy/test_dtensor_worker.py
  • nemo_rl/models/policy/workers/patches.py
🧬 Code graph analysis (2)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
nemo_rl/models/policy/workers/patches.py (2)
  • apply_torch_aten_alias_tensor_patch (113-128)
  • apply_transformer_engine_patch (51-110)
nemo_rl/models/policy/workers/dtensor_policy_worker.py (1)
nemo_rl/models/policy/workers/patches.py (1)
  • apply_torch_aten_alias_tensor_patch (113-128)
🪛 Ruff (0.14.10)
tests/unit/models/policy/test_patches.py

452-452: Unused function argument: rank

(ARG001)


470-470: Do not assert False (python -O removes these calls), raise AssertionError()

Replace assert False

(B011)

nemo_rl/models/policy/workers/patches.py

127-127: Do not catch blind exception: Exception

(BLE001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (2)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)

89-92: Patch call ordering looks right; ensure alias patch remains safe/idempotent across processes.

Calling apply_transformer_engine_patch() and then apply_torch_aten_alias_tensor_patch() early in init matches the intended “patch before DTensor setup” flow.

Main concern is the alias patch helper’s current brittleness (version assert + broad exception handling); see comments on nemo_rl/models/policy/workers/patches.py.

Also applies to: 131-135

tests/unit/models/policy/test_dtensor_worker.py (1)

584-587: Good coverage addition for TP=2 + SP=True; please sanity-check total parametrized runtime.

This expands coverage for the newly-unblocked configuration, but the matrix is already large—worth confirming this doesn’t push CI over time limits.

Comment on lines +79 to 80
from nemo_rl.models.policy.workers.patches import apply_torch_aten_alias_tensor_patch
from nemo_rl.utils.native_checkpoint import (

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

head -170 nemo_rl/models/policy/workers/dtensor_policy_worker.py | tail -30

Repository: NVIDIA-NeMo/RL

Length of output: 1191


🏁 Script executed:

sed -n '155,175p' nemo_rl/models/policy/workers/dtensor_policy_worker.py

Repository: NVIDIA-NeMo/RL

Length of output: 995


🏁 Script executed:

sed -n '75,85p' nemo_rl/models/policy/workers/dtensor_policy_worker.py

Repository: NVIDIA-NeMo/RL

Length of output: 498


Move __init__ docstring before the patch call (docstring must be the first statement).

The """Initialize the DTensorPolicyWorker.""" string is no longer recognized as a docstring because apply_torch_aten_alias_tensor_patch() is called before it. In Python, docstrings must be the first statement in a function or method body to be recognized as such. Move the docstring to the first line of the method and place the patch call after it.

Proposed fix
     def __init__(
         self,
         config: PolicyConfig,
         tokenizer: AutoTokenizer,
         processor: Optional[AutoProcessor] = None,
         weights_path: Optional[str] = None,
         optimizer_path: Optional[str] = None,
         init_optimizer: bool = True,
         init_reference_model: bool = True,
         **kwargs: Any,
     ):
+        """Initialize the DTensorPolicyWorker."""
+
         # Apply patch to work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered'
         apply_torch_aten_alias_tensor_patch()
-
-        """Initialize the DTensorPolicyWorker."""
         self.tokenizer = tokenizer

Comment on lines +18 to +21
import torch
from torch.distributed.tensor._ops._tensor_ops import propagate_single_input_strategy
from torch.distributed.tensor._ops.utils import register_op_strategy


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find . -type f -name "patches.py" | grep -E "nemo_rl.*workers" | head -20

Repository: NVIDIA-NeMo/RL

Length of output: 100


🏁 Script executed:

cat -n nemo_rl/models/policy/workers/patches.py | head -150

Repository: NVIDIA-NeMo/RL

Length of output: 6194


Guard private torch imports, replace assert with explicit version check, and raise on patch failure.

Lines 18–21 import from private torch.distributed.tensor._ops.* modules without error handling, which will fail on incompatible torch versions. Lines 120–128 use assert for version gating (which can be stripped with python -O) and silently swallow registration failures with a bare Exception catch that only prints.

Suggested fix
@@
-import torch
-from torch.distributed.tensor._ops._tensor_ops import propagate_single_input_strategy
-from torch.distributed.tensor._ops.utils import register_op_strategy
+import torch
+try:
+    # NOTE: These are internal torch APIs; keep guarded and remove when upstream fix is available.
+    from torch.distributed.tensor._ops._tensor_ops import propagate_single_input_strategy
+    from torch.distributed.tensor._ops.utils import register_op_strategy
+except ImportError as e:
+    propagate_single_input_strategy = None  # type: ignore[assignment]
+    register_op_strategy = None  # type: ignore[assignment]
+    _DTENSOR_OPS_IMPORT_ERROR = e
+else:
+    _DTENSOR_OPS_IMPORT_ERROR = None
@@
 def apply_torch_aten_alias_tensor_patch():
     """Register a sharding rule for `torch.ops.aten.alias.default`.
@@
-    assert torch.__version__.startswith("2.9.0"), (
-        "This patch is needed for torch 2.9.0. Please retest if you upgrade torch to a newer version and remove this patch."
-    )
-    try:
-        register_op_strategy(torch.ops.aten.alias.default)(
-            propagate_single_input_strategy
-        )
-    except Exception as e:
-        print(f"Error applying torch.ops.aten.alias.default patch: {e}")
+    if _DTENSOR_OPS_IMPORT_ERROR is not None:
+        raise RuntimeError(
+            "DTensor internal APIs needed for the aten.alias.default sharding patch "
+            "are not available in this torch build."
+        ) from _DTENSOR_OPS_IMPORT_ERROR
+
+    if not torch.__version__.startswith("2.9.0"):
+        # Patch is intended for torch 2.9.x; avoid crashing worker init on other versions.
+        return
+
+    try:
+        register_op_strategy(torch.ops.aten.alias.default)(propagate_single_input_strategy)
+    except RuntimeError as e:
+        raise RuntimeError(
+            "Failed to register sharding strategy for torch.ops.aten.alias.default"
+        ) from e

Comment on lines +452 to +483
def build_sharded_3d(rank: int, world_size: int):
    """Build tensor, DTensor, and test the sharding rule for torch.ops.aten.alias.default."""
    from torch.distributed.tensor import DeviceMesh, DTensor

    mesh = DeviceMesh("cuda", list(range(world_size)))
    global_shape = (4, 2, 4)
    tensor = torch.arange(
        torch.tensor(global_shape).prod(), dtype=torch.float32
    ).reshape(global_shape)

    try:
        dtensor = DTensor.from_local(
            tensor,
            device_mesh=mesh,
        )
        alias_dtensor = torch.ops.aten.alias.default(dtensor)
        assert False, (
            "Torch==2.9 should raise 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered', but it didn't. "
            "You can:\n"
            "1. Check if you bumped your torch version to one containing the fix https://github.com/pytorch/pytorch/pull/166867\n"
            "2. If yes, remove the patch apply_torch_aten_alias_tensor_patch in nemo_rl/models/policy/workers/patches.py\n"
            "3. Remove the patching call in nemo_rl/models/policy/workers/dtensor_policy_worker.py and nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py\n"
            "4. Delete this test"
        )
    except NotImplementedError:
        apply_torch_aten_alias_tensor_patch()
        alias_dtensor = torch.ops.aten.alias.default(dtensor)
        assert alias_dtensor.shape == global_shape
        assert torch.allclose(alias_dtensor.to_local().cpu(), tensor.cpu())



⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, locate and examine the actual test file
fd -t f test_patches.py -x wc -l {}

Repository: NVIDIA-NeMo/RL

Length of output: 104


🏁 Script executed:

# Read the specific lines mentioned in the review
sed -n '452,490p' tests/unit/models/policy/test_patches.py

Repository: NVIDIA-NeMo/RL

Length of output: 1838


🏁 Script executed:

# Check the imports at the top of the file to understand available tools
head -50 tests/unit/models/policy/test_patches.py

Repository: NVIDIA-NeMo/RL

Length of output: 1920


🏁 Script executed:

# Search for other uses of DTensor in the codebase to understand best practices
rg "DTensor\.from_local" -A 3 -B 1

Repository: NVIDIA-NeMo/RL

Length of output: 13928


Fix CUDA placement, replace assert False, and address unused rank in build_sharded_3d().

The local tensor is created on CPU, but DeviceMesh("cuda", ...) requires a CUDA tensor. Also, assert False is brittle under optimization flags, and rank is unused.

Proposed fix
-def build_sharded_3d(rank: int, world_size: int):
-    """
-    Build tensor, DTensor, and test the sharding rule for torch.ops.aten.alias.default.
-    """
+def build_sharded_3d(_rank: int, world_size: int):
+    """Build tensor/DTensor and validate the sharding rule for `aten.alias.default`."""
     from torch.distributed.tensor import DeviceMesh, DTensor
 
     mesh = DeviceMesh("cuda", list(range(world_size)))
     global_shape = (4, 2, 4)
     tensor = torch.arange(
         torch.tensor(global_shape).prod(), dtype=torch.float32
-    ).reshape(global_shape)
+    ).reshape(global_shape).cuda()
 
     try:
         dtensor = DTensor.from_local(
             tensor,
             device_mesh=mesh,
         )
         alias_dtensor = torch.ops.aten.alias.default(dtensor)
-        assert False, (
+        raise AssertionError(
             "Torch==2.9 should raise 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered', but it didn't."
             "You can:\n "
             "1. Check is you bump your torch version which contain the fix https://github.com/pytorch/pytorch/pull/166867\n"
             "2. If yes, remove patch apply_torch_aten_alias_tensor_patch in nemo_rl/models/policy/workers/patches.py \n"
             "3. Remove the patching call in nemo_rl/models/policy/workers/dtensor_policy_worker.py and nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py \n"
             "4. Delete this test"
         )
     except NotImplementedError:
         apply_torch_aten_alias_tensor_patch()
         alias_dtensor = torch.ops.aten.alias.default(dtensor)
         assert alias_dtensor.shape == global_shape
         assert torch.allclose(alias_dtensor.to_local().cpu(), tensor.cpu())

@yuki-97 merged commit 420d69d into r0.5.0 on Jan 9, 2026
68 of 71 checks passed
@yuki-97 deleted the cherry-pick-1728-r0.5.0 branch on January 9, 2026 17:03

Labels

cherry-pick · CI:L1 (Run doctests, unit tests, and functional tests) · Run CICD

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants