-
Notifications
You must be signed in to change notification settings - Fork 333
cp: fix: patch pytorch aten.alias.default shard strategy (1728) into r0.5.0
#1753
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,10 @@ | |
| import os | ||
| from importlib.util import find_spec | ||
|
|
||
| import torch | ||
| from torch.distributed.tensor._ops._tensor_ops import propagate_single_input_strategy | ||
| from torch.distributed.tensor._ops.utils import register_op_strategy | ||
|
|
||
|
Comment on lines
+18
to
+21
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

🧩 Analysis chain

🏁 Script executed: find . -type f -name "patches.py" | grep -E "nemo_rl.*workers" | head -20

Repository: NVIDIA-NeMo/RL
Length of output: 100

🏁 Script executed: cat -n nemo_rl/models/policy/workers/patches.py | head -150

Repository: NVIDIA-NeMo/RL
Length of output: 6194

Guard the private torch imports and replace the broad exception handling.

Lines 18–21 import from private torch modules (`torch.distributed.tensor._ops._tensor_ops` and `torch.distributed.tensor._ops.utils`). These internal APIs can change or disappear between torch releases, so the imports should be guarded, and failures should be raised explicitly rather than printed and ignored.

Suggested fix

@@
-import torch
-from torch.distributed.tensor._ops._tensor_ops import propagate_single_input_strategy
-from torch.distributed.tensor._ops.utils import register_op_strategy
+import torch
+try:
+ # NOTE: These are internal torch APIs; keep guarded and remove when upstream fix is available.
+ from torch.distributed.tensor._ops._tensor_ops import propagate_single_input_strategy
+ from torch.distributed.tensor._ops.utils import register_op_strategy
+except ImportError as e:
+ propagate_single_input_strategy = None # type: ignore[assignment]
+ register_op_strategy = None # type: ignore[assignment]
+ _DTENSOR_OPS_IMPORT_ERROR = e
+else:
+ _DTENSOR_OPS_IMPORT_ERROR = None
@@
def apply_torch_aten_alias_tensor_patch():
"""Register a sharding rule for `torch.ops.aten.alias.default`.
@@
- assert torch.__version__.startswith("2.9.0"), (
- "This patch is needed for torch 2.9.0. Please retest if you upgrade torch to a newer version and remove this patch."
- )
- try:
- register_op_strategy(torch.ops.aten.alias.default)(
- propagate_single_input_strategy
- )
- except Exception as e:
- print(f"Error applying torch.ops.aten.alias.default patch: {e}")
+ if _DTENSOR_OPS_IMPORT_ERROR is not None:
+ raise RuntimeError(
+ "DTensor internal APIs needed for the aten.alias.default sharding patch "
+ "are not available in this torch build."
+ ) from _DTENSOR_OPS_IMPORT_ERROR
+
+ if not torch.__version__.startswith("2.9.0"):
+ # Patch is intended for torch 2.9.x; avoid crashing worker init on other versions.
+ return
+
+ try:
+ register_op_strategy(torch.ops.aten.alias.default)(propagate_single_input_strategy)
+ except RuntimeError as e:
+ raise RuntimeError(
+ "Failed to register sharding strategy for torch.ops.aten.alias.default"
+ ) from e |
||
|
|
||
| def _get_transformer_engine_file(relative_path: str) -> str: | ||
| """Return absolute path to a Transformer Engine file or raise if it cannot be found. | ||
|
|
@@ -104,3 +108,21 @@ def apply_transformer_engine_patch(): | |
|
|
||
| except Exception as e: | ||
| print(f"Error checking/patching transformer_engine: {e}") | ||
|
|
||
|
|
||
| def apply_torch_aten_alias_tensor_patch(): | ||
| """Register a sharding rule for `torch.ops.aten.alias.default`. | ||
|
|
||
| Work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered' | ||
| in PyTorch 2.9. See https://github.com/pytorch/pytorch/pull/166867 for the upstream fix. | ||
| We can remove this patch when we upgrade torch to include this fix. | ||
| """ | ||
| assert torch.__version__.startswith("2.9.0"), ( | ||
| "This patch is needed for torch 2.9.0. Please retest if you upgrade torch to a newer version and remove this patch." | ||
| ) | ||
| try: | ||
| register_op_strategy(torch.ops.aten.alias.default)( | ||
| propagate_single_input_strategy | ||
| ) | ||
| except Exception as e: | ||
| print(f"Error applying torch.ops.aten.alias.default patch: {e}") | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,9 +18,11 @@ | |
| from unittest.mock import MagicMock, mock_open, patch | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from nemo_rl.models.policy.workers.patches import ( | ||
| _get_transformer_engine_file, | ||
| apply_torch_aten_alias_tensor_patch, | ||
| apply_transformer_engine_patch, | ||
| ) | ||
|
|
||
|
|
@@ -445,3 +447,44 @@ def permutation_kernel(x): | |
| assert captured.out.count("Successfully patched") == 1 | ||
| finally: | ||
| os.unlink(tmp_path) | ||
|
|
||
|
|
||
| def build_sharded_3d(rank: int, world_size: int): | ||
| """ | ||
| Build tensor, DTensor, and test the sharding rule for torch.ops.aten.alias.default. | ||
| """ | ||
| from torch.distributed.tensor import DeviceMesh, DTensor | ||
|
|
||
| mesh = DeviceMesh("cuda", list(range(world_size))) | ||
| global_shape = (4, 2, 4) | ||
| tensor = torch.arange( | ||
| torch.tensor(global_shape).prod(), dtype=torch.float32 | ||
| ).reshape(global_shape) | ||
|
|
||
| try: | ||
| dtensor = DTensor.from_local( | ||
| tensor, | ||
| device_mesh=mesh, | ||
| ) | ||
| alias_dtensor = torch.ops.aten.alias.default(dtensor) | ||
| assert False, ( | ||
| "Torch==2.9 should raise 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered', but it didn't." | ||
| "You can:\n " | ||
| "1. Check is you bump your torch version which contain the fix https://github.com/pytorch/pytorch/pull/166867\n" | ||
| "2. If yes, remove patch apply_torch_aten_alias_tensor_patch in nemo_rl/models/policy/workers/patches.py \n" | ||
| "3. Remove the patching call in nemo_rl/models/policy/workers/dtensor_policy_worker.py and nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py \n" | ||
| "4. Delete this test" | ||
| ) | ||
| except NotImplementedError: | ||
| apply_torch_aten_alias_tensor_patch() | ||
| alias_dtensor = torch.ops.aten.alias.default(dtensor) | ||
| assert alias_dtensor.shape == global_shape | ||
| assert torch.allclose(alias_dtensor.to_local().cpu(), tensor.cpu()) | ||
|
|
||
|
Comment on lines
+452
to
+483
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

🧩 Analysis chain

🏁 Script executed: # First, locate and examine the actual test file
fd -t f test_patches.py -x wc -l {}

Repository: NVIDIA-NeMo/RL
Length of output: 104

🏁 Script executed: # Read the specific lines mentioned in the review
sed -n '452,490p' tests/unit/models/policy/test_patches.py

Repository: NVIDIA-NeMo/RL
Length of output: 1838

🏁 Script executed: # Check the imports at the top of the file to understand available tools
head -50 tests/unit/models/policy/test_patches.py

Repository: NVIDIA-NeMo/RL
Length of output: 1920

🏁 Script executed: # Search for other uses of DTensor in the codebase to understand best practices
rg "DTensor\.from_local" -A 3 -B 1

Repository: NVIDIA-NeMo/RL
Length of output: 13928

Fix CUDA placement and replace `assert False` with `raise AssertionError`.

The local tensor is created on CPU, but the `DeviceMesh` targets CUDA; move the tensor to the GPU before constructing the DTensor. Also, `assert False` is stripped when Python runs with optimizations enabled, so raise `AssertionError` explicitly instead.

Proposed fix

-def build_sharded_3d(rank: int, world_size: int):
- """
- Build tensor, DTensor, and test the sharding rule for torch.ops.aten.alias.default.
- """
+def build_sharded_3d(_rank: int, world_size: int):
+ """Build tensor/DTensor and validate the sharding rule for `aten.alias.default`."""
from torch.distributed.tensor import DeviceMesh, DTensor
mesh = DeviceMesh("cuda", list(range(world_size)))
global_shape = (4, 2, 4)
tensor = torch.arange(
torch.tensor(global_shape).prod(), dtype=torch.float32
- ).reshape(global_shape)
+ ).reshape(global_shape).cuda()
try:
dtensor = DTensor.from_local(
tensor,
device_mesh=mesh,
)
alias_dtensor = torch.ops.aten.alias.default(dtensor)
- assert False, (
+ raise AssertionError(
"Torch==2.9 should raise 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered', but it didn't."
"You can:\n "
"1. Check is you bump your torch version which contain the fix https://github.com/pytorch/pytorch/pull/166867\n"
"2. If yes, remove patch apply_torch_aten_alias_tensor_patch in nemo_rl/models/policy/workers/patches.py \n"
"3. Remove the patching call in nemo_rl/models/policy/workers/dtensor_policy_worker.py and nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py \n"
"4. Delete this test"
)
except NotImplementedError:
apply_torch_aten_alias_tensor_patch()
alias_dtensor = torch.ops.aten.alias.default(dtensor)
assert alias_dtensor.shape == global_shape
assert torch.allclose(alias_dtensor.to_local().cpu(), tensor.cpu())🧰 Tools🪛 Ruff (0.14.10)452-452: Unused function argument: (ARG001) 470-470: Do not Replace (B011) |
||
|
|
||
| @pytest.mark.cuda | ||
| def test_aten_alias_sharding_still_missing_upstream( | ||
| distributed_test_runner, world_size=2 | ||
| ): | ||
| """Test that sharding rule for aten.alias.default is still missing upstream.""" | ||
| distributed_test_runner(build_sharded_3d, world_size=world_size) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
head -170 nemo_rl/models/policy/workers/dtensor_policy_worker.py | tail -30

Repository: NVIDIA-NeMo/RL
Length of output: 1191
🏁 Script executed:
sed -n '155,175p' nemo_rl/models/policy/workers/dtensor_policy_worker.py

Repository: NVIDIA-NeMo/RL
Length of output: 995
🏁 Script executed:
sed -n '75,85p' nemo_rl/models/policy/workers/dtensor_policy_worker.py

Repository: NVIDIA-NeMo/RL
Length of output: 498
Move the `__init__` docstring before the patch call (a docstring must be the first statement).

The `"""Initialize the DTensorPolicyWorker."""` string is no longer recognized as a docstring because `apply_torch_aten_alias_tensor_patch()` is called before it. In Python, a docstring must be the first statement in a function or method body to be recognized as such. Move the docstring to the first line of the method and place the patch call after it.

Proposed fix

  def __init__(
      self,
      config: PolicyConfig,
      tokenizer: AutoTokenizer,
      processor: Optional[AutoProcessor] = None,
      weights_path: Optional[str] = None,
      optimizer_path: Optional[str] = None,
      init_optimizer: bool = True,
      init_reference_model: bool = True,
      **kwargs: Any,
  ):
+     """Initialize the DTensorPolicyWorker."""
+     # Apply patch to work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered'
      apply_torch_aten_alias_tensor_patch()
-
-     """Initialize the DTensorPolicyWorker."""
      self.tokenizer = tokenizer

🤖 Prompt for AI Agents