@@ -88,18 +88,48 @@ def input_fn():
8888 "dp_shard" : Shard (0 ),
8989 "tp" : Replicate (),
9090 }
91+ # only used if loss parallel is enabled
92+ possible_output_shardings = {
93+ # maps relative to mesh dim names used in torchtitan
94+ "dp_shard" : Shard (0 ),
95+ "tp" : Shard (2 ),
96+ }
9197 assert all (
9298 name in possible_input_shardings for name in world_mesh .mesh_dim_names
9399 ), f"Unsupported mesh dim in world mesh, only { possible_input_shardings .keys ()} are supported by AutoParallel"
94100 x_sharding = tuple (
95101 possible_input_shardings [name ] for name in world_mesh .mesh_dim_names
96102 )
+    out_sharding = x_sharding
+    if parallel_dims.loss_parallel_enabled:
+        out_sharding = tuple(
+            possible_output_shardings[name]
+            for name in world_mesh.mesh_dim_names
+            if name != "dp_replicate"
+        )
     autop.add_input_constraints([x_sharding])
-    autop.add_output_constraints([x_sharding])
+    autop.add_output_constraints([out_sharding])
     t0 = time.time()
     sharding_placement = autop.optimize_placement()
     t1 = time.time()
     logger.info(f"AutoParallel took {t1 - t0} seconds")
     parallel_mod = autop.apply_placement(sharding_placement)
 
+    if parallel_dims.loss_parallel_enabled:
+
+        # PyTorch's current implementation of loss parallel assumes
+        # that the DTensor has a 1d device mesh. This is not true
+        # in our case, but we can work around it by casting the
+        # output to a DTensor on a 1d device mesh.
+        # We should just use AutoParallel to do this for us, but
+        # it would require putting the loss inside the model as well.
+        def _return_as_dtensor_for_loss_parallel(module, args, output):
+            return torch.distributed.tensor.DTensor.from_local(
+                output, world_mesh["tp"], (Shard(2),)
+            )
+
+        # not keeping a reference to the hook, don't plan on
+        # removing it at any point
+        parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel)
+
     return parallel_mod
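
For context on the constraints above: AutoParallel receives one placement per mesh dimension, in the order given by world_mesh.mesh_dim_names. The sketch below (not part of the commit; the 2-D mesh with dims ("dp_shard", "tp") is an illustrative assumption) shows what the input and output constraint tuples evaluate to: inputs are batch-sharded over dp_shard and replicated over tp, while with loss parallel enabled the logits stay sharded over the vocab dimension (dim 2) on tp.

# Illustrative sketch only: mesh_dim_names is assumed to be ("dp_shard", "tp"),
# mirroring how the diff builds x_sharding / out_sharding from the world mesh.
from torch.distributed.tensor import Replicate, Shard

mesh_dim_names = ("dp_shard", "tp")  # stand-in for world_mesh.mesh_dim_names

possible_input_shardings = {
    "dp_shard": Shard(0),  # shard the batch dim across data-parallel ranks
    "tp": Replicate(),     # replicate inputs across tensor-parallel ranks
}
possible_output_shardings = {
    "dp_shard": Shard(0),  # logits stay sharded on the batch dim
    "tp": Shard(2),        # loss parallel keeps logits sharded on the vocab dim
}

x_sharding = tuple(possible_input_shardings[n] for n in mesh_dim_names)
out_sharding = tuple(possible_output_shardings[n] for n in mesh_dim_names)
# x_sharding   == (Shard(0), Replicate())  -> constraint on the model input
# out_sharding == (Shard(0), Shard(2))     -> constraint on the logits output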
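The forward hook exists because torch.distributed.tensor.parallel.loss_parallel expects the logits as a DTensor sharded on the class dimension over a 1-D mesh; wrapping the local output with DTensor.from_local on the "tp" sub-mesh satisfies that. Below is a hedged sketch of how the training side might then compute the loss. parallel_mod, inputs, and labels are placeholders for objects owned by the torchtitan training loop, and loss_parallel_step is a hypothetical helper, not a name from this commit.

import torch
import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel


def loss_parallel_step(parallel_mod: torch.nn.Module,
                       inputs: torch.Tensor,
                       labels: torch.Tensor) -> torch.Tensor:
    # the hook above returns the logits as a DTensor placed as (Shard(2),)
    # on the 1d "tp" sub-mesh, i.e. sharded over the vocab dimension
    logits = parallel_mod(inputs)

    # inside loss_parallel(), cross_entropy operates directly on the
    # class-dim-sharded DTensor, so the full vocab dim is never all-gathered;
    # backward must also run inside the context manager
    with loss_parallel():
        loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten(0, 1))
        loss.backward()
    return loss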