
Commit 3f04d22

Add support for loss parallel (#1546)
IMO we should just add the loss in the model and let autoparallel parallelize it for us. But for now, let's follow how the other models are implemented
1 parent 4712163 commit 3f04d22

File tree

1 file changed: +31 −1 lines changed

torchtitan/experiments/auto_parallel/parallelize_llama.py

Lines changed: 31 additions & 1 deletion
@@ -88,18 +88,48 @@ def input_fn():
             "dp_shard": Shard(0),
             "tp": Replicate(),
         }
+        # only used if loss parallel is enabled
+        possible_output_shardings = {
+            # maps relative to mesh dim names used in torchtitan
+            "dp_shard": Shard(0),
+            "tp": Shard(2),
+        }
         assert all(
             name in possible_input_shardings for name in world_mesh.mesh_dim_names
         ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
         x_sharding = tuple(
             possible_input_shardings[name] for name in world_mesh.mesh_dim_names
         )
+        out_sharding = x_sharding
+        if parallel_dims.loss_parallel_enabled:
+            out_sharding = tuple(
+                possible_output_shardings[name]
+                for name in world_mesh.mesh_dim_names
+                if name != "dp_replicate"
+            )
         autop.add_input_constraints([x_sharding])
-        autop.add_output_constraints([x_sharding])
+        autop.add_output_constraints([out_sharding])
         t0 = time.time()
         sharding_placement = autop.optimize_placement()
         t1 = time.time()
         logger.info(f"AutoParallel took {t1 - t0} seconds")
         parallel_mod = autop.apply_placement(sharding_placement)

+        if parallel_dims.loss_parallel_enabled:
+
+            # current PyTorch's implementation of loss parallel assumes
+            # that the DTensor has a 1d device mesh. This is not true
+            # in our case, but we can work around it by
+            # casting the output to a DTensor on a 1d device mesh.
+            # We should just use AutoParallel to do this for us, but
+            # it would require putting the loss inside the model as well
+            def _return_as_dtensor_for_loss_parallel(module, args, output):
+                return torch.distributed.tensor.DTensor.from_local(
+                    output, world_mesh["tp"], (Shard(2),)
+                )
+
+            # not keeping a reference to the hook, don't plan on
+            # removing it at any point
+            parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel)
+
         return parallel_mod
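For a mesh with dim names ("dp_shard", "tp"), the output constraint above resolves to (Shard(0), Shard(2)): logits stay batch-sharded across dp_shard and become vocab-sharded across tp, which is what loss parallel expects. Below is a minimal sketch, not part of this commit, of how such an output could be consumed on the training side. It assumes PyTorch's torch.distributed.tensor.parallel.loss_parallel context manager; `training_step`, `model`, `inputs`, and `labels` are placeholder names, with `model` standing in for the parallel_mod returned by this function.

# Hypothetical usage sketch (not from this commit): `model` is assumed to be
# the parallel_mod returned above, so its forward hook yields logits as a
# DTensor placed with Shard(2) (the vocab dim) on the 1d "tp" mesh.
import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel

def training_step(model, inputs, labels):
    logits = model(inputs)  # DTensor, vocab-sharded via the forward hook
    with loss_parallel():
        # Under loss_parallel, cross_entropy accepts class-dim-sharded
        # DTensor logits and computes the loss without all-gathering the
        # full vocab dimension across the tp group. flatten(0, 1) is the
        # usual (batch, seq, vocab) -> (batch*seq, vocab) reshape, which
        # keeps the vocab dim (now dim 1) sharded. The backward pass must
        # also run inside the context.
        loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten(0, 1))
        loss.backward()
    return loss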
