-
Notifications
You must be signed in to change notification settings - Fork 4.1k
FP8 params support for megatron-fsdp (MXFP8/Blockwise) #2239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
BoxiangW
merged 4 commits into
NVIDIA:main
from
kunlunl:kunlunl/megatron-fsdp-fp8-params_main
Jan 9, 2026
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
feb6753
FP8 params support for megatron-fsdp (MXFP8/Blockwise) (#2086)
kunlunl 984912e
handle fp8_tensor _data is None situation
shjwudp 31624f7
Merge pull request #9 from shjwudp/megatron-fsdp-fp8-params_main
kunlunl cba67e3
Merge branch 'main' into kunlunl/megatron-fsdp-fp8-params_main
cspades File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,20 @@ | |
| import torch.nn as nn | ||
| from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten | ||
|
|
||
| from .mixed_precision import ( | ||
| fp8_create_transpose_cache, | ||
| fp8_discard_transpose_cache, | ||
| is_float8tensor, | ||
| ) | ||
| from .param_and_grad_buffer import ( | ||
| AllGatherPipeline, | ||
| BucketingPolicy, | ||
| GradReducePipeline, | ||
| ParamAndGradBuffer, | ||
| PrefetchOrder, | ||
| override_sharded_param_methods_with_safety_checks, | ||
| to_local_if_dtensor, | ||
| ) | ||
| from .utils import FSDPDistributedIndex | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
@@ -34,23 +48,12 @@ | |
| from megatron.core.distributed.distributed_data_parallel_config import ( | ||
| DistributedDataParallelConfig, | ||
| ) | ||
| from megatron.core.fp8_utils import is_float8tensor | ||
| from megatron.core.utils import is_submodule | ||
| except ImportError: | ||
| # Megatron-LM is not installed, use Megatron-FSDP as a standalone module. | ||
| logger.info("Megatron Core is not installed, Megatron-FSDP will run without Megatron Core.") | ||
| from .distributed_data_parallel_config import DistributedDataParallelConfig | ||
| from .utils import is_float8tensor, is_submodule | ||
|
|
||
| from .param_and_grad_buffer import ( | ||
| AllGatherPipeline, | ||
| BucketingPolicy, | ||
| GradReducePipeline, | ||
| ParamAndGradBuffer, | ||
| PrefetchOrder, | ||
| override_sharded_param_methods_with_safety_checks, | ||
| to_local_if_dtensor, | ||
| ) | ||
| from .utils import is_submodule | ||
|
|
||
|
|
||
| class TrainingState(Enum): | ||
|
|
@@ -168,6 +171,7 @@ def __init__( | |
| nccl_ub: bool = False, | ||
| fsdp_double_buffer: bool = False, | ||
| disable_symmetric_registration: bool = False, | ||
| enable_fine_grained_param_gather_hook: bool = False, | ||
| ): | ||
| super().__init__() | ||
| # If device is not specified, use the current device. | ||
|
|
@@ -217,6 +221,7 @@ def __init__( | |
|
|
||
| self.calculate_per_token_loss = calculate_per_token_loss | ||
| self.init_model_with_meta_device = init_model_with_meta_device | ||
| self.enable_fine_grained_param_gather_hook = enable_fine_grained_param_gather_hook | ||
|
|
||
| # Whether to constantly synchronize the model every training iteration, | ||
| # which defaults to False to overlap communication with computation | ||
|
|
@@ -406,6 +411,7 @@ def all_gather_and_wait_parameters_ready( | |
| prefetch=True, | ||
| prefetch_order=PrefetchOrder.FORWARD_PASS_ORDER, | ||
| wait_bucket_ready=True, | ||
| bwd=False, | ||
| ): | ||
| """ | ||
| All-gather parameters across the data parallel group and wait for | ||
|
|
@@ -432,11 +438,14 @@ def all_gather_and_wait_parameters_ready( | |
| and self.ddp_config.outer_dp_sharding_strategy != "no_shard" | ||
| and (self.microbatch_count == 0 or self.model_auto_sync) | ||
| ), | ||
| bwd=bwd, | ||
| ) | ||
| if wait_bucket_ready: | ||
| for param in params: | ||
| bucket_id = self.param_and_grad_buffer.param_to_param_group[param] | ||
| ag_pipeline.wait_bucket_ready(bucket_id) | ||
| ag_pipeline.wait_bucket_ready(bucket_id, bwd) | ||
| if bwd and is_float8tensor(param): | ||
| fp8_create_transpose_cache(param) | ||
|
|
||
| for param in params: | ||
| # This setting is needed to make FSDP store the weight object when used | ||
|
|
@@ -495,19 +504,17 @@ def _register_fsdp_hooks(self, root_module): | |
| """ | ||
| fsdp_unit_modules = self.fsdp_unit_modules | ||
|
|
||
| def release_module_parameters(module, *unused): | ||
| def release_module_parameters(module, bwd, *unused): | ||
| for param in module.parameters(): | ||
| bucket_id = self.param_and_grad_buffer.param_to_param_group[param] | ||
| self.all_gather_pipeline.release_bucket(bucket_id) | ||
|
|
||
| self.all_gather_pipeline.release_bucket(bucket_id, bwd) | ||
| if not self.ddp_config.keep_fp8_transpose_cache: | ||
| release_params_fp8_transpose_cache(module.parameters()) | ||
|
|
||
| def release_params_fp8_transpose_cache(params): | ||
| for param in params: | ||
| if is_float8tensor(param): | ||
| param._transpose_invalid = True | ||
| param._transpose = None | ||
| fp8_discard_transpose_cache(param) | ||
|
|
||
| def _grad_acc(param): | ||
| """ | ||
|
|
@@ -564,12 +571,15 @@ def _post_backward(module, *unused): | |
| if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params": | ||
| # Deallocate the module parameters after the backward pass, | ||
| # because we have our data-parallel gradients computed. | ||
| release_module_parameters(module) | ||
| release_module_parameters(module, bwd=True) | ||
| module._training_state = TrainingState.IDLE | ||
| param_list = list(module.parameters()) | ||
| else: | ||
| param_list = list(module.parameters(recurse=False)) | ||
|
|
||
| if self.enable_fine_grained_param_gather_hook: | ||
| param_list = list(module.parameters(recurse=False)) | ||
|
|
||
| # If the parameter is shared, we do not accumulate gradients | ||
| # here, as the gradients will be accumulated in the | ||
| # root post-backward hook. | ||
|
|
@@ -621,6 +631,9 @@ def _pre_forward_param_unshard( | |
| # to allocate as little memory as possible for this forward pass. | ||
| param_list = list(module.parameters(recurse=False)) | ||
|
|
||
| if self.enable_fine_grained_param_gather_hook: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here |
||
| param_list = list(module.parameters(recurse=False)) | ||
|
|
||
| # All-gather the parameters before the forward pass. | ||
| self.all_gather_and_wait_parameters_ready( | ||
| params=param_list, | ||
|
|
@@ -720,7 +733,7 @@ def _root_post_backward(*unused): | |
| if self.model_auto_sync: | ||
| self.finish_grad_sync() | ||
|
|
||
| def _pre_backward(module: nn.Module, *unused): | ||
| def _pre_backward_param_unshard(module: nn.Module, *unused): | ||
| """ | ||
| Sub-module pre-backward hook to all-gather the module parameters | ||
| before the backward pass. | ||
|
|
@@ -729,11 +742,19 @@ def _pre_backward(module: nn.Module, *unused): | |
| # and unsharding operations when performing activation recomputation | ||
| # / gradient checkpointing. | ||
| module._training_state = TrainingState.PRE_BACKWARD | ||
|
|
||
| if isinstance(module, tuple(fsdp_unit_modules)): | ||
| # All-gather / unshard the module parameters before the backward pass. | ||
| self.all_gather_and_wait_parameters_ready( | ||
| list(module.parameters()), prefetch_order=PrefetchOrder.BACKWARD_PASS_ORDER | ||
| ) | ||
| param_list = list(module.parameters()) | ||
| else: | ||
| param_list = list(module.parameters(recurse=False)) | ||
|
|
||
| if self.enable_fine_grained_param_gather_hook: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. same here |
||
| param_list = list(module.parameters(recurse=False)) | ||
|
|
||
| # All-gather / unshard the module parameters before the backward pass. | ||
| self.all_gather_and_wait_parameters_ready( | ||
| param_list, prefetch_order=PrefetchOrder.BACKWARD_PASS_ORDER, bwd=True | ||
| ) | ||
|
|
||
| self._root_pre_backward_hook_issued = False | ||
|
|
||
|
|
@@ -760,7 +781,9 @@ def _root_pre_backward(module: nn.Module, *unused): | |
| for bucket_id in range(ag_pipeline.num_buckets): | ||
| group = self.param_and_grad_buffer.parameter_groups[bucket_id] | ||
| if group.fsdp_unit_id is not None: | ||
| ag_pipeline.bucket_can_be_released[bucket_id] = True | ||
| ag_pipeline.bucket_can_be_released[ | ||
| ag_pipeline.get_bucket_key(bucket_id, bwd=False) | ||
| ] = True | ||
| # Track parameters that require gradient reduction and optimization. | ||
| self._params_require_handle_grad = set() | ||
| for param_group in self.param_and_grad_buffer.parameter_groups: | ||
|
|
@@ -782,8 +805,12 @@ def _post_forward(module: nn.Module, input: Any, output: Any): | |
| # during activation recomputation / gradient checkpointing. | ||
| return output | ||
|
|
||
| assert isinstance( | ||
| module, tuple(fsdp_unit_modules) | ||
| ), "_post_forward hook should only be registered on FSDP unit modules." | ||
|
|
||
| # Release the module parameters after the forward pass to save memory. | ||
| release_module_parameters(module) | ||
| release_module_parameters(module, bwd=False) | ||
| module._training_state = TrainingState.IDLE | ||
|
|
||
| return output | ||
|
|
@@ -824,21 +851,55 @@ def forward_hook(_module, inputs, output): | |
| # on the output tensor(s). | ||
| return module.register_forward_hook(forward_hook) | ||
|
|
||
| def _register_pre_forward_param_unshard_hook(module): | ||
| """ | ||
| Register the forward pre-hook to unshard parameters before the forward pass. | ||
| If we are not sharding anything, we do not have a model weight buffer and thus | ||
| have nothing to all-gather / un-shard. | ||
| """ | ||
| if self.ddp_config.data_parallel_sharding_strategy != "no_shard": | ||
| self.forward_pre_hooks[f"{module._get_name()} parameter unshard"] = ( | ||
| module.register_forward_pre_hook( | ||
| _pre_forward_param_unshard, prepend=True, with_kwargs=True | ||
| ) | ||
| ) | ||
|
|
||
| def _register_pre_backward_param_unshard_hook(module): | ||
| """ | ||
| Register the backward pre-hook to unshard FSDP unit module parameters | ||
| immediately before the backward pass via attaching a gradient-triggered | ||
| hook to the output tensor(s) of a module during a post-forward hook. | ||
| """ | ||
| self.backward_pre_hooks[f"all-gather {module._get_name()} parameters"] = ( | ||
| create_custom_backward_hook(module, _pre_backward_param_unshard) | ||
| ) | ||
|
|
||
| def _register_grad_acc_and_reduce_hook(module): | ||
| """ | ||
| Register the post-backward hook to deallocate model parameters and | ||
| reduce-scatter gradients immediately after the module backward pass | ||
| has completed to conserve memory for the subsequent backward pass. | ||
| """ | ||
| self.forward_pre_hooks[f"module {name} register post-backward hook"] = ( | ||
| module.register_forward_pre_hook( | ||
| functools.partial(_register_post_backward_hook, _post_backward), | ||
| with_kwargs=True, | ||
| ) | ||
| ) | ||
|
|
||
| fsdp_modules = [] | ||
| for name, module in root_module.named_modules(): | ||
| if self.enable_fine_grained_param_gather_hook: | ||
| _register_pre_forward_param_unshard_hook(module) | ||
| _register_pre_backward_param_unshard_hook(module) | ||
| _register_grad_acc_and_reduce_hook(module) | ||
|
|
||
| # Skip if the module is already registered in fsdp_modules. | ||
| if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules): | ||
| continue | ||
|
|
||
| # Register the forward pre-hook to unshard parameters before the forward pass. | ||
| # If we are not sharding anything, we do not have a model weight buffer and thus | ||
| # have nothing to all-gather / un-shard. | ||
| if self.ddp_config.data_parallel_sharding_strategy != "no_shard": | ||
| self.forward_pre_hooks[f"module {name} parameter unshard"] = ( | ||
| module.register_forward_pre_hook( | ||
| _pre_forward_param_unshard, prepend=True, with_kwargs=True | ||
| ) | ||
| ) | ||
| if not self.enable_fine_grained_param_gather_hook: | ||
| _register_pre_forward_param_unshard_hook(module) | ||
|
|
||
| if isinstance(module, tuple(fsdp_unit_modules)): | ||
| fsdp_modules.append(module) | ||
|
|
@@ -849,12 +910,8 @@ def forward_hook(_module, inputs, output): | |
| module.register_forward_hook(_post_forward, prepend=False) | ||
| ) | ||
|
|
||
| # Register the backward pre-hook to unshard FSDP unit module parameters | ||
| # immediately before the backward pass via attaching a gradient-triggered | ||
| # hook to the output tensor(s) of a module during a post-forward hook. | ||
| self.backward_pre_hooks[f"all-gather module {name} parameters"] = ( | ||
| create_custom_backward_hook(module, _pre_backward) | ||
| ) | ||
| if not self.enable_fine_grained_param_gather_hook: | ||
| _register_pre_backward_param_unshard_hook(module) | ||
| elif ( | ||
| not self.ddp_config.keep_fp8_transpose_cache | ||
| and self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params" | ||
|
|
@@ -867,15 +924,8 @@ def forward_hook(_module, inputs, output): | |
| module.register_forward_hook(_release_module_fp8_transpose_cache, prepend=False) | ||
| ) | ||
|
|
||
| # Register the post-backward hook to deallocate model parameters and | ||
| # reduce-scatter gradients immediately after the module backward pass | ||
| # has completed to conserve memory for the subsequent backward pass. | ||
| self.forward_pre_hooks[f"module {name} register post-backward hook"] = ( | ||
| module.register_forward_pre_hook( | ||
| functools.partial(_register_post_backward_hook, _post_backward), | ||
| with_kwargs=True, | ||
| ) | ||
| ) | ||
| if not self.enable_fine_grained_param_gather_hook: | ||
| _register_grad_acc_and_reduce_hook(module) | ||
|
|
||
| # Register root module pre- and post-backward hooks in cases where the | ||
| # forward function of root module is not called, but rather the forward | ||
|
|
@@ -992,17 +1042,18 @@ def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bo | |
| else: | ||
| self.synchronize_param_gather() | ||
| for bucket_id in range(self.all_gather_pipeline.num_buckets): | ||
| self.all_gather_pipeline.async_bucket_gather(bucket_id=bucket_id) | ||
| self.all_gather_pipeline.async_bucket_gather(bucket_id=bucket_id, bwd=False) | ||
| group = self.param_and_grad_buffer.parameter_groups[bucket_id] | ||
| if group.model_weight_buffer is None: | ||
| continue | ||
|
|
||
| if group.model_weight_buffer.is_data_distributed: | ||
| # If model weight is sharded, we wait for the all-gather to complete and | ||
| # then release the bucket immediately to save memory usage. | ||
| self.all_gather_pipeline.wait_bucket_ready(bucket_id) | ||
| self.all_gather_pipeline.wait_bucket_ready(bucket_id, False) | ||
|
|
||
| for bucket_id in range(self.all_gather_pipeline.num_buckets): | ||
| self.all_gather_pipeline.wait_bucket_ready(bucket_id) | ||
| self.all_gather_pipeline.wait_bucket_ready(bucket_id, False) | ||
|
|
||
| def start_grad_sync(self, *unused): | ||
| """ | ||
|
|
||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we combine this if condition with the above?