18 changes: 10 additions & 8 deletions torchtitan/experiments/simple_fsdp/README.md
@@ -52,14 +52,16 @@ SimpleFSDP relies on compiler backend to perform optimizations (i.e., bucketing
1. no optimization: default torch.compile backends (e.g., "inductor", "aot_eager", "eager")

2. auto optimization: perform auto-bucketing & reordering without user inputs. **Note: auto optimization is not guaranteed to deliver the most optimized training performance**
- "aot_eager_autobucketing": perform autobucketing at aten fx-level, and perform code execution with aot_eager backend.


users can specify the pass (e.g., "aot_eager_autobucketing") via additional configs:

```bash
--compile.model_backend_override "aot_eager_autobucketing"
```
- "auto_bucketing": perform autobucketing at aten fx-level, and perform code execution with aot_eager backend. (We also support `inductor` backend).
```bash
--compile.backend "aot_eager" --compile.compiler_passes "auto_bucketing"
```

3. manual optimization: perform manual bucketing & reordering based on user-provided module FQNs.
- "transformer_block_bucketing": performs manual bucketing at the aten fx level and executes code with the aot_eager backend (the `inductor` backend is also supported). A sketch of how these flags are consumed appears after this list.
```bash
--compile.backend "aot_eager" --compile.compiler_passes "transformer_block_bucketing"
```
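
Below is a minimal sketch, assuming only the modules added in this PR, of what the flags above translate to at parallelization time: the configured pass is folded into a compile backend and handed to `torch.compile`, mirroring `llama3/parallelize.py`. The `compile_config`, `reshard_after_forward`, and `fsdp_buckets` arguments stand in for whatever torchtitan has already derived from the CLI flags and the model.

```python
# Minimal sketch (mirrors llama3/parallelize.py in this PR); the arguments are
# assumed to be produced by torchtitan's config parsing and bucket-plan helpers.
import torch

from torchtitan.experiments.simple_fsdp.backend import get_compile_backend_and_passes


def compile_with_simple_fsdp_passes(model, compile_config, reshard_after_forward, fsdp_buckets):
    # compile_config carries --compile.backend and --compile.compiler_passes;
    # fsdp_buckets is only consulted by the transformer_block_bucketing pass.
    backend = get_compile_backend_and_passes(
        compile_config, reshard_after_forward, fsdp_buckets
    )
    return torch.compile(model, backend=backend, fullgraph=True)
```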

### Citation

133 changes: 107 additions & 26 deletions torchtitan/experiments/simple_fsdp/backend.py
@@ -9,48 +9,129 @@
import torch
import torch._functorch.config as functorch_config

from .job_config import Compile as CompileConfig

from .reshard_after_forward import annotate_fsdp_all_gather


def get_compile_backend(
backend_name: str, fsdp_reshard_after_forward: bool
def get_compile_backend_and_passes(
compile_config: CompileConfig,
fsdp_reshard_after_forward: bool,
fsdp_buckets: list[list[str] | str],
) -> callable:
# return the compile backends used in SimpleFSDP training
# Step1: check if backend_name is inside available torch.compile backends
# Step2: check if the backend_name has been registered as a customized backend
available_torch_backend = torch._dynamo.list_backends(exclude_tags=())

if backend_name in available_torch_backend:
backend = torch._dynamo.lookup_backend(backend_name)
elif backend_name == "aot_eager_autobucketing":
# Perform auto optimization in aten fx-level and execute code in aot_eager backend
# The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
"""
Apply compile backend and additional graph passes.
Args:
compile_config: compile configs to apply torch.compile.
fsdp_reshard_after_forward: whether to enable reshard_after_forward in SimpleFSDP,
which is implemented via a customized AC graph pass.
fsdp_buckets: used in transformer_block_bucketing to define which modules should be bucketed.
Returns:
compile backend with applied graph passes.
"""
backend = torch._dynamo.lookup_backend(compile_config.backend)

# Apply bucketing and overlapping pass on fwd and bwd graph separately
if compile_config.compiler_passes == "auto_bucketing":
# Perform auto optimization in aten fx-level and execute code in aot_eager/inductor backend
# The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
from torch._inductor.config import aten_distributed_optimizations as dist_opts
from torch._inductor.fx_passes.overlap_scheduling import (
schedule_overlap_bucketing,
)

dist_opts.collective_bucketing = True
dist_opts.insert_overlap_deps = False
torch._inductor.config.allow_buffer_reuse = False

def aten_autobucketing_reordering_pass(
gm: torch.fx.GraphModule, example_inputs: Any
) -> torch.fx.GraphModule:
schedule_overlap_bucketing(gm)
gm.recompile()
return gm

backend = aot_autograd_backend(
fw_compiler=aten_autobucketing_reordering_pass,
bw_compiler=aten_autobucketing_reordering_pass,
keep_inference_input_mutations=True,
if compile_config.backend == "aot_eager":
from torch._dynamo.backends.common import (
aot_autograd as aot_autograd_backend,
)

def aot_eager_autobucketing_reordering_pass(
gm: torch.fx.GraphModule, example_inputs: Any
) -> torch.fx.GraphModule:
schedule_overlap_bucketing(gm)
gm.recompile()
return gm

dist_opts.insert_overlap_deps = False
backend = aot_autograd_backend(
fw_compiler=aot_eager_autobucketing_reordering_pass,
bw_compiler=aot_eager_autobucketing_reordering_pass,
keep_inference_input_mutations=True,
)
elif compile_config.backend == "inductor":

def inductor_autobucketing_reordering_pass(
gm: torch.fx.Graph,
) -> torch.fx.GraphModule:
return schedule_overlap_bucketing(gm.owning_module)

dist_opts.insert_overlap_deps = True
torch._inductor.config.reorder_for_peak_memory = False
torch._inductor.config.reorder_for_compute_comm_overlap = False
torch._inductor.config.post_grad_custom_post_pass = (
inductor_autobucketing_reordering_pass
)
else:
raise ValueError(
f"Unsupported backend {compile_config.backend} for auto_bucketing pass"
)

elif compile_config.compiler_passes == "transformer_block_bucketing":
# Perform manual optimization in aten fx-level and execute code in aot_eager/inductor backend
# The manualbucketing logic is here: https://github.com/pytorch/pytorch/pull/165487
from functools import partial

from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
from torch._inductor.fx_passes.overlap_manual_scheduling import (
manual_overlap_bucketing,
)

torch._inductor.config.allow_buffer_reuse = False
Review thread on `torch._inductor.config.allow_buffer_reuse = False`:

**Reviewer (Member):** what happens by default?

**Author (@ruisizhang123):** In bucketing we shouldn't allow buffer reuse; otherwise the newly created comm copy-in/copy-out buffers would reuse previous buffers, which corrupts the copied-out data and makes the loss NaN.

**Reviewer (Contributor):** Aren't we doing the passes on the fx graph / with the aot_eager backend? Why does this have anything to do with inductor? In fact, I have this confusion for all the other torch._inductor fields.

**Author (@ruisizhang123, Nov 11, 2025):** The passes live under the torch/_inductor/fx_passes/ folder. It is a bit counter-intuitive that fx graph passes live under _inductor, but for legacy reasons they were originally post-grad passes in inductor rather than aot_eager fx passes. That's why these configs have torch._inductor fields -- they control the pass via inductor's config.
manual_overlap_bucketing = partial(
manual_overlap_bucketing,
module_bucket_plans=fsdp_buckets,
)

if compile_config.backend == "aot_eager":

def aot_eager_transformer_block_bucketing_reordering_pass(
gm: torch.fx.GraphModule, example_inputs: Any
) -> torch.fx.GraphModule:
manual_overlap_bucketing(gm, insert_overlap_deps=False)
return gm

backend = aot_autograd_backend(
fw_compiler=aot_eager_transformer_block_bucketing_reordering_pass,
bw_compiler=aot_eager_transformer_block_bucketing_reordering_pass,
keep_inference_input_mutations=True,
)
elif compile_config.backend == "inductor":

def inductor_transformer_block_bucketing_reordering_pass(
gm: torch.fx.Graph,
) -> torch.fx.GraphModule:
return manual_overlap_bucketing(
gm.owning_module, insert_overlap_deps=True
)

torch._inductor.config.reorder_for_peak_memory = False
torch._inductor.config.reorder_for_compute_comm_overlap = False
torch._inductor.config.post_grad_custom_post_pass = (
inductor_transformer_block_bucketing_reordering_pass
)
else:
raise ValueError(
f"Unsupported backend {compile_config.backend} for transformer_block_bucketing pass"
)
else:
raise AssertionError(f"Unsupported customized backend: {backend_name}")
raise AssertionError(
f"Unsupported customized pass: {compile_config.compiler_passes}"
)

# Apply activation checkpointing on joint graph before partitioner
def joint_ac_pass(
gm: torch.fx.GraphModule, example_inputs: Any
) -> torch.fx.GraphModule:
34 changes: 29 additions & 5 deletions torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
@@ -19,10 +19,35 @@
)
from torchtitan.tools.logging import logger

from ..backend import get_compile_backend
from ..backend import get_compile_backend_and_passes

from ..simple_fsdp import data_parallel, MixedPrecisionPolicy


def get_fsdp_buckets(model) -> list[list[str] | str]:
module_list = [
model.tok_embeddings,
[model.norm, model.output],
]
for layer_id, transformer_block in model.layers.items():
# [TODO](ruisizhang123) add EP support for transformer block bucketing
module_list.append(transformer_block)

def convert_modules_to_fqns(modules, module_to_fqn_mapping):
"""Convert a (possibly nested) list of modules to FQN strings."""
result = []
for m in modules:
if isinstance(m, list):
result.append(convert_modules_to_fqns(m, module_to_fqn_mapping))
else:
result.append(module_to_fqn_mapping.get(m, None))
return result

module_to_name = {m: n for n, m in model.named_modules()}
module_fqns = convert_modules_to_fqns(module_list, module_to_name)
return module_fqns
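
As an illustration, here is a self-contained toy sketch (with made-up module names mirroring the attributes used above) of the bucket plan this helper produces: FQN strings for single-module buckets and a nested list for the grouped `[norm, output]` bucket.

```python
# Toy sketch of get_fsdp_buckets: embeddings and (norm, output) get their own
# buckets, and every transformer block becomes one bucket.
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self, n_layers: int = 2):
        super().__init__()
        self.tok_embeddings = nn.Embedding(100, 16)
        self.layers = nn.ModuleDict(
            {str(i): nn.Linear(16, 16) for i in range(n_layers)}
        )
        self.norm = nn.LayerNorm(16)
        self.output = nn.Linear(16, 100)


def get_fsdp_buckets(model) -> list[list[str] | str]:
    module_list = [model.tok_embeddings, [model.norm, model.output]]
    for _, block in model.layers.items():
        module_list.append(block)

    module_to_name = {m: n for n, m in model.named_modules()}

    def to_fqns(modules):
        return [
            to_fqns(m) if isinstance(m, list) else module_to_name.get(m)
            for m in modules
        ]

    return to_fqns(module_list)


print(get_fsdp_buckets(ToyModel()))
# ['tok_embeddings', ['norm', 'output'], 'layers.0', 'layers.1']
```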


# Adapted from llama4/infra/parallelize.py
def parallelize_deepseekv3(
model: nn.Module,
@@ -177,13 +202,12 @@ def parallelize_deepseekv3(
f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
)

backend = (
getattr(job_config.compile, "model_backend_override", None)
or job_config.compile.backend
backend = get_compile_backend_and_passes(
job_config.compile, fsdp_reshard_after_forward, get_fsdp_buckets(model)
)
model = torch.compile(
model,
backend=get_compile_backend(backend, fsdp_reshard_after_forward),
backend=backend,
fullgraph=True,
)

7 changes: 5 additions & 2 deletions torchtitan/experiments/simple_fsdp/job_config.py
@@ -9,8 +9,11 @@

@dataclass
class Compile:
model_backend_override: str | None = None
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
compiler_passes: str | None = None
"""
Bucketing and overlapping passes in SimpleFSDP. Supported passes:
auto_bucketing, transformer_block_bucketing
"""


@dataclass
32 changes: 27 additions & 5 deletions torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -14,7 +14,7 @@
from torchtitan.models.llama3.infra.parallelize import apply_tp
from torchtitan.tools.logging import logger

from ..backend import get_compile_backend
from ..backend import get_compile_backend_and_passes

from ..simple_fsdp import data_parallel, MixedPrecisionPolicy

@@ -33,6 +33,29 @@
}


def get_fsdp_buckets(model) -> list[list[str] | str]:
module_list = [
model.tok_embeddings,
[model.norm, model.output],
]
for layer_id, transformer_block in model.layers.items():
module_list.append(transformer_block)

def convert_modules_to_fqns(modules, module_to_fqn_mapping):
"""Convert a (possibly nested) list of modules to FQN strings."""
result = []
for m in modules:
if isinstance(m, list):
result.append(convert_modules_to_fqns(m, module_to_fqn_mapping))
else:
result.append(module_to_fqn_mapping.get(m, None))
return result

module_to_name = {m: n for n, m in model.named_modules()}
module_fqns = convert_modules_to_fqns(module_list, module_to_name)
return module_fqns


def parallelize_llama(
model: nn.Module,
parallel_dims: ParallelDims,
@@ -139,13 +162,12 @@ def parallelize_llama(
f"Invalid fsdp_reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
)

backend = (
getattr(job_config.compile, "model_backend_override", None)
or job_config.compile.backend
backend = get_compile_backend_and_passes(
job_config.compile, fsdp_reshard_after_forward, get_fsdp_buckets(model)
)
model = torch.compile(
model,
backend=get_compile_backend(backend, fsdp_reshard_after_forward),
backend=backend,
fullgraph=True,
)

Expand Down
20 changes: 17 additions & 3 deletions torchtitan/experiments/simple_fsdp/tests/integration_tests.py
@@ -35,11 +35,25 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--model.name simple_fsdp.llama3",
"--compile.enable",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.model_backend_override aot_eager_autobucketing",
"--compile.backend aot_eager",
"--compile.compiler_passes auto_bucketing",
],
],
"1D+aot_eager_autobucketing",
"1d_aot_eager_autobucketing",
"1D+autobucketing",
"1d_autobucketing",
),
OverrideDefinitions(
[
[
"--model.name simple_fsdp.llama3",
"--compile.enable",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.backend aot_eager",
"--compile.compiler_passes transformer_block_bucketing",
],
],
"1D+transformer_block_bucketing",
"1d_transformer_block_bucketing",
),
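# Hypothetical extra case (not part of this PR): the README above notes that the
# inductor backend is also supported, so an analogous test entry might look like
# the following (kept commented out since it is only a sketch):
# OverrideDefinitions(
#     [
#         [
#             "--model.name simple_fsdp.llama3",
#             "--compile.enable",
#             "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
#             "--compile.backend inductor",
#             "--compile.compiler_passes auto_bucketing",
#         ],
#     ],
#     "1D+inductor_autobucketing",
#     "1d_inductor_autobucketing",
# ),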
OverrideDefinitions(
[