157 changes: 115 additions & 42 deletions megatron/core/distributed/fsdp/src/README.md
fully_shard(model)
# Your model is now ready for distributed training!
```

### `torch.compile` Compatibility

Megatron-FSDP is compatible with `torch.compile`, but this feature is still experimental and may introduce performance regressions in some workloads.
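As a quick sanity check of the mechanics, `torch.compile` wraps the sharded model like any other `torch.nn.Module`. This is a minimal sketch on a toy module: the `backend="eager"` argument is only used here to keep the example free of compiler-toolchain requirements, and real Megatron-FSDP runs would compile the wrapped model on GPU with the default backend.

```python
import torch

# Minimal sketch: torch.compile treats the (sharded) model as an ordinary
# torch.nn.Module. backend="eager" avoids requiring a compiler toolchain here.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
compiled_model = torch.compile(model, backend="eager")
out = compiled_model(torch.randn(2, 4))
```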

## 📖 Megatron-FSDP Comprehensive Walkthrough

### Import `megatron_fsdp`.

```python
import torch
from megatron_fsdp import (
fully_shard_model,
fully_shard_optimizer,
)
```

### Set up a distributed environment using `DeviceMesh`.

`DeviceMesh` simplifies the construction of complex arrangements of devices
to support various parallelisms.

```python
from torch.distributed.device_mesh import DeviceMesh

"""
Megatron-FSDP DeviceMesh Distributed Environment
"""
# Initialize DeviceMesh.
device_mesh = torch.distributed.device_mesh.init_device_mesh(
    "cuda",
    # Mesh dimensions for HSDP: (dp_outer, dp_shard, cp, tp). The sizes are
    # user-defined parallelism degrees whose product is the world size.
    mesh_shape=(dp_outer_size, dp_shard_size, cp_size, tp_size),
    mesh_dim_names=("dp_outer", "dp_shard", "cp", "tp"),
)

# Flatten the DP-Shard and CP dimensions into the FSDP process group dimension.
device_mesh[("dp_shard", "cp")]._flatten("dp_shard_cp")
# Only required if using HSDP. Otherwise, don't pass hybrid_fsdp_group.
device_mesh[("dp_outer", "dp_shard", "cp")]._flatten("hsdp")
hsdp_group = device_mesh["hsdp"].get_group()

# Initialize DeviceMesh for expert parallel (EP) modules when using FSDP + EP.
expt_device_mesh = DeviceMesh.from_group(
    [expt_dp_group, expt_tp_group],
    device_type="cuda",
    mesh=expt_mesh.tolist(),
    mesh_dim_names=["dp_shard_cp", "tp"],
)
```

### Convert models into fully-sharded `MegatronFSDP` models with `fully_shard_model`.

This wraps the model in a MegatronFSDP class that schedules the sharding
lifecycle of the model parameters and gradients during training and inference.

```python
model = fully_shard_model(
    # PyTorch (Root) Module
    model,
    # ... (DeviceMesh, sharding strategy, and other arguments omitted in this excerpt) ...
    # Preprocess state dict for DCP checkpointing. Required for Torch Distributed Checkpoint.
    preproc_state_dict_for_dcp_ckpt=True,
)
```

The original `torch.nn.Module` can be accessed at `MegatronFSDP.module`.

### Initialize and fully-shard your optimizer on the `MegatronFSDP` model.

Initialize your optimizer on the Megatron-FSDP model distributed `Parameter`(s).
If your optimizer has already been initialized, either use the `fully_shard`
entrypoint, or use `optimizer.add_param_group({"params": model.parameters()})`
after resetting your optimizer state via `optimizer.param_groups.clear()`
and `optimizer.state.clear()`.

```python
optimizer = torch.optim.Adam(model.parameters())  # Or any other PyTorch optimizer.
```
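The re-binding pattern described above can be demonstrated on a plain PyTorch optimizer. This is a generic sketch using `torch.optim.Adam` and a toy `Linear` module as stand-ins; nothing here is Megatron-FSDP-specific.

```python
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Re-bind an already-initialized optimizer to the (re-wrapped) model's
# parameters: clear the param groups and state, then re-add the parameters.
optimizer.param_groups.clear()
optimizer.state.clear()
optimizer.add_param_group({"params": list(model.parameters())})

# The optimizer now tracks the new parameters as usual.
loss = model(torch.randn(2, 8)).sum()
loss.backward()
optimizer.step()
```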

`fully_shard_optimizer` modifies your `optimizer.step()`, `optimizer.zero_grad()`,
and distributed optimizer parameters to trigger scheduled FSDP operations for
Megatron-FSDP at the appropriate points.

```python
fully_shard_optimizer(
    # PyTorch Optimizer
    optimizer,
    # Preprocess state dict for DCP checkpointing.
    # Required for Torch Distributed Checkpoint.
    preproc_state_dict_for_dcp_ckpt=True,
)
```

Extended arguments to `step()` and `zero_grad()` control these FSDP operations:

```python
optimizer.step(
    ...,
    # Sync all gradients before the optimizer step. Alternatively enabled using
    # `sync_model_each_microbatch=True` in MegatronFSDP.
    sync_grad_before_optimizer_step=True,
    # After `optimizer.step()`, install optimized weights into MegatronFSDP's buffers.
    install_optimized_model_weights=True,
    # ... (additional arguments omitted in this excerpt) ...
)

optimizer.zero_grad(
    ...,
    # Also zero out MegatronFSDP's gradient accumulation buffers.
    zero_grad_buffer=True
)
```
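To make the mechanics concrete, here is an illustrative sketch of how an optimizer's `step()` can be extended with keyword-controlled hooks. This is a hypothetical pattern, not Megatron-FSDP's actual implementation, and the hook names are placeholders for FSDP actions like gradient synchronization and weight installation.

```python
import functools
import torch

def extend_optimizer_step(optimizer, pre_step_hook, post_step_hook):
    """Hypothetical sketch: wrap optimizer.step() so extra keyword arguments
    control FSDP-style actions before and after the parameter update."""
    original_step = optimizer.step

    @functools.wraps(original_step)
    def step(*args, sync_grad_before_optimizer_step=True,
             install_optimized_model_weights=True, **kwargs):
        if sync_grad_before_optimizer_step:
            pre_step_hook()   # e.g. wait on gradient reduce-scatter
        result = original_step(*args, **kwargs)
        if install_optimized_model_weights:
            post_step_hook()  # e.g. copy optimized shards into model buffers
        return result

    optimizer.step = step
    return optimizer
```

Calling `opt.step()` then fires the pre-hook, the original update, and the post-hook in order, while `opt.step(sync_grad_before_optimizer_step=False)` skips the pre-hook.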

### `MegatronFSDP` Distributed Checkpointing

Distributed checkpoints can be saved and loaded using Torch DCP. Alternatively,
you can load non-distributed checkpoints before fully-sharding your model with
any existing checkpoint utility compatible with PyTorch Modules.

```python
# Save model and optimizer state.
torch.distributed.checkpoint.save(
    {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
    checkpoint_id=str(CKPT_DIR),
)

# Load model and optimizer state.
ckpt_state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
torch.distributed.checkpoint.load(ckpt_state_dict, checkpoint_id=str(CKPT_DIR))
model.load_state_dict(ckpt_state_dict["model"], strict=False)
optimizer.load_state_dict(ckpt_state_dict["optimizer"])
```

## ⚙️ `fully_shard` / `MegatronFSDP` API - Advanced Features

Megatron-FSDP's `fully_shard_*` API has a comprehensive set of arguments for fine-tuning your model's performance.

- `fsdp_unit_modules` is a list of sub-module classes or `str` import-paths associated with modules that you want `MegatronFSDP` to fully-shard.
- Required if `1`, `2`, or `3` are specified as the sharding strategy. Defaults to `None`, in which case Megatron-FSDP will replicate the parameters similar to DDP.
- `zero_dp_strategy` (and `outer_dp_sharding_strategy`) configure different degrees of zero-redundancy data parallelism as described in [ZeRO (Zero Redundancy Optimizer)](https://arxiv.org/abs/1910.02054). It reduces CUDA memory utilization during model training by distributing model parameters, gradients, and optimizer states across multiple devices in the DP `ProcessGroup`, and collectively communicating subsets of parameters and gradients to specific devices when needed for computation or differentiation. More aggressive sharding strategies will entail more communication overhead, with `no_shard` being the least memory efficient but most communication efficient, and `optim_grads_params` being the most memory efficient but least communication efficient. `outer_dp_sharding_strategy` has the same options, except for the (required) "outer" DP group (`dp_outer_dim` / `hybrid_fsdp_group`) when using [Hybrid-Sharded Data Parallelism (HSDP)](https://arxiv.org/pdf/2304.11277), and only `no_shard` (DP Replication) and `optim` (Optimizer State Hybrid Sharding, requires `zero_dp_strategy='optim_grads_params'`) are supported.
Expand Down Expand Up @@ -280,8 +311,9 @@ optimizer.load_state_dict(ckpt_state_dict["optimizer"])
- Both default to `True`.
- `sync_model_each_microbatch` will trigger a `wait` (`MegatronFSDP.finish_grad_sync()`) on gradient reduction, parameter de-allocation, and optimizer parameter / gradient installation (in preparation for `optimizer.step()`) after every forward-backward pass. When using HSDP, parameters and gradients will be all-gathered and reduced respectively on the "outer" DP group each training step instead of each optimization cycle. This behavior is desirable for a transparent and user-friendly sharded training loop where post-backward transformations on the gradient and a clean compute / memory state are necessary between training iterations, but damages performance in situations where optimization is delayed (e.g. gradient accumulation) where the communications of the previous training iteration can be overlapped with the compute of the next training iteration. Will also override `is_last_microbatch` / `microbatch_count` logic in `MegatronFSDP`.
- Defaults to `True` for `fully_shard`, but defaults to `False` when using the `MegatronFSDP` class directly.
- `enable_fine_grained_param_gather` modifies FSDP to all-gather parameters with per-Module granularity instead of collectively unsharding all sub-modules of a unit module in Megatron-FSDP.
- Defaults to `False`.
- `keep_fp8_transpose_cache` will keep the FP8 transpose cache when using `MegatronFSDP`. This option incurs (number of parameters $\times$ 1 byte) of memory overhead, but skips the weight transpose operation in backward propagation. This feature provides no benefit on the Blackwell architecture.
- Defaults to `False`.
- `nccl_ub` will allocate and register the NCCL userbuffer for param and grad buffers. This option enables an SM-efficient NCCL algorithm that could improve the performance of overlapped computations. This flag will be much more effective when used together with SHARP if the FSDP communication includes both NVL and IB domains. Enabling this option will cause additional memory overhead due to the requirement to enable the `fsdp_double_buffer` option.
- **Only effective when using Megatron-Core.**
- `fsdp_double_buffer` maintains persistent double buffers for Megatron-FSDP communication, which is required for NCCL userbuffer (`nccl_ub`) registration.
- Defaults to `False`. Automatically overridden to `True` when `nccl_ub` is enabled.
- `preproc_state_dict_for_dcp_ckpt` adds `model.state_dict()` and `optimizer.state_dict()` post-hooks that modify the model and optimizer state in preparation for `torch.distributed.checkpoint.{save,load}` ([Torch DCP](https://docs.pytorch.org/docs/stable/distributed.checkpoint.html)) checkpointing. Specifically, it adds `__create_write_items__` and `__create_chunk_list__` methods to Tensors utilized by Torch DCP to redistribute parameters when saving and loading model and optimizer checkpoints. Can be deactivated should the user need a custom distributed checkpointing strategy.
- Defaults to `True`.
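As a back-of-the-envelope guide to choosing a strategy, per-GPU memory for model states can be estimated with the standard ZeRO accounting. The constants below are assumptions from the ZeRO paper's mixed-precision setup (bf16 parameters and gradients at 2 bytes each, fp32 Adam states at 12 bytes per parameter), not Megatron-FSDP internals; the strategy names mirror the `zero_dp_strategy` options above.

```python
def zero_memory_per_gpu(num_params: int, dp_size: int, strategy: str) -> float:
    """Rough per-GPU bytes for model states under ZeRO-style sharding.

    Assumes bf16 params/grads (2 bytes each) and fp32 Adam states
    (12 bytes: fp32 master weights plus two moments)."""
    P, G, O = 2 * num_params, 2 * num_params, 12 * num_params
    if strategy == "no_shard":
        return P + G + O                              # DDP-style replication
    if strategy == "optim":                           # ZeRO-1
        return P + G + O / dp_size
    if strategy == "optim_grads":                     # ZeRO-2
        return P + G / dp_size + O / dp_size
    if strategy == "optim_grads_params":              # ZeRO-3 / FSDP
        return (P + G + O) / dp_size
    raise ValueError(f"Unknown strategy: {strategy}")
```

For example, with 1B parameters and `dp_size=8`, `optim_grads_params` cuts the per-GPU model-state footprint to one eighth of `no_shard`.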

## 🧮 Using Megatron-FSDP with [`TransformerEngine`](https://github.com/NVIDIA/TransformerEngine)

Megatron-FSDP natively supports mixed-precision activations and parameter sharding in conjunction with [TransformerEngine](https://github.com/NVIDIA/TransformerEngine).

- Within the [`transformer_engine.pytorch.autocast(recipe: transformer_engine.common.recipe.Recipe)`](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html#transformer_engine.pytorch.autocast) context, model activations are converted based on the recipe.
- Within the [`transformer_engine.pytorch.quantized_model_init(recipe: transformer_engine.common.recipe.Recipe)`](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html#transformer_engine.pytorch.quantized_model_init) context, TransformerEngine native modules (e.g. [`transformer_engine.pytorch.TransformerLayer`](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html#transformer_engine.pytorch.TransformerLayer)) have their parameters converted based on the recipe.
- Requires FP8 model activations, i.e. `transformer_engine.pytorch.autocast`.

```python
import transformer_engine.common.recipe
import transformer_engine.pytorch

# FP8 Recipe
fp8_recipe = transformer_engine.common.recipe.MXFP8BlockScaling(
    fp8_format=transformer_engine.common.recipe.Format.HYBRID,
)

# Construct TransformerEngine model with FP8 parameters.
with transformer_engine.pytorch.quantized_model_init(
    recipe=fp8_recipe,
    # Needed for FP8 parameters with Megatron-FSDP.
    preserve_high_precision_init_val=True,
):
    te_model = transformer_engine.pytorch.TransformerLayer(...)

# Fully-shard the model.
mfsdp_model = fully_shard_model(
    module=te_model,
    fsdp_unit_modules=[transformer_engine.pytorch.TransformerLayer],
    # Only FSDP / ZeRO-3 supports FP8 parameters.
    zero_dp_strategy=3,
    # Needed for FP8 parameters. (Default is already True.)
    preserve_fp32_weights=True,
    # Needed for select FP8 recipes.
    keep_fp8_transpose_cache=True,
)

# Evaluate and differentiate the model with FP8 activations.
with transformer_engine.pytorch.autocast(recipe=fp8_recipe):
    mfsdp_model(x).sum().backward()
```

ℹ️ `TransformerEngine` kernels are subject to a number of configuration constraints when using FP8-quantized parameters, such as using fused QKV parameters or defining activations and parameters with shapes compatible with FP8 cuBLAS kernels on supported NVIDIA hardware. To properly initialize `TransformerLayer`, refer to the toy model used in our FP8 unit tests: `Megatron-LM/tests/unit_tests/distributed/fsdp/test_mfsdp_fully_shard.py::TestMegatronFsdpFullyShard::test_fully_shard_te_quantized`.
20 changes: 17 additions & 3 deletions megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py
def fully_shard_model(
    ...,
    nccl_ub: bool = False,
    fsdp_double_buffer: bool = False,
    disable_symmetric_registration: bool = False,
    enable_fine_grained_param_gather: bool = False,
) -> torch.nn.Module:
    """
    Fully-shard the model for Megatron-FSDP. This wraps the model in a MegatronFSDP
    class that schedules the sharding lifecycle of the model parameters and gradients
    during training and inference.
    ...
        disable_symmetric_registration (bool):
            Whether to disable symmetric (window) registration for NCCL UB registration.
            This option forces conventional (local) UB registration when nccl_ub is set.
            Defaults to False.

        enable_fine_grained_param_gather (bool):
            Whether to enable "fine-grained" param all-gather, which can improve performance
            when using MXFP8 parameters with activation recomputation. Specifically, it
            unshards parameters per-Module instead of unsharding all sub-modules of an FSDP
            unit module simultaneously. Defaults to False.

    Returns:
        model (MegatronFSDP): The wrapped Megatron-FSDP model configured for FSDP.
    """
    ...
    if device_mesh is None:
        if dp_shard_dim is None:
            dp_shard_dim = "fsdp"
        if tp_dim is None:
            # Trivial TP dimension to seamlessly support TransformerEngine.
            tp_dim = "tp"
        # Deactivate DP-Outer, which needs to be consistent with Expert DeviceMesh.
        dp_outer_dim = None
        hybrid_fsdp_group = None
        outer_dp_sharding_strategy = ShardingStrategy.NO_SHARD
        device_mesh = init_device_mesh(
            device_type="cuda",
            mesh_shape=(torch.distributed.get_world_size(), 1),
            mesh_dim_names=(dp_shard_dim, tp_dim),
        )

    # Parse zero_dp_strategy and outer_dp_sharding_strategy.
    ...
    if _outer_fsdp_sharding and zero_dp_strategy != "optim_grads_params":
        # If sharding on outer DP using HSDP, then we must use HSDP buffers and
        # we must be fully-sharding on inner DP. HSDP is an extension of FSDP.
        # TODO(@shjwudp, @cspades): Requires various modifications to support.
        raise ValueError(
            f"Sharding with Hybrid (Fully) Sharded Data Parallel (HSDP) requires "
            "zero_dp_strategy to use FSDP ('optim_grads_params', 3), because "
    ...
        calculate_per_token_loss=calculate_per_token_loss,
        init_model_with_meta_device=init_model_with_meta_device,
        sync_model_each_microbatch=sync_model_each_microbatch,
        enable_fine_grained_param_gather_hook=enable_fine_grained_param_gather,
    )

    # Register a state dict post-hook to add Torch DCP metadata for writing checkpoints.
    ...


def fully_shard(
    ...,
    nccl_ub: bool = False,
    fsdp_double_buffer: bool = False,
    disable_symmetric_registration: bool = False,
    enable_fine_grained_param_gather: bool = False,
) -> tuple[MegatronFSDP, torch.optim.Optimizer]:
    """
    Fully shard the model and the optimizer for Megatron-FSDP.
    ...
    """
    model = fully_shard_model(
        ...,
        nccl_ub=nccl_ub,
        fsdp_double_buffer=fsdp_double_buffer,
        disable_symmetric_registration=disable_symmetric_registration,
        enable_fine_grained_param_gather=enable_fine_grained_param_gather,
    )

    # Extend optimizer methods to support Megatron-FSDP operations.