diff --git a/megatron/core/distributed/fsdp/src/README.md b/megatron/core/distributed/fsdp/src/README.md index c96e4425f9a..bc4cdaa078e 100644 --- a/megatron/core/distributed/fsdp/src/README.md +++ b/megatron/core/distributed/fsdp/src/README.md @@ -116,13 +116,13 @@ fully_shard(model) # Your model is now ready for distributed training! ``` -### `torch.compile` support +### `torch.compile` Compatibility -Megatron-FSDP supports `torch.compile`, but this feature is still experimental and may introduce performance regressions in some workloads. +Megatron-FSDP is compatible with `torch.compile`, but this feature is still experimental and may introduce performance regressions in some workloads. -## `fully_shard` / `MegatronFSDP` API - Advanced Features +## 📖 Megatron-FSDP Comprehensive Walkthrough -Megatron-FSDP's `fully_shard_*` API has a comprehensive set of arguments for fine-tuning your model's performance: +### Import `megatron_fsdp`. ```python import torch @@ -130,10 +130,16 @@ from megatron_fsdp import ( fully_shard_model, fully_shard_optimizer, ) +``` + +### Set up a distributed environment using `DeviceMesh`. + +`DeviceMesh` simplifies the construction of complex arrangements of devices +to support various parallelisms. + +```python +from torch.distributed.device_mesh import DeviceMesh -""" -Megatron-FSDP DeviceMesh Distributed Environment -""" # Initialize DeviceMesh. device_mesh = torch.distributed.device_mesh.init_device_mesh( "cuda", @@ -148,20 +154,22 @@ device_mesh[("dp_shard", "cp")]._flatten("dp_shard_cp") # Only required if using HSDP. Otherwise, don't pass hybrid_fsdp_group. device_mesh[("dp_outer", "dp_shard", "cp")]._flatten("hsdp") hsdp_group = device_mesh["hsdp"].get_group() + # Initialize DeviceMesh for expert parallel (EP) modules when using FSDP + EP. 
-expert_device_mesh = torch.distributed.device_mesh.init_device_mesh( - "cuda", - mesh_shape=(expt_dp_shard_size, expt_tp_size), - mesh_dim_names=("dp_shard", "tp"), +expt_device_mesh = DeviceMesh.from_group( + [expt_dp_group, expt_tp_group], + device_type="cuda", + mesh=expt_mesh.tolist(), + mesh_dim_names=["dp_shard_cp", "tp"], ) +``` -""" -Fully-shard the model for Megatron-FSDP. This wraps the model in a MegatronFSDP -class that schedules the sharding lifecycle of the model parameters and gradients -during training and inference. +### Convert models into fully-sharded `MegatronFSDP` models with `fully_shard_model`. -The original `torch.nn.Module` can be accessed at `MegatronFSDP.module`. -""" +This wraps the model in a MegatronFSDP class that schedules the sharding +lifecycle of the model parameters and gradients during training and inference. + +```python model = fully_shard_model( # PyTorch (Root) Module model, @@ -196,25 +204,43 @@ model = fully_shard_model( # Preprocess state dict for DCP checkpointing. Required for Torch Distributed Checkpoint. preproc_state_dict_for_dcp_ckpt=True, ) +``` -# Initialize your optimizer on the Megatron-FSDP model distributed Parameter(s). -# If your optimizer has already been initialized, either use the `fully_shard` -# entrypoint, or use `optimizer.add_param_group({"params": model.parameters()})` -# after resetting your optimizer state via `optimizer.param_groups.clear()` -# and `optimizer.state.clear()`. +The original `torch.nn.Module` can be accessed at `MegatronFSDP.module`. + +### Initialize and fully-shard your optimizer on the `MegatronFSDP` model. + +Initialize your optimizer on the Megatron-FSDP model distributed `Parameter`(s). +If your optimizer has already been initialized, either use the `fully_shard` +entrypoint, or use `optimizer.add_param_group({"params": model.parameters()})` +after resetting your optimizer state via `optimizer.param_groups.clear()` +and `optimizer.state.clear()`. 
+ +```python optimizer = torch.optim.Optimizer(model.parameters()) +``` -""" -Fully-shard your optimizer, which just modifies your `optimizer.step()`, `optimizer.zero_grad()`, -and distributed optimizer parameters to punctually trigger scheduled FSDP operations for Megatron-FSDP. +`fully_shard_optimizer` modifies your `optimizer.step()`, `optimizer.zero_grad()`, +and distributed optimizer parameters to punctually trigger scheduled FSDP operations +for Megatron-FSDP. + +```python +fully_shard_optimizer( + # PyTorch Optimizer + optimizer, + # Preprocess state dict for DCP checkpointing. + # Required for Torch Distributed Checkpoint. + preproc_state_dict_for_dcp_ckpt=True, +) +``` -These operations can be customized precisely via extended arguments to `step()` and `zero_grad()`: +Extended arguments to `step()` and `zero_grad()` control these FSDP operations: +```python optimizer.step( ..., - # Sync all gradients before the optimizer step. Not necessary and disabled - # automatically when `sync_model_each_microbatch=True` in MegatronFSDP, in - # which case we already synchronize gradients every step but lose performance. + # Sync all gradients before the optimizer step. Alternatively enabled using + # `sync_model_each_microbatch=True` in MegatronFSDP. sync_grad_before_optimizer_step=True, # After `optimizer.step()`, install optimized weights into MegatronFSDP's buffers. install_optimized_model_weights=True, @@ -225,19 +251,20 @@ These operations can be customized precisely via extended arguments to `step()` # Also zero out MegatronFSDP's gradient accumulation buffers. zero_grad_buffer=True ) -""" -fully_shard_optimizer( - # PyTorch Optimizer - optimizer, - # Preprocess state dict for DCP checkpointing. Required for Torch Distributed Checkpoint. - preproc_state_dict_for_dcp_ckpt=True, -) +``` -""" -Megatron-FSDP Model Checkpointing -""" +### `MegatronFSDP` Distributed Checkpointing + +Distributed checkpoints can be saved and loaded using Torch DCP. 
Alternatively, +you can load non-distributed checkpoints before fully-sharding your model with +any existing checkpoint utility compatible with PyTorch Modules. + +```python # Save model and optimizer state. -torch.distributed.checkpoint.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, checkpoint_id=str(CKPT_DIR)) +torch.distributed.checkpoint.save( + {"model": model.state_dict(), "optimizer": optimizer.state_dict()}, + checkpoint_id=str(CKPT_DIR) +) # Load model and optimizer state. ckpt_state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict()} @@ -249,6 +276,10 @@ model.load_state_dict(ckpt_state_dict["model"], strict=False) optimizer.load_state_dict(ckpt_state_dict["optimizer"]) ``` +## ⚙️ `fully_shard` / `MegatronFSDP` API - Advanced Features + +Megatron-FSDP's `fully_shard_*` API has a comprehensive set of arguments for fine-tuning your model's performance. + - `fsdp_unit_modules` is a list of sub-module classes or `str` import-paths associated with modules that you want `MegatronFSDP` to fully-shard. - Required if `1`, `2`, or `3` are specified as the sharding strategy. Defaults to `None`, in which case Megatron-FSDP will replicate the parameters similar to DDP. - `zero_dp_strategy` (and `outer_dp_sharding_strategy`) configure different degrees of zero-redundancy data parallelism as described in [ZeRO (Zero Redundancy Optimizer)](https://arxiv.org/abs/1910.02054). It reduces CUDA memory utilization during model training by distributing model parameters, gradients, and optimizer states across multiple devices in the DP `ProcessGroup`, and collectively communicating subsets of parameters and gradients to specific devices when needed for computation or differentiation. 
More aggressive sharding strategies will entail more communication overhead, with `no_shard` being the least memory efficient but most communication efficient, and `optim_grads_params` being the most memory efficient but least communication efficient. `outer_dp_sharding_strategy` has the same options, except for the (required) "outer" DP group (`dp_outer_dim` / `hybrid_fsdp_group`) when using [Hybrid-Sharded Data Parallelism (HSDP)](https://arxiv.org/pdf/2304.11277), and only `no_shard` (DP Replication) and `optim` (Optimizer State Hybrid Sharding, requires `zero_dp_strategy='optim_grads_params`) are supported. @@ -280,8 +311,9 @@ optimizer.load_state_dict(ckpt_state_dict["optimizer"]) - Both default to `True`. - `sync_model_each_microbatch` will trigger a `wait` (`MegatronFSDP.finish_grad_sync()`) on gradient reduction, parameter de-allocation, and optimizer parameter / gradient installation (in preparation for `optimizer.step()`) after every forward-backward pass. When using HSDP, parameters and gradients will be all-gathered and reduced respectively on the "outer" DP group each training step instead of each optimization cycle. This behavior is desirable for a transparent and user-friendly sharded training loop where post-backward transformations on the gradient and a clean compute / memory state are necessary between training iterations, but damages performance in situations where optimization is delayed (e.g. gradient accumulation) where the communications of the previous training iteration can be overlapped with the compute of the next training iteration. Will also override `is_last_microbatch` / `microbatch_count` logic in `MegatronFSDP`. - Defaults to `True` for `fully_shard`, but defaults to `False` when using the `MegatronFSDP` class directly. -- `keep_fp8_transpose_cache_when_using_custom_fsdp` will keep the fp8 transpose cache when using `MegatronFSDP`. 
This option will cause (number of parameter $\times$ 1 Byte) of memory overhead, but can skip the weight transpose operation in the backward propagation. This feature will not give any benefit from the Blackwell architecture. - - **Only effective when using Megatron-LM.** +- `enable_fine_grained_param_gather` modifies FSDP to all-gather parameters with per-Module granularity instead of collectively unsharding all sub-modules of a unit module in Megatron-FSDP. + - Defaults to `False`. +- `keep_fp8_transpose_cache` will keep the fp8 transpose cache when using `MegatronFSDP`. This option will cause (number of parameter $\times$ 1 Byte) of memory overhead, but can skip the weight transpose operation in the backward propagation. This feature will not give any benefit from the Blackwell architecture. - Defaults to `False`. - `nccl_ub` will allocate and register the NCCL userbuffer for param and grad buffers. This option enables an SM-efficient NCCL algorithm that could improve the performance of overlapped computations. This flag will be much more effective when used together with SHARP if the FSDP communication includes both NVL and IB domains. Enabling this option will cause additional memory overhead due to the requirement to enable the `fsdp_double_buffer` option. - **Only effective when using with Megatron-Core.** @@ -297,3 +329,44 @@ optimizer.load_state_dict(ckpt_state_dict["optimizer"]) - Defaults to `False`. Automatically overridden to `True` when `nccl_ub` is enabled. - `preproc_state_dict_for_dcp_ckpt` adds `model.state_dict()` and `optimizer.state_dict()` post-hooks that modify the model and optimizer state in preparation for `torch.distributed.checkpoint.{save,load}` ([Torch DCP](https://docs.pytorch.org/docs/stable/distributed.checkpoint.html)) checkpointing. Specifically, it adds `__create_write_items__` and `__create_chunk_list__` methods to Tensors utilized by Torch DCP to redistribute parameters when saving and loading model and optimizer checkpoints. 
Can be deactivated should the user need a custom distributed checkpointing strategy. - Defaults to `True`. + +## 🧮 Using Megatron-FSDP with [`TransformerEngine`](https://github.com/NVIDIA/TransformerEngine) + +Megatron-FSDP natively supports mixed-precision activations and parameter sharding in conjunction with [TransformerEngine](https://github.com/NVIDIA/TransformerEngine). + +- Within the [`transformer_engine.pytorch.autocast(recipe: transformer_engine.common.recipe.Recipe)`](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html#transformer_engine.pytorch.autocast) context, model activations are converted based on the recipe. +- Within the [`transformer_engine.pytorch.quantized_model_init(recipe: transformer_engine.common.recipe.Recipe)`](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html#transformer_engine.pytorch.quantized_model_init) context, TransformerEngine native modules (e.g. [`transformer_engine.pytorch.TransformerLayer`](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html#transformer_engine.pytorch.TransformerLayer)) have their parameters converted based on the recipe. + - Requires FP8 model activations, i.e. `transformer_engine.pytorch.autocast`. + +```python +# FP8 Recipe +fp8_recipe = transformer_engine.common.recipe.MXFP8BlockScaling( + fp8_format=transformer_engine.common.recipe.Format.HYBRID, +) + +# Construct TransformerEngine model with FP8 parameters. +with transformer_engine.pytorch.quantized_model_init( + recipe=fp8_recipe, + # Needed for FP8 parameters with Megatron-FSDP. + preserve_high_precision_init_val=True, +): + te_model = transformer_engine.pytorch.TransformerLayer(...) + +# Fully-shard the model. +mfsdp_model = fully_shard_model( + module=te_model, + fsdp_unit_modules=[transformer_engine.pytorch.TransformerLayer], + # Only FSDP / ZeRO-3 supports FP8 parameters. + zero_dp_strategy=3, + # Needed for FP8 parameters. (Default is already True.) 
+ preserve_fp32_weights=True, + # Needed for select FP8 recipes. + keep_fp8_transpose_cache=True, +) + +# Evaluate and differentiate the model with FP8 activations. +with transformer_engine.pytorch.autocast(recipe=fp8_recipe): + mfsdp_model(x).sum().backward() +``` + +ℹ️ `TransformerEngine` kernels have a fair number of configuration constraints when using FP8-quantized parameters, such as using fused QKV parameters or defining activations and parameters with shapes compatible with FP8 cuBLAS kernels on supported hardware from NVIDIA. To properly initialize `TransformerLayer`, you can refer to the toy model used in our FP8 unit tests: `Megatron-LM/tests/unit_tests/distributed/fsdp/test_mfsdp_fully_shard.py::TestMegatronFsdpFullyShard::test_fully_shard_te_quantized`. \ No newline at end of file diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py index c3e50e769bf..df210f15f05 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py @@ -97,6 +97,7 @@ def fully_shard_model( nccl_ub: bool = False, fsdp_double_buffer: bool = False, disable_symmetric_registration: bool = False, + enable_fine_grained_param_gather: bool = False, ) -> torch.nn.Module: """ Fully-shard the model for Megatron-FSDP. This wraps the model in a MegatronFSDP @@ -232,6 +233,13 @@ class that schedules the sharding lifecycle of the model parameters and gradient disable_symmetric_registration (bool): Whether to disable symmetric (window) registration for NCCL UB registration. This option forces conventional (local) UB registration when nccl_ub is set. + Defaults to False. + + enable_fine_grained_param_gather (bool): + Whether to enable "fine-grained" param all-gather, which can improve performance + when using MXFP8 parameters with activation recomputation. 
Specifically, it + unshards parameters per-Module instead of unsharding all sub-modules of an FSDP + unit module simultaneously. Defaults to False. Returns: model (MegatronFSDP): The wrapped Megatron-FSDP model configured for FSDP. @@ -241,14 +249,17 @@ class that schedules the sharding lifecycle of the model parameters and gradient if device_mesh is None: if dp_shard_dim is None: dp_shard_dim = "fsdp" + if tp_dim is None: + # Trivial TP dimension to seamlessly support TransformerEngine. + tp_dim = "tp" # Deactivate DP-Outer, which needs to be consistent with Expert DeviceMesh. dp_outer_dim = None hybrid_fsdp_group = None outer_dp_sharding_strategy = ShardingStrategy.NO_SHARD device_mesh = init_device_mesh( device_type="cuda", - mesh_shape=(torch.distributed.get_world_size(),), - mesh_dim_names=(dp_shard_dim,), + mesh_shape=(torch.distributed.get_world_size(), 1), + mesh_dim_names=(dp_shard_dim, tp_dim), ) # Parse zero_dp_strategy and outer_dp_sharding_strategy. @@ -293,7 +304,7 @@ class that schedules the sharding lifecycle of the model parameters and gradient if _outer_fsdp_sharding and zero_dp_strategy != "optim_grads_params": # If sharding on outer DP using HSDP, then we must use HSDP buffers and # we must be fully-sharding on inner DP. HSDP is an extension of FSDP. - # FIXME(@shjwudp, @cspades): This is an unexpected lack of support. + # TODO(@shjwudp, @cspades): Requires various modifications to support. 
raise ValueError( f"Sharding with Hybrid (Fully) Sharded Data Parallel (HSDP) requires " "zero_dp_strategy to use FSDP ('optim_grads_params', 3), because " @@ -358,6 +369,7 @@ class that schedules the sharding lifecycle of the model parameters and gradient calculate_per_token_loss=calculate_per_token_loss, init_model_with_meta_device=init_model_with_meta_device, sync_model_each_microbatch=sync_model_each_microbatch, + enable_fine_grained_param_gather_hook=enable_fine_grained_param_gather, ) # Register a state dict post-hook to add Torch DCP metadata for writing checkpoints. @@ -529,6 +541,7 @@ def fully_shard( nccl_ub: bool = False, fsdp_double_buffer: bool = False, disable_symmetric_registration: bool = False, + enable_fine_grained_param_gather: bool = False, ) -> tuple[MegatronFSDP, torch.optim.Optimizer]: """ Fully shard the model and the optimizer for Megatron-FSDP. @@ -575,6 +588,7 @@ def fully_shard( nccl_ub=nccl_ub, fsdp_double_buffer=fsdp_double_buffer, disable_symmetric_registration=disable_symmetric_registration, + enable_fine_grained_param_gather=enable_fine_grained_param_gather, ) # Extend optimizer methods to support Megatron-FSDP operations. diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py index d93a13d241b..c1c11721f7e 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py @@ -139,6 +139,9 @@ class MegatronFSDP(torch.nn.Module): disable_symmetric_registration (bool): Whether to disable symmetric (window) registration for NCCL userbuffer registration. This option will force to use conventional (local) userbuffer registration when nccl_ub is set. + enable_fine_grained_param_gather (bool): Whether to enable "fine-grained" param all-gather, + which can improve performance when using MXFP8 parameters with activation recomputation. 
+ Examples: >>> model = GPTModel(config) >>> model = MegatronFSDP( @@ -541,6 +544,7 @@ def _grad_acc(param): param.main_grad = param.get_main_grad() if param.grad is not None: # Copy the gradient into the allocated main gradient bucket. + # It will be reduce-scattered and accumulated into gbuf. param.main_grad.copy_(to_local_if_dtensor(param.grad)) del param.grad else: @@ -550,6 +554,7 @@ def _grad_acc(param): if not param.grad_added_to_main_grad: if param.grad is not None: # Add the gradient into the allocated main gradient bucket. + # For unsharded gradients, this is gradient accumulation. param.main_grad = param.get_main_grad() param.main_grad.add_(to_local_if_dtensor(param.grad)) del param.grad @@ -654,9 +659,8 @@ def _register_post_backward_hook( Pre-forward hook utilized to attach a gradient reduction post-backward hook to the module. """ - # Register the backward function to reduce gradients after the backward pass. - # And for optim_grads_params, we need to release the parameters after the backward pass. if not torch.is_grad_enabled(): + # No gradients / backward pass, don't attach the post-backward hook. return args, kwargs # Preprocess the input arguments. @@ -675,10 +679,10 @@ def _register_post_backward_hook( """ Bootstrapped identity autograd function that attaches a post-backward - "hook" to the module to trigger model resharding / deallocation and - gradient reduce-scatter immediately after the module backward pass has - completed to deallocate this layer's model and gradient memory before - the subsequent backward pass. + "hook" to the module to trigger model compute parameter deallocation + and gradient reduce-scatter immediately after the module backward pass + has completed to shard this layer's model and gradient memory after + the current backward pass stage is complete. 
""" inp_tensors = RegisterFSDPBackwardFunction.apply( functools.partial(post_backward_hook, module), *inp_tensors @@ -741,9 +745,7 @@ def _pre_backward_param_unshard(module: nn.Module, *unused): Sub-module pre-backward hook to all-gather the module parameters before the backward pass. """ - # Set the module's training state to PRE_BACKWARD to skip resharding - # and unsharding operations when performing activation recomputation - # / gradient checkpointing. + # Set the module's training state to PRE_BACKWARD. module._training_state = TrainingState.PRE_BACKWARD if isinstance(module, tuple(fsdp_unit_modules)): @@ -762,12 +764,13 @@ def _pre_backward_param_unshard(module: nn.Module, *unused): self._root_pre_backward_hook_issued = False def _root_pre_backward(module: nn.Module, *unused): - """Marks the module's training state as 'pre_backward' before the + """Marks the module's training state as PRE_BACKWARD before the backprop, this function is registered on the root module. - This marking enables us to determine whether forward pass needs to - perform reshard/unshard operations in activation recomputation - scenarios. + This root pre-backward hook informs all modules to skip forward + pre-fetching in the pre-forward hooks (for activation recomputation) + and skip weight deallocation / resharding in the post-forward hooks + during the backward pass, which are instead performed by backward hooks. """ if self._root_pre_backward_hook_issued: return @@ -776,7 +779,7 @@ def _root_pre_backward(module: nn.Module, *unused): if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params": for module in root_module.modules(): if isinstance(module, tuple(fsdp_unit_modules)): - # Set PRE_BACKWARD state to skip resharding and unsharding operations + # Set PRE_BACKWARD state to skip resharding and forward pre-fetching # when performing activation recomputation / gradient checkpointing. 
module._training_state = TrainingState.PRE_BACKWARD # set all param buckets can be released @@ -940,10 +943,7 @@ def _register_grad_acc_and_reduce_hook(module): if len(list(module.parameters())) != len(list(root_module.parameters())): # Only attach to root sub-module. continue - # Add a pre-backward hook to reshard / deallocate model parameters prior - # to the backward pass. - # Furthermore, add a gradient-triggered post-backward hook to reduce-scatter - # leftover gradients. + # Install the root pre-backward hook. self.backward_pre_hooks[f"{name} _root_pre_backward"] = create_custom_backward_hook( module, _root_pre_backward ) diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/mixed_precision.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/mixed_precision.py index d7156bea5c6..177e3b1caa2 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/mixed_precision.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/mixed_precision.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from contextlib import nullcontext from importlib.metadata import version from typing import List, Optional, Tuple @@ -43,6 +44,20 @@ except: TE_VERSION = None +# Detect the quantized_model_init or fp8_model_init context manager. +if HAVE_TE: + try: + from transformer_engine.pytorch import quantized_model_init + + QUANTIZED_MODEL_INIT_CLASS = quantized_model_init + except: + # Fallback to original FP8 model init. 
+ from transformer_engine.pytorch import fp8_model_init + + QUANTIZED_MODEL_INIT_CLASS = fp8_model_init +else: + QUANTIZED_MODEL_INIT_CLASS = nullcontext + # Detect the FP8 tensor class try: from transformer_engine.pytorch.tensor import QuantizedTensor @@ -332,3 +347,15 @@ def _fp8_quantize_fallback( packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=data_parallel_group ) _multi_tensor_copy_this_to_that(packed_amax_views, amaxes, dummy_overflow_buf) + + +def get_quantized_model_init_context_cls(): + """ + Get the TransformerEngine model parameter quantization context manager. + """ + if QUANTIZED_MODEL_INIT_CLASS is nullcontext: + logger.warning( + f"quantized_model_init / fp8_model_init context was requested but does not exist. " + f"Verify TransformerEngine is installed (TE_INSTALLED={HAVE_TE})." + ) + return QUANTIZED_MODEL_INIT_CLASS diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py index 04ea09970f4..d63d2de951d 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py @@ -39,6 +39,7 @@ fp8_need_transpose_data_for_meta_device_init, fp8_quantize, fp8_set_raw_data, + get_quantized_model_init_context_cls, is_blockwise_float8tensor, is_float8tensor, is_te_min_version, @@ -74,7 +75,6 @@ logger.info("Megatron Core is not installed, Megatron-FSDP will run without Megatron Core.") try: - from transformer_engine.pytorch import fp8_model_init from transformer_engine.pytorch.module.base import TransformerEngineBaseModule HAVE_TE = True @@ -2641,7 +2641,12 @@ def num_buckets(self): @torch.no_grad() def copy_main_weights_to_model_weights(self): - """Update the model weights from the main weights.""" + """ + Update the model weights from the main weights. 
+ + If FP8 parameters are utilized, this function will quantize the high-precision + main weights prior to installation into the model compute weight buffers. + """ dense_param_quantize_kwargs = { "model_params": [], "main_params": [], @@ -2737,9 +2742,16 @@ def _batch_quantize_blockwise_fp8_params( model_param = to_local_if_dtensor(param) main_weight = mbuf.get_item(item_id) + # TODO(@kunlunl, @cspades): Currently, we only support FP8 parameters + # for FSDP, i.e. fully-sharded compute parameters with a high-precision + # main weight buffer. Would it be possible to add if branches here to + # quantize the original param (no_shard) or wbuf data (optim, optim_grads) + # for a seamless user experience and coverage for ZeRO-1 and ZeRO-2? + if is_blockwise_float8tensor(param): fp8_params.append(param) if model_param.numel() == 0: + # Empty parameter. shard_fp32_from_fp8.append(None) shard_offsets_in_fp8.append(None) shard_model_params.append([None, None]) @@ -2768,6 +2780,7 @@ def _batch_quantize_blockwise_fp8_params( if is_float8tensor(param): fp8_params.append(param) if model_param.numel() == 0: + # Empty parameter. shard_fp32_from_fp8.append(None) shard_offsets_in_fp8.append(None) shard_model_params.append([None, None]) @@ -3731,11 +3744,26 @@ def __init__(self, init_param_with_fp8=False, with_cuda_rng_tracker=False): def __enter__(self): self.stack = ExitStack() if self.init_param_with_fp8: - assert HAVE_TE - args = {"enabled": True} - if "preserve_high_precision_init_val" in inspect.signature(fp8_model_init).parameters: - args["preserve_high_precision_init_val"] = True - self.stack.enter_context(fp8_model_init(**args)) + # FIXME(@cspades): This appears to be a legacy dependency that is not needed for + # more recent versions of TransformerEngine, which only requires this context during + # TransformerEngineBaseModule.__init__. Should be removed if backwards compatibility + # is confirmed, because overwrites the quantized_model_init context specified by user. 
+ assert ( + HAVE_TE + ), "TransformerEngine is required for using FP8 parameters with Megatron-FSDP." + # Retrieve import for quantized_model_init (new) or fp8_model_init (old). + # Will be nullcontext if TE is not installed. + te_quantized_model_init_cls = get_quantized_model_init_context_cls() + if te_quantized_model_init_cls is not nullcontext: + # Enable TE quantized parameter context manager. + args = {"enabled": True} + if ( + "preserve_high_precision_init_val" + in inspect.signature(te_quantized_model_init_cls).parameters + ): + # Required for Megatron-FSDP + FP8 parameters. + args["preserve_high_precision_init_val"] = True + self.stack.enter_context(te_quantized_model_init_cls(**args)) if self.with_cuda_rng_tracker: # Megatron / TE RNG tracker needs to be initialized and seeded by the user or FW diff --git a/tests/unit_tests/distributed/fsdp/test_mfsdp_fully_shard.py b/tests/unit_tests/distributed/fsdp/test_mfsdp_fully_shard.py index 191aac3e01b..cbca505b405 100644 --- a/tests/unit_tests/distributed/fsdp/test_mfsdp_fully_shard.py +++ b/tests/unit_tests/distributed/fsdp/test_mfsdp_fully_shard.py @@ -2,6 +2,7 @@ import logging import shutil +from contextlib import nullcontext from copy import deepcopy from pathlib import Path @@ -33,6 +34,10 @@ DIM_SIZE = 2 NUM_LAYERS = 2 NUM_STEPS = 2 +DELAYED_FP8_RECIPE = "fp8_delayed_scaling" +CURRENT_FP8_RECIPE = "fp8_current_scaling" +BLOCKWISE_FP8_RECIPE = "fp8_blockwise_scaling" +MXFP8_BLOCKWISE_RECIPE = "mxfp8_blockwise" # Needed for `torch.distributed.checkpoint.{save,load}` because # multiple processes need to write to the same directory. 
@@ -119,17 +124,33 @@ def forward(self, x, y): class ToyTETransformer(torch.nn.Module): """Toy Transformer model for testing Megatron-FSDP with Transformer Engine.""" - def __init__(self, model_dim, num_heads, num_layers, output_dim): + def __init__( + self, + model_dim, + num_heads, + num_layers, + output_dim, + fuse_qkv_params=False, + params_dtype=torch.float32, + device="cuda", + ): super().__init__() self.layers = torch.nn.ModuleList( [ te.pytorch.TransformerLayer( - hidden_size=model_dim, ffn_hidden_size=model_dim, num_attention_heads=num_heads + hidden_size=model_dim, + ffn_hidden_size=model_dim, + num_attention_heads=num_heads, + fuse_qkv_params=fuse_qkv_params, + params_dtype=params_dtype, + device=device, ) for _ in range(num_layers) ] ) - self.fc_out = te.pytorch.Linear(model_dim, output_dim) + self.fc_out = te.pytorch.Linear( + model_dim, output_dim, params_dtype=params_dtype, device=device + ) def forward(self, x): for layer in self.layers: @@ -166,7 +187,11 @@ def build_toy_model(model_type: str, init_model_with_meta_device: bool, seed=Non fsdp_unit_modules = [torch.nn.Transformer] elif model_type == TE_TRANSFORMER: toy_model = ToyTETransformer( - model_dim=DIM_SIZE, num_heads=2, num_layers=NUM_LAYERS, output_dim=DIM_SIZE + model_dim=DIM_SIZE, + num_heads=2, + num_layers=NUM_LAYERS, + output_dim=DIM_SIZE, + device="meta" if init_model_with_meta_device else "cuda", ) fsdp_unit_modules = [te.pytorch.TransformerLayer] @@ -272,7 +297,7 @@ def test_fully_shard( ) elif dp_outer_strategy == OPTIM: if dp_shard_strategy != OPTIM_GRADS_PARAMS: - # FIXME(@shjwudp, @cspades): This is an unexpected lack of support. + # TODO(@shjwudp, @cspades): Requires various modifications to support. 
# [default0]:FAILED tests/unit_tests/distributed/test_mfsdp_fully_shard.py # [False-True-True-True-mesh_dim_config0-optim-optim-cnn] # [False-True-True-True-mesh_dim_config0-optim-optim_grads-cnn] @@ -650,3 +675,102 @@ def test_fully_shard_ez(self, shard_strategy): # Optimizer step. optimizer.step() optimizer.zero_grad() + + @pytest.mark.parametrize("init_model_with_meta_device", [True, False]) + @pytest.mark.parametrize( + "te_recipe", + [DELAYED_FP8_RECIPE, CURRENT_FP8_RECIPE, BLOCKWISE_FP8_RECIPE, MXFP8_BLOCKWISE_RECIPE], + ) + def test_fully_shard_te_quantized(self, init_model_with_meta_device, te_recipe): + """ + Test Megatron-FSDP with FP8 activations and parameters via TransformerEngine. + """ + if te_recipe == MXFP8_BLOCKWISE_RECIPE: + # TODO(@cspades, @ko3n1g): Add this test case in. + pytest.skip(f"[Megatron CI/CD] MXFP8 requires Blackwell nodes to test.") + + from megatron.core.distributed.fsdp.src.megatron_fsdp.fully_shard import ( + fully_shard_model, + fully_shard_optimizer, + ) + + # Build FP8 recipe. + te_quant_recipe = None + if te_recipe == MXFP8_BLOCKWISE_RECIPE: + te_quant_recipe = te.common.recipe.MXFP8BlockScaling( + fp8_format=te.common.recipe.Format.HYBRID + ) + elif te_recipe == DELAYED_FP8_RECIPE: + te_quant_recipe = te.common.recipe.DelayedScaling() + elif te_recipe == CURRENT_FP8_RECIPE: + te_quant_recipe = te.common.recipe.Float8CurrentScaling() + elif te_recipe == BLOCKWISE_FP8_RECIPE: + te_quant_recipe = te.common.recipe.Float8BlockScaling() + + # Construct toy model compatible with FP8. + with ( + te.pytorch.quantized_model_init( + recipe=te_quant_recipe, + # Needed for FP8 parameters with Megatron-FSDP. + preserve_high_precision_init_val=True, + ) + if te_quant_recipe is not None + else nullcontext() + ): + # Fused QKV, BF16 precision for high-precision weights, + # and hidden dimension divisibility by 32 is required + # for some FP8 recipes such as MXFP8. 
+ toy_model = ToyTETransformer( + model_dim=64, + num_heads=2, + num_layers=2, + output_dim=64, + fuse_qkv_params=True, + params_dtype=torch.bfloat16, + device="meta" if init_model_with_meta_device else "cuda", + ) + + # Fully-shard the model. + mfsdp_model = fully_shard_model( + module=toy_model, + fsdp_unit_modules=[te.pytorch.TransformerLayer, te.pytorch.Linear], + # Only ZeRO-3 / FSDP supports FP8 parameters. + zero_dp_strategy=3, + init_model_with_meta_device=init_model_with_meta_device, + # Required for FP8 parameter support, except for MXFP8 which has + # its own row-wise and col-wise (transpose) buffer management + # schedule that is natively managed by Megatron-FSDP. + keep_fp8_transpose_cache=True, + # Required for FP8 parameters. The optimizer state (and gradients) + # are never quantized, as TE produces high-precision wgrad and + # dgrad from FP8 weights and activations. Already defaults to True. + preserve_fp32_weights=True, + ) + + # Initialize the distributed optimizer on the MegatronFSDP model. + toy_adam = Adam(params=mfsdp_model.parameters(), lr=0.01) + optimizer = fully_shard_optimizer(optimizer=toy_adam) + + # Mock input and target. Requires 2^N batch size for (MX)FP8 kernels. + toy_input = torch.randn(16, 64, 64, dtype=torch.bfloat16).to("cuda") + toy_target = torch.randn(16, 64, 64, dtype=torch.bfloat16).to("cuda") + + for step in range(NUM_STEPS): + + # Forward pass. + with ( + te.pytorch.autocast(recipe=te_quant_recipe) + if te_quant_recipe is not None + else nullcontext() + ): + output = mfsdp_model(toy_input) + + # Loss. + loss = mse_loss(output, toy_target) + + # Backward pass. + loss.backward() + + # Optimizer step. + optimizer.step() + optimizer.zero_grad()