@@ -41,19 +41,28 @@ def input_fn():
             # step.
             dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
             global_batch_size = job_config.training.local_batch_size * dp_degree
-        return torch.randint(
-            0,
-            # job_config.training.vocab_size,
-            model.vocab_size,
-            (global_batch_size, job_config.training.seq_len),
-            device=torch.device("cuda"),
-        ),
+        return (
+            torch.randint(
+                0,
+                # job_config.training.vocab_size,
+                model.vocab_size,
+                (global_batch_size, job_config.training.seq_len),
+                device=torch.device("cuda"),
+            ),
+        )
 
     # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP
     assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet"
     assert parallel_dims.cp_enabled is False, "CP not supported yet"
     assert parallel_dims.pp_enabled is False, "PP not supported yet"
 
+    torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = (
+        lambda bucket_idx: 500 / parallel_dims.tp
+    )
+    torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = (
+        lambda bucket_idx: 1000 / parallel_dims.tp
+    )
+
     # bail out
     # model = model_fn()
     # return model
@@ -64,7 +73,13 @@ def input_fn():
     param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
     reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce]
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
-    with AutoParallel(model, input_fn, world_mesh, mp_policy=mp_policy, compile=job_config.training.compile) as autop:
+    with AutoParallel(
+        model,
+        input_fn,
+        world_mesh,
+        mp_policy=mp_policy,
+        compile=job_config.training.compile,
+    ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)
 
         possible_input_shardings = {
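For reference, a minimal standalone sketch of the pattern this diff moves to, under two assumptions not stated in the diff itself: that AutoParallel expects input_fn to return a tuple of example inputs (hence the explicit one-element tuple), and that the inductor bucket-size determinators are called once per bucket index and return a size budget that should shrink as the TP degree grows. The values vocab_size, global_batch_size, seq_len, and tp_degree below are illustrative placeholders, and the torch._inductor.config knobs require a PyTorch build that exposes them plus a CUDA device.

import torch

# Illustrative placeholders standing in for job_config / parallel_dims fields.
vocab_size, global_batch_size, seq_len, tp_degree = 32768, 8, 2048, 2


def input_fn():
    # Return a one-element tuple of example inputs rather than a bare tensor,
    # mirroring the change in the hunk above.
    return (
        torch.randint(
            0,
            vocab_size,
            (global_batch_size, seq_len),
            device=torch.device("cuda"),
        ),
    )


# Scale per-bucket communication sizes inversely with the TP degree, as the
# added config lines do (assumed semantics: the lambda is invoked once per
# bucket index and returns a bucket size budget).
torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = (
    lambda bucket_idx: 500 / tp_degree
)
torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = (
    lambda bucket_idx: 1000 / tp_degree
)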