diff --git a/torchtitan/experiments/simple_fsdp/backend.py b/torchtitan/experiments/simple_fsdp/backend.py index 55fb099afb..36abe4ad0b 100644 --- a/torchtitan/experiments/simple_fsdp/backend.py +++ b/torchtitan/experiments/simple_fsdp/backend.py @@ -21,12 +21,14 @@ def get_compile_backend(backend_name: str) -> Union[str, callable]: # Perform auto optimization in aten fx-level and execute code in aot_eager backend # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960 from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend + + from torch._inductor.config import aten_distributed_optimizations as dist_opts from torch._inductor.fx_passes.overlap_scheduling import ( schedule_overlap_bucketing, ) - torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = True - torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps = False + dist_opts.collective_bucketing = True + dist_opts.insert_overlap_deps = False torch._inductor.config.allow_buffer_reuse = False def aten_autobucketing_reordering_pass(