56 changes: 44 additions & 12 deletions autoparallel/api.py
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
-import itertools
 import warnings
 from contextlib import ExitStack, contextmanager
 from types import MethodType
@@ -21,7 +20,7 @@
 from torch._logging import trace_structured
 from torch._subclasses import FakeTensorMode
 from torch.distributed.fsdp import MixedPrecisionPolicy
-from torch.distributed.tensor import DeviceMesh
+from torch.distributed.tensor import DeviceMesh, DTensor
 from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
 
@@ -457,24 +456,57 @@ def apply_placement(self, sharding_placement=None):
             self.joint_with_descriptors
         )
 
+        param_mappings = {}
+
+        def fwd_hook(model, args):
+            nonlocal param_mappings
+            param_mappings.clear()
+            for module in model.modules():
+                for name, p in module.named_parameters(recurse=False):
+                    if not isinstance(p, DTensor):
+                        continue
+                    p_new = torch.nn.Parameter(p._local_tensor)
+                    param_mappings[p_new] = p
+                    module._parameters[name] = p_new
+
+        def bwd_hook(model, grad_input, grad_output):
+            nonlocal param_mappings
+            for module in model.modules():
+                for name, p in module.named_parameters(recurse=False):
+                    if p not in param_mappings:
+                        continue
+                    orig_p = param_mappings[p]
+                    grad = p.grad
+                    if grad is not None:
+                        grad = DTensor(
+                            grad.detach(),
+                            orig_p._spec,
+                            requires_grad=grad.requires_grad,
+                        )
+                    orig_p.grad = grad
+                    module._parameters[name] = orig_p
+
         # TODO: this probably belongs in the AOTAutograd API
         # TODO: pytree handling
         class AutoParallelModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
+                self.register_forward_pre_hook(fwd_hook, prepend=True)
+                self.register_full_backward_hook(bwd_hook)
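As an aside for readers following along, here is a hedged, single-device analogue of the swap/restore pattern these hooks implement (no DTensor or sharding involved; the names mirror the diff, but the code is purely illustrative): the forward pre-hook swaps each parameter for a plain stand-in, and the full backward hook puts the originals back, copying over any gradient that has already landed on the stand-in.

```python
import torch

param_mappings = {}

def fwd_hook(model, args):
    # before forward runs, replace each parameter with a plain stand-in
    param_mappings.clear()
    for module in model.modules():
        for name, p in list(module.named_parameters(recurse=False)):
            p_new = torch.nn.Parameter(p.detach())
            param_mappings[p_new] = p
            module._parameters[name] = p_new

def bwd_hook(model, grad_input, grad_output):
    # after backward reaches the module inputs, restore the original
    # parameters and hand back any gradient the stand-in already holds
    for module in model.modules():
        for name, p in list(module.named_parameters(recurse=False)):
            if p in param_mappings:
                orig_p = param_mappings[p]
                if p.grad is not None:
                    orig_p.grad = p.grad
                module._parameters[name] = orig_p

model = torch.nn.Linear(4, 4)
orig_weight = model.weight
model.register_forward_pre_hook(fwd_hook, prepend=True)
model.register_full_backward_hook(bwd_hook)
x = torch.randn(2, 4, requires_grad=True)
model(x).sum().backward()
assert model.weight is orig_weight  # the original parameter is back
```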
Contributor:
one thing that seems at least a bit nicer about this compared to the SimpleFSDP setup is that since we are calling compile ourselves, we don't actually have to worry about these hooks causing graph breaks (since we are calling compile manually on the fw/bw graphs instead of the user calling compile on the entire module themselves). Although I guess we still have the "composability risk" of the params being implicitly-sharded plain tensors rather than DTensors.
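To illustrate the point above (a hedged sketch, not the actual AutoParallel code): when torch.compile is applied to the already-captured callable rather than to the hook-carrying module, the hooks stay in eager mode and cannot introduce graph breaks inside the compiled region. `captured_fw` and `HookedWrapper` here are made-up stand-ins.

```python
import torch

def captured_fw(x, w):
    # stand-in for the AOT-produced forward graph
    return torch.relu(x @ w)

compiled_fw = torch.compile(captured_fw)

class HookedWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))
        # the hook runs eagerly, outside the compiled region
        self.register_forward_pre_hook(lambda mod, args: None, prepend=True)

    def forward(self, x):
        # only the captured callable is compiled, not this module
        return compiled_fw(x, self.w)

out = HookedWrapper()(torch.randn(2, 4))
```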

Contributor Author:
Yes, indeed the AutoParallel case is simpler than in SimpleFSDP, but the general idea for SimpleFSDP was to introduce graph breaks only at the outer-most FSDP block (which performs the fwd / bwd hooks).

If the model has no graph breaks, then it would hopefully be equivalent to having a single full graph, as the graph break introduced by this change would be in the outer-most wrapper.

Does it make sense?
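A rough sketch of the layout described above (illustrative names, not the real SimpleFSDP code): the hooks live only on the outermost wrapper, so the graph break they would force sits at that boundary; here the inner module is compiled directly just to show that its compute can still be captured as a single graph.

```python
import torch

class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.lin(x))

class OuterWrapper(torch.nn.Module):
    def __init__(self, inner):
        super().__init__()
        # only the outermost wrapper carries hooks; everything below it
        # stays inside one compiled graph
        self.inner = torch.compile(inner)
        self.register_forward_pre_hook(lambda mod, args: None, prepend=True)
        self.register_full_backward_hook(lambda mod, gi, go: None)

    def forward(self, x):
        return self.inner(x)

model = OuterWrapper(Inner())
model(torch.randn(2, 8, requires_grad=True)).sum().backward()
```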

Contributor:

yep, I think we're on the same page - in the SimpleFSDP setup, even if we force the graph break on the top-level module that has the backward hooks, we can still expect to capture all of the model's actual compute/comms in a single graph in the inner module.

The only thing I really meant by my comment is that "graph breaks are spooky" (at the very least they add noise to tlparse), so compiling only the stuff inside the wrapper feels a tiny bit nicer (but the graph break idea for SimpleFSDP still seems perfectly reasonable)


             def forward(self, *args):
                 # NB: don't close over the parameters/buffers, as the user may
                 # reassign the module!
                 # TODO: It's this to just exactly match
                 # prepare_aot_module_simplified, this seems like an API gap
-                params = [
-                    v.to_local()
-                    for k, v in
-                    # TODO: this is very slow
-                    itertools.chain(
-                        dict(self.named_parameters(remove_duplicate=False)).items(),
-                        dict(self.named_buffers(remove_duplicate=False)).items(),
-                    )
-                ]
-                boxed_args = [*params, *args]
+                params = tuple(
+                    v for k, v in self.named_parameters(remove_duplicate=False)
+                )
+                # buffers aren't impacted by the fwd hook
+                buffers = tuple(
+                    v.to_local() for k, v in self.named_buffers(remove_duplicate=False)
+                )
+                boxed_args = [*params, *buffers, *args]
+                del params
                 # NB: don't do self.parallel_model_fn work around Dynamo bug
                 out = parallel_model_fn(boxed_args)
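For reference, a minimal sketch of the boxed calling convention the new forward uses (with `toy_parallel_model_fn` as a hypothetical stand-in for the AOT-produced callable): parameters, buffers, and user inputs are flattened into a single list and passed as one argument.

```python
import torch

def toy_parallel_model_fn(boxed_args):
    # hypothetical stand-in for the AOT-produced callable
    weight, bias, scale, x = boxed_args
    return [torch.nn.functional.linear(x, weight, bias) * scale]

class ToyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(3, 3))
        self.bias = torch.nn.Parameter(torch.zeros(3))
        self.register_buffer("scale", torch.ones(3))

    def forward(self, *args):
        # flatten params, buffers, and user inputs into one boxed list
        params = tuple(v for _, v in self.named_parameters(remove_duplicate=False))
        buffers = tuple(v for _, v in self.named_buffers(remove_duplicate=False))
        boxed_args = [*params, *buffers, *args]
        return toy_parallel_model_fn(boxed_args)[0]

print(ToyModule()(torch.randn(2, 3)).shape)  # torch.Size([2, 3])
```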