6 changes: 4 additions & 2 deletions torchtitan/experiments/simple_fsdp/backend.py
@@ -21,12 +21,14 @@ def get_compile_backend(backend_name: str) -> Union[str, callable]:
# Perform auto optimization in aten fx-level and execute code in aot_eager backend
# The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
+
+ from torch._inductor.config import aten_distributed_optimizations as dist_opts
from torch._inductor.fx_passes.overlap_scheduling import (
schedule_overlap_bucketing,
)

- torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = True
- torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps = False
+ dist_opts.collective_bucketing = True
+ dist_opts.insert_overlap_deps = False
torch._inductor.config.allow_buffer_reuse = False

def aten_autobucketing_reordering_pass(
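Reviewer note: for readers who want to try the new knobs outside this backend, here is a minimal sketch of the configuration the hunk above switches to. The per-flag comments are inferred from the flag names and the linked PR, not stated in this diff.

import torch
from torch._inductor.config import aten_distributed_optimizations as dist_opts

# New-style aten-FX distributed optimization knobs (replacing the old
# torch._inductor.config.test_configs.* flags removed above).
dist_opts.collective_bucketing = True   # bucket collectives at the aten FX level
dist_opts.insert_overlap_deps = False   # let the scheduler handle overlap instead of explicit deps
torch._inductor.config.allow_buffer_reuse = False  # keep buffers distinct so bucketed collectives stay valid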
46 changes: 17 additions & 29 deletions torchtitan/experiments/simple_fsdp/simple_fsdp.py
@@ -7,7 +7,6 @@
from collections.abc import Sequence
from contextlib import contextmanager
from dataclasses import dataclass
- from typing import List, Optional

import torch
import torch.nn as nn
@@ -19,7 +18,7 @@
Replicate,
Shard,
)
- from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+ from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -45,8 +44,8 @@ def disable_active_parametrization():

@dataclass(frozen=True)
class MixedPrecisionPolicy:
- param_dtype: Optional[torch.dtype] = None
- reduce_dtype: Optional[torch.dtype] = None
+ param_dtype: torch.dtype | None = None
+ reduce_dtype: torch.dtype | None = None


class _ScaledPartial(Partial):
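As a usage sketch (not part of the PR), the dataclass above can be filled in like this; bf16 compute with fp32 reductions is only an illustrative choice.

import torch

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # dtype parameters are cast to for compute
    reduce_dtype=torch.float32,   # dtype used when reducing gradients
)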
@@ -95,19 +94,7 @@ def _distribute_dtensor(
"""
inner_spec = tensor._spec
outer_mesh, inner_mesh = device_mesh, inner_spec.mesh
- outer_global_mesh = _mesh_resources.get_root_mesh(outer_mesh)
- inner_global_mesh = _mesh_resources.get_root_mesh(inner_mesh)
- if outer_global_mesh != inner_global_mesh or (
- outer_global_mesh is None or inner_global_mesh is None
- ):
- raise AssertionError(
- "Cannot distribute tensor across two meshes without the same root mesh: \n"
- f"outer global mesh: {outer_global_mesh}\ninner global mesh: {inner_global_mesh}"
- )
- assert outer_mesh.mesh_dim_names is not None
- assert inner_mesh.mesh_dim_names is not None
- submesh_names = outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
- spanned_mesh = outer_global_mesh[submesh_names]
+ spanned_mesh = DeviceMesh._concatenate([outer_mesh, inner_mesh])

if len(dp_placements) == 1:
assert dp_placements[0].is_replicate() or dp_placements[0].is_shard()
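Reviewer note: the removed block rebuilt the spanned mesh by walking both submeshes back to a common root mesh and re-slicing it by dim names; the new DeviceMesh._concatenate call joins the two submeshes directly. A rough, hypothetical sketch of the shapes involved (assumes torch.distributed is initialized and eight devices are available; mesh names are illustrative):

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

# Hypothetical dp x ep x tp layout on 8 devices.
world_mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp", "ep", "tp"))
outer_mesh = world_mesh["dp"]          # mesh the data-parallel wrapper shards over
inner_mesh = world_mesh["ep", "tp"]    # mesh the parameter DTensor already lives on

# New path: concatenate the submeshes instead of re-slicing the root mesh.
spanned_mesh = DeviceMesh._concatenate([outer_mesh, inner_mesh])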
@@ -173,8 +160,8 @@ def _distribute_dtensor(


def _register_parametrization(
- module: nn.Module, param_names: List[str], parametrization: nn.Module
- ):
+ module: nn.Module, param_names: list[str], parametrization: nn.Module
+ ) -> None:
"""
It works with state_dict without incurring parametrization calls because
state_dict accesses parameters directly from self._parameters, not from getters
@@ -242,16 +229,14 @@ def __init__(
self.param_dtype = mp_policy.param_dtype
self.reduce_dtype = mp_policy.reduce_dtype

- def replicate_compute(self, x):
+ def replicate_compute(self, x: DTensor) -> torch.Tensor:
# data parallel runtime replicate parameters and do local compute
# the gradients are partial tensors that needs to perform reduction
# (i.e. DDP: allreduce, FSDP: reduce_scatter, HSDP: mix of both)
# support FSDP/DDP/HSDP + EP + TP (assuming TP shards the inner-most dim)
non_dp_mesh_dims = x._spec.mesh.ndim - self.device_mesh.ndim
assert non_dp_mesh_dims <= 2, "Only DP + EP/TP/EP+TP is supported"
if non_dp_mesh_dims > 0:
- # TODO: remove tp_mesh as an input arg to data_parallel API and use x._spec.mesh["tp"]
- # after DeviceMesh supports slicing a non-root mesh
dp_mesh = self.device_mesh
# re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather
sharded_local_tensor = x.to_local()
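To illustrate the comment block above ("replicate parameters, reduce partial gradients") outside of this class, here is a minimal hypothetical sketch using the public DTensor API; replicate_for_compute is a stand-in name, not the code in this file.

import torch
from torch.distributed.tensor import DTensor, Replicate

def replicate_for_compute(param: DTensor, dp_mesh) -> torch.Tensor:
    # Forward: all-gather (replicate) the sharded parameter so local compute sees
    # the full tensor. Per the comments above, the gradients produced by this
    # compute are Partial tensors that need reduction (DDP: all-reduce,
    # FSDP: reduce-scatter, HSDP: a mix of both).
    return param.redistribute(dp_mesh, placements=(Replicate(),)).to_local()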
@@ -295,7 +280,7 @@ def replicate_compute(self, x):

return output

- def forward(self, x):
+ def forward(self, x: DTensor) -> torch.Tensor:
global _active_parametrization
# This should never be set to true during forward, only outside for model
# inspection / debugging / initialization
@@ -308,7 +293,10 @@ def forward(self, x):
if self.regional_ac and self.mode in ("fully_shard", "hybrid_shard"):
# apply checkpointing to implement reshard_after_forward
output = checkpoint(
- self.replicate_compute, x, use_reentrant=False, context_fn=fsdp_policy
+ self.replicate_compute,
+ x,
+ use_reentrant=False,
+ context_fn=fsdp_policy,
)
else:
output = self.replicate_compute(x)
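For context, the reshard_after_forward trick above leans on torch.utils.checkpoint's non-reentrant context_fn hook. A standalone, hypothetical sketch of that hook, where fn and the no-op contexts are stand-ins for replicate_compute and fsdp_policy:

from contextlib import nullcontext

import torch
from torch.utils.checkpoint import checkpoint

def fn(x: torch.Tensor) -> torch.Tensor:
    # stand-in for replicate_compute: work whose intermediates can be recomputed
    return torch.relu(x @ x.t())

x = torch.randn(4, 4, requires_grad=True)
# context_fn must return a (forward ctx, recompute ctx) pair; simple_fsdp passes
# fsdp_policy here, presumably to control what is saved versus re-gathered in backward.
out = checkpoint(fn, x, use_reentrant=False, context_fn=lambda: (nullcontext(), nullcontext()))
out.sum().backward()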
@@ -317,13 +305,13 @@


def data_parallel(
- model,
- device_mesh,
- mode="replicate",
+ model: nn.Module,
+ device_mesh: DeviceMesh,
+ mode: str = "replicate",
ac_mode: str = "none",
- mp_policy: Optional[MixedPrecisionPolicy] = None,
+ mp_policy: MixedPrecisionPolicy | None = None,
shard_dim: int = 0,
- reduction_divide_factor: Optional[float] = None,
+ reduction_divide_factor: float | None = None,
):
if mode == "replicate":
param_sharding = (Replicate(),)
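Finally, a hypothetical call site against the newly typed signature; the mesh shape, mode, and dtypes are illustrative, and a model plus an initialized process group are assumed.

import torch
from torch.distributed.device_mesh import init_device_mesh

dp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp",))
model = data_parallel(
    model,
    dp_mesh,
    mode="fully_shard",   # the diff also handles "replicate" and "hybrid_shard"
    mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
)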