From 027be201c1fe311c7ba4a62dc1a8caf7e4bcad1f Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Fri, 8 Aug 2025 02:43:27 -0700
Subject: [PATCH 1/2] Add support for loss parallel

IMO we should just add the loss in the model and let autoparallel
parallelize it for us
---
 .../auto_parallel/parallelize_llama.py | 33 ++++++++++++++++++-
 1 file changed, 32 insertions(+), 1 deletion(-)

diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py
index 2d3a3e2e2c..b641b67080 100644
--- a/torchtitan/experiments/auto_parallel/parallelize_llama.py
+++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py
@@ -73,18 +73,49 @@ def input_fn():
         "dp_shard": Shard(0),
         "tp": Replicate(),
     }
+    # only used if loss parallel is enabled
+    possible_output_shardings = {
+        # maps relative to mesh dim names used in torchtitan
+        "dp_replicate": Shard(0),
+        "dp_shard": Shard(0),
+        "tp": Shard(2),
+    }
     assert all(
         name in possible_input_shardings for name in world_mesh.mesh_dim_names
     ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
     x_sharding = tuple(
         possible_input_shardings[name] for name in world_mesh.mesh_dim_names
     )
+    out_sharding = x_sharding
+    if parallel_dims.loss_parallel_enabled:
+        out_sharding = tuple(
+            possible_output_shardings[name]
+            for name in world_mesh.mesh_dim_names
+            if name != "dp_replicate"
+        )
     autop.add_input_constraints([x_sharding])
-    autop.add_output_constraints([x_sharding])
+    autop.add_output_constraints([out_sharding])
     t0 = time.time()
     sharding_placement = autop.optimize_placement()
     t1 = time.time()
     logger.info(f"AutoParallel took {t1 - t0} seconds")
     parallel_mod = autop.apply_placement(sharding_placement)
 
+    if parallel_dims.loss_parallel_enabled:
+
+        # PyTorch's current implementation of loss parallel assumes
+        # that the DTensor has a 1d device mesh. This is not true
+        # in our case, but we can work around it by casting the
+        # output to a DTensor on a 1d device mesh.
+        # We should just use AutoParallel to do this for us, but
+        # it would require putting the loss inside the model as well
+        def _return_as_dtensor_for_loss_parallel(module, args, output):
+            return torch.distributed.tensor.DTensor.from_local(
+                output, world_mesh["tp"], (Shard(2),)
+            )
+
+        # not keeping a reference to the hook, don't plan on
+        # removing it at any point
+        parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel)
+
     return parallel_mod

From 09aeafa4efb9094379f5da3f42d91c95f8ddb122 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Sun, 10 Aug 2025 10:11:12 -0700
Subject: [PATCH 2/2] Address review feedback

---
 torchtitan/experiments/auto_parallel/parallelize_llama.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py
index 29b253a9d7..49a8bc49ff 100644
--- a/torchtitan/experiments/auto_parallel/parallelize_llama.py
+++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py
@@ -91,7 +91,6 @@ def input_fn():
     # only used if loss parallel is enabled
     possible_output_shardings = {
         # maps relative to mesh dim names used in torchtitan
-        "dp_replicate": Shard(0),
         "dp_shard": Shard(0),
         "tp": Shard(2),
     }
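
Note: a minimal sketch of how the DTensor output produced by the forward hook above
could be consumed on the training-loop side with PyTorch's loss_parallel context
manager. The names parallel_mod, input_ids, and labels are illustrative and not part
of this patch series:

    import torch.nn.functional as F
    from torch.distributed.tensor.parallel import loss_parallel

    # the hook makes the module output a DTensor on the 1d "tp" mesh,
    # sharded along dim 2 (the vocab dimension) via Shard(2)
    pred = parallel_mod(input_ids)

    with loss_parallel():
        # cross_entropy on a class-dim-sharded DTensor computes the loss
        # without all-gathering the full-vocab logits on every TP rank;
        # the backward pass must also run inside the context
        loss = F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))
        loss.backward()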