1 file changed: +2 -14 lines
torchtitan/experiments/simple_fsdp

     Replicate,
     Shard,
 )
-from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -95,19 +95,7 @@ def _distribute_dtensor(
     """
     inner_spec = tensor._spec
     outer_mesh, inner_mesh = device_mesh, inner_spec.mesh
-    outer_global_mesh = _mesh_resources.get_root_mesh(outer_mesh)
-    inner_global_mesh = _mesh_resources.get_root_mesh(inner_mesh)
-    if outer_global_mesh != inner_global_mesh or (
-        outer_global_mesh is None or inner_global_mesh is None
-    ):
-        raise AssertionError(
-            "Cannot distribute tensor across two meshes without the same root mesh:\n"
-            f"outer global mesh: {outer_global_mesh}\ninner global mesh: {inner_global_mesh}"
-        )
-    assert outer_mesh.mesh_dim_names is not None
-    assert inner_mesh.mesh_dim_names is not None
-    submesh_names = outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
-    spanned_mesh = outer_global_mesh[submesh_names]
+    spanned_mesh = DeviceMesh._concatenate([outer_mesh, inner_mesh])
 
     if len(dp_placements) == 1:
         assert dp_placements[0].is_replicate() or dp_placements[0].is_shard()
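The change drops the requirement that the outer and inner meshes share a root mesh (the removed AssertionError and mesh_dim_names checks) and delegates spanned-mesh construction to the private DeviceMesh._concatenate helper. Below is a minimal sketch of the before/after behavior, assuming a PyTorch build where DeviceMesh._concatenate and _mesh_resources.get_root_mesh exist (both are internal APIs and may change); the 2x2 "dp"/"tp" mesh is illustrative only and not part of this PR.

```python
# Illustrative sketch only (not part of the PR): assumes a PyTorch build that
# provides the private DeviceMesh._concatenate and _mesh_resources.get_root_mesh
# helpers. Run under torchrun with 4 ranks so the 2x2 mesh can be created.
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, _mesh_resources, init_device_mesh

# A 2x2 root mesh with named dims, standing in for the dp/tp meshes SimpleFSDP sees.
root = init_device_mesh("cpu", (2, 2), mesh_dim_names=("dp", "tp"))
outer_mesh, inner_mesh = root["dp"], root["tp"]

# Old path (removed by this diff): look up the shared root mesh and slice it by the
# concatenated dim names; this only works when both submeshes share a root mesh.
root_mesh = _mesh_resources.get_root_mesh(outer_mesh)
spanned_old = root_mesh[outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names]

# New path: concatenate the two meshes directly, with no root-mesh requirement.
spanned_new = DeviceMesh._concatenate([outer_mesh, inner_mesh])

print(spanned_old.mesh_dim_names, spanned_new.mesh_dim_names)
dist.destroy_process_group()
```

In the sketch, both submeshes happen to come from one root mesh, so the old path still works; the concatenation-based path simply no longer needs to rule out the case where they do not.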