
Conversation

@vthumbe1503
Collaborator

@vthumbe1503 vthumbe1503 commented Oct 7, 2025

Description

Motivation:

  • FSDP2 training currently doesn't work with models initialized with FP8 weights. If high-precision weights are used with TE layers instead, the model consumes more memory than it would with BF16 under TE autocast, making it difficult to adopt TE for FP8-based FSDP2 training (issue). It is therefore useful to get FSDP2 working with FP8-initialized weights (issue).
  • Along with fixing the memory usage for models initialized with FP8 weight tensors, we also want FSDP2 training to be correct: the FP8 weight tensors must be updated after every training step. Currently, the Float8Tensors for weights do not get updated. This is not specific to FSDP; it also affects DDP with FP8-initialized weights (issue).
  • We also want the FSDP weight allgather to use FP8 instead of high precision for better training performance. Currently, TE performs the allgather in high precision for FP8-initialized weights (issue).

What this PR does

  • Enables FSDP2-based model training end-to-end for any PyTorch model with TE layers and FP8-initialized weights
    • Solves the memory-footprint issue with FP8-initialized weights. Initialization with FP8 (per-tensor scaling) on Blackwell takes half the memory footprint of BF16, as expected. MXFP8 and BF16 consume the same amount of memory because MXFP8 needs both rowwise and columnwise usages.
    • Fixes FP8 weight updates when the model is initialized with FP8 weights, ensuring correct training results
    • Enables 8-bit weight allgather for both FP8 and MXFP8 tensors.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • FSDP Allgather Hooks for FP8/MXFP8: Adds fsdp_pre_all_gather and fsdp_post_all_gather methods for FP8/MXFP8 tensors, since allgather is only supported for native torch tensors with uint8/fp16/bf16/fp32 data types. fsdp_pre_all_gather returns the sharded uint8 tensors for FP8/MXFP8 that need to be allgathered, along with the metadata needed to reconstruct the FP8/MXFP8 tensor after the allgather. fsdp_post_all_gather reconstructs the Float8/MXFP8 tensor from the allgathered uint8 data (a sketch of these hooks follows this sub-list).

    • Handling quantization usages in allgather: The assumption here is that fsdp_pre_all_gather and fsdp_post_all_gather are only called for weight tensors, which is fair since FSDP only shards the weights. This means rowwise usage is needed for the forward pass and columnwise usage for the backward pass.
    • Identifying the forward/backward pass during allgather: This is needed because only one of the rowwise/columnwise usages should be allgathered, depending on whether it is the forward or backward pass of the training step. fsdp_pre_all_gather receives module as an argument, which is the nn.Module that has the quantized tensor registered as a parameter. This module is not necessarily an FSDP module, since we may wrap the model at a much higher level in the hierarchy (e.g. TransformerLayer rather than the Linear submodule). Hence we have a helper that finds the closest ancestor FSDP module and uses it to get the FSDP state, which records whether we are in a forward or backward pass. NOTE: The return value is cached with lru_cache since we don't want to recompute it on every iteration/allgather during training. The cached value is a reference that FSDP mutates internally over the course of training.
    • Reshard After Forward: FSDP2 has a configuration that controls whether parameters are resharded after the forward pass (meaning weights are allgathered again for the backward pass). By default it is False for the root module and True for submodules. This setting is obtainable from the FSDP state of the module the parameter belongs to, and it determines whether we send both rowwise/columnwise data in one go or just one of them based on the forward/backward pass. This matters more for MXFP8, where we may want to send both usages rather than sending just one usage and then dequantizing and re-quantizing to recover the other (which introduces quantization error).
    • Current Scaling Quantization: For current-scaling quantization, a single amax/scale inverse must be used across all shards. This holds at model initialization, but each quantized weight shard is then updated independently by the optimizer during training. We therefore set the amax reduction group on the quantizer if it is not already set, and we do so in the forward-pass allgather itself (using the FSDP mesh information), so that when a weight shard is updated the quantizer synchronizes across shards to compute a single amax and every shard uses the same scale inverse.
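
A minimal sketch of the hook pair described above. The class is a stand-in for Float8Tensor/MXFP8Tensor; the attribute names (`_data`, `_scale_inv`, `_fp8_dtype`, `_quantizer`) and the helpers `in_backward_pass` / `build_fp8_tensor` are illustrative, and the hook signatures follow this PR's description rather than a definitive API:

```python
# Sketch only: the FSDP2 allgather hook pair described above. Attribute names,
# helper functions, and exact signatures are illustrative, not the actual TE code.
class Fp8ShardedParam:  # stands in for Float8Tensor / MXFP8Tensor
    def fsdp_pre_all_gather(self, mesh, orig_size, dtype, module, *unused):
        # Pick rowwise vs columnwise usage from the owning module's FSDP state:
        # forward pass -> rowwise, backward pass -> columnwise.
        quantizer = self._quantizer.copy()
        backward = in_backward_pass(module)  # illustrative helper (closest-FSDP-ancestor state lookup)
        quantizer.set_usage(rowwise=not backward, columnwise=backward)
        # Only plain uint8 buffers are allgathered; everything needed to rebuild
        # the quantized tensor afterwards rides along as metadata.
        shards_to_gather = (self._data,)
        metadata = (self._scale_inv, self._fp8_dtype, quantizer)
        return shards_to_gather, metadata

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        (gathered_data,) = all_gather_outputs
        scale_inv, fp8_dtype, quantizer = metadata
        if out is not None:
            # Later iterations reuse a preallocated output tensor.
            out._data.copy_(gathered_data)
            return
        gathered = build_fp8_tensor(  # illustrative constructor
            data=gathered_data, scale_inv=scale_inv, fp8_dtype=fp8_dtype,
            dtype=param_dtype, quantizer=quantizer,
        )
        return gathered, all_gather_outputs
```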
  • FP8/MXFP8 Torch Dispatch Functions for FSDP2 to handle ops on both rowwise/columnwise data (MXFP8) and data/transpose (FP8). NOTE: Only MXFP8 tensors without padding requirements are handled; if padding is needed we go down the dequantize-compute-quantize route. (A sketch of one dispatch handler follows this sub-list.)

    • Split Function: If the model is initialized on a CUDA device up front, torch chunk/split is called on our custom quantized tensors to split them along dimension 0. After the split, FSDP2 keeps the shard needed by the current process and discards everything else to free memory before training. NOTE: With deferred (meta-device) initialization this method isn't called; quantized tensors are instantiated directly for the weight shard belonging to the process, instead of initializing everything and discarding the unneeded shards.
    • new_zeros: Implementing this op ensures a new tensor is created with the shape the shard is supposed to have. The original Float8Tensor implementation didn't create a deep copy of the scale inverses; that is fixed now.
    • copy_: The split/sharded tensor is then copied into the zero tensor created above.
    • as_strided: FSDP2 allows one shard to have fewer elements than the others when the size of dimension 0 is not divisible by the number of shards. It pads the smaller shard and then calls as_strided after the allgather to remove the padding. We currently don't handle the case where the divisibility condition is not met (this would be complicated for MXFP8 and is beyond the scope of this PR), so as_strided is essentially a no-op for us.
    • view: In FSDP2, sharded parameters are flattened with view and used for the allgather when compiled autograd is enabled. However, MXFP8 throws an error when the tensor is flattened, since the last dimension of an MXFP8 tensor must never change. In that case we currently fall back to dequantization followed by a high-precision view so that FSDP2 doesn't fail, and we raise a warning when that happens. This is not a concern at the moment since we don't use compiled autograd, so this view is essentially never used.
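
As an illustration of the dispatch handlers listed above, here is a hedged sketch of how `aten.split.Tensor` could be intercepted for a tensor that carries both a rowwise data buffer and a transpose. The `_data`/`_transpose` attributes and the `_make_from_buffers` helper are illustrative; the real handlers also deal with scale inverses, `None` buffers, and MXFP8 alignment constraints:

```python
# Sketch only: a __torch_dispatch__ branch for aten.split.Tensor on a quantized
# tensor that keeps both data and transpose buffers. Names are illustrative.
import torch

class QuantizedParam(torch.Tensor):  # stands in for a TE quantized tensor class
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ops.aten.split.Tensor:
            tensor, split_size = args[0], args[1]
            data_chunks = tensor._data.split(split_size, dim=0)
            # The transpose stores dim 0 of the data as its last dim, so it is
            # split along dim -1 to stay consistent with the data split.
            t_chunks = tensor._transpose.split(split_size, dim=-1)
            return [
                cls._make_from_buffers(d, t)  # illustrative helper
                for d, t in zip(data_chunks, t_chunks)
            ]
        # Other ops fall back to the generic dequantize-compute-quantize path
        # (not shown in this sketch).
        raise NotImplementedError(f"{func} is handled elsewhere in this sketch")
```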
  • Quantized Tensor Class Issues:

    • Missing Dequantize/Compute/Quantize Pathway: When the optimizer is applied to FP8/MXFP8 weights, it issues the optimizer ops (e.g. lerp for the weight update) on a list of Float8 weights instead of applying the op to each Float8 weight separately. Our usual dequantize/compute/quantize route didn't handle a list of Float8 tensors, so weights were not updated in place. This PR fixes that (see the sketch after this sub-list).
    • make_like API relying on data attribute: The make_like API in the quantized-tensor base class should not set the data attribute, since that is specific to Float8Tensor. That logic is now moved into the Float8Tensor class.
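
A minimal sketch of the recursive write-back described above, assuming the `QuantizedTensor` base class from `transformer_engine/pytorch/quantized_tensor.py` (the import path, function name, and surrounding flow are illustrative):

```python
# Sketch only: after running an in-place op (e.g. the optimizer's fused lerp_)
# in high precision, write results back recursively so that lists of quantized
# weights are updated too, not just single tensors.
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor  # assumed import path

def _write_back(original, updated):
    if isinstance(original, (list, tuple)):
        # Batched optimizer ops return lists of results - recurse element-wise.
        for orig_item, upd_item in zip(original, updated):
            _write_back(orig_item, upd_item)
    elif isinstance(original, QuantizedTensor):
        # copy_ on a quantized tensor re-quantizes the high-precision result
        # into the existing FP8 storage in place.
        original.copy_(updated)
```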
  • Validating rowwise/columnwise Usages for quantizers/tensors in TE Layers

    • Weight Tensor Usage Validation: Currently we validate that all desired rowwise/columnwise usages are present for weight tensors in the forward pass of our layers. With FSDP2, however, different usages are allgathered in the forward and backward passes. Validation of the appropriate quantization usage is therefore moved into the layers' forward and backward functions, i.e. rowwise usage is required in forward and columnwise usage in backward (see the sketch after this sub-list).
    • Quantizer Usage Validation: We also used to update the weight quantizer even when the weights are already in FP8. In that case there is no need to update the quantizer, since the weights are already quantized and that quantizer is never used, so this update has been removed from the code.
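
For illustration, the phase-specific check roughly amounts to the following. The attribute names follow the `_rowwise_data`/`_columnwise_data` fields referenced elsewhere in this PR for MXFP8; Float8Tensor checks its data/transpose buffers instead:

```python
# Sketch only: usage validation split across forward and backward, rather than
# requiring both usages up front in the layer's forward().
def check_weight_usage(weight, backward: bool):
    if backward:
        assert weight._columnwise_data is not None, \
            "columnwise usage must be allgathered before the backward pass"
    else:
        assert weight._rowwise_data is not None, \
            "rowwise usage must be allgathered before the forward pass"
```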
  • Resetting Parameters for Deferred Initialization (meta device)

    • Updating DTensors instead of regular tensors: With deferred initialization under FSDP2, parameters are DTensors that hold the unmaterialized shard needed by the process. The DTensor's local tensor therefore needs to be updated with the quantized weights produced by param_init_fn (see the sketch after this sub-list).
    • Current scaling quantization: In this case, the amax reduction group needs to be passed to the quantizer so that all initialized weight shards share a single scale inverse.
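
A hedged sketch of the DTensor handling described above. The import path works on recent PyTorch; `init_fn`, the quantizer call, and the `amax_reduction_group` attribute are illustrative placeholders, not the actual reset_parameters implementation:

```python
# Sketch only: materialize and quantize a DTensor parameter's local shard during
# deferred (meta-device) initialization.
import torch
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older PyTorch

def reset_dtensor_param(module, name, param, init_fn, quantizer):
    if not isinstance(param, DTensor):
        return
    local = param._local_tensor
    if local.device.type == "meta":
        # Materialize only this rank's shard on the GPU.
        local = torch.empty_like(local, device="cuda")
    init_fn(local)  # high-precision init of the local shard
    # For current scaling, share one amax/scale_inv across all shards
    # (illustrative attribute; assumes a 1-D device mesh).
    quantizer.amax_reduction_group = param.device_mesh.get_group()
    quantized_local = quantizer(local)  # illustrative: quantize the shard
    new_param = DTensor.from_local(
        quantized_local, param.device_mesh, param.placements, run_check=False
    )
    module.register_parameter(name, torch.nn.Parameter(new_param))
```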
  • Tests and Miscellaneous Issues

    • More complete test cases for FSDP2: Originally the test only covered a Linear layer. Now we can test a model built from different TE layers, with combinations of with/without FP8 model init and different quantization recipes (FP8/MXFP8). NOTE: NVFP4 is pending.
    • View and reshape not handling columnwise data elegantly: When the columnwise data is present and valid, view and reshape ops are now also applied to the transpose (FP8) / columnwise data (MXFP8) instead of invalidating them.
    • Float8 make_empty API: When a transpose is requested, make_empty originally created the transpose with shape (shape[-1], math.prod(shape[:-1])). It is now consistent with the transpose shapes we create in C++, i.e. (shape[-1], shape[0], shape[1], ..., shape[-2]). This is needed because we handle transpose ops in the torch dispatch required for FSDP2 and need to be consistent everywhere. A short example follows this list.
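
For concreteness, the transpose-shape change amounts to the following (plain Python, purely illustrative):

```python
# Illustrative only: the transpose shape convention before and after this PR.
import math

shape = (4, 6, 8)                                          # logical FP8 data shape
old_transpose_shape = (shape[-1], math.prod(shape[:-1]))   # (8, 24): previous make_empty behavior
new_transpose_shape = (shape[-1], *shape[:-1])             # (8, 4, 6): matches the C++ convention
```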

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Summary by CodeRabbit

Release Notes

  • New Features
    • Added FP8 mixed-precision training support with FSDP2/HSDP distributed sharding.
    • Introduced multiple FP8 quantization scaling recipes: delayed scaling, current scaling, and MX_FP8 block scaling.
    • Expanded distributed training configuration options: batch size, sequence length, data type, layer configuration, number of layers, device placement, and sharding specification.
    • Improved distributed tensor parameter support and synchronization for FSDP integration.

@vthumbe1503 vthumbe1503 changed the title FSDP2 Weight Update Fix [Pytorch] FSDP2 Weight Update Fix Oct 8, 2025
@vthumbe1503 vthumbe1503 changed the title [Pytorch] FSDP2 Weight Update Fix [PyTorch] FSDP2 Weight Update Fix Oct 8, 2025
@vthumbe1503 vthumbe1503 changed the title [PyTorch] FSDP2 Weight Update Fix [PyTorch] TE FSDP2 Support for FP8/MXFP8 Oct 17, 2025

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

Adds FSDP2 (Fully Sharded Data Parallel 2) support for FP8 and MXFP8 quantized tensors in PyTorch Transformer Engine, enabling distributed training with FP8 mixed-precision.

Key Changes:

  • Implemented fsdp_pre_all_gather and fsdp_post_all_gather hooks for both Float8Tensor and MXFP8Tensor to handle FSDP weight sharding/gathering lifecycle
  • Added custom __torch_dispatch__ handlers for FSDP-required operations: aten.split.Tensor, aten.new_zeros, aten.as_strided, aten.copy_, and aten.slice.Tensor
  • Enhanced transpose caching logic to properly maintain transposed views through various tensor operations
  • Added training state-aware quantizer usage control (rowwise vs columnwise) based on forward/backward pass detection

Major Implementation Details:

  • FSDP2 integration distinguishes between forward/backward passes using TrainingState.PRE_BACKWARD to selectively gather only needed tensor representations
  • For MXFP8, operations validate 128-byte alignment constraints and fall back to dequantization when constraints aren't met
  • Transpose cache maintenance across splits, views, and resharding ensures performance optimization for Hopper/L40 architectures

Issues Found:
Multiple critical None-handling bugs exist in MXFP8 dispatch handlers where operations assume non-None data/scale tensors, which would cause AttributeError at runtime when certain usage flags are disabled.

Confidence Score: 3/5

  • This PR has several critical runtime issues that need resolution before merging, particularly around None-handling in MXFP8 tensor operations
  • Score reflects multiple logic bugs identified by previous reviewers (AttributeError, NameError, variable shadowing) that would cause runtime failures in MXFP8 operations. While Float8Tensor changes appear more robust, MXFP8Tensor has ~8-10 critical None-handling issues across split, slice, copy, and post_all_gather operations that need fixes
  • transformer_engine/pytorch/tensor/mxfp8_tensor.py requires significant attention for None-handling fixes across all new dispatch handlers before this can safely merge

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/tensor/float8_tensor.py 4/5 Adds FSDP2 support with fsdp_pre_all_gather and fsdp_post_all_gather hooks, implements aten.split.Tensor, aten.new_zeros, and aten.as_strided handlers with transpose caching for FP8 tensors
transformer_engine/pytorch/tensor/mxfp8_tensor.py 3/5 Implements FSDP2 support and multiple torch dispatch handlers (split, as_strided, copy_, slice, new_zeros) for MXFP8 tensors; contains several critical None-handling issues that need resolution

Sequence Diagram

sequenceDiagram
    participant FSDP2
    participant Float8Tensor/MXFP8Tensor
    participant Quantizer
    participant DeviceMesh

    Note over FSDP2: Forward Pass (weights needed)
    FSDP2->>Float8Tensor/MXFP8Tensor: fsdp_pre_all_gather(mesh, module, ...)
    Float8Tensor/MXFP8Tensor->>Quantizer: check training_state & reshard_after_forward
    Quantizer->>Quantizer: set_usage(rowwise=True, columnwise=False)
    Float8Tensor/MXFP8Tensor->>FSDP2: return (sharded_data, metadata)
    FSDP2->>DeviceMesh: all_gather(sharded_data)
    DeviceMesh->>FSDP2: all_gather_outputs
    FSDP2->>Float8Tensor/MXFP8Tensor: fsdp_post_all_gather(outputs, metadata, ...)
    Float8Tensor/MXFP8Tensor->>Float8Tensor/MXFP8Tensor: reconstruct full tensor
    Float8Tensor/MXFP8Tensor->>Float8Tensor/MXFP8Tensor: update_usage(rowwise=True)
    Float8Tensor/MXFP8Tensor->>FSDP2: return gathered_tensor
    
    Note over FSDP2: Compute forward pass
    
    Note over FSDP2: Backward Pass (gradients computed)
    FSDP2->>Float8Tensor/MXFP8Tensor: fsdp_pre_all_gather(mesh, module, ...)
    Float8Tensor/MXFP8Tensor->>Quantizer: check training_state (PRE_BACKWARD)
    Quantizer->>Quantizer: set_usage(rowwise=False, columnwise=True)
    Float8Tensor/MXFP8Tensor->>FSDP2: return (transpose_data, metadata)
    FSDP2->>DeviceMesh: all_gather(transpose_data)
    DeviceMesh->>FSDP2: all_gather_outputs
    FSDP2->>Float8Tensor/MXFP8Tensor: fsdp_post_all_gather(outputs, metadata, ...)
    Float8Tensor/MXFP8Tensor->>Float8Tensor/MXFP8Tensor: reconstruct with transpose
    Float8Tensor/MXFP8Tensor->>Float8Tensor/MXFP8Tensor: update_usage(columnwise=True)
    Float8Tensor/MXFP8Tensor->>FSDP2: return gathered_tensor

2 files reviewed, 1 comment



@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR significantly enhances the FSDP2 test infrastructure for Transformer Engine by adding comprehensive support for FP8 mixed-precision training with distributed sharding.

Key Changes:

  • Expanded FP8 recipe support: added Float8CurrentScaling and MXFP8BlockScaling alongside existing DelayedScaling
  • Introduced flexible layer configuration system supporting 5 TE layer types (Linear, LayerNormLinear, LayerNormMLP, MultiheadAttention, TransformerLayer)
  • Added meta device initialization workflow for deferred parameter materialization after FSDP2 sharding
  • Implemented test_fp8_fsdp2_allgather() validation function to verify FP8 allgather correctness against manual FP32 allgather
  • Enhanced custom attribute save/restore logic to handle QuantizedTensor metadata correctly with FSDP2 DTensors
  • Replaced simple 3-layer network with configurable multi-layer architecture supporting both reshard_after_forward=True/False test cases

The test file is well-structured with clear separation of concerns: model initialization, FSDP2 setup, training loop, and validation logic.

Confidence Score: 5/5

  • This PR is safe to merge with high confidence - the changes are well-tested, properly structured, and add comprehensive FSDP2 support.
  • Score reflects thorough implementation with proper error handling, comprehensive test coverage of multiple FP8 recipes and layer types, correct FSDP2 integration patterns (save/restore custom attrs, DTensor handling), and validation logic to verify FP8 allgather correctness. The code follows established patterns and includes clear documentation.
  • No files require special attention - the test file is comprehensive and correctly implements FSDP2 FP8 support.

Important Files Changed

File Analysis

Filename Score Overview
tests/pytorch/distributed/run_fsdp2_model.py 5/5 Comprehensive FSDP2 test script adding support for multiple FP8 recipes, flexible layer configurations, meta device initialization, and FP8 allgather validation

Sequence Diagram

sequenceDiagram
    participant Main as Main Process
    participant Init as Model Init
    participant FSDP as FSDP2 Sharding
    participant Train as Training Loop
    participant Test as FP8 Test

    Main->>Main: Parse args & setup distributed
    Main->>Init: Create FP8 recipe (delayed/current/mx_fp8)
    
    alt FP8 Init Enabled
        Init->>Init: fp8_model_init(recipe)
    else FP8 Init Disabled
        Init->>Init: nullcontext()
    end
    
    Init->>Init: init_te_model(config)
    Note over Init: Create model on meta/cuda device
    
    Init->>FSDP: save_custom_attrs(model)
    Note over FSDP: Save QuantizedTensor metadata
    
    FSDP->>FSDP: get_device_mesh(world_size, sharding_dims)
    Note over FSDP: Setup FSDP or HSDP mesh
    
    FSDP->>FSDP: shard_model_with_fsdp2(model, mesh)
    Note over FSDP: Apply fully_shard to children & root
    
    FSDP->>FSDP: restore_custom_attrs(model, custom_attrs)
    Note over FSDP: Restore FP8 metadata to DTensors
    
    alt Meta Device Init
        FSDP->>FSDP: reset_parameters()
        Note over FSDP: Materialize sharded params on cuda
    end
    
    FSDP->>Train: Create optimizer
    
    loop For each iteration
        Train->>Train: Generate input & target
        Train->>Train: Forward with te.autocast(recipe)
        Train->>Train: Compute loss
        Train->>Train: Backward pass
        Train->>Train: Optimizer step
    end
    
    alt FP8 Init Enabled
        Train->>Test: test_fp8_fsdp2_allgather(model)
        Test->>Test: Manual FP32 allgather
        Test->>Test: FSDP2 FP8 allgather (unshard)
        Test->>Test: Validate both match
        Test->>Test: Reshard model
    end
    
    Main->>Main: Destroy process group

1 file reviewed, no comments



@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR adds FSDP2 (Fully Sharded Data Parallel v2) support to Transformer Engine by enabling DTensor parameter handling in the base module.

Key Changes:

  • Modified register_parameter() to prevent overwriting FP8-specific metadata when FSDP2 re-registers parameters as DTensors
  • Enhanced reset_parameters() to detect and handle DTensor parameters by operating on their local tensors
  • Added device mesh integration for Float8CurrentScalingQuantizer to configure amax reduction groups for distributed training
  • Implemented proper DTensor reconstruction after meta-device materialization
  • Ensured quantized local tensors are correctly wrapped back into DTensor parameters

Integration Points:

  • DTensor detection via isinstance(param, DTensor) check
  • Local tensor extraction and manipulation via param._local_tensor
  • Device mesh group configuration for FP8 scaling synchronization across shards
  • Parameter wrapping preserves both DTensor structure and FP8 quantization

Confidence Score: 4/5

  • This PR is safe to merge with minor considerations for edge cases in DTensor handling
  • The implementation correctly handles DTensor parameter registration and initialization. The logic properly distinguishes between DTensor and regular tensors, extracts local tensors for processing, and reconstructs DTensors with appropriate device mesh configuration. The amax reduction group setup for Float8CurrentScalingQuantizer is correctly conditioned on both DTensor type and quantizer type. However, the score is 4 instead of 5 because: (1) the high-precision init value methods are attached to local tensors which relies on DTensor's attribute delegation pattern, and (2) there's no explicit validation that dtensor_param maintains valid device_mesh/placements attributes throughout the flow, though the logic appears sound
  • No files require special attention beyond standard FSDP2 testing with FP8 quantization enabled

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/module/base.py 4/5 Added FSDP2 DTensor support in parameter registration and reset, including proper handling of local tensors, device mesh configuration for FP8 quantization, and parameter wrapping

Sequence Diagram

sequenceDiagram
    participant FSDP2 as FSDP2
    participant Module as TransformerEngineBaseModule
    participant ResetParams as reset_parameters()
    participant Quantizer as Float8CurrentScalingQuantizer
    participant DTensor as DTensor

    FSDP2->>Module: register_parameter(name, DTensor)
    Note over Module: Check if param_init_meta exists<br/>Only initialize once to preserve FP8 kwargs
    Module->>Module: Store param_init_meta[name]
    
    FSDP2->>ResetParams: Trigger parameter initialization
    ResetParams->>ResetParams: Check if param is DTensor
    ResetParams->>DTensor: Extract _local_tensor
    
    alt Parameter on meta device
        ResetParams->>ResetParams: Create empty_like on cuda
        ResetParams->>DTensor: Reconstruct DTensor.from_local()<br/>with device_mesh & placements
    end
    
    ResetParams->>ResetParams: Apply init_fn to local tensor
    
    alt FP8 quantization enabled
        ResetParams->>Quantizer: Configure quantizer settings
        alt Is DTensor && Float8CurrentScaling
            ResetParams->>DTensor: Get device_mesh
            ResetParams->>Quantizer: Set amax_reduction_group<br/>from device_mesh.get_group()
            ResetParams->>Quantizer: Enable with_amax_reduction
        end
        ResetParams->>Quantizer: Quantize local tensor
        Quantizer-->>ResetParams: Return QuantizedTensor
    end
    
    ResetParams->>DTensor: Update _local_tensor with quantized tensor
    ResetParams->>DTensor: Wrap as nn.Parameter
    ResetParams->>Module: setattr(name, DTensor parameter)

1 file reviewed, no comments



@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

Enables FSDP2 training with FP8/MXFP8 initialized weights by implementing custom allgather hooks (fsdp_pre_all_gather and fsdp_post_all_gather) that serialize FP8 tensors to uint8 for distributed communication and reconstruct them post-allgather.

Key Changes:

  • FP8 Allgather Support: Float8Tensor and MXFP8Tensor now implement FSDP2 hooks that return uint8 data with metadata (scale_inv, dtype, quantizer) for allgather, enabling FP8 communication instead of high-precision
  • Selective Usage Based on Training State: Pre-allgather hooks optimize memory by gathering only rowwise data for forward pass and columnwise data for backward pass when reshard_after_forward=True
  • DTensor Integration: TransformerEngineBaseModule.reset_parameters() now handles FSDP2's DTensor parameters by operating on _local_tensor and preserving FP8 metadata across parameter re-registration
  • Transpose Cache Management: Enhanced __torch_dispatch__ handlers for split/view/new_zeros/as_strided ops to maintain transpose caches for both data and data_transpose, improving performance
  • Amax Reduction Setup: Quantizers are configured with appropriate reduction groups for synchronized scale updates across FSDP shards

Issues Found:

  • Potential tensor unpacking bug in mxfp8_tensor.py:613-617 where both [:2] and [-2:] slicing could select duplicate tensors if validation fails

Confidence Score: 4/5

  • This PR is largely safe to merge with one logical issue that needs verification in edge cases
  • The implementation is well-structured and addresses a significant feature gap (FSDP2 support for FP8 weights). The core allgather hook logic is sound and properly handles the forward/backward pass distinction. However, there's a potential edge-case bug in MXFP8Tensor's fsdp_post_all_gather where tensor unpacking could fail if the tuple length doesn't match usage flags, though this is unlikely in normal operation since the pre/post hooks are paired
  • transformer_engine/pytorch/tensor/mxfp8_tensor.py - verify the tensor unpacking logic at lines 613-617 handles all edge cases correctly

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/tensor/float8_tensor.py 4/5 Adds FSDP2 support with fsdp_pre_all_gather and fsdp_post_all_gather hooks, implements FP8 allgather by returning uint8 data with metadata for reconstruction, enhances __torch_dispatch__ to handle transpose caching for split/view/new_zeros/as_strided ops
transformer_engine/pytorch/tensor/mxfp8_tensor.py 3/5 Adds FSDP2 allgather hooks for MXFP8 tensors with selective rowwise/columnwise data gathering based on training state, implements torch dispatch handlers for split/as_strided/copy_/slice/new_zeros ops with MXFP8 block scaling constraints, has potential tensor unpacking issue in fsdp_post_all_gather
transformer_engine/pytorch/distributed.py 4/5 Adds _get_module_fsdp_state helper with LRU caching to retrieve FSDP state from modules or their closest FSDP parent
transformer_engine/pytorch/module/base.py 4/5 Updates reset_parameters to handle DTensor (FSDP2) by operating on _local_tensor, preserves FP8 metadata during FSDP2's re-registration of parameters as DTensors, sets up amax reduction groups for DTensor quantizers

Sequence Diagram

sequenceDiagram
    participant FSDP2 as FSDP2
    participant Float8Tensor as Float8Tensor/MXFP8Tensor
    participant Quantizer as Quantizer
    participant Module as TransformerEngineModule
    
    Note over FSDP2,Module: Forward Pass (or Backward if reshard_after_forward=True)
    
    FSDP2->>Float8Tensor: fsdp_pre_all_gather(mesh, orig_size, module, ...)
    Float8Tensor->>Module: _get_module_fsdp_state(module)
    Module-->>Float8Tensor: fsdp_state
    
    Float8Tensor->>Quantizer: copy()
    Quantizer-->>Float8Tensor: quantizer_copy
    
    alt reshard_after_forward=True
        Float8Tensor->>Float8Tensor: Determine forward vs backward from training_state
        Float8Tensor->>Quantizer: set_usage(rowwise=!is_backward, columnwise=is_backward)
        Note over Float8Tensor: Pack only needed data based on pass direction
    else reshard_after_forward=False
        Note over Float8Tensor: Pack both rowwise and columnwise if needed
    end
    
    Float8Tensor-->>FSDP2: (sharded_uint8_tensors, metadata)
    
    FSDP2->>FSDP2: AllGather uint8 tensors across ranks
    
    FSDP2->>Float8Tensor: fsdp_post_all_gather(all_gather_outputs, metadata, param_dtype, out)
    
    Float8Tensor->>Float8Tensor: Unpack all_gather_outputs and metadata
    
    alt out exists
        Float8Tensor->>Float8Tensor: update_usage() on existing tensor
    else out is None
        Float8Tensor->>Float8Tensor: Construct new Float8Tensor/MXFP8Tensor
        Float8Tensor->>Float8Tensor: update_usage() on new tensor
    end
    
    Float8Tensor-->>FSDP2: (reconstructed_fp8_tensor, all_gather_outputs)
    
    Note over FSDP2,Module: Tensor ready for forward/backward computation

1 file reviewed, 1 comment


@vthumbe1503
Collaborator Author

/te-ci pytorch


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR makes focused changes to quantized_tensor.py to improve FSDP2 compatibility:

  • Recursive list handling: Added support for recursively updating lists of tensors in in-place operations (lines 436-439). This handles operations like split that return multiple tensors, ensuring QuantizedTensors within lists are properly updated.

  • Simplified make_like method: Removed the data parameter from the base class implementation (lines 493-506). The method now focuses solely on creating views of tensors. This change is safe because:

    • Subclasses like Float8Tensor override this method and still support the data parameter for backward compatibility
    • The base class docstring now correctly reflects that the method is "intended to create view of tensors"
    • Existing usages with data= parameter are handled by the overridden methods in subclasses

These are minimal, well-scoped changes that support the broader FSDP2 integration without breaking existing functionality.

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk
  • The changes are minimal and focused, with only two small modifications to quantized_tensor.py. The recursive list handling is a straightforward addition that improves robustness. The make_like signature change is safe because subclasses override the method and maintain backward compatibility. No issues found that would impact correctness or introduce bugs.
  • No files require special attention

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/quantized_tensor.py 4/5 Added recursive list handling for in-place operations and simplified make_like method by removing data parameter. Changes are minimal and focused on improving FSDP2 compatibility.

Sequence Diagram

sequenceDiagram
    participant FSDP as FSDP2 Framework
    participant PreHook as fsdp_pre_allgather
    participant QT as QuantizedTensor
    participant PostHook as fsdp_post_allgather
    
    Note over FSDP: Forward/Backward Pass Begins
    FSDP->>PreHook: Call pre_allgather hook
    PreHook->>QT: Extract uint8 data + metadata
    Note over QT: For FP8: extract _data tensor<br/>For MXFP8: extract rowwise/columnwise data
    QT-->>PreHook: Return (uint8_tensors, metadata)
    PreHook-->>FSDP: Return allgather input
    
    Note over FSDP: Perform AllGather on uint8 data
    
    FSDP->>PostHook: Call post_allgather hook
    PostHook->>QT: Reconstruct from allgathered data
    Note over QT: Rebuild Float8/MXFP8 tensor<br/>from uint8 + metadata
    QT-->>PostHook: Return reconstructed tensor
    PostHook-->>FSDP: Return full tensor
    
    Note over FSDP: Continue computation with full tensor

1 file reviewed, no comments


@vthumbe1503
Collaborator Author

/te-ci pytorch


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

Enables FSDP2 distributed training with FP8-initialized weights by implementing allgather hooks, torch dispatch operations, and DTensor support for TE quantized tensors.

Key Changes:

  • Implements fsdp_pre_all_gather and fsdp_post_all_gather hooks for Float8Tensor and MXFP8Tensor to enable 8-bit weight allgather (instead of high-precision)
  • Adds torch dispatch support for FSDP2 tensor operations: split, copy_, slice, view, as_strided, new_zeros
  • Updates reset_parameters in TransformerEngineBaseModule to handle DTensor for deferred initialization (meta device)
  • Fixes optimizer weight updates by recursively handling lists of quantized tensors in in-place operations
  • Moves quantizer usage validation from forward to backward pass to support phase-aware allgather
  • Configures amax reduction groups for current scaling quantizer to synchronize scale inverses across FSDP shards
  • Comprehensive test coverage for multiple TE layers with delayed/current/MX_FP8 scaling recipes

Memory Impact:

  • FP8 per-tensor scaling reduces memory footprint by ~50% vs BF16 on Blackwell (as expected)
  • MXFP8 block scaling maintains similar memory to BF16 due to rowwise+columnwise storage requirements

Confidence Score: 4/5

  • This PR is mostly safe to merge with one critical bug fix needed in MXFP8 shape handling
  • Score reflects solid implementation with comprehensive test coverage, but deducted 1 point due to critical bug in mxfp8_tensor.py:658 where both rowwise/columnwise data can be None causing AttributeError, and minor concerns about LRU cache causing potential memory leaks
  • transformer_engine/pytorch/tensor/mxfp8_tensor.py:658 requires fix for None handling

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/tensor/float8_tensor.py 4/5 Adds FSDP2 allgather hooks and torch dispatch ops (view, split, copy, slice, as_strided, new_zeros) to support 8-bit weight sharding. Implements current scaling quantizer sync for FSDP weight updates. Potential issue with shape handling in line 658.
transformer_engine/pytorch/tensor/mxfp8_tensor.py 3/5 Implements FSDP2 support with rowwise/columnwise data handling for block-scaled FP8. Adds dispatch ops (split, copy, slice, as_strided, new_zeros). Critical bug at line 658 where both data tensors can be None causing AttributeError.
transformer_engine/pytorch/module/base.py 4/5 Updates reset_parameters to handle DTensor (FSDP2 deferred init) by quantizing local tensor and reconstructing DTensor. Adds amax_reduction_group configuration for current scaling. Guard prevents metadata loss during DTensor conversion.
transformer_engine/pytorch/distributed.py 4/5 Adds _get_module_fsdp_state helper with LRU cache to find FSDP state for allgather hooks. Cache could cause memory leaks but likely acceptable given module stability during training.
transformer_engine/pytorch/quantized_tensor.py 5/5 Fixes in-place ops to recursively handle lists of tensors (optimizer sends batched updates). Removes data parameter from make_like to avoid confusion between view creation and data initialization.

Sequence Diagram

sequenceDiagram
    participant FSDP as FSDP2
    participant QT as QuantizedTensor (FP8/MXFP8)
    participant Helper as _get_module_fsdp_state
    participant Optimizer as Optimizer
    
    Note over FSDP,QT: Forward Pass - Weight Allgather
    FSDP->>QT: fsdp_pre_all_gather(module, mesh, ...)
    QT->>Helper: Get FSDP state to determine phase
    Helper-->>QT: training_state, reshard_after_forward
    QT->>QT: Set quantizer.rowwise_usage=True, columnwise=False
    QT-->>FSDP: (uint8_data, ...), metadata
    FSDP->>FSDP: All-gather uint8 shards
    FSDP->>QT: fsdp_post_all_gather(gathered_outputs, metadata)
    QT->>QT: Reconstruct FP8 tensor with rowwise usage
    QT-->>FSDP: Allgathered FP8 weight
    
    Note over FSDP,QT: Forward Pass Compute
    FSDP->>QT: Forward computation with FP8 weights
    
    Note over FSDP,QT: Backward Pass - Weight Allgather (if reshard_after_forward)
    FSDP->>QT: fsdp_pre_all_gather(module, mesh, ...)
    QT->>Helper: Get FSDP state
    Helper-->>QT: training_state=PRE_BACKWARD
    QT->>QT: Set quantizer.rowwise=False, columnwise_usage=True
    QT-->>FSDP: (uint8_data_transpose, ...), metadata
    FSDP->>FSDP: All-gather transpose/columnwise shards
    FSDP->>QT: fsdp_post_all_gather(gathered_outputs, metadata)
    QT->>QT: Reconstruct FP8 tensor with columnwise usage
    QT-->>FSDP: Allgathered FP8 weight
    
    Note over FSDP,Optimizer: Gradient Computation & Weight Update
    FSDP->>FSDP: Compute gradients, reduce-scatter
    Optimizer->>QT: In-place update (lerp on list of tensors)
    QT->>QT: Dequantize, apply op, quantize with amax reduction
    Note over QT: Amax synchronized across shards<br/>for current scaling

11 files reviewed, 2 comments


columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=fp8_dtype,
dtype=param_dtype,
shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,

logic: if both rowwise_data and columnwise_data are None (when both usage flags are False), accessing .shape raises AttributeError

Suggested change
shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,
shape=rowwise_data.shape if rowwise_data is not None else (columnwise_data.shape if columnwise_data is not None else torch.Size([0])),

@vthumbe1503
Collaborator Author

/te-ci L1 pytorch


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR enables FSDP2 training with FP8-initialized weights by implementing custom allgather hooks and torch dispatch handlers. The implementation addresses three key issues: memory footprint with FP8 weights, correct weight updates during training, and efficient 8-bit weight allgather.

Key Changes:

  • Implements fsdp_pre_all_gather and fsdp_post_all_gather hooks for Float8Tensor and MXFP8Tensor to handle FP8/MXFP8 allgather using uint8 data
  • Adds torch dispatch handlers for split, slice, copy, new_zeros, as_strided, and view operations needed by FSDP2
  • Uses FSDP state to detect forward vs backward pass and set appropriate rowwise/columnwise quantizer usage
  • Sets amax_reduction_group for current scaling quantization to synchronize scale inverses across shards
  • Updates DTensor parameters correctly during deferred initialization (meta device)
  • Moves quantizer usage validation from layer forward() to _apply_forward/backward functions to accommodate FSDP2's separate allgather for forward/backward

Critical Issues Found:

  • mxfp8_tensor.py:502 and mxfp8_tensor.py:389 have potential AttributeError when accessing .shape on tensors that can be None (when neither rowwise nor columnwise data exists)

Confidence Score: 3/5

  • This PR introduces critical bugs that will cause runtime failures in edge cases, but the core FSDP2 integration logic is sound
  • Score of 3 reflects two critical logic errors in MXFP8Tensor dispatch handlers (lines 389 and 502) that will cause AttributeError when accessing .shape on None values. These bugs occur when quantizer has neither rowwise nor columnwise usage enabled, which may be rare but is not prevented. The rest of the implementation is well-designed with proper handling of forward/backward distinction, amax reduction groups, and DTensor support
  • transformer_engine/pytorch/tensor/mxfp8_tensor.py lines 389 and 502 require immediate fixes to handle None tensor data

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/tensor/mxfp8_tensor.py 2/5 Adds FSDP2 torch dispatch handlers for split, slice, copy, new_zeros, as_strided, and view ops. Contains critical bug: line 502 accesses .shape on out_data[0] which can be None when neither rowwise nor columnwise data exists
transformer_engine/pytorch/tensor/float8_tensor.py 3/5 Adds FSDP2 allgather hooks and torch dispatch handlers for various ops. Implements rowwise/columnwise usage tracking for forward/backward passes. Generally well-structured but relies on cached FSDP state lookup
transformer_engine/pytorch/distributed.py 4/5 Adds _get_module_fsdp_state helper with @lru_cache to find FSDP state for modules. Cache is appropriate since it stores reference to mutable state object
transformer_engine/pytorch/module/base.py 4/5 Updates reset_parameters to handle DTensor params for FSDP2 deferred init, sets amax reduction group for current scaling quantization. Logic is sound

Sequence Diagram

sequenceDiagram
    participant FSDP2
    participant Float8Tensor
    participant MXFP8Tensor
    participant Quantizer
    participant TE_Module
    
    Note over FSDP2,TE_Module: Forward Pass
    FSDP2->>Float8Tensor: fsdp_pre_all_gather(mesh, module, ...)
    Float8Tensor->>Float8Tensor: Set amax_reduction_group for current scaling
    Float8Tensor->>Float8Tensor: Get FSDP state, check reshard_after_forward
    Float8Tensor->>Quantizer: copy() and set_usage(rowwise=True)
    Float8Tensor-->>FSDP2: Return (uint8 data, metadata)
    FSDP2->>FSDP2: All-gather uint8 data across shards
    FSDP2->>Float8Tensor: fsdp_post_all_gather(gathered_data, metadata)
    Float8Tensor->>Float8Tensor: Reconstruct with rowwise usage
    Float8Tensor-->>FSDP2: Return reconstructed Float8Tensor
    FSDP2->>TE_Module: forward(input)
    TE_Module->>TE_Module: Validate rowwise usage in _apply_forward
    TE_Module-->>FSDP2: output
    
    Note over FSDP2,TE_Module: Backward Pass (if reshard_after_forward)
    FSDP2->>Float8Tensor: fsdp_pre_all_gather(mesh, module, ...)
    Float8Tensor->>Float8Tensor: Get FSDP state, detect PRE_BACKWARD
    Float8Tensor->>Quantizer: copy() and set_usage(columnwise=True)
    Float8Tensor-->>FSDP2: Return (uint8 data, metadata)
    FSDP2->>FSDP2: All-gather uint8 data across shards
    FSDP2->>Float8Tensor: fsdp_post_all_gather(gathered_data, metadata)
    Float8Tensor->>Float8Tensor: Reconstruct with columnwise usage
    Float8Tensor-->>FSDP2: Return reconstructed Float8Tensor
    TE_Module->>TE_Module: Validate columnwise usage in _apply_backward
    
    Note over FSDP2,TE_Module: Weight Initialization (Meta Device)
    TE_Module->>TE_Module: reset_parameters() with DTensor
    TE_Module->>Quantizer: Set amax_reduction_group from mesh
    TE_Module->>Float8Tensor: Initialize with FP8 quantization
    TE_Module->>TE_Module: Update DTensor.local_tensor

11 files reviewed, 2 comments


)
out_data.append(scale_inv_out)
return MXFP8Tensor(
shape=out_data[0].shape,

logic: AttributeError if both _rowwise_data and _columnwise_data are None - accessing .shape on out_data[0] will fail when neither usage flag is set

Suggested change
shape=out_data[0].shape,
shape=out_data[0].shape if out_data[0] is not None else out_data[1].shape,

out_data.append(scale_inv_out)
return [
MXFP8Tensor(
shape=splitted_tensor_data[0].size(),

logic: AttributeError if splitted_tensor_data[0] is None - happens when tensor._rowwise_data is None before splitting

Suggested change
shape=splitted_tensor_data[0].size(),
shape=splitted_tensor_data[0].size() if splitted_tensor_data[0] is not None else splitted_tensor_data[1].size(),


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR enables end-to-end FSDP2 training for PyTorch models with Transformer Engine layers initialized with FP8/MXFP8 weights, solving three critical issues: memory footprint problems with FP8-initialized weights, weight update correctness during training, and enabling 8-bit allgather instead of high-precision.

Key Changes:

  • FSDP2 Allgather Hooks: Implements fsdp_pre_all_gather and fsdp_post_all_gather methods for Float8Tensor and MXFP8Tensor to support 8-bit weight allgather by returning uint8 data and metadata for reconstruction
  • Torch Dispatch Operations: Adds handlers for split, new_zeros, as_strided, copy_, slice, and view operations to support FSDP2 sharding and resharding of quantized tensors
  • FSDP State Management: Introduces _get_module_fsdp_state helper with LRU cache to determine forward/backward pass and reshard_after_forward configuration, enabling proper rowwise/columnwise usage selection
  • Current Scaling Synchronization: Sets amax reduction group in quantizers during allgather to ensure all weight shards share the same scale inverse after optimizer updates
  • DTensor Support: Updates reset_parameters in base module to handle DTensor parameters for FSDP2 deferred initialization with proper quantizer configuration
  • Quantized Tensor Fixes: Fixes in-place operations to handle lists of tensors (for optimizer lerp operations) and removes incorrect data parameter from make_like API
  • Usage Validation Refactoring: Moves quantizer usage validation from layer forward to forward/backward functions, and removes unnecessary quantizer updates when weights are already quantized

Memory Impact: FP8 per-tensor quantization reduces memory by 50% vs BF16 on Blackwell. MXFP8 has similar memory footprint to BF16 due to needing both rowwise/columnwise representations.

Test Coverage: Comprehensive tests cover delayed scaling, current scaling, and MX_FP8 block scaling recipes with various layer types (Linear, LayerNormLinear, TransformerLayer) and both FSDP/HSDP configurations.

Confidence Score: 4/5

  • Safe to merge with minor considerations - addresses long-standing FSDP2+FP8 issues with comprehensive implementation
  • Score of 4 reflects solid implementation with extensive test coverage addressing critical functionality gaps. The changes are well-architected with proper separation between FP8/MXFP8 tensor handling, FSDP2 hooks, and torch dispatch operations. Previous syntax errors in mxfp8_tensor.py mentioned in earlier comments have been fixed. Main concerns are: (1) LRU cache on _get_module_fsdp_state could retain module references indefinitely though the return value is a mutable state reference, (2) complex logic for determining forward/backward pass and reshard_after_forward could benefit from additional inline documentation, (3) MXFP8 view operation intentionally falls back to dequantize path with warning when flattening inner dimension. The PR resolves three critical GitHub issues (#1688, #401, #1135, #1188) and includes validation tests.
  • Pay close attention to transformer_engine/pytorch/tensor/float8_tensor.py and transformer_engine/pytorch/tensor/mxfp8_tensor.py for the complex torch dispatch logic and allgather hooks

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/distributed.py 4/5 Adds _get_module_fsdp_state helper with @lru_cache to find FSDP state for modules - enables determining forward/backward pass during allgather
transformer_engine/pytorch/tensor/float8_tensor.py 4/5 Implements FSDP2 hooks (fsdp_pre_all_gather, fsdp_post_all_gather) and torch dispatch for split/new_zeros/as_strided/copy operations to support 8-bit allgather
transformer_engine/pytorch/tensor/mxfp8_tensor.py 4/5 Implements FSDP2 hooks and torch dispatch for MXFP8 tensors with rowwise/columnwise data handling - includes split/as_strided/copy/slice operations
transformer_engine/pytorch/quantized_tensor.py 4/5 Fixes in-place ops to handle lists of tensors (for optimizer updates) and removes data parameter from make_like to fix view semantics
transformer_engine/pytorch/module/base.py 4/5 Adds DTensor support in reset_parameters for FSDP2 deferred initialization, handles amax reduction group setup for current scaling quantization
transformer_engine/pytorch/module/linear.py 4/5 Removes quantizer updates when weight is already quantized, moves columnwise usage validation from forward to backward function

Sequence Diagram

sequenceDiagram
    participant User
    participant FSDP2
    participant TEModule as TE Module
    participant Float8Tensor
    participant Quantizer
    participant AllGather as FSDP AllGather

    User->>FSDP2: Initialize model with fp8_model_init
    FSDP2->>TEModule: Create FP8/MXFP8 weight shards
    TEModule->>Float8Tensor: Initialize quantized weights
    Float8Tensor->>Quantizer: Setup amax reduction group
    
    User->>FSDP2: Start training iteration (forward pass)
    FSDP2->>Float8Tensor: fsdp_pre_all_gather(module, mesh)
    Float8Tensor->>TEModule: Get FSDP state via _get_module_fsdp_state
    Float8Tensor->>Quantizer: Set rowwise usage for forward
    Float8Tensor-->>FSDP2: Return (uint8_data,), metadata
    
    FSDP2->>AllGather: AllGather uint8 data across shards
    AllGather-->>FSDP2: Gathered uint8 data
    
    FSDP2->>Float8Tensor: fsdp_post_all_gather(outputs, metadata)
    Float8Tensor->>Float8Tensor: Reconstruct Float8Tensor with gathered data
    Float8Tensor->>Float8Tensor: update_usage(rowwise=True)
    Float8Tensor-->>FSDP2: Return reconstructed weight tensor
    
    FSDP2->>TEModule: Forward pass with gathered weights
    TEModule->>TEModule: Compute activations
    
    alt reshard_after_forward=True
        FSDP2->>FSDP2: Reshard weights after forward
    end
    
    User->>FSDP2: Backward pass
    
    alt reshard_after_forward=True
        FSDP2->>Float8Tensor: fsdp_pre_all_gather (backward)
        Float8Tensor->>Quantizer: Set columnwise usage for backward
        Float8Tensor-->>FSDP2: Return appropriate data for backward
        FSDP2->>AllGather: AllGather for backward pass
        FSDP2->>Float8Tensor: fsdp_post_all_gather
        Float8Tensor-->>FSDP2: Reconstructed weight for backward
    end
    
    FSDP2->>TEModule: Backward pass with weights
    TEModule->>TEModule: Compute gradients
    FSDP2->>FSDP2: ReduceScatter gradients
    
    User->>FSDP2: Optimizer step
    FSDP2->>Float8Tensor: Update weight shards (via lerp/copy_)
    Float8Tensor->>Quantizer: Sync amax across shards
    Quantizer->>Quantizer: AllReduce amax for single scale_inv

11 files reviewed, no comments



@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR implements comprehensive FSDP2 (Fully Sharded Data Parallel 2) support for TransformerEngine, enabling distributed training with FP8 quantized weights. The implementation addresses three critical issues:

Key Achievements:

  1. Memory Efficiency: FP8-initialized weights now consume half the memory of BF16 (per-tensor scaling) during FSDP2 training
  2. Weight Updates: Fixed FP8 weight gradient updates that were previously broken for both DDP and FSDP with FP8-initialized weights
  3. 8-bit Allgather: Enabled FP8/MXFP8 weight all-gather instead of high-precision communication for efficient training

Technical Implementation:

  • Adds fsdp_pre_all_gather and fsdp_post_all_gather hooks to Float8Tensor and MXFP8Tensor for handling quantized tensor communication
  • Implements torch dispatch handlers for FSDP2-required operations: split, copy_, new_zeros, as_strided, view
  • Handles rowwise/columnwise usage tracking for forward vs backward passes with reshard_after_forward configuration
  • Sets up amax reduction groups for current scaling quantizers to ensure consistent scale_inv across shards
  • Supports both DTensor (FSDP2) and regular tensor paths in reset_parameters for deferred initialization

Areas of Concern:

  • MXFP8 copy and split operations have potential AttributeError risks when _columnwise_data is None
  • The _get_module_fsdp_state function uses unbounded @lru_cache which could accumulate module references
  • Some edge cases around tensor shape validation when both rowwise and columnwise usages are disabled

The implementation is well-structured with comprehensive test coverage for different quantization recipes (FP8, MXFP8, delayed/current scaling). The PR successfully enables production FSDP2 training with FP8 quantization.

Confidence Score: 4/5

  • This PR is relatively safe to merge with minor issues that should be addressed
  • The implementation is comprehensive with good test coverage, but has a few logical edge cases around None handling in MXFP8 operations that could cause runtime errors in specific configurations. The core FSDP2 integration logic is sound and addresses real production issues.
  • Pay close attention to transformer_engine/pytorch/tensor/mxfp8_tensor.py for None handling in copy and split operations

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/tensor/float8_tensor.py 4/5 Adds FSDP2 support for Float8Tensor with pre/post allgather hooks, torch dispatch handlers for split/copy/view/new_zeros/as_strided ops. Most changes are solid but transpose cache logic in view operation needs attention for edge cases.
transformer_engine/pytorch/tensor/mxfp8_tensor.py 3/5 Implements FSDP2 hooks and torch dispatch for MXFP8 with complex split/padding logic. Copy operation has potential None handling issues when columnwise data doesn't exist. Split tensor handling assumes both data and scale_inv exist for all usages.
transformer_engine/pytorch/distributed.py 4/5 Adds _get_module_fsdp_state helper with @lru_cache that finds closest FSDP ancestor module. Caching by module identity could potentially cause issues if modules are recreated, though this is unlikely in typical FSDP usage.
transformer_engine/pytorch/module/base.py 5/5 Updates reset_parameters to handle DTensor for FSDP2 and sets amax_reduction_group for current scaling quantizers. Clean implementation that properly handles both DTensor and regular tensors with FP8 quantization.

Sequence Diagram

sequenceDiagram
    participant FSDP2 as FSDP2 Framework
    participant Tensor as Float8/MXFP8Tensor
    participant Quantizer as Quantizer
    participant NCCL as NCCL AllGather

    Note over FSDP2,NCCL: Forward Pass - Weight AllGather
    
    FSDP2->>Tensor: fsdp_pre_all_gather(mesh, module, ...)
    Tensor->>Tensor: Get FSDP state & training phase
    Tensor->>Quantizer: Set usage (rowwise=True, columnwise=False)
    Tensor->>Tensor: Remove padding from scale_inv
    Tensor-->>FSDP2: Return (sharded_uint8_data, scale_inv, metadata)
    
    FSDP2->>NCCL: all_gather(sharded_uint8_data)
    NCCL-->>FSDP2: gathered_uint8_data
    
    FSDP2->>Tensor: fsdp_post_all_gather(gathered_data, metadata)
    Tensor->>Tensor: Add padding to scale_inv
    Tensor->>Tensor: Reconstruct Float8/MXFP8Tensor
    Tensor-->>FSDP2: Return full quantized weight tensor
    
    Note over FSDP2,NCCL: Forward Compute with Quantized Weights
    
    FSDP2->>FSDP2: Optional: Reshard after forward
    
    Note over FSDP2,NCCL: Backward Pass - Weight AllGather
    
    FSDP2->>Tensor: fsdp_pre_all_gather(mesh, module, ...)
    Tensor->>Tensor: Detect backward phase
    Tensor->>Quantizer: Set usage (rowwise=False, columnwise=True)
    Tensor-->>FSDP2: Return columnwise data for backward
    
    FSDP2->>NCCL: all_gather(columnwise_data)
    NCCL-->>FSDP2: gathered_columnwise_data
    
    FSDP2->>Tensor: fsdp_post_all_gather(gathered_data, metadata)
    Tensor-->>FSDP2: Return full quantized weight tensor
    
    Note over FSDP2,NCCL: Backward Compute & Gradient Update
    
    FSDP2->>Quantizer: Update weights with gradients
    Quantizer->>Quantizer: Amax reduction across shards
    Quantizer->>Tensor: Quantize updated weights

11 files reviewed, 5 comments


Comment on lines +346 to +348
if src._columnwise_data is not None and dst._columnwise_data is not None:
    dst._columnwise_data.copy_(src._columnwise_data.detach())
    dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv.detach())

logic: potential AttributeError if _columnwise_scale_inv is None - the guard only checks _columnwise_data on src and dst, so the subsequent .copy_() on _columnwise_scale_inv can still hit None
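
As a hedged sketch, a guarded version of this copy could look like the following; the helper name is hypothetical and only the underscored attribute names come from the snippet above.

def _copy_columnwise(src, dst) -> None:
    # Copy the columnwise payload only when both the data and its scale_inv
    # actually exist on both sides; otherwise skip it entirely.
    src_ok = src._columnwise_data is not None and src._columnwise_scale_inv is not None
    dst_ok = dst._columnwise_data is not None and dst._columnwise_scale_inv is not None
    if src_ok and dst_ok:
        dst._columnwise_data.copy_(src._columnwise_data.detach())
        dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv.detach())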

out_data.append(scale_inv_out)
return [
    MXFP8Tensor(
        shape=splitted_tensor_data[0].size(),

logic: splitted_tensor_data[0].size() could raise AttributeError if _rowwise_data was None for the tensor being split - should validate data exists or use fallback to splitted_tensor_data[1].size()
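
For illustration, a fallback along the lines this comment suggests might look like the sketch below; the helper name is hypothetical and the indices follow the snippet above.

def _split_shape(splitted_tensor_data):
    # Use the first buffer's size when it exists, otherwise fall back to the
    # second one; fail loudly if the split produced neither.
    first, second = splitted_tensor_data[0], splitted_tensor_data[1]
    if first is not None:
        return first.size()
    if second is not None:
        return second.size()
    raise ValueError("split produced no data to take a shape from")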

@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

weightmat = weight
quantized_weight = False
if fp8 or debug:
    quantized_weight = not isinstance(weight, QuantizedTensorStorage)
Collaborator

Oof, this variable name was aggressively bad. Just to make things extremely explicit, perhaps we should rename to is_weight_param_quantized.

If we change this variable, we need to make sure to update the other places it is used:

weightmat if quantized_weight else None,

weight if ctx.fp8 and ctx.quantized_weight else None,

Collaborator Author

Sure, will do that.

else:
    min_nodes_in_parent = float("inf")
    closest_parent_fsdp_mod = None
    for fsdp_mod in _module_state_mapping.keys():
Collaborator

I don't see any good alternatives to this O(n^2) search, but we should switch if we ever find a more natural way to traverse parent modules.

Collaborator Author

@vthumbe1503 vthumbe1503 Nov 6, 2025

Actually, if we are guaranteed Python 3.7 or higher, where dict order matches insertion order, we can stop at the first FSDP module that is a parent, since FSDP2 inserts the FSDP modules into the dict bottom-up. In that case it is O(n). Since TE already requires Python 3.10+, I can remove the logic that finds the parent with the minimum number of submodules and instead return the first parent encountered. Does that sound reasonable?
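
A minimal sketch of that simplification, assuming _module_state_mapping is an insertion-ordered dict that FSDP2 populates bottom-up as described above; the function name and containment check are illustrative.

def _closest_fsdp_state(module, module_state_mapping):
    # Because the mapping is filled bottom-up, the first FSDP-wrapped module
    # whose subtree contains `module` is also its closest FSDP ancestor.
    for fsdp_mod, state in module_state_mapping.items():
        if any(m is module for m in fsdp_mod.modules()):
            return state
    return None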

return inp, handle


@lru_cache
Collaborator

With the default maxsize=128, a Transformer model with more than 32 layers will not fit in the LRU cache. Since this logic is per-module, it's more natural to cache the FSDP state within the module rather than in a global cache. How about something like:

def _get_module_fsdp_state(module):
    if hasattr(module, "_get_fsdp_state"):
        return module._get_fsdp_state()
    if getattr(module, "_parent_fsdp_state", None) is not None:
        return module._parent_fsdp_state

    # Traverse to find parent FSDP

    module._parent_fsdp_state = fsdp_state
    return fsdp_state

Collaborator Author

Yep, reasonable. Will do that!

def __torch_dispatch__(cls, func, types, args, kwargs=None):
    # FSDP2 related functions
    # View op
    if func == aten.view.default:
Collaborator

This implementation is incorrect because it doesn't handle scales or column-wise data. We should prefer using the existing view impl:

class _ViewFunc(torch.autograd.Function):

I assume that this case is because FSDP does something like view(-1)? If so, we should check the view dims and only use this impl for that case:

if func == aten.view.default:
    tensor = args[0]
    dims = args[1]  # ?
    if len(dims) == 1:
        if dims[0] not in (-1, tensor.numel()):
            raise ValueError(...)  # Invalid dims
        warnings.warn(...)  # Warn non-FSDP users
        return MXFP8Tensor(
            rowwise_data=tensor._rowwise_data.view(-1),
            columnwise_data=tensor._columnwise_data.view(-1),
            ...
        )
    return tensor.view(dims)

Collaborator Author

@vthumbe1503 vthumbe1503 Nov 6, 2025

When FSDP calls view, it's actually the view method of the MXFP8Tensor class that gets called (which does _ViewFunc.apply) rather than the torch dispatch path. This view implementation existed before this PR; I only moved it up. But it makes sense to correct it. How about I use _ViewFunc.apply here as well, rather than repeating the view and FSDP-specific logic?

fp8_dtype=tensor._fp8_dtype,
)

if func == torch.ops.aten.copy_.default:
Collaborator

The base implementation respects the quantizer usages of the dst tensor, so it's incorrect for this impl to wipe everything out based on the quantizer usages of the src tensor.

The use case I am thinking about is a model that has been initialized with quantized_model_init. The modules have configured the weight quantizers, possibly with very complicated internal logic. Then we load a checkpoint and update the weights with param.copy_(checkpoint_param). Overwriting the quantizer would undo all of the hard work we have already done in the configuration.

Collaborator Author

@vthumbe1503 vthumbe1503 Nov 6, 2025

In that case the checkpoint param would be a high-precision tensor, right? I have a check to copy things only if both src and dst are MXFP8 tensors, so this code won't be executed; otherwise it falls back to the base implementation.

Collaborator Author

But I see your argument. That philosophy is consistent even with the per-tensor scaling class, so I will remove the quantizer override.

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR enables full FSDP2 support for TransformerEngine with FP8-initialized weights, addressing three critical issues: memory footprint problems, weight update correctness, and enabling 8-bit weight allgather.

Key Changes

FSDP2 Allgather Hooks: Implemented fsdp_pre_all_gather and fsdp_post_all_gather methods for Float8Tensor and MXFP8Tensor classes. These hooks enable FSDP2 to work with custom tensor types by converting FP8 tensors to uint8 data for allgather, then reconstructing them post-allgather with appropriate metadata (scale_inv, dtype, quantizer).
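
For orientation, a toy version of that hook pair is sketched below. It is deliberately simplified: the real methods also receive mesh/module/training-state context, strip and re-add scale_inv padding, and return the extra bookkeeping FSDP2 expects, so every name and signature here is illustrative rather than TE's actual API.

import torch

class ToyFP8Shard:
    """Toy stand-in for an FP8 weight shard; shows only the allgather data flow."""

    def __init__(self, data_uint8: torch.Tensor, scale_inv: torch.Tensor):
        self._data = data_uint8      # raw FP8 bytes stored as uint8
        self._scale_inv = scale_inv  # dequantization scale(s)

    def fsdp_pre_all_gather(self, mesh=None):
        # Hand FSDP plain uint8 tensors it can gather natively, plus the
        # metadata needed to rebuild the quantized tensor afterwards.
        return (self._data,), (self._scale_inv,)

    @staticmethod
    def fsdp_post_all_gather(all_gather_outputs, metadata, param_dtype=None, *, out=None):
        # Reassemble the full quantized weight from the gathered bytes.
        (gathered_data,) = all_gather_outputs
        (scale_inv,) = metadata
        return ToyFP8Shard(gathered_data, scale_inv)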

Quantization Usage Management: The forward pass allgathers rowwise data (used as the B operand of the TN GEMM), while the backward pass allgathers columnwise data (used as the A operand). The PR determines forward vs backward state using FSDP's TrainingState and handles the reshard_after_forward configuration.
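
One way such a phase check can be expressed is sketched below; the attribute and enum spellings are guesses at torch's FSDP2 internals, which change between versions.

def _usages_for_phase(fsdp_state):
    # Forward gathers rowwise data (weight as the GEMM's B operand); backward
    # gathers columnwise data (the transpose layout used as the A operand).
    phase = getattr(fsdp_state, "_training_state", None)
    in_backward = phase is not None and "BACKWARD" in str(phase)
    return (not in_backward), in_backward  # (rowwise_usage, columnwise_usage)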

Torch Dispatch Handlers: Added __torch_dispatch__ implementations for FSDP2-required ops: split (sharding), new_zeros (buffer creation), copy_ (data transfer), as_strided (padding removal), view (flattening). These handlers preserve FP8 tensor subclass through FSDP2's operations.

Current Scaling Quantization: For current scaling, the PR sets amax_reduction_group in quantizers to ensure all weight shards use synchronized scale_inv values after optimizer updates. Critical for training correctness.
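
A minimal sketch of that wiring, assuming the current-scaling quantizer exposes with_amax_reduction / amax_reduction_group attributes as the summary describes; the exact attribute names may differ from TE's API.

import torch.distributed as dist

def _enable_shard_amax_reduction(quantizer, shard_group: dist.ProcessGroup) -> None:
    # Each rank computes amax only over its local weight shard; reducing amax
    # across the sharding group before requantization keeps scale_inv identical
    # on every shard after the optimizer step.
    quantizer.with_amax_reduction = True
    quantizer.amax_reduction_group = shard_group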

DTensor Support: Updated reset_parameters to handle DTensor parameters in deferred initialization, correctly updating the local_tensor of sharded parameters.
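
A condensed sketch of that DTensor handling; _local_tensor is a torch internal (and the DTensor import path varies across torch versions), so the helper is illustrative rather than the PR's code.

import torch
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older releases

def _install_quantized_weight(param: torch.nn.Parameter, quantized: torch.Tensor) -> None:
    # Under FSDP2 deferred init the parameter wraps a DTensor, and only the
    # local shard on this rank should be replaced by the quantized weight.
    if isinstance(param.data, DTensor):
        param.data._local_tensor = quantized
    else:
        param.data = quantized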

Usage Validation Changes: Moved weight quantization usage validation from module __init__ to forward/backward functions, since FSDP2 requires different usages at different training phases.

Optimizer Fix: Fixed in-place operations to handle lists of quantized tensors, enabling optimizers like Adam to correctly update FP8 weights.
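
A self-contained sketch of that recursion; the stand-in class and the body of maybe_update_inplace are placeholders, not TE's implementation.

import torch

class QuantizedStandIn:
    """Placeholder for a quantized tensor; records the update instead of quantizing."""

    def __init__(self):
        self.last_update = None

    def requantize_from(self, value: torch.Tensor) -> None:
        self.last_update = value  # a real implementation would requantize from `value`

def maybe_update_inplace(arg, new_value) -> None:
    # Fused optimizers pass lists of parameters to in-place ops, so the walk
    # recurses into lists before requantizing individual quantized tensors
    # from their updated high-precision values.
    if isinstance(arg, (list, tuple)):
        for a, v in zip(arg, new_value):
            maybe_update_inplace(a, v)
    elif isinstance(arg, QuantizedStandIn):
        arg.requantize_from(new_value)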

Testing

Significantly expanded test coverage with parameterized tests for different TE modules (Linear, LayerNormLinear, LayerNormMLP), quantization recipes (delayed scaling, current scaling, MXFP8), and FSDP configurations.

Confidence Score: 4/5

  • This PR is safe to merge with minor reservations about edge case handling
  • The implementation is comprehensive and well-tested. Core FSDP2 integration logic is sound with proper handling of forward/backward distinction, allgather hooks, and quantizer synchronization. The PR successfully addresses the three stated issues. Deducting 1 point due to minor concerns: (1) potential AttributeError in MXFP8 edge cases when columnwise_data is None, (2) @lru_cache on _get_module_fsdp_state could theoretically cause stale references if modules are recreated (though unlikely in practice), and (3) previous reviewer identified syntax issues in view/reshape transpose handling that need verification. These are edge cases unlikely to occur in normal usage but should be validated.
  • transformer_engine/pytorch/tensor/mxfp8_tensor.py - verify None handling in copy_ operation (line 346), transformer_engine/pytorch/tensor/float8_tensor.py - verify transpose shape checking logic in view (lines 576-585)

Important Files Changed

File Analysis

  • transformer_engine/pytorch/tensor/mxfp8_tensor.py (4/5): Added FSDP2 support via torch dispatch handlers for split, new_zeros, copy_, as_strided, and slice operations. Implemented fsdp_pre_all_gather and fsdp_post_all_gather hooks for weight allgather. Minor issues: missing None checks could cause AttributeError in edge cases with columnwise_data.
  • transformer_engine/pytorch/tensor/float8_tensor.py (4/5): Added FSDP2 support via torch dispatch handlers and fsdp_pre/post_all_gather methods. Improved view/reshape to handle the transpose cache. Enhanced new_zeros to deep copy scale_inv and quantizer. The split operation now handles the transpose cache correctly. Minor syntax/logic issues with transpose shape checking.
  • transformer_engine/pytorch/distributed.py (5/5): Added the _get_module_fsdp_state helper function with @lru_cache to find the FSDP state for modules. The function finds the lowest common ancestor FSDP module in the hierarchy. Implementation looks correct with proper error handling for non-FSDP modules.
  • transformer_engine/pytorch/module/base.py (5/5): Updated reset_parameters to handle DTensor parameters for FSDP2 deferred initialization. Correctly updates the local tensor of the DTensor with the quantized weight. Sets amax_reduction_group for current scaling quantizers to ensure weight shards share the same scale inverse.
  • transformer_engine/pytorch/quantized_tensor.py (5/5): Fixed in-place ops to handle lists of quantized tensors (needed for optimizer weight updates). Added recursive handling in maybe_update_inplace. Removed the data parameter from the make_like signature for clarity - the method creates views, not copies.
  • transformer_engine/pytorch/module/linear.py (5/5): Moved weight quantization usage validation from __init__ to the forward/backward functions, required for FSDP2 where different usages are allgathered in forward vs backward. Removed unnecessary quantizer updates when weights are already in FP8.

Sequence Diagram

sequenceDiagram
    participant User
    participant FSDP2
    participant TEModule
    participant Float8Tensor
    participant Quantizer
    participant AllGather

    User->>FSDP2: Initialize model with fp8_model_init
    FSDP2->>TEModule: Apply fully_shard()
    TEModule->>Float8Tensor: Register FP8 weight parameters
    
    Note over FSDP2,Float8Tensor: Forward Pass
    FSDP2->>Float8Tensor: fsdp_pre_all_gather(module, mesh)
    Float8Tensor->>Quantizer: Set rowwise_usage=True, columnwise_usage=False
    Float8Tensor->>FSDP2: Return (sharded_uint8_data, metadata)
    FSDP2->>AllGather: all_gather_into_tensor(uint8_data)
    AllGather->>FSDP2: Return allgathered_data
    FSDP2->>Float8Tensor: fsdp_post_all_gather(allgathered_data, metadata)
    Float8Tensor->>Float8Tensor: Reconstruct FP8 tensor with rowwise usage
    Float8Tensor->>TEModule: Return unsharded weight
    TEModule->>TEModule: Forward computation
    
    Note over FSDP2,Float8Tensor: Backward Pass
    TEModule->>TEModule: Backward computation
    FSDP2->>Float8Tensor: fsdp_pre_all_gather(module, mesh)
    Float8Tensor->>Quantizer: Set rowwise_usage=False, columnwise_usage=True
    Float8Tensor->>FSDP2: Return (sharded_uint8_data, metadata)
    FSDP2->>AllGather: all_gather_into_tensor(uint8_data)
    AllGather->>FSDP2: Return allgathered_data
    FSDP2->>Float8Tensor: fsdp_post_all_gather(allgathered_data, metadata)
    Float8Tensor->>Float8Tensor: Reconstruct FP8 tensor with columnwise usage
    Float8Tensor->>TEModule: Return unsharded weight
    
    Note over FSDP2,Float8Tensor: Optimizer Step
    FSDP2->>Float8Tensor: Reduce-scatter gradients
    User->>Quantizer: optimizer.step() on FP8 shards
    Quantizer->>Float8Tensor: Update FP8 weight shards with same scale_inv
    Float8Tensor->>Quantizer: Sync amax across shards (current scaling)

11 files reviewed, no comments

