
Commit 4712163

Fix bucket sizes for AutoParallel 1D (#1545)
This PR makes the bucket sizes for all-gather and reduce-scatter the same for 1D FSDP.
1 parent 6c782eb commit 4712163
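
For context, the substance of the change is to register bucket-size determinators on torch._inductor.config so that inductor's collective-bucketing passes cap all-gather and reduce-scatter buckets consistently, scaled down by the TP degree. Below is a minimal standalone sketch of that idea, not the PR itself; the config attribute names and the 500/1000 values are taken from the diff further down, while tp_degree is a hypothetical stand-in for parallel_dims.tp.

# Minimal sketch (not the PR): cap inductor's collective bucket sizes,
# scaled by the tensor-parallel degree. Attribute names and the 500/1000
# values come from the diff below; tp_degree is a hypothetical stand-in
# for parallel_dims.tp.
import torch

tp_degree = 1  # hypothetical; in the PR this is parallel_dims.tp

# Each determinator maps a bucket index to a size cap for that bucket
# (in whatever unit inductor's bucketing pass uses for its default caps).
torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = (
    lambda bucket_idx: 500 / tp_degree
)
torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = (
    lambda bucket_idx: 1000 / tp_degree
)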


torchtitan/experiments/auto_parallel/parallelize_llama.py

Lines changed: 23 additions & 8 deletions
@@ -41,19 +41,28 @@ def input_fn():
         # step.
         dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
         global_batch_size = job_config.training.local_batch_size * dp_degree
-        return torch.randint(
-            0,
-            # job_config.training.vocab_size,
-            model.vocab_size,
-            (global_batch_size, job_config.training.seq_len),
-            device=torch.device("cuda"),
-        ),
+        return (
+            torch.randint(
+                0,
+                # job_config.training.vocab_size,
+                model.vocab_size,
+                (global_batch_size, job_config.training.seq_len),
+                device=torch.device("cuda"),
+            ),
+        )
 
     # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP
     assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet"
     assert parallel_dims.cp_enabled is False, "CP not supported yet"
     assert parallel_dims.pp_enabled is False, "PP not supported yet"
 
+    torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = (
+        lambda bucket_idx: 500 / parallel_dims.tp
+    )
+    torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = (
+        lambda bucket_idx: 1000 / parallel_dims.tp
+    )
+
     # bail out
     # model = model_fn()
     # return model
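
As a side note on the hunk above: the old return already produced a one-element tuple via the trailing comma after torch.randint(...); the rewrite only makes that tuple explicit. A minimal standalone sketch of the new shape, with hypothetical values standing in for model.vocab_size, global_batch_size, and job_config.training.seq_len:

import torch

def input_fn():
    # Hypothetical stand-ins for model.vocab_size, global_batch_size,
    # and job_config.training.seq_len from the diff above.
    vocab_size, global_batch_size, seq_len = 32000, 8, 2048
    # Return an explicit one-element tuple of example inputs (token ids);
    # requires a CUDA device, as in the original.
    return (
        torch.randint(
            0,
            vocab_size,
            (global_batch_size, seq_len),
            device=torch.device("cuda"),
        ),
    )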
@@ -64,7 +73,13 @@ def input_fn():
     param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
     reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce]
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
-    with AutoParallel(model, input_fn, world_mesh, mp_policy=mp_policy, compile=job_config.training.compile) as autop:
+    with AutoParallel(
+        model,
+        input_fn,
+        world_mesh,
+        mp_policy=mp_policy,
+        compile=job_config.training.compile,
+    ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)
 
         possible_input_shardings = {

0 commit comments