[SimpleFSDP] add manual bucketing pass #1881
@@ -9,48 +9,129 @@

This hunk replaces the old `get_compile_backend(backend_name: str, fsdp_reshard_after_forward: bool)` helper, which checked `backend_name` against `torch._dynamo.list_backends()`, looked up registered backends via `torch._dynamo.lookup_backend()`, and special-cased a single custom `aot_eager_autobucketing` backend built from `aot_autograd` with an `aten_autobucketing_reordering_pass`. The new version keys off the compile config and a bucket plan instead:

import torch
import torch._functorch.config as functorch_config

from .job_config import Compile as CompileConfig
from .reshard_after_forward import annotate_fsdp_all_gather


def get_compile_backend_and_passes(
    compile_config: CompileConfig,
    fsdp_reshard_after_forward: bool,
    fsdp_buckets: list[list[str] | str],
) -> callable:
    """
    Apply compile backend and additional graph passes.

    Args:
        compile_config: compile configs to apply torch.compile.
        fsdp_reshard_after_forward: whether to enable reshard_after_forward in SimpleFSDP,
            which is implemented via a customized AC graph pass.
        fsdp_buckets: used in transformer_block_bucketing to define which modules should be bucketed.

    Returns:
        compile backend with applied graph passes.
    """
    backend = torch._dynamo.lookup_backend(compile_config.backend)

    # Apply bucketing and overlapping pass on fwd and bwd graph separately
    if compile_config.compiler_passes == "auto_bucketing":
        # Perform auto optimization in aten fx-level and execute code in aot_eager/inductor backend
        # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
        from torch._inductor.config import aten_distributed_optimizations as dist_opts
        from torch._inductor.fx_passes.overlap_scheduling import (
            schedule_overlap_bucketing,
        )

        dist_opts.collective_bucketing = True
        torch._inductor.config.allow_buffer_reuse = False

        if compile_config.backend == "aot_eager":
            from torch._dynamo.backends.common import (
                aot_autograd as aot_autograd_backend,
            )

            def aot_eager_autobucketing_reordering_pass(
                gm: torch.fx.GraphModule, example_inputs: Any
            ) -> torch.fx.GraphModule:
                schedule_overlap_bucketing(gm)
                gm.recompile()
                return gm

            dist_opts.insert_overlap_deps = False
            backend = aot_autograd_backend(
                fw_compiler=aot_eager_autobucketing_reordering_pass,
                bw_compiler=aot_eager_autobucketing_reordering_pass,
                keep_inference_input_mutations=True,
            )
        elif compile_config.backend == "inductor":

            def inductor_autobucketing_reordering_pass(
                gm: torch.fx.Graph,
            ) -> torch.fx.GraphModule:
                return schedule_overlap_bucketing(gm.owning_module)

            dist_opts.insert_overlap_deps = True
            torch._inductor.config.reorder_for_peak_memory = False
            torch._inductor.config.reorder_for_compute_comm_overlap = False
            torch._inductor.config.post_grad_custom_post_pass = (
                inductor_autobucketing_reordering_pass
            )
        else:
            raise ValueError(
                f"Unsupported backend {compile_config.backend} for auto_bucketing pass"
            )

    elif compile_config.compiler_passes == "transformer_block_bucketing":
        # Perform manual optimization in aten fx-level and execute code in aot_eager/inductor backend
        # The manual bucketing logic is here: https://github.com/pytorch/pytorch/pull/165487
        from functools import partial

        from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
        from torch._inductor.fx_passes.overlap_manual_scheduling import (
            manual_overlap_bucketing,
        )

        torch._inductor.config.allow_buffer_reuse = False
Contributor: Aren't we doing the passes on the fx graph / in the aot_eager backend? Why does it have anything to do with inductor? In fact, I have this confusion for all the other …

Contributor (Author): the passes live in …
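To illustrate the two hook points used in this diff, here is a minimal sketch, not code from this PR; the pass names `my_aten_pass` and `my_post_grad_pass` are hypothetical placeholders. On the aot_eager path, a pass runs as the forward/backward "compiler" that `aot_autograd` hands the traced aten-level fx graphs to; on the inductor path, the pass is registered on inductor's config and runs in its post-grad pass pipeline, which is why inductor config knobs appear even though the pass itself rewrites an fx graph.

```python
from typing import Any

import torch
import torch._inductor.config  # ensure the inductor config module is loaded
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend


def my_aten_pass(gm: torch.fx.GraphModule, example_inputs: Any) -> torch.fx.GraphModule:
    # Hypothetical placeholder pass operating on the aten-level fx graph.
    gm.recompile()
    return gm


# Hook point 1 (aot_eager path): the pass is the fw/bw compiler that aot_autograd
# invokes on the traced forward and backward graphs.
aot_eager_style_backend = aot_autograd_backend(
    fw_compiler=my_aten_pass,
    bw_compiler=my_aten_pass,
    keep_inference_input_mutations=True,
)


def my_post_grad_pass(graph: torch.fx.Graph) -> None:
    # Hypothetical placeholder pass over inductor's post-grad fx graph (in place).
    pass


# Hook point 2 (inductor path): the pass is registered on inductor's config and runs
# inside inductor's post-grad pass pipeline.
torch._inductor.config.post_grad_custom_post_pass = my_post_grad_pass
```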
The diff then continues:

        manual_overlap_bucketing = partial(
            manual_overlap_bucketing,
            module_bucket_plans=fsdp_buckets,
        )

        if compile_config.backend == "aot_eager":

            def aot_eager_transformer_block_bucketing_reordering_pass(
                gm: torch.fx.GraphModule, example_inputs: Any
            ) -> torch.fx.GraphModule:
                manual_overlap_bucketing(gm, insert_overlap_deps=False)
                return gm

            backend = aot_autograd_backend(
                fw_compiler=aot_eager_transformer_block_bucketing_reordering_pass,
                bw_compiler=aot_eager_transformer_block_bucketing_reordering_pass,
                keep_inference_input_mutations=True,
            )
        elif compile_config.backend == "inductor":

            def inductor_transformer_block_bucketing_reordering_pass(
                gm: torch.fx.Graph,
            ) -> torch.fx.GraphModule:
                return manual_overlap_bucketing(
                    gm.owning_module, insert_overlap_deps=True
                )

            torch._inductor.config.reorder_for_peak_memory = False
            torch._inductor.config.reorder_for_compute_comm_overlap = False
            torch._inductor.config.post_grad_custom_post_pass = (
                inductor_transformer_block_bucketing_reordering_pass
            )
        else:
            raise ValueError(
                f"Unsupported backend {compile_config.backend} for transformer_block_bucketing pass"
            )
    else:
        raise AssertionError(
            f"Unsupported customized pass: {compile_config.compiler_passes}"
        )

    # Apply activation checkpointing on joint graph before partitioner
    def joint_ac_pass(
        gm: torch.fx.GraphModule, example_inputs: Any
    ) -> torch.fx.GraphModule:
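For orientation, a hedged sketch of how the returned backend might be wired into torch.compile for SimpleFSDP training. It assumes the imports and the `get_compile_backend_and_passes` definition from this file; the config values, bucket FQNs, and the `model` variable are illustrative assumptions, not taken from this PR.

```python
import torch

# Illustrative values; the real CompileConfig comes from .job_config and is normally
# populated from the training job config rather than constructed by hand.
compile_config = CompileConfig()
compile_config.backend = "aot_eager"  # or "inductor"
compile_config.compiler_passes = "transformer_block_bucketing"

backend = get_compile_backend_and_passes(
    compile_config=compile_config,
    fsdp_reshard_after_forward=True,
    # Each entry is either a list of module FQNs bucketed together or a single FQN
    # (hypothetical names shown here).
    fsdp_buckets=[["layers.0", "layers.1"], "output"],
)

compiled_model = torch.compile(model, backend=backend)  # `model` assumed to exist
```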
What happens by default?

In bucketing we shouldn't allow buffer reuse; otherwise the newly created comm copy-in/copy-out buffers will reuse a previous buffer, which corrupts the copied-out data values and makes the loss NaN.
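For reference, a minimal sketch of the flag in question, assuming inductor's default of `allow_buffer_reuse = True`; both bucketing passes in this hunk flip it off before compilation so the communication copy-in/copy-out buffers get their own allocations.

```python
import torch
import torch._inductor.config  # ensure the inductor config module is loaded

# Inductor is assumed to reuse intermediate buffers by default to reduce peak memory.
# With collective bucketing, the copy-in/copy-out buffers created around the bucketed
# communication ops must not alias earlier buffers, or the copied-out values are
# clobbered and the loss becomes NaN, so the bucketing passes disable reuse up front.
torch._inductor.config.allow_buffer_reuse = False
```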