Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions megatron/core/distributed/fsdp/mcore_fsdp_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def __init__(
dist_index=self.megatron_fsdp_dist_index,
calculate_per_token_loss=config.calculate_per_token_loss,
init_model_with_meta_device=config.init_model_with_meta_device,
enable_fine_grained_param_gather_hook=(
config.fp8_recipe == "mxfp8" and ddp_config.fp8_param_gather
),
),
)
self.param_and_grad_buffer = self.module.param_and_grad_buffer
Expand All @@ -123,6 +126,7 @@ def __init__(
self.broadcast_params = self.module.broadcast_params
self.module.state_dict_for_save_checkpoint = self.module.state_dict
self.state_dict_for_save_checkpoint = self.state_dict
self.module.config = config

self.sync_rng_states_across_tp_group()

Expand Down
157 changes: 104 additions & 53 deletions megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@
import torch.nn as nn
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

from .mixed_precision import (
fp8_create_transpose_cache,
fp8_discard_transpose_cache,
is_float8tensor,
)
from .param_and_grad_buffer import (
AllGatherPipeline,
BucketingPolicy,
GradReducePipeline,
ParamAndGradBuffer,
PrefetchOrder,
override_sharded_param_methods_with_safety_checks,
to_local_if_dtensor,
)
from .utils import FSDPDistributedIndex

logger = logging.getLogger(__name__)
Expand All @@ -34,23 +48,12 @@
from megatron.core.distributed.distributed_data_parallel_config import (
DistributedDataParallelConfig,
)
from megatron.core.fp8_utils import is_float8tensor
from megatron.core.utils import is_submodule
except ImportError:
# Megatron-LM is not installed, use Megatron-FSDP as a standalone module.
logger.info("Megatron Core is not installed, Megatron-FSDP will run without Megatron Core.")
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .utils import is_float8tensor, is_submodule

from .param_and_grad_buffer import (
AllGatherPipeline,
BucketingPolicy,
GradReducePipeline,
ParamAndGradBuffer,
PrefetchOrder,
override_sharded_param_methods_with_safety_checks,
to_local_if_dtensor,
)
from .utils import is_submodule


class TrainingState(Enum):
Expand Down Expand Up @@ -168,6 +171,7 @@ def __init__(
nccl_ub: bool = False,
fsdp_double_buffer: bool = False,
disable_symmetric_registration: bool = False,
enable_fine_grained_param_gather_hook: bool = False,
):
super().__init__()
# If device is not specified, use the current device.
Expand Down Expand Up @@ -217,6 +221,7 @@ def __init__(

self.calculate_per_token_loss = calculate_per_token_loss
self.init_model_with_meta_device = init_model_with_meta_device
self.enable_fine_grained_param_gather_hook = enable_fine_grained_param_gather_hook

# Whether to constantly synchronize the model every training iteration,
# which defaults to False to overlap communication with computation
Expand Down Expand Up @@ -406,6 +411,7 @@ def all_gather_and_wait_parameters_ready(
prefetch=True,
prefetch_order=PrefetchOrder.FORWARD_PASS_ORDER,
wait_bucket_ready=True,
bwd=False,
):
"""
All-gather parameters across the data parallel group and wait for
Expand All @@ -432,11 +438,14 @@ def all_gather_and_wait_parameters_ready(
and self.ddp_config.outer_dp_sharding_strategy != "no_shard"
and (self.microbatch_count == 0 or self.model_auto_sync)
),
bwd=bwd,
)
if wait_bucket_ready:
for param in params:
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
ag_pipeline.wait_bucket_ready(bucket_id)
ag_pipeline.wait_bucket_ready(bucket_id, bwd)
if bwd and is_float8tensor(param):
fp8_create_transpose_cache(param)

for param in params:
# This setting is needed to make FSDP store the weight object when used
Expand Down Expand Up @@ -495,19 +504,17 @@ def _register_fsdp_hooks(self, root_module):
"""
fsdp_unit_modules = self.fsdp_unit_modules

def release_module_parameters(module, *unused):
def release_module_parameters(module, bwd, *unused):
for param in module.parameters():
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
self.all_gather_pipeline.release_bucket(bucket_id)

self.all_gather_pipeline.release_bucket(bucket_id, bwd)
if not self.ddp_config.keep_fp8_transpose_cache:
release_params_fp8_transpose_cache(module.parameters())

def release_params_fp8_transpose_cache(params):
for param in params:
if is_float8tensor(param):
param._transpose_invalid = True
param._transpose = None
fp8_discard_transpose_cache(param)

def _grad_acc(param):
"""
Expand Down Expand Up @@ -564,12 +571,15 @@ def _post_backward(module, *unused):
if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
# Deallocate the module parameters after the backward pass,
# because we have our data-parallel gradients computed.
release_module_parameters(module)
release_module_parameters(module, bwd=True)
module._training_state = TrainingState.IDLE
param_list = list(module.parameters())
else:
param_list = list(module.parameters(recurse=False))

if self.enable_fine_grained_param_gather_hook:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we combine this if condition with the above?

param_list = list(module.parameters(recurse=False))

# If the parameter is shared, we do not accumulate gradients
# here, as the gradients will be accumulated in the
# root post-backward hook.
Expand Down Expand Up @@ -621,6 +631,9 @@ def _pre_forward_param_unshard(
# to allocate as little memory as possible for this forward pass.
param_list = list(module.parameters(recurse=False))

if self.enable_fine_grained_param_gather_hook:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: can this `if` condition be combined with the `if`/`else` that sets `param_list` just above it?

param_list = list(module.parameters(recurse=False))

# All-gather the parameters before the forward pass.
self.all_gather_and_wait_parameters_ready(
params=param_list,
Expand Down Expand Up @@ -720,7 +733,7 @@ def _root_post_backward(*unused):
if self.model_auto_sync:
self.finish_grad_sync()

def _pre_backward(module: nn.Module, *unused):
def _pre_backward_param_unshard(module: nn.Module, *unused):
"""
Sub-module pre-backward hook to all-gather the module parameters
before the backward pass.
Expand All @@ -729,11 +742,19 @@ def _pre_backward(module: nn.Module, *unused):
# and unsharding operations when performing activation recomputation
# / gradient checkpointing.
module._training_state = TrainingState.PRE_BACKWARD

if isinstance(module, tuple(fsdp_unit_modules)):
# All-gather / unshard the module parameters before the backward pass.
self.all_gather_and_wait_parameters_ready(
list(module.parameters()), prefetch_order=PrefetchOrder.BACKWARD_PASS_ORDER
)
param_list = list(module.parameters())
else:
param_list = list(module.parameters(recurse=False))

if self.enable_fine_grained_param_gather_hook:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: can this `if` condition be combined with the `if`/`else` that sets `param_list` just above it?

param_list = list(module.parameters(recurse=False))

# All-gather / unshard the module parameters before the backward pass.
self.all_gather_and_wait_parameters_ready(
param_list, prefetch_order=PrefetchOrder.BACKWARD_PASS_ORDER, bwd=True
)

self._root_pre_backward_hook_issued = False

Expand All @@ -760,7 +781,9 @@ def _root_pre_backward(module: nn.Module, *unused):
for bucket_id in range(ag_pipeline.num_buckets):
group = self.param_and_grad_buffer.parameter_groups[bucket_id]
if group.fsdp_unit_id is not None:
ag_pipeline.bucket_can_be_released[bucket_id] = True
ag_pipeline.bucket_can_be_released[
ag_pipeline.get_bucket_key(bucket_id, bwd=False)
] = True
# Track parameters that require gradient reduction and optimization.
self._params_require_handle_grad = set()
for param_group in self.param_and_grad_buffer.parameter_groups:
Expand All @@ -782,8 +805,12 @@ def _post_forward(module: nn.Module, input: Any, output: Any):
# during activation recomputation / gradient checkpointing.
return output

assert isinstance(
module, tuple(fsdp_unit_modules)
), "_post_forward hook should only be registered on FSDP unit modules."

# Release the module parameters after the forward pass to save memory.
release_module_parameters(module)
release_module_parameters(module, bwd=False)
module._training_state = TrainingState.IDLE

return output
Expand Down Expand Up @@ -824,21 +851,55 @@ def forward_hook(_module, inputs, output):
# on the output tensor(s).
return module.register_forward_hook(forward_hook)

def _register_pre_forward_param_unshard_hook(module):
    """
    Register the forward pre-hook that unshards parameters before the forward pass.

    If we are not sharding anything, we do not have a model weight buffer and thus
    have nothing to all-gather / un-shard, so no hook is registered.

    Args:
        module: The module on which the forward pre-hook is registered.
    """
    if self.ddp_config.data_parallel_sharding_strategy == "no_shard":
        return

    # Key the handle by the module's qualified name (`name` is the loop variable
    # of the enclosing `named_modules()` iteration, exactly as the post-backward
    # hook registration does). `module._get_name()` is only the class name, so
    # every module of the same class would overwrite the previously stored handle
    # and those hooks could no longer be tracked or removed.
    self.forward_pre_hooks[f"module {name} parameter unshard"] = (
        module.register_forward_pre_hook(
            _pre_forward_param_unshard, prepend=True, with_kwargs=True
        )
    )

def _register_pre_backward_param_unshard_hook(module):
    """
    Register the backward pre-hook to unshard module parameters immediately
    before the backward pass via attaching a gradient-triggered hook to the
    output tensor(s) of a module during a post-forward hook.

    Args:
        module: The module whose backward pass should trigger the unshard.
    """
    # Key the handle by the module's qualified name (`name` is the loop variable
    # of the enclosing `named_modules()` iteration). `module._get_name()` is only
    # the class name, so modules of the same class would overwrite each other's
    # stored handle.
    self.backward_pre_hooks[f"all-gather module {name} parameters"] = (
        create_custom_backward_hook(module, _pre_backward_param_unshard)
    )

def _register_grad_acc_and_reduce_hook(module):
    """
    Install the forward pre-hook that sets up the post-backward hook for `module`.

    The post-backward hook deallocates model parameters and reduce-scatters
    gradients right after the module's backward pass has completed, which
    conserves memory for the subsequent backward pass.
    """
    # Bind `_post_backward` so the forward pre-hook only needs the usual
    # (module, args, kwargs) arguments supplied by PyTorch.
    attach_post_backward = functools.partial(_register_post_backward_hook, _post_backward)
    handle = module.register_forward_pre_hook(attach_post_backward, with_kwargs=True)
    # `name` is the qualified module name from the enclosing `named_modules()` loop.
    self.forward_pre_hooks[f"module {name} register post-backward hook"] = handle

fsdp_modules = []
for name, module in root_module.named_modules():
if self.enable_fine_grained_param_gather_hook:
_register_pre_forward_param_unshard_hook(module)
_register_pre_backward_param_unshard_hook(module)
_register_grad_acc_and_reduce_hook(module)

# Skip if the module is already registered in fsdp_modules.
if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
continue

# Register the forward pre-hook to unshard parameters before the forward pass.
# If we are not sharding anything, we do not have a model weight buffer and thus
# have nothing to all-gather / un-shard.
if self.ddp_config.data_parallel_sharding_strategy != "no_shard":
self.forward_pre_hooks[f"module {name} parameter unshard"] = (
module.register_forward_pre_hook(
_pre_forward_param_unshard, prepend=True, with_kwargs=True
)
)
if not self.enable_fine_grained_param_gather_hook:
_register_pre_forward_param_unshard_hook(module)

if isinstance(module, tuple(fsdp_unit_modules)):
fsdp_modules.append(module)
Expand All @@ -849,12 +910,8 @@ def forward_hook(_module, inputs, output):
module.register_forward_hook(_post_forward, prepend=False)
)

# Register the backward pre-hook to unshard FSDP unit module parameters
# immediately before the backward pass via attaching a gradient-triggered
# hook to the output tensor(s) of a module during a post-forward hook.
self.backward_pre_hooks[f"all-gather module {name} parameters"] = (
create_custom_backward_hook(module, _pre_backward)
)
if not self.enable_fine_grained_param_gather_hook:
_register_pre_backward_param_unshard_hook(module)
elif (
not self.ddp_config.keep_fp8_transpose_cache
and self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params"
Expand All @@ -867,15 +924,8 @@ def forward_hook(_module, inputs, output):
module.register_forward_hook(_release_module_fp8_transpose_cache, prepend=False)
)

# Register the post-backward hook to deallocate model parameters and
# reduce-scatter gradients immediately after the module backward pass
# has completed to conserve memory for the subsequent backward pass.
self.forward_pre_hooks[f"module {name} register post-backward hook"] = (
module.register_forward_pre_hook(
functools.partial(_register_post_backward_hook, _post_backward),
with_kwargs=True,
)
)
if not self.enable_fine_grained_param_gather_hook:
_register_grad_acc_and_reduce_hook(module)

# Register root module pre- and post-backward hooks in cases where the
# forward function of root module is not called, but rather the forward
Expand Down Expand Up @@ -992,17 +1042,18 @@ def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bo
else:
self.synchronize_param_gather()
for bucket_id in range(self.all_gather_pipeline.num_buckets):
self.all_gather_pipeline.async_bucket_gather(bucket_id=bucket_id)
self.all_gather_pipeline.async_bucket_gather(bucket_id=bucket_id, bwd=False)
group = self.param_and_grad_buffer.parameter_groups[bucket_id]
if group.model_weight_buffer is None:
continue

if group.model_weight_buffer.is_data_distributed:
# If model weight is sharded, we wait for the all-gather to complete and
# then release the bucket immediately to save memory usage.
self.all_gather_pipeline.wait_bucket_ready(bucket_id)
self.all_gather_pipeline.wait_bucket_ready(bucket_id, False)

for bucket_id in range(self.all_gather_pipeline.num_buckets):
self.all_gather_pipeline.wait_bucket_ready(bucket_id)
self.all_gather_pipeline.wait_bucket_ready(bucket_id, False)

def start_grad_sync(self, *unused):
"""
Expand Down
Loading
Loading