[PyTorch] FSDP2 Support for TE #2245
base: main
Conversation
Greptile Overview
Greptile Summary
Adds FSDP2 (Fully Sharded Data Parallel 2) support for FP8 and MXFP8 quantized tensors in PyTorch Transformer Engine, enabling distributed training with FP8 mixed-precision.
Key Changes:
- Implemented `fsdp_pre_all_gather` and `fsdp_post_all_gather` hooks for both `Float8Tensor` and `MXFP8Tensor` to handle the FSDP weight sharding/gathering lifecycle
- Added custom `__torch_dispatch__` handlers for FSDP-required operations: `aten.split.Tensor`, `aten.new_zeros`, `aten.as_strided`, `aten.copy_`, and `aten.slice.Tensor`
- Enhanced transpose caching logic to properly maintain transposed views through various tensor operations
- Added training-state-aware quantizer usage control (rowwise vs. columnwise) based on forward/backward pass detection
Major Implementation Details:
- FSDP2 integration distinguishes between forward and backward passes using `TrainingState.PRE_BACKWARD` to selectively gather only the needed tensor representations (a minimal sketch of the hook contract follows below)
- For MXFP8, operations validate 128-byte alignment constraints and fall back to dequantization when the constraints aren't met
- Transpose cache maintenance across splits, views, and resharding preserves the performance optimization for Hopper/L40 architectures
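A minimal sketch of what this hook contract looks like, assuming a heavily simplified signature (the real hooks also receive the mesh, module, and training state, and live on `Float8Tensor`/`MXFP8Tensor`); `ToyFP8Param` and its fields are hypothetical stand-ins, not TE classes:

```python
import torch

class ToyFP8Param:
    """Toy stand-in: a uint8 payload plus the metadata needed to rebuild it."""

    def __init__(self, data: torch.Tensor, scale_inv: torch.Tensor):
        self.data = data            # uint8 shard that FSDP2 can all-gather natively
        self.scale_inv = scale_inv  # metadata carried through the all-gather

    def fsdp_pre_all_gather(self):
        # Hand FSDP2 plain uint8 tensors to gather, plus metadata to pass along.
        return (self.data,), (self.scale_inv,)

    @staticmethod
    def fsdp_post_all_gather(gathered, metadata):
        # Rebuild the quantized tensor from the gathered bytes + metadata.
        (data,), (scale_inv,) = gathered, metadata
        return ToyFP8Param(data, scale_inv)

# Manual round trip (no distributed setup needed for the toy):
shard = ToyFP8Param(torch.zeros(4, dtype=torch.uint8), torch.ones(1))
inputs, meta = shard.fsdp_pre_all_gather()
rebuilt = ToyFP8Param.fsdp_post_all_gather(inputs, meta)
assert rebuilt.data.dtype == torch.uint8
```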
Issues Found:
Multiple critical None-handling bugs exist in MXFP8 dispatch handlers where operations assume non-None data/scale tensors, which would cause AttributeError at runtime when certain usage flags are disabled.
Confidence Score: 3/5
- This PR has several critical runtime issues that need resolution before merging, particularly around None-handling in MXFP8 tensor operations
- Score reflects multiple logic bugs identified by previous reviewers (AttributeError, NameError, variable shadowing) that would cause runtime failures in MXFP8 operations. While Float8Tensor changes appear more robust, MXFP8Tensor has ~8-10 critical None-handling issues across split, slice, copy, and post_all_gather operations that need fixes
- `transformer_engine/pytorch/tensor/mxfp8_tensor.py` requires significant attention for None-handling fixes across all new dispatch handlers before this can safely merge
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/tensor/float8_tensor.py | 4/5 | Adds FSDP2 support with fsdp_pre_all_gather and fsdp_post_all_gather hooks, implements aten.split.Tensor, aten.new_zeros, and aten.as_strided handlers with transpose caching for FP8 tensors |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 3/5 | Implements FSDP2 support and multiple torch dispatch handlers (split, as_strided, copy_, slice, new_zeros) for MXFP8 tensors; contains several critical None-handling issues that need resolution |
Sequence Diagram
sequenceDiagram
participant FSDP2
participant Float8Tensor/MXFP8Tensor
participant Quantizer
participant DeviceMesh
Note over FSDP2: Forward Pass (weights needed)
FSDP2->>Float8Tensor/MXFP8Tensor: fsdp_pre_all_gather(mesh, module, ...)
Float8Tensor/MXFP8Tensor->>Quantizer: check training_state & reshard_after_forward
Quantizer->>Quantizer: set_usage(rowwise=True, columnwise=False)
Float8Tensor/MXFP8Tensor->>FSDP2: return (sharded_data, metadata)
FSDP2->>DeviceMesh: all_gather(sharded_data)
DeviceMesh->>FSDP2: all_gather_outputs
FSDP2->>Float8Tensor/MXFP8Tensor: fsdp_post_all_gather(outputs, metadata, ...)
Float8Tensor/MXFP8Tensor->>Float8Tensor/MXFP8Tensor: reconstruct full tensor
Float8Tensor/MXFP8Tensor->>Float8Tensor/MXFP8Tensor: update_usage(rowwise=True)
Float8Tensor/MXFP8Tensor->>FSDP2: return gathered_tensor
Note over FSDP2: Compute forward pass
Note over FSDP2: Backward Pass (gradients computed)
FSDP2->>Float8Tensor/MXFP8Tensor: fsdp_pre_all_gather(mesh, module, ...)
Float8Tensor/MXFP8Tensor->>Quantizer: check training_state (PRE_BACKWARD)
Quantizer->>Quantizer: set_usage(rowwise=False, columnwise=True)
Float8Tensor/MXFP8Tensor->>FSDP2: return (transpose_data, metadata)
FSDP2->>DeviceMesh: all_gather(transpose_data)
DeviceMesh->>FSDP2: all_gather_outputs
FSDP2->>Float8Tensor/MXFP8Tensor: fsdp_post_all_gather(outputs, metadata, ...)
Float8Tensor/MXFP8Tensor->>Float8Tensor/MXFP8Tensor: reconstruct with transpose
Float8Tensor/MXFP8Tensor->>Float8Tensor/MXFP8Tensor: update_usage(columnwise=True)
Float8Tensor/MXFP8Tensor->>FSDP2: return gathered_tensor
2 files reviewed, 1 comment
Greptile Overview
Greptile Summary
This PR significantly enhances the FSDP2 test infrastructure for Transformer Engine by adding comprehensive support for FP8 mixed-precision training with distributed sharding.
Key Changes:
- Expanded FP8 recipe support: added `Float8CurrentScaling` and `MXFP8BlockScaling` alongside the existing `DelayedScaling`
- Introduced a flexible layer configuration system supporting 5 TE layer types (Linear, LayerNormLinear, LayerNormMLP, MultiheadAttention, TransformerLayer)
- Added a meta-device initialization workflow for deferred parameter materialization after FSDP2 sharding
- Implemented a `test_fp8_fsdp2_allgather()` validation function to verify FP8 allgather correctness against a manual FP32 allgather (a hedged sketch of the idea follows below)
- Enhanced custom attribute save/restore logic to handle `QuantizedTensor` metadata correctly with FSDP2 DTensors
- Replaced the simple 3-layer network with a configurable multi-layer architecture supporting both `reshard_after_forward=True/False` test cases
The test file is well-structured with clear separation of concerns: model initialization, FSDP2 setup, training loop, and validation logic.
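A hedged sketch of that validation idea (not the test's literal code): all-gather a high-precision copy of the local shard manually and compare it against the weight obtained from FSDP2's FP8 all-gather. It assumes an initialized process group and that a `dequantize()` method is available on the quantized tensors, as described above.

```python
import torch
import torch.distributed as dist

def reference_allgather(local_shard_fp32: torch.Tensor) -> torch.Tensor:
    # Manual FP32 all-gather along dim 0 (FSDP2 shards parameters on dim 0).
    buckets = [torch.empty_like(local_shard_fp32) for _ in range(dist.get_world_size())]
    dist.all_gather(buckets, local_shard_fp32)
    return torch.cat(buckets, dim=0)

def check_fp8_allgather(local_fp8_shard, fsdp2_gathered_fp8) -> None:
    expected = reference_allgather(local_fp8_shard.dequantize().float())
    actual = fsdp2_gathered_fp8.dequantize().float()
    # The same FP8 bytes were communicated, so the dequantized values should match exactly.
    torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0)
```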
Confidence Score: 5/5
- This PR is safe to merge with high confidence - the changes are well-tested, properly structured, and add comprehensive FSDP2 support.
- Score reflects thorough implementation with proper error handling, comprehensive test coverage of multiple FP8 recipes and layer types, correct FSDP2 integration patterns (save/restore custom attrs, DTensor handling), and validation logic to verify FP8 allgather correctness. The code follows established patterns and includes clear documentation.
- No files require special attention - the test file is comprehensive and correctly implements FSDP2 FP8 support.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/distributed/run_fsdp2_model.py | 5/5 | Comprehensive FSDP2 test script adding support for multiple FP8 recipes, flexible layer configurations, meta device initialization, and FP8 allgather validation |
Sequence Diagram
sequenceDiagram
participant Main as Main Process
participant Init as Model Init
participant FSDP as FSDP2 Sharding
participant Train as Training Loop
participant Test as FP8 Test
Main->>Main: Parse args & setup distributed
Main->>Init: Create FP8 recipe (delayed/current/mx_fp8)
alt FP8 Init Enabled
Init->>Init: fp8_model_init(recipe)
else FP8 Init Disabled
Init->>Init: nullcontext()
end
Init->>Init: init_te_model(config)
Note over Init: Create model on meta/cuda device
Init->>FSDP: save_custom_attrs(model)
Note over FSDP: Save QuantizedTensor metadata
FSDP->>FSDP: get_device_mesh(world_size, sharding_dims)
Note over FSDP: Setup FSDP or HSDP mesh
FSDP->>FSDP: shard_model_with_fsdp2(model, mesh)
Note over FSDP: Apply fully_shard to children & root
FSDP->>FSDP: restore_custom_attrs(model, custom_attrs)
Note over FSDP: Restore FP8 metadata to DTensors
alt Meta Device Init
FSDP->>FSDP: reset_parameters()
Note over FSDP: Materialize sharded params on cuda
end
FSDP->>Train: Create optimizer
loop For each iteration
Train->>Train: Generate input & target
Train->>Train: Forward with te.autocast(recipe)
Train->>Train: Compute loss
Train->>Train: Backward pass
Train->>Train: Optimizer step
end
alt FP8 Init Enabled
Train->>Test: test_fp8_fsdp2_allgather(model)
Test->>Test: Manual FP32 allgather
Test->>Test: FSDP2 FP8 allgather (unshard)
Test->>Test: Validate both match
Test->>Test: Reshard model
end
Main->>Main: Destroy process group
1 file reviewed, no comments
Greptile Overview
Greptile Summary
This PR adds FSDP2 (Fully Sharded Data Parallel v2) support to Transformer Engine by enabling DTensor parameter handling in the base module.
Key Changes:
- Modified `register_parameter()` to prevent overwriting FP8-specific metadata when FSDP2 re-registers parameters as DTensors
- Enhanced `reset_parameters()` to detect and handle DTensor parameters by operating on their local tensors
- Added device mesh integration for `Float8CurrentScalingQuantizer` to configure amax reduction groups for distributed training
- Implemented proper DTensor reconstruction after meta-device materialization
- Ensured quantized local tensors are correctly wrapped back into DTensor parameters
Integration Points:
- DTensor detection via an `isinstance(param, DTensor)` check
- Local tensor extraction and manipulation via `param._local_tensor`
- Device mesh group configuration for FP8 scaling synchronization across shards
- Parameter wrapping preserves both the DTensor structure and FP8 quantization (see the sketch after this list)
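A hedged sketch of the flow those points describe, not the module's actual code; `init_fn` and `maybe_quantize` are placeholders, and the DTensor import path can differ across PyTorch versions.

```python
import torch
from torch import nn
from torch.distributed.tensor import DTensor  # older releases: torch.distributed._tensor

def reset_dtensor_param(param, init_fn, maybe_quantize) -> nn.Parameter:
    local = param._local_tensor if isinstance(param, DTensor) else param
    if local.device.type == "meta":
        # Deferred init: materialize the local shard on the real device first.
        local = torch.empty_like(local, device="cuda")
    init_fn(local)                 # run the usual weight init on the local shard
    local = maybe_quantize(local)  # optionally quantize the shard to FP8
    if isinstance(param, DTensor):
        # Re-wrap so FSDP2 keeps its device-mesh/placement bookkeeping.
        local = DTensor.from_local(local, param.device_mesh, param.placements)
    return nn.Parameter(local)
```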
Confidence Score: 4/5
- This PR is safe to merge with minor considerations for edge cases in DTensor handling
- The implementation correctly handles DTensor parameter registration and initialization. The logic properly distinguishes between DTensor and regular tensors, extracts local tensors for processing, and reconstructs DTensors with appropriate device mesh configuration. The amax reduction group setup for Float8CurrentScalingQuantizer is correctly conditioned on both DTensor type and quantizer type. However, the score is 4 instead of 5 because: (1) the high-precision init value methods are attached to local tensors which relies on DTensor's attribute delegation pattern, and (2) there's no explicit validation that
`dtensor_param` maintains valid device_mesh/placements attributes throughout the flow, though the logic appears sound
- No files require special attention beyond standard FSDP2 testing with FP8 quantization enabled
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/base.py | 4/5 | Added FSDP2 DTensor support in parameter registration and reset, including proper handling of local tensors, device mesh configuration for FP8 quantization, and parameter wrapping |
Sequence Diagram
sequenceDiagram
participant FSDP2 as FSDP2
participant Module as TransformerEngineBaseModule
participant ResetParams as reset_parameters()
participant Quantizer as Float8CurrentScalingQuantizer
participant DTensor as DTensor
FSDP2->>Module: register_parameter(name, DTensor)
Note over Module: Check if param_init_meta exists<br/>Only initialize once to preserve FP8 kwargs
Module->>Module: Store param_init_meta[name]
FSDP2->>ResetParams: Trigger parameter initialization
ResetParams->>ResetParams: Check if param is DTensor
ResetParams->>DTensor: Extract _local_tensor
alt Parameter on meta device
ResetParams->>ResetParams: Create empty_like on cuda
ResetParams->>DTensor: Reconstruct DTensor.from_local()<br/>with device_mesh & placements
end
ResetParams->>ResetParams: Apply init_fn to local tensor
alt FP8 quantization enabled
ResetParams->>Quantizer: Configure quantizer settings
alt Is DTensor && Float8CurrentScaling
ResetParams->>DTensor: Get device_mesh
ResetParams->>Quantizer: Set amax_reduction_group<br/>from device_mesh.get_group()
ResetParams->>Quantizer: Enable with_amax_reduction
end
ResetParams->>Quantizer: Quantize local tensor
Quantizer-->>ResetParams: Return QuantizedTensor
end
ResetParams->>DTensor: Update _local_tensor with quantized tensor
ResetParams->>DTensor: Wrap as nn.Parameter
ResetParams->>Module: setattr(name, DTensor parameter)
1 file reviewed, no comments
Greptile Overview
Greptile Summary
Enables FSDP2 training with FP8/MXFP8 initialized weights by implementing custom allgather hooks (fsdp_pre_all_gather and fsdp_post_all_gather) that serialize FP8 tensors to uint8 for distributed communication and reconstruct them post-allgather.
Key Changes:
- FP8 Allgather Support: Float8Tensor and MXFP8Tensor now implement FSDP2 hooks that return uint8 data with metadata (scale_inv, dtype, quantizer) for allgather, enabling FP8 communication instead of high-precision
- Selective Usage Based on Training State: Pre-allgather hooks optimize memory by gathering only rowwise data for forward pass and columnwise data for backward pass when
`reshard_after_forward=True`
- DTensor Integration: `TransformerEngineBaseModule.reset_parameters()` now handles FSDP2's DTensor parameters by operating on `_local_tensor` and preserving FP8 metadata across parameter re-registration
- Transpose Cache Management: Enhanced `__torch_dispatch__` handlers for split/view/new_zeros/as_strided ops to maintain transpose caches for both data and data_transpose, improving performance
- Amax Reduction Setup: Quantizers are configured with appropriate reduction groups for synchronized scale updates across FSDP shards (a small sketch of this wiring follows below)
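A small sketch of that amax wiring, assuming the attribute names quoted in this review (`with_amax_reduction`, `amax_reduction_group`) rather than verified TE fields: every shard reduces its amax over the FSDP process group so all ranks quantize with the same scale_inv.

```python
from torch.distributed.device_mesh import DeviceMesh

def configure_current_scaling(quantizer, mesh: DeviceMesh) -> None:
    # Synchronize amax across the data-parallel shards before computing scales.
    quantizer.with_amax_reduction = True
    quantizer.amax_reduction_group = mesh.get_group()
```

Without this reduction, each rank would derive its own scale from its local shard, and the shards' scale inverses would drift apart after optimizer updates.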
Issues Found:
- Potential tensor unpacking bug in `mxfp8_tensor.py:613-617` where both `[:2]` and `[-2:]` slicing could select duplicate tensors if validation fails
Confidence Score: 4/5
- This PR is largely safe to merge with one logical issue that needs verification in edge cases
- The implementation is well-structured and addresses a significant feature gap (FSDP2 support for FP8 weights). The core allgather hook logic is sound and properly handles the forward/backward pass distinction. However, there's a potential edge-case bug in MXFP8Tensor's
`fsdp_post_all_gather` where tensor unpacking could fail if the tuple length doesn't match the usage flags, though this is unlikely in normal operation since the pre/post hooks are paired
- transformer_engine/pytorch/tensor/mxfp8_tensor.py - verify that the tensor unpacking logic at lines 613-617 handles all edge cases correctly
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/tensor/float8_tensor.py | 4/5 | Adds FSDP2 support with fsdp_pre_all_gather and fsdp_post_all_gather hooks, implements FP8 allgather by returning uint8 data with metadata for reconstruction, enhances __torch_dispatch__ to handle transpose caching for split/view/new_zeros/as_strided ops |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 3/5 | Adds FSDP2 allgather hooks for MXFP8 tensors with selective rowwise/columnwise data gathering based on training state, implements torch dispatch handlers for split/as_strided/copy_/slice/new_zeros ops with MXFP8 block scaling constraints, has potential tensor unpacking issue in fsdp_post_all_gather |
| transformer_engine/pytorch/distributed.py | 4/5 | Adds _get_module_fsdp_state helper with LRU caching to retrieve FSDP state from modules or their closest FSDP parent |
| transformer_engine/pytorch/module/base.py | 4/5 | Updates reset_parameters to handle DTensor (FSDP2) by operating on _local_tensor, preserves FP8 metadata during FSDP2's re-registration of parameters as DTensors, sets up amax reduction groups for DTensor quantizers |
Sequence Diagram
sequenceDiagram
participant FSDP2 as FSDP2
participant Float8Tensor as Float8Tensor/MXFP8Tensor
participant Quantizer as Quantizer
participant Module as TransformerEngineModule
Note over FSDP2,Module: Forward Pass (or Backward if reshard_after_forward=True)
FSDP2->>Float8Tensor: fsdp_pre_all_gather(mesh, orig_size, module, ...)
Float8Tensor->>Module: _get_module_fsdp_state(module)
Module-->>Float8Tensor: fsdp_state
Float8Tensor->>Quantizer: copy()
Quantizer-->>Float8Tensor: quantizer_copy
alt reshard_after_forward=True
Float8Tensor->>Float8Tensor: Determine forward vs backward from training_state
Float8Tensor->>Quantizer: set_usage(rowwise=!is_backward, columnwise=is_backward)
Note over Float8Tensor: Pack only needed data based on pass direction
else reshard_after_forward=False
Note over Float8Tensor: Pack both rowwise and columnwise if needed
end
Float8Tensor-->>FSDP2: (sharded_uint8_tensors, metadata)
FSDP2->>FSDP2: AllGather uint8 tensors across ranks
FSDP2->>Float8Tensor: fsdp_post_all_gather(all_gather_outputs, metadata, param_dtype, out)
Float8Tensor->>Float8Tensor: Unpack all_gather_outputs and metadata
alt out exists
Float8Tensor->>Float8Tensor: update_usage() on existing tensor
else out is None
Float8Tensor->>Float8Tensor: Construct new Float8Tensor/MXFP8Tensor
Float8Tensor->>Float8Tensor: update_usage() on new tensor
end
Float8Tensor-->>FSDP2: (reconstructed_fp8_tensor, all_gather_outputs)
Note over FSDP2,Module: Tensor ready for forward/backward computation
1 file reviewed, 1 comment
/te-ci pytorch
Greptile Overview
Greptile Summary
This PR makes focused changes to quantized_tensor.py to improve FSDP2 compatibility:
- Recursive list handling: Added support for recursively updating lists of tensors in in-place operations (lines 436-439). This handles operations like `split` that return multiple tensors, ensuring QuantizedTensors within lists are properly updated (a toy sketch follows below).
- Simplified `make_like` method: Removed the `data` parameter from the base class implementation (lines 493-506). The method now focuses solely on creating views of tensors. This change is safe because:
  - Subclasses like `Float8Tensor` override this method and still support the `data` parameter for backward compatibility
  - The base class docstring now correctly reflects that the method is "intended to create view of tensors"
  - Existing usages with the `data=` parameter are handled by the overridden methods in subclasses
These are minimal, well-scoped changes that support the broader FSDP2 integration without breaking existing functionality.
Confidence Score: 4/5
- This PR is safe to merge with minimal risk
- The changes are minimal and focused, with only two small modifications to
quantized_tensor.py. The recursive list handling is a straightforward addition that improves robustness. Themake_likesignature change is safe because subclasses override the method and maintain backward compatibility. No issues found that would impact correctness or introduce bugs. - No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/quantized_tensor.py | 4/5 | Added recursive list handling for in-place operations and simplified make_like method by removing data parameter. Changes are minimal and focused on improving FSDP2 compatibility. |
Sequence Diagram
sequenceDiagram
participant FSDP as FSDP2 Framework
participant PreHook as fsdp_pre_allgather
participant QT as QuantizedTensor
participant PostHook as fsdp_post_allgather
Note over FSDP: Forward/Backward Pass Begins
FSDP->>PreHook: Call pre_allgather hook
PreHook->>QT: Extract uint8 data + metadata
Note over QT: For FP8: extract _data tensor<br/>For MXFP8: extract rowwise/columnwise data
QT-->>PreHook: Return (uint8_tensors, metadata)
PreHook-->>FSDP: Return allgather input
Note over FSDP: Perform AllGather on uint8 data
FSDP->>PostHook: Call post_allgather hook
PostHook->>QT: Reconstruct from allgathered data
Note over QT: Rebuild Float8/MXFP8 tensor<br/>from uint8 + metadata
QT-->>PostHook: Return reconstructed tensor
PostHook-->>FSDP: Return full tensor
Note over FSDP: Continue computation with full tensor
1 file reviewed, no comments
/te-ci pytorch
Greptile Overview
Greptile Summary
Enables FSDP2 distributed training with FP8-initialized weights by implementing allgather hooks, torch dispatch operations, and DTensor support for TE quantized tensors.
Key Changes:
- Implements `fsdp_pre_all_gather` and `fsdp_post_all_gather` hooks for Float8Tensor and MXFP8Tensor to enable 8-bit weight allgather (instead of high-precision)
- Adds torch dispatch support for FSDP2 tensor operations: `split`, `copy_`, `slice`, `view`, `as_strided`, `new_zeros`
- Updates `reset_parameters` in TransformerEngineBaseModule to handle DTensor for deferred initialization (meta device)
- Fixes optimizer weight updates by recursively handling lists of quantized tensors in in-place operations
- Moves quantizer usage validation from the forward to the backward pass to support phase-aware allgather (a toy illustration follows after this list)
- Configures amax reduction groups for current scaling quantizer to synchronize scale inverses across FSDP shards
- Comprehensive test coverage for multiple TE layers with delayed/current/MX_FP8 scaling recipes
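A toy of the phase-aware usage selection described in this list; `TrainingState` here is a simplified stand-in for the internal FSDP2 enum, not the real type.

```python
from enum import Enum, auto

class TrainingState(Enum):          # simplified stand-in for FSDP2's internal enum
    FORWARD = auto()
    PRE_BACKWARD = auto()

def select_usage(training_state: TrainingState, reshard_after_forward: bool):
    """Return (rowwise, columnwise) quantizer usage for the upcoming all-gather."""
    if not reshard_after_forward:
        # Weights stay gathered across forward and backward: keep both layouts.
        return True, True
    is_backward = training_state == TrainingState.PRE_BACKWARD
    # Forward needs the rowwise data; backward needs the columnwise/transpose data.
    return (not is_backward, is_backward)

assert select_usage(TrainingState.FORWARD, True) == (True, False)
assert select_usage(TrainingState.PRE_BACKWARD, True) == (False, True)
```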
Memory Impact:
- FP8 per-tensor scaling reduces memory footprint by ~50% vs BF16 on Blackwell (as expected)
- MXFP8 block scaling maintains similar memory to BF16 due to rowwise+columnwise storage requirements
Confidence Score: 4/5
- This PR is mostly safe to merge with one critical bug fix needed in MXFP8 shape handling
- Score reflects solid implementation with comprehensive test coverage, but deducted 1 point due to critical bug in mxfp8_tensor.py:658 where both rowwise/columnwise data can be None causing AttributeError, and minor concerns about LRU cache causing potential memory leaks
- transformer_engine/pytorch/tensor/mxfp8_tensor.py:658 requires fix for None handling
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/tensor/float8_tensor.py | 4/5 | Adds FSDP2 allgather hooks and torch dispatch ops (view, split, copy, slice, as_strided, new_zeros) to support 8-bit weight sharding. Implements current scaling quantizer sync for FSDP weight updates. Potential issue with shape handling in line 658. |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 3/5 | Implements FSDP2 support with rowwise/columnwise data handling for block-scaled FP8. Adds dispatch ops (split, copy, slice, as_strided, new_zeros). Critical bug at line 658 where both data tensors can be None causing AttributeError. |
| transformer_engine/pytorch/module/base.py | 4/5 | Updates reset_parameters to handle DTensor (FSDP2 deferred init) by quantizing local tensor and reconstructing DTensor. Adds amax_reduction_group configuration for current scaling. Guard prevents metadata loss during DTensor conversion. |
| transformer_engine/pytorch/distributed.py | 4/5 | Adds _get_module_fsdp_state helper with LRU cache to find FSDP state for allgather hooks. Cache could cause memory leaks but likely acceptable given module stability during training. |
| transformer_engine/pytorch/quantized_tensor.py | 5/5 | Fixes in-place ops to recursively handle lists of tensors (optimizer sends batched updates). Removes data parameter from make_like to avoid confusion between view creation and data initialization. |
Sequence Diagram
sequenceDiagram
participant FSDP as FSDP2
participant QT as QuantizedTensor (FP8/MXFP8)
participant Helper as _get_module_fsdp_state
participant Optimizer as Optimizer
Note over FSDP,QT: Forward Pass - Weight Allgather
FSDP->>QT: fsdp_pre_all_gather(module, mesh, ...)
QT->>Helper: Get FSDP state to determine phase
Helper-->>QT: training_state, reshard_after_forward
QT->>QT: Set quantizer.rowwise_usage=True, columnwise=False
QT-->>FSDP: (uint8_data, ...), metadata
FSDP->>FSDP: All-gather uint8 shards
FSDP->>QT: fsdp_post_all_gather(gathered_outputs, metadata)
QT->>QT: Reconstruct FP8 tensor with rowwise usage
QT-->>FSDP: Allgathered FP8 weight
Note over FSDP,QT: Forward Pass Compute
FSDP->>QT: Forward computation with FP8 weights
Note over FSDP,QT: Backward Pass - Weight Allgather (if reshard_after_forward)
FSDP->>QT: fsdp_pre_all_gather(module, mesh, ...)
QT->>Helper: Get FSDP state
Helper-->>QT: training_state=PRE_BACKWARD
QT->>QT: Set quantizer.rowwise=False, columnwise_usage=True
QT-->>FSDP: (uint8_data_transpose, ...), metadata
FSDP->>FSDP: All-gather transpose/columnwise shards
FSDP->>QT: fsdp_post_all_gather(gathered_outputs, metadata)
QT->>QT: Reconstruct FP8 tensor with columnwise usage
QT-->>FSDP: Allgathered FP8 weight
Note over FSDP,Optimizer: Gradient Computation & Weight Update
FSDP->>FSDP: Compute gradients, reduce-scatter
Optimizer->>QT: In-place update (lerp on list of tensors)
QT->>QT: Dequantize, apply op, quantize with amax reduction
Note over QT: Amax synchronized across shards<br/>for current scaling
11 files reviewed, 2 comments
```python
    columnwise_scale_inv=columnwise_scale_inv,
    fp8_dtype=fp8_dtype,
    dtype=param_dtype,
    shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,
```
logic: if both rowwise_data and columnwise_data are None (when both usage flags are False), accessing .shape raises AttributeError
Suggested change:
```python
    shape=rowwise_data.shape if rowwise_data is not None else (columnwise_data.shape if columnwise_data is not None else torch.Size([0])),
```
/te-ci L1 pytorch
Greptile Overview
Greptile Summary
This PR enables FSDP2 training with FP8-initialized weights by implementing custom allgather hooks and torch dispatch handlers. The implementation addresses three key issues: memory footprint with FP8 weights, correct weight updates during training, and efficient 8-bit weight allgather.
Key Changes:
- Implements `fsdp_pre_all_gather` and `fsdp_post_all_gather` hooks for Float8Tensor and MXFP8Tensor to handle FP8/MXFP8 allgather using uint8 data
- Adds torch dispatch handlers for `split`, `slice`, `copy`, `new_zeros`, `as_strided`, and `view` operations needed by FSDP2
- Uses the FSDP state to detect forward vs. backward pass and set the appropriate rowwise/columnwise quantizer usage
- Sets `amax_reduction_group` for current scaling quantization to synchronize scale inverses across shards
- Updates DTensor parameters correctly during deferred initialization (meta device)
- Moves quantizer usage validation from the layer `forward()` to the `_apply_forward/backward` functions to accommodate FSDP2's separate allgathers for forward and backward
Critical Issues Found:
- `mxfp8_tensor.py:502` and `mxfp8_tensor.py:389` have a potential `AttributeError` when accessing `.shape` on tensors that can be `None` (when neither rowwise nor columnwise data exists)
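A tiny helper illustrating the None-safe lookup these comments ask for (purely illustrative, not the fix that ends up in the PR):

```python
from typing import Optional
import torch

def first_available_shape(rowwise: Optional[torch.Tensor],
                          columnwise: Optional[torch.Tensor]) -> torch.Size:
    if rowwise is not None:
        return rowwise.shape
    if columnwise is not None:
        return columnwise.shape
    # Neither usage is enabled: return an empty shape instead of raising AttributeError.
    return torch.Size([0])
```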
Confidence Score: 3/5
- This PR introduces critical bugs that will cause runtime failures in edge cases, but the core FSDP2 integration logic is sound
- Score of 3 reflects two critical logic errors in MXFP8Tensor dispatch handlers (lines 389 and 502) that will cause AttributeError when accessing
`.shape` on None values. These bugs occur when the quantizer has neither rowwise nor columnwise usage enabled, which may be rare but is not prevented. The rest of the implementation is well-designed, with proper handling of the forward/backward distinction, amax reduction groups, and DTensor support
- transformer_engine/pytorch/tensor/mxfp8_tensor.py lines 389 and 502 require immediate fixes to handle None tensor data
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 2/5 | Adds FSDP2 torch dispatch handlers for split, slice, copy, new_zeros, as_strided, and view ops. Contains critical bug: line 502 accesses .shape on out_data[0] which can be None when neither rowwise nor columnwise data exists |
| transformer_engine/pytorch/tensor/float8_tensor.py | 3/5 | Adds FSDP2 allgather hooks and torch dispatch handlers for various ops. Implements rowwise/columnwise usage tracking for forward/backward passes. Generally well-structured but relies on cached FSDP state lookup |
| transformer_engine/pytorch/distributed.py | 4/5 | Adds _get_module_fsdp_state helper with @lru_cache to find FSDP state for modules. Cache is appropriate since it stores reference to mutable state object |
| transformer_engine/pytorch/module/base.py | 4/5 | Updates reset_parameters to handle DTensor params for FSDP2 deferred init, sets amax reduction group for current scaling quantization. Logic is sound |
Sequence Diagram
sequenceDiagram
participant FSDP2
participant Float8Tensor
participant MXFP8Tensor
participant Quantizer
participant TE_Module
Note over FSDP2,TE_Module: Forward Pass
FSDP2->>Float8Tensor: fsdp_pre_all_gather(mesh, module, ...)
Float8Tensor->>Float8Tensor: Set amax_reduction_group for current scaling
Float8Tensor->>Float8Tensor: Get FSDP state, check reshard_after_forward
Float8Tensor->>Quantizer: copy() and set_usage(rowwise=True)
Float8Tensor-->>FSDP2: Return (uint8 data, metadata)
FSDP2->>FSDP2: All-gather uint8 data across shards
FSDP2->>Float8Tensor: fsdp_post_all_gather(gathered_data, metadata)
Float8Tensor->>Float8Tensor: Reconstruct with rowwise usage
Float8Tensor-->>FSDP2: Return reconstructed Float8Tensor
FSDP2->>TE_Module: forward(input)
TE_Module->>TE_Module: Validate rowwise usage in _apply_forward
TE_Module-->>FSDP2: output
Note over FSDP2,TE_Module: Backward Pass (if reshard_after_forward)
FSDP2->>Float8Tensor: fsdp_pre_all_gather(mesh, module, ...)
Float8Tensor->>Float8Tensor: Get FSDP state, detect PRE_BACKWARD
Float8Tensor->>Quantizer: copy() and set_usage(columnwise=True)
Float8Tensor-->>FSDP2: Return (uint8 data, metadata)
FSDP2->>FSDP2: All-gather uint8 data across shards
FSDP2->>Float8Tensor: fsdp_post_all_gather(gathered_data, metadata)
Float8Tensor->>Float8Tensor: Reconstruct with columnwise usage
Float8Tensor-->>FSDP2: Return reconstructed Float8Tensor
TE_Module->>TE_Module: Validate columnwise usage in _apply_backward
Note over FSDP2,TE_Module: Weight Initialization (Meta Device)
TE_Module->>TE_Module: reset_parameters() with DTensor
TE_Module->>Quantizer: Set amax_reduction_group from mesh
TE_Module->>Float8Tensor: Initialize with FP8 quantization
TE_Module->>TE_Module: Update DTensor.local_tensor
11 files reviewed, 2 comments
```python
)
out_data.append(scale_inv_out)
return MXFP8Tensor(
    shape=out_data[0].shape,
```
logic: AttributeError if both _rowwise_data and _columnwise_data are None - accessing .shape on out_data[0] will fail when neither usage flag is set
Suggested change:
```python
    shape=out_data[0].shape if out_data[0] is not None else out_data[1].shape,
```
```python
out_data.append(scale_inv_out)
return [
    MXFP8Tensor(
        shape=splitted_tensor_data[0].size(),
```
logic: AttributeError if splitted_tensor_data[0] is None - happens when tensor._rowwise_data is None before splitting
Suggested change:
```python
        shape=splitted_tensor_data[0].size() if splitted_tensor_data[0] is not None else splitted_tensor_data[1].size(),
```
Greptile Overview
Greptile Summary
This PR enables end-to-end FSDP2 training for PyTorch models with Transformer Engine layers initialized with FP8/MXFP8 weights, solving three critical issues: memory footprint problems with FP8-initialized weights, weight update correctness during training, and enabling 8-bit allgather instead of high-precision.
Key Changes:
- FSDP2 Allgather Hooks: Implements `fsdp_pre_all_gather` and `fsdp_post_all_gather` methods for Float8Tensor and MXFP8Tensor to support 8-bit weight allgather by returning uint8 data and metadata for reconstruction
- Torch Dispatch Operations: Adds handlers for `split`, `new_zeros`, `as_strided`, `copy_`, `slice`, and `view` operations to support FSDP2 sharding and resharding of quantized tensors (a toy handler sketch follows below)
- FSDP State Management: Introduces a `_get_module_fsdp_state` helper with an LRU cache to determine the forward/backward pass and the `reshard_after_forward` configuration, enabling proper rowwise/columnwise usage selection
- Current Scaling Synchronization: Sets the amax reduction group in quantizers during allgather to ensure all weight shards share the same scale inverse after optimizer updates
- DTensor Support: Updates `reset_parameters` in the base module to handle DTensor parameters for FSDP2 deferred initialization with proper quantizer configuration
- Quantized Tensor Fixes: Fixes in-place operations to handle lists of tensors (for optimizer lerp operations) and removes the incorrect `data` parameter from the `make_like` API
- Usage Validation Refactoring: Moves quantizer usage validation from the layer forward to the forward/backward functions, and removes unnecessary quantizer updates when weights are already quantized
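A toy `__torch_dispatch__` handler in the spirit of the split support described above: it keeps a second payload (a cached transpose) consistent with the primary data when FSDP2 splits the tensor. This is a self-contained illustration, not TE's implementation.

```python
import torch
from torch.utils._pytree import tree_map

aten = torch.ops.aten

class TwoPayloadTensor(torch.Tensor):
    """Toy wrapper carrying a payload plus its cached transpose."""

    @staticmethod
    def __new__(cls, data, data_t):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data, data_t):
        self._data = data       # "rowwise" payload
        self._data_t = data_t   # cached transpose ("columnwise" payload)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func == aten.split.Tensor:
            tensor, split_size = args[0], args[1]
            data_chunks = tensor._data.split(split_size, dim=0)
            # Keep the transpose cache in sync by splitting along its last dim.
            t_chunks = tensor._data_t.split(split_size, dim=-1)
            return [cls(d, t) for d, t in zip(data_chunks, t_chunks)]

        # Fallback for any other op: unwrap and operate on the main payload only.
        def unwrap(x):
            return x._data if isinstance(x, TwoPayloadTensor) else x

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

x = TwoPayloadTensor(torch.arange(12.0).reshape(6, 2),
                     torch.arange(12.0).reshape(6, 2).t().contiguous())
halves = torch.split(x, 3)
assert halves[0]._data.shape == (3, 2) and halves[0]._data_t.shape == (2, 3)
```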
Memory Impact: FP8 per-tensor quantization reduces memory by 50% vs BF16 on Blackwell. MXFP8 has similar memory footprint to BF16 due to needing both rowwise/columnwise representations.
Test Coverage: Comprehensive tests cover delayed scaling, current scaling, and MX_FP8 block scaling recipes with various layer types (Linear, LayerNormLinear, TransformerLayer) and both FSDP/HSDP configurations.
Confidence Score: 4/5
- Safe to merge with minor considerations - addresses long-standing FSDP2+FP8 issues with comprehensive implementation
- Score of 4 reflects solid implementation with extensive test coverage addressing critical functionality gaps. The changes are well-architected with proper separation between FP8/MXFP8 tensor handling, FSDP2 hooks, and torch dispatch operations. Previous syntax errors in mxfp8_tensor.py mentioned in earlier comments have been fixed. Main concerns are: (1) LRU cache on
`_get_module_fsdp_state` could retain module references indefinitely even though the return value is a mutable state reference, (2) the complex logic for determining forward/backward pass and reshard_after_forward could benefit from additional inline documentation, and (3) the MXFP8 view operation intentionally falls back to the dequantize path with a warning when flattening the inner dimension. The PR resolves several critical GitHub issues (#1688, #401, #1135, #1188) and includes validation tests
- Pay close attention to `transformer_engine/pytorch/tensor/float8_tensor.py` and `transformer_engine/pytorch/tensor/mxfp8_tensor.py` for the complex torch dispatch logic and allgather hooks
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/distributed.py | 4/5 | Adds _get_module_fsdp_state helper with @lru_cache to find FSDP state for modules - enables determining forward/backward pass during allgather |
| transformer_engine/pytorch/tensor/float8_tensor.py | 4/5 | Implements FSDP2 hooks (fsdp_pre_all_gather, fsdp_post_all_gather) and torch dispatch for split/new_zeros/as_strided/copy operations to support 8-bit allgather |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 4/5 | Implements FSDP2 hooks and torch dispatch for MXFP8 tensors with rowwise/columnwise data handling - includes split/as_strided/copy/slice operations |
| transformer_engine/pytorch/quantized_tensor.py | 4/5 | Fixes in-place ops to handle lists of tensors (for optimizer updates) and removes data parameter from make_like to fix view semantics |
| transformer_engine/pytorch/module/base.py | 4/5 | Adds DTensor support in reset_parameters for FSDP2 deferred initialization, handles amax reduction group setup for current scaling quantization |
| transformer_engine/pytorch/module/linear.py | 4/5 | Removes quantizer updates when weight is already quantized, moves columnwise usage validation from forward to backward function |
Sequence Diagram
sequenceDiagram
participant User
participant FSDP2
participant TEModule as TE Module
participant Float8Tensor
participant Quantizer
participant AllGather as FSDP AllGather
User->>FSDP2: Initialize model with fp8_model_init
FSDP2->>TEModule: Create FP8/MXFP8 weight shards
TEModule->>Float8Tensor: Initialize quantized weights
Float8Tensor->>Quantizer: Setup amax reduction group
User->>FSDP2: Start training iteration (forward pass)
FSDP2->>Float8Tensor: fsdp_pre_all_gather(module, mesh)
Float8Tensor->>TEModule: Get FSDP state via _get_module_fsdp_state
Float8Tensor->>Quantizer: Set rowwise usage for forward
Float8Tensor-->>FSDP2: Return (uint8_data,), metadata
FSDP2->>AllGather: AllGather uint8 data across shards
AllGather-->>FSDP2: Gathered uint8 data
FSDP2->>Float8Tensor: fsdp_post_all_gather(outputs, metadata)
Float8Tensor->>Float8Tensor: Reconstruct Float8Tensor with gathered data
Float8Tensor->>Float8Tensor: update_usage(rowwise=True)
Float8Tensor-->>FSDP2: Return reconstructed weight tensor
FSDP2->>TEModule: Forward pass with gathered weights
TEModule->>TEModule: Compute activations
alt reshard_after_forward=True
FSDP2->>FSDP2: Reshard weights after forward
end
User->>FSDP2: Backward pass
alt reshard_after_forward=True
FSDP2->>Float8Tensor: fsdp_pre_all_gather (backward)
Float8Tensor->>Quantizer: Set columnwise usage for backward
Float8Tensor-->>FSDP2: Return appropriate data for backward
FSDP2->>AllGather: AllGather for backward pass
FSDP2->>Float8Tensor: fsdp_post_all_gather
Float8Tensor-->>FSDP2: Reconstructed weight for backward
end
FSDP2->>TEModule: Backward pass with weights
TEModule->>TEModule: Compute gradients
FSDP2->>FSDP2: ReduceScatter gradients
User->>FSDP2: Optimizer step
FSDP2->>Float8Tensor: Update weight shards (via lerp/copy_)
Float8Tensor->>Quantizer: Sync amax across shards
Quantizer->>Quantizer: AllReduce amax for single scale_inv
11 files reviewed, no comments
Greptile Overview
Greptile Summary
This PR implements comprehensive FSDP2 (Fully Sharded Data Parallel 2) support for TransformerEngine, enabling distributed training with FP8 quantized weights. The implementation addresses three critical issues:
Key Achievements:
- Memory Efficiency: FP8-initialized weights now consume half the memory of BF16 (per-tensor scaling) during FSDP2 training
- Weight Updates: Fixed FP8 weight gradient updates that were previously broken for both DDP and FSDP with FP8-initialized weights
- 8-bit Allgather: Enabled FP8/MXFP8 weight all-gather instead of high-precision communication for efficient training
Technical Implementation:
- Adds `fsdp_pre_all_gather` and `fsdp_post_all_gather` hooks to Float8Tensor and MXFP8Tensor for handling quantized tensor communication
- Implements torch dispatch handlers for FSDP2-required operations: `split`, `copy_`, `new_zeros`, `as_strided`, `view`
- Handles rowwise/columnwise usage tracking for forward vs. backward passes with the `reshard_after_forward` configuration
- Sets up amax reduction groups for current scaling quantizers to ensure a consistent scale_inv across shards
- Supports both DTensor (FSDP2) and regular tensor paths in `reset_parameters` for deferred initialization
Areas of Concern:
- MXFP8 copy and split operations have potential `AttributeError` risks when `_columnwise_data` is `None`
- The `_get_module_fsdp_state` function uses an unbounded `@lru_cache`, which could accumulate module references
- Some edge cases around tensor shape validation when both rowwise and columnwise usages are disabled
The implementation is well-structured with comprehensive test coverage for different quantization recipes (FP8, MXFP8, delayed/current scaling). The PR successfully enables production FSDP2 training with FP8 quantization.
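A hedged end-to-end usage sketch based on the test flow summarized in these reviews. The `fully_shard` import path and the exact TE entry points (`fp8_model_init`, `fp8_autocast`) vary by PyTorch/TE version, and the snippet assumes a process group and device mesh are already initialized, so treat it as illustrative rather than canonical.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
from torch.distributed.fsdp import fully_shard  # older releases: torch.distributed._composable.fsdp

recipe = DelayedScaling()

# Create the module with FP8-initialized weights (half the memory of BF16 weights).
with te.fp8_model_init(enabled=True, recipe=recipe):
    model = te.Linear(1024, 1024, params_dtype=torch.bfloat16)

# Shard the FP8 weights; reshard_after_forward=True exercises the separate
# rowwise (forward) and columnwise (backward) all-gathers discussed above.
fully_shard(model, reshard_after_forward=True)

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = model(x)
y.sum().backward()
```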
Confidence Score: 4/5
- This PR is relatively safe to merge with minor issues that should be addressed
- The implementation is comprehensive with good test coverage, but has a few logical edge cases around None handling in MXFP8 operations that could cause runtime errors in specific configurations. The core FSDP2 integration logic is sound and addresses real production issues.
- Pay close attention to
`transformer_engine/pytorch/tensor/mxfp8_tensor.py` for None handling in the copy and split operations
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/tensor/float8_tensor.py | 4/5 | Adds FSDP2 support for Float8Tensor with pre/post allgather hooks, torch dispatch handlers for split/copy/view/new_zeros/as_strided ops. Most changes are solid but transpose cache logic in view operation needs attention for edge cases. |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 3/5 | Implements FSDP2 hooks and torch dispatch for MXFP8 with complex split/padding logic. Copy operation has potential None handling issues when columnwise data doesn't exist. Split tensor handling assumes both data and scale_inv exist for all usages. |
| transformer_engine/pytorch/distributed.py | 4/5 | Adds _get_module_fsdp_state helper with @lru_cache that finds closest FSDP ancestor module. Caching by module identity could potentially cause issues if modules are recreated, though this is unlikely in typical FSDP usage. |
| transformer_engine/pytorch/module/base.py | 5/5 | Updates reset_parameters to handle DTensor for FSDP2 and sets amax_reduction_group for current scaling quantizers. Clean implementation that properly handles both DTensor and regular tensors with FP8 quantization. |
Sequence Diagram
sequenceDiagram
participant FSDP2 as FSDP2 Framework
participant Tensor as Float8/MXFP8Tensor
participant Quantizer as Quantizer
participant NCCL as NCCL AllGather
Note over FSDP2,NCCL: Forward Pass - Weight AllGather
FSDP2->>Tensor: fsdp_pre_all_gather(mesh, module, ...)
Tensor->>Tensor: Get FSDP state & training phase
Tensor->>Quantizer: Set usage (rowwise=True, columnwise=False)
Tensor->>Tensor: Remove padding from scale_inv
Tensor-->>FSDP2: Return (sharded_uint8_data, scale_inv, metadata)
FSDP2->>NCCL: all_gather(sharded_uint8_data)
NCCL-->>FSDP2: gathered_uint8_data
FSDP2->>Tensor: fsdp_post_all_gather(gathered_data, metadata)
Tensor->>Tensor: Add padding to scale_inv
Tensor->>Tensor: Reconstruct Float8/MXFP8Tensor
Tensor-->>FSDP2: Return full quantized weight tensor
Note over FSDP2,NCCL: Forward Compute with Quantized Weights
FSDP2->>FSDP2: Optional: Reshard after forward
Note over FSDP2,NCCL: Backward Pass - Weight AllGather
FSDP2->>Tensor: fsdp_pre_all_gather(mesh, module, ...)
Tensor->>Tensor: Detect backward phase
Tensor->>Quantizer: Set usage (rowwise=False, columnwise=True)
Tensor-->>FSDP2: Return columnwise data for backward
FSDP2->>NCCL: all_gather(columnwise_data)
NCCL-->>FSDP2: gathered_columnwise_data
FSDP2->>Tensor: fsdp_post_all_gather(gathered_data, metadata)
Tensor-->>FSDP2: Return full quantized weight tensor
Note over FSDP2,NCCL: Backward Compute & Gradient Update
FSDP2->>Quantizer: Update weights with gradients
Quantizer->>Quantizer: Amax reduction across shards
Quantizer->>Tensor: Quantize updated weights
11 files reviewed, 5 comments
```python
if src._columnwise_data is not None and dst._columnwise_data is not None:
    dst._columnwise_data.copy_(src._columnwise_data.detach())
    dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv.detach())
```
logic: potential AttributeError if _columnwise_data or _columnwise_scale_inv is None - the copy operation assumes both exist when both src and dst have the attributes, but doesn't validate they're not None before calling .copy_()
```python
out_data.append(scale_inv_out)
return [
    MXFP8Tensor(
        shape=splitted_tensor_data[0].size(),
```
logic: splitted_tensor_data[0].size() could raise AttributeError if _rowwise_data was None for the tensor being split - should validate data exists or use fallback to splitted_tensor_data[1].size()
/te-ci L1 pytorch
```python
weightmat = weight
quantized_weight = False
if fp8 or debug:
    quantized_weight = not isinstance(weight, QuantizedTensorStorage)
```
Oof, this variable name was aggressively bad. Just to make things extremely explicit, perhaps we should rename to is_weight_param_quantized.
If we change this variable, we need to make sure to update the other places it is used:
```python
weightmat if quantized_weight else None,
weight if ctx.fp8 and ctx.quantized_weight else None,
```
Sure, will do that.
```python
else:
    min_nodes_in_parent = float("inf")
    closest_parent_fsdp_mod = None
    for fsdp_mod in _module_state_mapping.keys():
```
I don't see any good alternatives to this O(n^2) search, but we should switch if we ever find a more natural way to traverse parent modules.
Actually, if we are guaranteed Python 3.7 or higher, where dict order matches insertion order, we can stop at the first FSDP module that is a parent, since the FSDP modules are inserted into the dict bottom-up in FSDP2. In that case it is O(n). And since TE requires Python 3.10+, I can remove the logic that finds the parent with the minimum number of submodules and instead return the first parent encountered. Does that sound reasonable?
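A toy version of that O(n) early-return idea (the mapping and the bottom-up registration order are taken from the discussion above; `module_state_mapping` here is a stand-in dict, not the real FSDP internals):

```python
def closest_fsdp_state(module, module_state_mapping):
    # Dicts preserve insertion order (Python 3.7+) and FSDP2 registers states
    # bottom-up, so the first registered module that contains `module` is its
    # closest FSDP parent. nn.Module.modules() yields the module itself too.
    for fsdp_mod, state in module_state_mapping.items():
        if module in fsdp_mod.modules():
            return state
    return None
```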
```python
    return inp, handle


@lru_cache
```
With the default of maxsize=128, a Transformer model with more than 32 layers will not fit in the LRU cache. Since this logic is per-module, it's more natural to cache the FSDP state within the module rather than in a global cache. How about something like:
```python
def _get_module_fsdp_state(module):
    if hasattr(module, "_get_fsdp_state"):
        return module._get_fsdp_state()
    if getattr(module, "_parent_fsdp_state", None) is not None:
        return module._parent_fsdp_state
    # Traverse to find parent FSDP
    module._parent_fsdp_state = fsdp_state
    return fsdp_state
```
Yep, reasonable. Will do that!
```python
def __torch_dispatch__(cls, func, types, args, kwargs=None):
    # FSDP2 related functions
    # View op
    if func == aten.view.default:
```
This implementation is incorrect because it doesn't handle scales or column-wise data. We should prefer using the existing view impl:
```python
class _ViewFunc(torch.autograd.Function):
```
I assume that this case is because FSDP does something like view(-1)? If so, we should check the view dims and only use this impl for that case:
```python
if func == aten.view.default:
    tensor = args[0]
    dims = args[1]  # ?
    if len(dims) == 1:
        if dims[0] not in (-1, tensor.numel()):
            raise ValueError(...)  # Invalid dims
        warnings.warn(...)  # Warn non-FSDP users
        return MXFP8Tensor(
            rowwise_data=tensor._rowwise_data.view(-1),
            columnwise_data=tensor._columnwise_data.view(-1),
            ...
        )
    return tensor.view(dims)
```
When FSDP calls view, it is actually the view method of the MXFP8 tensor class that is called (which does _ViewFunc.apply) rather than going through the torch dispatch mechanism. This view implementation is what was already there before this PR; I just moved it up. But it makes sense to correct it. How about I use _ViewFunc.apply here itself, rather than repeating the view and FSDP-specific logic?
```python
            fp8_dtype=tensor._fp8_dtype,
        )

        if func == torch.ops.aten.copy_.default:
```
The base implementation respects the quantizer usages of the dst tensor, so it's incorrect for this impl to wipe everything out based on the quantizer usages of the src tensor.
The use case I am thinking about is a model that has been initialized with quantized_model_init. The modules have configured the weight quantizers, possibly with very complicated internal logic. Then we load a checkpoint and update the weights with param.copy_(checkpoint_param). Overwriting the quantizer would undo all of the hard work we have already done in the configuration.
In that case the checkpoint param would be a high-precision tensor, right? I have a check to copy things only if both src and dst are MXFP8 tensors, so this code won't be executed; otherwise it falls back to the base implementation.
But I see your argument. That philosophy is consistent with the per-tensor scaling class as well, so I will remove the quantizer override.
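A toy of the copy_ semantics agreed on here: copy only the payloads the destination already keeps, and never overwrite its quantizer/usages from the source. Purely illustrative, with stand-in fields rather than TE's.

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class ToyQuantized:
    rowwise: Optional[torch.Tensor]
    columnwise: Optional[torch.Tensor]

def copy_quantized_(dst: ToyQuantized, src: ToyQuantized) -> ToyQuantized:
    # Respect the destination's configuration: only fill payloads it already has.
    if dst.rowwise is not None and src.rowwise is not None:
        dst.rowwise.copy_(src.rowwise)
    if dst.columnwise is not None and src.columnwise is not None:
        dst.columnwise.copy_(src.columnwise)
    return dst

dst = ToyQuantized(torch.zeros(2), None)          # destination keeps only rowwise data
src = ToyQuantized(torch.ones(2), torch.ones(2))
copy_quantized_(dst, src)
assert dst.columnwise is None and float(dst.rowwise.sum()) == 2.0
```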
Greptile Overview
Greptile Summary
This PR enables full FSDP2 support for TransformerEngine with FP8-initialized weights, addressing three critical issues: memory footprint problems, weight update correctness, and enabling 8-bit weight allgather.
Key Changes
FSDP2 Allgather Hooks: Implemented fsdp_pre_all_gather and fsdp_post_all_gather methods for Float8Tensor and MXFP8Tensor classes. These hooks enable FSDP2 to work with custom tensor types by converting FP8 tensors to uint8 data for allgather, then reconstructing them post-allgather with appropriate metadata (scale_inv, dtype, quantizer).
Quantization Usage Management: Forward pass allgathers rowwise data (for TN GEMM as B argument), backward pass allgathers columnwise data (for TN GEMM as A argument). The PR intelligently determines forward vs backward state using FSDP's TrainingState and handles reshard_after_forward configuration.
Torch Dispatch Handlers: Added __torch_dispatch__ implementations for FSDP2-required ops: split (sharding), new_zeros (buffer creation), copy_ (data transfer), as_strided (padding removal), view (flattening). These handlers preserve FP8 tensor subclass through FSDP2's operations.
Current Scaling Quantization: For current scaling, the PR sets amax_reduction_group in quantizers to ensure all weight shards use synchronized scale_inv values after optimizer updates. Critical for training correctness.
DTensor Support: Updated reset_parameters to handle DTensor parameters in deferred initialization, correctly updating the local_tensor of sharded parameters.
Usage Validation Changes: Moved weight quantization usage validation from module __init__ to forward/backward functions, since FSDP2 requires different usages at different training phases.
Optimizer Fix: Fixed in-place operations to handle lists of quantized tensors, enabling optimizers like Adam to correctly update FP8 weights.
Testing
Significantly expanded test coverage with parameterized tests for different TE modules (Linear, LayerNormLinear, LayerNormMLP), quantization recipes (delayed scaling, current scaling, MXFP8), and FSDP configurations.
Confidence Score: 4/5
- This PR is safe to merge with minor reservations about edge case handling
- The implementation is comprehensive and well-tested. Core FSDP2 integration logic is sound with proper handling of forward/backward distinction, allgather hooks, and quantizer synchronization. The PR successfully addresses the three stated issues. Deducting 1 point due to minor concerns: (1) potential
`AttributeError` in MXFP8 edge cases when `columnwise_data` is None, (2) `@lru_cache` on `_get_module_fsdp_state` could theoretically cause stale references if modules are recreated (though unlikely in practice), and (3) a previous reviewer identified syntax issues in view/reshape transpose handling that need verification. These are edge cases unlikely to occur in normal usage but should be validated
- transformer_engine/pytorch/tensor/mxfp8_tensor.py - verify None handling in the copy_ operation (line 346); transformer_engine/pytorch/tensor/float8_tensor.py - verify the transpose shape checking logic in view (lines 576-585)
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 4/5 | Added FSDP2 support via torch dispatch handlers for split, new_zeros, copy_, as_strided, and slice operations. Implemented fsdp_pre_all_gather and fsdp_post_all_gather hooks for weight allgather. Minor issues: missing None checks could cause AttributeError in edge cases with columnwise_data. |
| transformer_engine/pytorch/tensor/float8_tensor.py | 4/5 | Added FSDP2 support via torch dispatch handlers and fsdp_pre/post_all_gather methods. Improved view/reshape to handle transpose cache. Enhanced new_zeros to deep copy scale_inv and quantizer. Split operation now handles transpose cache correctly. Minor syntax/logic issues with transpose shape checking. |
| transformer_engine/pytorch/distributed.py | 5/5 | Added _get_module_fsdp_state helper function with @lru_cache to find FSDP state for modules. Function finds lowest common ancestor FSDP module in hierarchy. Implementation looks correct with proper error handling for non-FSDP modules. |
| transformer_engine/pytorch/module/base.py | 5/5 | Updated reset_parameters to handle DTensor parameters for FSDP2 deferred initialization. Correctly updates local_tensor of DTensor with quantized weight. Sets amax_reduction_group for current scaling quantizers to ensure weight shards share same scale inverse. |
| transformer_engine/pytorch/quantized_tensor.py | 5/5 | Fixed in-place ops to handle lists of quantized tensors (needed for optimizer weight updates). Added recursive handling in maybe_update_inplace. Removed data parameter from make_like signature for clarity - method creates views, not copies. |
| transformer_engine/pytorch/module/linear.py | 5/5 | Moved weight quantization usage validation from init to forward/backward functions. Required for FSDP2 where different usages are allgathered in forward vs backward. Removed unnecessary quantizer updates when weights already in FP8. |
Sequence Diagram
sequenceDiagram
participant User
participant FSDP2
participant TEModule
participant Float8Tensor
participant Quantizer
participant AllGather
User->>FSDP2: Initialize model with fp8_model_init
FSDP2->>TEModule: Apply fully_shard()
TEModule->>Float8Tensor: Register FP8 weight parameters
Note over FSDP2,Float8Tensor: Forward Pass
FSDP2->>Float8Tensor: fsdp_pre_all_gather(module, mesh)
Float8Tensor->>Quantizer: Set rowwise_usage=True, columnwise_usage=False
Float8Tensor->>FSDP2: Return (sharded_uint8_data, metadata)
FSDP2->>AllGather: all_gather_into_tensor(uint8_data)
AllGather->>FSDP2: Return allgathered_data
FSDP2->>Float8Tensor: fsdp_post_all_gather(allgathered_data, metadata)
Float8Tensor->>Float8Tensor: Reconstruct FP8 tensor with rowwise usage
Float8Tensor->>TEModule: Return unsharded weight
TEModule->>TEModule: Forward computation
Note over FSDP2,Float8Tensor: Backward Pass
TEModule->>TEModule: Backward computation
FSDP2->>Float8Tensor: fsdp_pre_all_gather(module, mesh)
Float8Tensor->>Quantizer: Set rowwise_usage=False, columnwise_usage=True
Float8Tensor->>FSDP2: Return (sharded_uint8_data, metadata)
FSDP2->>AllGather: all_gather_into_tensor(uint8_data)
AllGather->>FSDP2: Return allgathered_data
FSDP2->>Float8Tensor: fsdp_post_all_gather(allgathered_data, metadata)
Float8Tensor->>Float8Tensor: Reconstruct FP8 tensor with columnwise usage
Float8Tensor->>TEModule: Return unsharded weight
Note over FSDP2,Float8Tensor: Optimizer Step
FSDP2->>Float8Tensor: Reduce-scatter gradients
User->>Quantizer: optimizer.step() on FP8 shards
Quantizer->>Float8Tensor: Update FP8 weight shards with same scale_inv
Float8Tensor->>Quantizer: Sync amax across shards (current scaling)
11 files reviewed, no comments
Description
Motivation:
What this PR does?
Type of change
Changes
FSDP Allgather Hooks for FP8/MXFP8: Adds fsdp_pre_all_gather and fsdp_post_all_gather methods for FP8/MXFP8 tensors, since allgather is only supported for native torch tensors with uint8/fp16/bf16/fp32 data types. For us, fsdp_pre_all_gather returns the uint8 sharded tensors for FP8/MXFP8 that need to be allgathered, along with the metadata needed to reconstruct the FP8/MXFP8 tensor after the allgather. fsdp_post_all_gather reconstructs the Float8/MXFP8 tensor from the allgathered uint8 data.
FP8/MXFP8 Torch Dispatch Functions for FSDP2 to handle ops on both rowwise/columnwise data (MXFP8) and data/transpose (FP8). NOTE: only MXFP8 tensors without padding requirements are handled; if padding is needed, we go down the dequantize-compute-quantize route.
Quantized Tensor Class Issues:
Validating rowwise/columnwise Usages for quantizers/tensors in TE Layers
Resetting Parameters for Deferred Initialization(meta device)
Test and Miscellaneous issues
Checklist: