diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 8571e5680c..8e10785bdf 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -739,6 +739,11 @@ class Experimental:
 
     enable_simplefsdp_passes: bool = False
 
+    enable_inductor_aten_fx_overlap_scheduler: bool = False
+    enable_inductor_aten_fx_overlap_scheduler_bucketing: bool = False
+    enable_autoparallel_asynctp: bool = False
+
+
 @dataclass
 class Validation:
     enable: bool = False
diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py
index 6648f29ab8..6851ecbc4a 100644
--- a/torchtitan/experiments/auto_parallel/parallelize_llama.py
+++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py
@@ -10,7 +10,6 @@
 
 from autoparallel.api import AutoParallel
 
-from torch.distributed import DeviceMesh
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor.placement_types import Replicate, Shard
 
@@ -33,6 +32,7 @@ def parallelize_llama(
     the model must fit on GPU or CPU memory.
     """
     world_mesh = parallel_dims.world_mesh
+
     def input_fn():
         global_batch_size = job_config.training.global_batch_size
         if global_batch_size < 0:
@@ -62,6 +62,52 @@ def input_fn():
         lambda bucket_idx: 1000 / parallel_dims.tp
     )
 
+    enable_overlap_scheduling = (
+        job_config.experimental.enable_inductor_aten_fx_overlap_scheduler
+    )
+    enable_overlap_scheduling_bucketing = (
+        job_config.experimental.enable_inductor_aten_fx_overlap_scheduler_bucketing
+    )
+    if enable_overlap_scheduling_bucketing:
+        assert (
+            enable_overlap_scheduling
+        ), "bucketing can not be used without overlap scheduling"
+
+    if enable_overlap_scheduling:
+        from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler
+
+        torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = (
+            enable_overlap_scheduling_bucketing
+        )
+
+        def _overlap_bucketing_pass(graph):
+            overlap_scheduler = OverlapScheduler(graph.owning_module)
+            overlap_scheduler.run()
+
+        torch._inductor.config.post_grad_custom_post_pass = _overlap_bucketing_pass
+
+    enable_asynctp = job_config.experimental.enable_autoparallel_asynctp
+    if enable_asynctp:
+        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+        assert "tp" in world_mesh.mesh_dim_names
+        enable_symm_mem_for_group(world_mesh["tp"].get_group().group_name)
+        torch._inductor.config._micro_pipeline_tp = False
+        # Disable inductor AsyncTP passes, in favor of using Autoparallel passes fork.
+        from autoparallel.asynctp import micro_pipeline_tp_pass
+
+        existing_post_grad_custom_post_pass = (
+            torch._inductor.config.post_grad_custom_post_pass
+        )
+
+        def _pass(graph):
+            if existing_post_grad_custom_post_pass is not None:
+                existing_post_grad_custom_post_pass(graph)
+
+            micro_pipeline_tp_pass(graph, None)
+
+        torch._inductor.config.post_grad_custom_post_pass = _pass
+
     # bail out
     # model = model_fn()
     # return model
@@ -101,7 +147,8 @@ def input_fn():
     )
     out_sharding = x_sharding
     loss_parallel_enabled = (
-        parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel
+        parallel_dims.tp_enabled
+        and not job_config.parallelism.disable_loss_parallel
     )
     if loss_parallel_enabled:
         out_sharding = tuple(
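
Note on the AsyncTP hook: when both experimental flags are set, _pass deliberately chains onto the previously registered inductor post-grad pass (the overlap-scheduling one) by closing over existing_post_grad_custom_post_pass, rather than replacing it. Below is a minimal standalone sketch of that chaining pattern; install_post_grad_pass is a hypothetical helper name for illustration and the passes are placeholders, not the torchtitan/autoparallel implementations.

    import torch._inductor.config


    def install_post_grad_pass(new_pass):
        # Hypothetical helper (illustration only): keep whatever post-grad pass
        # is already registered and run the new pass after it, instead of
        # silently overwriting the registration.
        existing = torch._inductor.config.post_grad_custom_post_pass

        def _chained(graph):
            if existing is not None:
                existing(graph)  # run the previously installed pass first
            new_pass(graph)  # then run the newly added pass

        torch._inductor.config.post_grad_custom_post_pass = _chained

This is the same composition order the diff uses: the overlap-scheduling pass runs first, then micro_pipeline_tp_pass, so enabling AsyncTP does not drop the scheduling pass.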