cp: fix: patch pytorch aten.alias.default shard strategy (1728) into r0.5.0#1753
Signed-off-by: ruit <ruit@nvidia.com> Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
ℹ️ File Consistency Check
Check based on commit: be6b54e (PR #1753)

✅ DTensor Policy Worker Synchronization Check
Both DTensor policy worker files were modified in this PR:
Please ensure that the changes are consistent between both files where applicable. This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.
📝 Walkthrough

The PR introduces a new patch function to handle torch.ops.aten.alias.default sharding and applies it to DTensorPolicyWorker implementations. The patch is invoked during worker initialization, and the prior runtime error blocking sequence parallel with tp_size > 1 is removed. Tests are added to validate the patch behavior.

Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In @nemo_rl/models/policy/workers/dtensor_policy_worker.py:
- Around line 79-80: The __init__ docstring in DTensorPolicyWorker is not the
first statement because apply_torch_aten_alias_tensor_patch() is called before
it; move the string literal """Initialize the DTensorPolicyWorker.""" to be the
very first line inside the __init__ method (so it is recognized as the
docstring), then call apply_torch_aten_alias_tensor_patch() after that docstring
and proceed with the rest of the initialization.
🧹 Nitpick comments (1)
tests/unit/models/policy/test_patches.py (1)
485-490: Consider version-gating or marking expected behavior explicitly to reduce future maintenance churn.

Once PyTorch includes the upstream fix (or if a different 2.9.x build behaves differently), this test becomes a forced failure. If the intent is “fail loudly so we remove the patch,” that’s fine—but then it should at least be guarded by the exact torch version(s) where the failure is expected to avoid surprising breakage during upgrades.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- nemo_rl/models/policy/workers/dtensor_policy_worker.py
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- nemo_rl/models/policy/workers/patches.py
- tests/unit/models/policy/test_dtensor_worker.py
- tests/unit/models/policy/test_patches.py
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code
Files:
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- nemo_rl/models/policy/workers/dtensor_policy_worker.py
- tests/unit/models/policy/test_patches.py
- tests/unit/models/policy/test_dtensor_worker.py
- nemo_rl/models/policy/workers/patches.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes
Files:
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- nemo_rl/models/policy/workers/dtensor_policy_worker.py
- nemo_rl/models/policy/workers/patches.py
!(**/tests/**|**/test_*.py|**/test_*.sh)
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year
Files:
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- nemo_rl/models/policy/workers/dtensor_policy_worker.py
- tests/unit/models/policy/test_patches.py
- tests/unit/models/policy/test_dtensor_worker.py
- nemo_rl/models/policy/workers/patches.py
**/*.{py,sh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)
Files:
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- nemo_rl/models/policy/workers/dtensor_policy_worker.py
- tests/unit/models/policy/test_patches.py
- tests/unit/models/policy/test_dtensor_worker.py
- nemo_rl/models/policy/workers/patches.py
🧬 Code graph analysis (2)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
nemo_rl/models/policy/workers/patches.py (2)
- apply_torch_aten_alias_tensor_patch (113-128)
- apply_transformer_engine_patch (51-110)
nemo_rl/models/policy/workers/dtensor_policy_worker.py (1)
nemo_rl/models/policy/workers/patches.py (1)
apply_torch_aten_alias_tensor_patch(113-128)
🪛 Ruff (0.14.10)
tests/unit/models/policy/test_patches.py
452-452: Unused function argument: rank
(ARG001)
470-470: Do not assert False (python -O removes these calls), raise AssertionError()
Replace assert False
(B011)
nemo_rl/models/policy/workers/patches.py
127-127: Do not catch blind exception: Exception
(BLE001)
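The B011 finding above is worth illustrating: CPython compiles out `assert` statements when run with `-O`, so an `assert False` guard silently disappears. A minimal, standalone demonstration:

```python
import subprocess
import sys

# An assert followed by code that should be unreachable if the assert fires.
code = "assert False, 'should fail'\nprint('reached')"

# Normal run: the assert raises AssertionError and the process exits nonzero.
normal = subprocess.run([sys.executable, "-c", code], capture_output=True, text=True)

# Optimized run: -O strips asserts, so execution falls through to the print.
optimized = subprocess.run([sys.executable, "-O", "-c", code], capture_output=True, text=True)

print(normal.returncode != 0)    # True
print(optimized.stdout.strip())  # reached
```

This is why a test that relies on `assert False` as a failure signal should use `raise AssertionError(...)` instead.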
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Lint check
- GitHub Check: Lint check
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (2)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
89-92: Patch call ordering looks right; ensure alias patch remains safe/idempotent across processes.

Calling apply_transformer_engine_patch() and then apply_torch_aten_alias_tensor_patch() early in init matches the intended “patch before DTensor setup” flow. Main concern is the alias patch helper’s current brittleness (version assert + broad exception handling); see comments on nemo_rl/models/policy/workers/patches.py.

Also applies to: 131-135
tests/unit/models/policy/test_dtensor_worker.py (1)
584-587: Good coverage addition for TP=2 + SP=True; please sanity-check total parametrized runtime.

This expands coverage for the newly-unblocked configuration, but the matrix is already large—worth confirming this doesn’t push CI over time limits.
from nemo_rl.models.policy.workers.patches import apply_torch_aten_alias_tensor_patch
from nemo_rl.utils.native_checkpoint import (
🧩 Analysis chain

🏁 Script executed:

head -170 nemo_rl/models/policy/workers/dtensor_policy_worker.py | tail -30

Repository: NVIDIA-NeMo/RL
Length of output: 1191

🏁 Script executed:

sed -n '155,175p' nemo_rl/models/policy/workers/dtensor_policy_worker.py

Repository: NVIDIA-NeMo/RL
Length of output: 995

🏁 Script executed:

sed -n '75,85p' nemo_rl/models/policy/workers/dtensor_policy_worker.py

Repository: NVIDIA-NeMo/RL
Length of output: 498
Move __init__ docstring before the patch call (docstring must be the first statement).
The """Initialize the DTensorPolicyWorker.""" string is no longer recognized as a docstring because apply_torch_aten_alias_tensor_patch() is called before it. In Python, docstrings must be the first statement in a function or method body to be recognized as such. Move the docstring to the first line of the method and place the patch call after it.
Proposed fix
def __init__(
self,
config: PolicyConfig,
tokenizer: AutoTokenizer,
processor: Optional[AutoProcessor] = None,
weights_path: Optional[str] = None,
optimizer_path: Optional[str] = None,
init_optimizer: bool = True,
init_reference_model: bool = True,
**kwargs: Any,
):
+ """Initialize the DTensorPolicyWorker."""
+
# Apply patch to work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered'
apply_torch_aten_alias_tensor_patch()
-
- """Initialize the DTensorPolicyWorker."""
self.tokenizer = tokenizer🤖 Prompt for AI Agents
In @nemo_rl/models/policy/workers/dtensor_policy_worker.py around lines 79 - 80,
The __init__ docstring in DTensorPolicyWorker is not the first statement because
apply_torch_aten_alias_tensor_patch() is called before it; move the string
literal """Initialize the DTensorPolicyWorker.""" to be the very first line
inside the __init__ method (so it is recognized as the docstring), then call
apply_torch_aten_alias_tensor_patch() after that docstring and proceed with the
rest of the initialization.
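The docstring rule at issue here is easy to verify: CPython only binds a string literal to `__doc__` when it is the very first statement in the function body, which is why the misplaced string in `__init__` is silently ignored:

```python
def patched_first():
    _patch = None  # any statement before the string literal
    """Not recognized as a docstring."""


def docstring_first():
    """Recognized: first statement in the body."""
    _patch = None


print(patched_first.__doc__)    # None
print(docstring_first.__doc__)  # Recognized: first statement in the body.
```

Tools like Sphinx and `help()` read `__doc__`, so the misplaced string is invisible to them until it is moved up.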
import torch
from torch.distributed.tensor._ops._tensor_ops import propagate_single_input_strategy
from torch.distributed.tensor._ops.utils import register_op_strategy
🧩 Analysis chain

🏁 Script executed:

find . -type f -name "patches.py" | grep -E "nemo_rl.*workers" | head -20

Repository: NVIDIA-NeMo/RL
Length of output: 100

🏁 Script executed:

cat -n nemo_rl/models/policy/workers/patches.py | head -150

Repository: NVIDIA-NeMo/RL
Length of output: 6194
Guard private torch imports, replace assert with explicit version check, and raise on patch failure.
Lines 18–21 import from private torch.distributed.tensor._ops.* modules without error handling, which will fail on incompatible torch versions. Lines 120–128 use assert for version gating (which can be stripped with python -O) and silently swallow registration failures with a bare Exception catch that only prints.
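Independent of torch, the guarded-import plus exception-chaining pattern recommended here can be sketched as follows. The module and function names are illustrative (the module is deliberately nonexistent so the `ImportError` path runs):

```python
# Guard an import of a private/unstable API at module load time.
try:
    from _nonexistent_private_module import register_op_strategy  # assumed absent
except ImportError as e:
    register_op_strategy = None
    _IMPORT_ERROR = e
else:
    _IMPORT_ERROR = None


def apply_patch() -> None:
    """Fail loudly, preserving the original ImportError as the cause."""
    if _IMPORT_ERROR is not None:
        raise RuntimeError(
            "internal APIs needed for the patch are unavailable"
        ) from _IMPORT_ERROR
    register_op_strategy()


try:
    apply_patch()
except RuntimeError as err:
    print(type(err.__cause__).__name__)  # ImportError subclass
```

`raise ... from ...` keeps the original failure in `__cause__`, so worker-init logs show both why the patch failed and why the import did.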
Suggested fix
@@
-import torch
-from torch.distributed.tensor._ops._tensor_ops import propagate_single_input_strategy
-from torch.distributed.tensor._ops.utils import register_op_strategy
+import torch
+try:
+ # NOTE: These are internal torch APIs; keep guarded and remove when upstream fix is available.
+ from torch.distributed.tensor._ops._tensor_ops import propagate_single_input_strategy
+ from torch.distributed.tensor._ops.utils import register_op_strategy
+except ImportError as e:
+ propagate_single_input_strategy = None # type: ignore[assignment]
+ register_op_strategy = None # type: ignore[assignment]
+ _DTENSOR_OPS_IMPORT_ERROR = e
+else:
+ _DTENSOR_OPS_IMPORT_ERROR = None
@@
def apply_torch_aten_alias_tensor_patch():
"""Register a sharding rule for `torch.ops.aten.alias.default`.
@@
- assert torch.__version__.startswith("2.9.0"), (
- "This patch is needed for torch 2.9.0. Please retest if you upgrade torch to a newer version and remove this patch."
- )
- try:
- register_op_strategy(torch.ops.aten.alias.default)(
- propagate_single_input_strategy
- )
- except Exception as e:
- print(f"Error applying torch.ops.aten.alias.default patch: {e}")
+ if _DTENSOR_OPS_IMPORT_ERROR is not None:
+ raise RuntimeError(
+ "DTensor internal APIs needed for the aten.alias.default sharding patch "
+ "are not available in this torch build."
+ ) from _DTENSOR_OPS_IMPORT_ERROR
+
+ if not torch.__version__.startswith("2.9.0"):
+ # Patch is intended for torch 2.9.x; avoid crashing worker init on other versions.
+ return
+
+ try:
+ register_op_strategy(torch.ops.aten.alias.default)(propagate_single_input_strategy)
+ except RuntimeError as e:
+ raise RuntimeError(
+ "Failed to register sharding strategy for torch.ops.aten.alias.default"
        ) from e

def build_sharded_3d(rank: int, world_size: int):
    """
    Build tensor, DTensor, and test the sharding rule for torch.ops.aten.alias.default.
    """
    from torch.distributed.tensor import DeviceMesh, DTensor

    mesh = DeviceMesh("cuda", list(range(world_size)))
    global_shape = (4, 2, 4)
    tensor = torch.arange(
        torch.tensor(global_shape).prod(), dtype=torch.float32
    ).reshape(global_shape)

    try:
        dtensor = DTensor.from_local(
            tensor,
            device_mesh=mesh,
        )
        alias_dtensor = torch.ops.aten.alias.default(dtensor)
        assert False, (
            "Torch==2.9 should raise 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered', but it didn't."
            "You can:\n "
            "1. Check is you bump your torch version which contain the fix https://github.com/pytorch/pytorch/pull/166867\n"
            "2. If yes, remove patch apply_torch_aten_alias_tensor_patch in nemo_rl/models/policy/workers/patches.py \n"
            "3. Remove the patching call in nemo_rl/models/policy/workers/dtensor_policy_worker.py and nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py \n"
            "4. Delete this test"
        )
    except NotImplementedError:
        apply_torch_aten_alias_tensor_patch()
        alias_dtensor = torch.ops.aten.alias.default(dtensor)
        assert alias_dtensor.shape == global_shape
        assert torch.allclose(alias_dtensor.to_local().cpu(), tensor.cpu())
🧩 Analysis chain

🏁 Script executed:

# First, locate and examine the actual test file
fd -t f test_patches.py -x wc -l {}

Repository: NVIDIA-NeMo/RL
Length of output: 104

🏁 Script executed:

# Read the specific lines mentioned in the review
sed -n '452,490p' tests/unit/models/policy/test_patches.py

Repository: NVIDIA-NeMo/RL
Length of output: 1838

🏁 Script executed:

# Check the imports at the top of the file to understand available tools
head -50 tests/unit/models/policy/test_patches.py

Repository: NVIDIA-NeMo/RL
Length of output: 1920

🏁 Script executed:

# Search for other uses of DTensor in the codebase to understand best practices
rg "DTensor\.from_local" -A 3 -B 1

Repository: NVIDIA-NeMo/RL
Length of output: 13928
Fix CUDA placement, replace assert False, and address unused rank in build_sharded_3d().
The local tensor is created on CPU, but DeviceMesh("cuda", ...) requires a CUDA tensor. Also, assert False is brittle under optimization flags, and rank is unused.
Proposed fix
-def build_sharded_3d(rank: int, world_size: int):
- """
- Build tensor, DTensor, and test the sharding rule for torch.ops.aten.alias.default.
- """
+def build_sharded_3d(_rank: int, world_size: int):
+ """Build tensor/DTensor and validate the sharding rule for `aten.alias.default`."""
from torch.distributed.tensor import DeviceMesh, DTensor
mesh = DeviceMesh("cuda", list(range(world_size)))
global_shape = (4, 2, 4)
tensor = torch.arange(
torch.tensor(global_shape).prod(), dtype=torch.float32
- ).reshape(global_shape)
+ ).reshape(global_shape).cuda()
try:
dtensor = DTensor.from_local(
tensor,
device_mesh=mesh,
)
alias_dtensor = torch.ops.aten.alias.default(dtensor)
- assert False, (
+ raise AssertionError(
"Torch==2.9 should raise 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered', but it didn't."
"You can:\n "
"1. Check is you bump your torch version which contain the fix https://github.com/pytorch/pytorch/pull/166867\n"
"2. If yes, remove patch apply_torch_aten_alias_tensor_patch in nemo_rl/models/policy/workers/patches.py \n"
"3. Remove the patching call in nemo_rl/models/policy/workers/dtensor_policy_worker.py and nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py \n"
"4. Delete this test"
)
except NotImplementedError:
apply_torch_aten_alias_tensor_patch()
alias_dtensor = torch.ops.aten.alias.default(dtensor)
assert alias_dtensor.shape == global_shape
    assert torch.allclose(alias_dtensor.to_local().cpu(), tensor.cpu())

🧰 Tools
🪛 Ruff (0.14.10)
452-452: Unused function argument: rank
(ARG001)
470-470: Do not assert False (python -O removes these calls), raise AssertionError()
Replace assert False
(B011)
beep boop [🤖]: Hi @RayenTian 👋,
Summary by CodeRabbit
New Features
Bug Fixes
Tests
✏️ Tip: You can customize this high-level summary in your review settings.