Skip to content

Commit a0fc25b

Browse files
fegin authored and jquesnelle committed
Fix how SimpleFSDP get the nD mesh (pytorch#1959)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#1960 * __->__ pytorch#1959 This is the recommended way to get the nD mesh now that DeviceMesh has _concatenate(). **Squash and Merge button won't work for this PR. I'll merge by myself.**
1 parent 9c54cd4 commit a0fc25b

File tree

1 file changed

+2
-14
lines changed

1 file changed

+2
-14
lines changed

torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 2 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
Replicate,
2020
Shard,
2121
)
22-
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
22+
from torch.distributed.device_mesh import DeviceMesh
2323
from torch.distributed.tensor._dtensor_spec import DTensorSpec
2424
from torch.distributed.tensor._redistribute import redistribute_local_tensor
2525
from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -95,19 +95,7 @@ def _distribute_dtensor(
9595
"""
9696
inner_spec = tensor._spec
9797
outer_mesh, inner_mesh = device_mesh, inner_spec.mesh
98-
outer_global_mesh = _mesh_resources.get_root_mesh(outer_mesh)
99-
inner_global_mesh = _mesh_resources.get_root_mesh(inner_mesh)
100-
if outer_global_mesh != inner_global_mesh or (
101-
outer_global_mesh is None or inner_global_mesh is None
102-
):
103-
raise AssertionError(
104-
"Cannot distribute tensor across two meshes without the same root mesh: \n"
105-
f"outer global mesh: {outer_global_mesh}\ninner global mesh: {inner_global_mesh}"
106-
)
107-
assert outer_mesh.mesh_dim_names is not None
108-
assert inner_mesh.mesh_dim_names is not None
109-
submesh_names = outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
110-
spanned_mesh = outer_global_mesh[submesh_names]
98+
spanned_mesh = DeviceMesh._concatenate([outer_mesh, inner_mesh])
11199

112100
if len(dp_placements) == 1:
113101
assert dp_placements[0].is_replicate() or dp_placements[0].is_shard()

0 commit comments

Comments (0)