103 commits
d64ee35
fix for float8 tensor fsdp2 training
vthumbe1503 Oct 7, 2025
e64f5bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 7, 2025
980330c
zeros_like should return fp32 for fsdp2 to work
vthumbe1503 Oct 7, 2025
737852c
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Oct 7, 2025
2a3ca77
minor cleanup
vthumbe1503 Oct 7, 2025
65f50af
fix unsharded weights not releasing memory
vthumbe1503 Oct 8, 2025
1360381
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2025
2ad5624
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Oct 8, 2025
09c6848
implement using fsdp preallgather and postallgather functions
vthumbe1503 Oct 13, 2025
322853c
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Oct 13, 2025
485cab3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 13, 2025
f5bedbb
FSDP2 works on Hopper/L40
vthumbe1503 Oct 13, 2025
4a72968
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 13, 2025
fddfd21
minor comment
vthumbe1503 Oct 13, 2025
099a92b
fix merge conflict
vthumbe1503 Oct 13, 2025
0b81d32
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 13, 2025
ae2bde5
some fixes for fp8 + handwavy changes for mxfp8
vthumbe1503 Oct 17, 2025
c4c4347
only transpose saved for backward pass allgather in case of L40/Hoppe…
vthumbe1503 Oct 17, 2025
eec56da
missed minor change to hopper use-case
vthumbe1503 Oct 17, 2025
79475fe
communicate only required data in mxfp8, fix for updating weight usag…
vthumbe1503 Oct 20, 2025
f881535
changes for meta Dtensors for weights and better all gather data hand…
vthumbe1503 Oct 22, 2025
9d87fb7
better solution to figure out forward pass in FSDP2
vthumbe1503 Oct 22, 2025
40d4dfd
address review comments
vthumbe1503 Oct 22, 2025
71d9f5d
Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
vthumbe1503 Oct 23, 2025
20228c4
resolve merge conflict
vthumbe1503 Oct 24, 2025
67dc86e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2025
0a6e23c
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Oct 25, 2025
631f01f
everything functioning except hack for transformerlayer
vthumbe1503 Oct 28, 2025
fb49cfe
fix merge conflict
vthumbe1503 Oct 28, 2025
9516213
fix merge conflict
vthumbe1503 Oct 28, 2025
c9dc10c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
8c8ef93
revert change of commit id for cudnn-frontend
vthumbe1503 Oct 28, 2025
04d7416
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Oct 28, 2025
c97e204
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
dbbbaab
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Oct 28, 2025
0ac3bfa
unnecessary change
vthumbe1503 Oct 28, 2025
2c3b958
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Oct 28, 2025
36a6d01
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
44d8c72
minor issues with linting, add some comments
vthumbe1503 Oct 28, 2025
bca48a8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
a50928d
minor stuff
vthumbe1503 Oct 29, 2025
b5bd42e
revert space removal
vthumbe1503 Oct 29, 2025
2d75310
fix the fsdp state collection issue, and minor review comments addres…
vthumbe1503 Oct 31, 2025
eb6964d
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Oct 31, 2025
c16a844
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2025
5db9a33
revert change for dgrad redundant computation
vthumbe1503 Oct 31, 2025
e207845
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Oct 31, 2025
32365ce
bug: get fsdp param group's training state instead of root training s…
vthumbe1503 Nov 1, 2025
8c7f375
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2025
63b517f
address review comments
vthumbe1503 Nov 1, 2025
26d617c
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Nov 1, 2025
6f7f037
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Nov 1, 2025
6034f5d
address review comments
vthumbe1503 Nov 1, 2025
cd454e0
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Nov 1, 2025
7269ab3
address coderabbit review comments
vthumbe1503 Nov 1, 2025
044d098
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2025
ee3c33e
address review comments
vthumbe1503 Nov 1, 2025
7d1726e
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/Transformer…
vthumbe1503 Nov 1, 2025
bf3546a
address review comments
vthumbe1503 Nov 1, 2025
fb58afe
address review comments; fix fp8 allgather test to do after fsdp lazy …
vthumbe1503 Nov 2, 2025
7ca1961
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2025
9894d95
address review comments
vthumbe1503 Nov 3, 2025
c6b5b42
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Nov 3, 2025
b661461
remove detach
vthumbe1503 Nov 3, 2025
c99a92b
do what makes sense
vthumbe1503 Nov 3, 2025
0d291d9
Update transformer_engine/pytorch/tensor/float8_tensor.py
vthumbe1503 Nov 3, 2025
5588e96
Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
vthumbe1503 Nov 3, 2025
3f4957d
Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
vthumbe1503 Nov 3, 2025
3a2b2c1
Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
vthumbe1503 Nov 3, 2025
61f4fc8
Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
vthumbe1503 Nov 3, 2025
94a8126
Update transformer_engine/pytorch/tensor/mxfp8_tensor.py
vthumbe1503 Nov 3, 2025
62ad0aa
address review comments
vthumbe1503 Nov 3, 2025
29db9af
address review comments
vthumbe1503 Nov 3, 2025
16c77c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2025
023a4f4
address review comments
vthumbe1503 Nov 3, 2025
651ebd8
address review comments
vthumbe1503 Nov 3, 2025
15793fa
fix merge conflicts
vthumbe1503 Nov 3, 2025
c6e0a08
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2025
4aeb6ca
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Nov 3, 2025
8e09624
address review comments
vthumbe1503 Nov 3, 2025
207c2eb
have better dtype for fsdp_post_all_gather arguments
vthumbe1503 Nov 3, 2025
948ab91
minor comment
vthumbe1503 Nov 3, 2025
1f1de13
improve comment
vthumbe1503 Nov 3, 2025
acec1b5
fix the error in CI
vthumbe1503 Nov 4, 2025
ff20145
minor comment add
vthumbe1503 Nov 4, 2025
2593a8a
accidentally removed view function
vthumbe1503 Nov 4, 2025
4cde1c3
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Nov 4, 2025
892024e
fix minor bug for h100
vthumbe1503 Nov 4, 2025
205be41
minor addition
vthumbe1503 Nov 5, 2025
0300a5f
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Nov 5, 2025
adbb31e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2025
fa71d84
implement padding removal/addition for allgather
vthumbe1503 Nov 5, 2025
cc41893
fix merge conflict
vthumbe1503 Nov 5, 2025
5a40f2b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2025
2d0b2d5
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Nov 6, 2025
0e849ce
address review comments
vthumbe1503 Nov 6, 2025
019ae7f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2025
11f72dd
address review comments
vthumbe1503 Nov 6, 2025
d6a8dbb
Merge branch 'fsdp2_issue_fix' of github.com:vthumbe1503/TransformerE…
vthumbe1503 Nov 6, 2025
92f932a
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Nov 6, 2025
a72eedc
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Nov 6, 2025
c69bd00
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Nov 7, 2025
8abffa3
Merge branch 'main' into fsdp2_issue_fix
vthumbe1503 Nov 8, 2025
38 changes: 21 additions & 17 deletions tests/pytorch/distributed/run_fsdp2_model.py
@@ -9,7 +9,7 @@
import argparse

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling

import torch
import torch.distributed as dist
@@ -18,6 +18,7 @@
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from transformer_engine.pytorch import QuantizedTensor
from contextlib import nullcontext


@@ -36,8 +37,14 @@ def forward(self, x):
def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
if isinstance(param, QuantizedTensor):
# Ignore FP8 metadata attributes. Otherwise we will save duplicate copies
# for data/transpose FP8 tensors on top of FP8 tensors that FSDP2 will save.
ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")]
else:
ignore_keys = []
attrs = vars(param)
custom_attrs[name] = {k: v for k, v in attrs.items()}
custom_attrs[name] = {k: v for k, v in attrs.items() if k not in ignore_keys}
return custom_attrs


@@ -103,25 +110,23 @@ def _train(args):

# FP8 Configuration
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
# fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
# fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format)
fp8_recipe = MXFP8BlockScaling(fp8_format=fp8_format)

build_model_context_args = {}
if not args.fp8_init:
# Build model context (FP8 init)
build_model_context = nullcontext
build_model_context_args = {}

else:
from transformer_engine.pytorch import fp8_model_init

build_model_context = fp8_model_init
build_model_context_args["enabled"] = True

# Build the model with the specified context
with build_model_context(**build_model_context_args):
model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
else:
model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
build_model_context_args["recipe"] = fp8_recipe
# Move the model to the correct device

# Build the model with the specified context
with build_model_context(**build_model_context_args):
model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
model.to(device)

if LOCAL_RANK == 0:
@@ -146,7 +151,6 @@ def _train(args):
)
else:
assert False

# Apply FSDP/HSDP
custom_attrs = save_custom_attrs(model)
for sub_module in model.modules():
@@ -156,20 +160,20 @@ def _train(args):
fully_shard(sub_module, mesh=mesh)
fully_shard(model, mesh=mesh)
restore_custom_attrs(model, custom_attrs)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

for iteration in range(args.iter):
# Zero the parameter gradients
optimizer.zero_grad()
input_data = torch.randn(args.batch_size, args.input_size).to(device)
output = model(input_data)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
output = model(input_data)
target = torch.randn(args.batch_size, args.output_size).to(device)
loss = F.mse_loss(output, target)
loss.backward()
optimizer.step()
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")
print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed with loss {loss.item()}")

dist.destroy_process_group()
if LOCAL_RANK == 0:
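The test-script changes above choose the recipe once (MXFP8BlockScaling in this revision, with DelayedScaling and Float8CurrentScaling left as commented-out alternatives), thread it into fp8_model_init so parameters are created directly as quantized tensors, and run the forward pass under te.fp8_autocast with the same recipe. A minimal single-GPU sketch of that flow; the layer type and sizes are illustrative, not from the PR:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, MXFP8BlockScaling

# MXFP8 needs hardware support; the script's commented-out DelayedScaling /
# Float8CurrentScaling recipes are drop-in alternatives on older GPUs.
fp8_recipe = MXFP8BlockScaling(fp8_format=Format.HYBRID)

# Passing the recipe at init time creates the weights directly as quantized
# tensors, so FSDP2 can later all-gather low-precision data instead of FP32.
with te.fp8_model_init(enabled=True, recipe=fp8_recipe):
    model = te.Linear(1024, 1024)
model.cuda()

x = torch.randn(32, 1024, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(x)
out.sum().backward()
```

In the actual script, fully_shard(sub_module, mesh=mesh) is additionally applied per submodule and fully_shard(model, mesh=mesh) at the root before the training loop.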
1 change: 1 addition & 0 deletions transformer_engine/pytorch/module/base.py
@@ -1371,6 +1371,7 @@ def get_weight_workspace(
# Note: Make sure weights have required usages, but do not
# destroy unnecessary usages since they may be used later.
if isinstance(tensor, QuantizedTensor):
quantizer = tensor._quantizer
update_rowwise_usage = True if quantizer.rowwise_usage else None
update_columnwise_usage = True if quantizer.columnwise_usage else None
tensor.update_usage(
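The added line derives the usage flags from the quantizer attached to the tensor itself, rather than a quantizer passed in from an outer scope, which can disagree with an FSDP2 all-gathered weight. The idiom only ever adds usages: passing None to update_usage leaves that usage untouched, matching the comment about not destroying usages that may be needed later. A sketch of the pattern, with a hypothetical helper name:

```python
def ensure_required_usages(tensor):
    # Ask the tensor's own quantizer which layouts are required, then add
    # (never remove) them: None means "leave this usage flag as it is".
    quantizer = tensor._quantizer
    tensor.update_usage(
        rowwise_usage=True if quantizer.rowwise_usage else None,
        columnwise_usage=True if quantizer.columnwise_usage else None,
    )
```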
21 changes: 8 additions & 13 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -44,6 +44,7 @@

from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
@@ -107,7 +108,8 @@ def forward(
is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
)
if weight_quantizers[0] is not None:
# No need to set the quantizer states if weight is already quantized
if weight_quantizers[0] is not None and not isinstance(weights[0], QuantizedTensor):
for weight_quantizer in weight_quantizers:
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
if output_quantizers[0] is not None:
@@ -204,11 +206,6 @@ def forward(
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
inputmats = [None] * num_gemms
if inp.requires_grad:
for weight in weights_fp8:
if isinstance(weight, QuantizedTensorStorage):
weight.update_usage(columnwise_usage=True)

tensors_to_save, tensor_objects = prepare_for_saving(
*inputmats,
*weights_fp8,
@@ -336,13 +333,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
dtype=ctx.activation_dtype,
device=ctx.device,
)

for weight, quantizer in zip(weights, ctx.weight_quantizers):
if quantizer is not None and isinstance(weight, QuantizedTensorStorage):
weight.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
)
# Make sure weights are available in column-wise format
# for dgrad computation.
for weight in weights:
if isinstance(weight, QuantizedTensorStorage):
weight.update_usage(columnwise_usage=True)
general_grouped_gemm(
weights,
grad_output,
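For context on the backward change: dgrad consumes the weight in its transposed (column-wise) layout, so every quantized weight saved for backward must carry that usage regardless of what the forward quantizer requested; the old code mirrored the quantizer's flags and could miss it. A plain-PyTorch sketch of the three GEMMs and the copy each one needs (shapes illustrative; exact kernel layouts are architecture-dependent):

```python
import torch

x = torch.randn(32, 64)    # input           [batch, in]
w = torch.randn(128, 64)   # weight          [out, in]
g = torch.randn(32, 128)   # output gradient [batch, out]

out = x @ w.T   # forward GEMM -> served by the weight's row-wise FP8 copy
dx = g @ w      # dgrad GEMM   -> weight used transposed relative to forward,
                #                 served by the column-wise FP8 copy
dw = g.T @ x    # wgrad GEMM   -> input used transposed, which is why the
                #                 forward saves inputmat with column-wise usage
```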
5 changes: 3 additions & 2 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -280,8 +280,9 @@ def forward(
if fp8 or debug:
quantized_weight = not isinstance(weight, QuantizedTensorStorage)

# Configure quantizer
if weight_quantizer is not None:
# Configure quantizer.
# No need to set the quantizer states if weight is already quantized
if weight_quantizer is not None and not isinstance(weight, QuantizedTensor):
weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)

# Get quantized weight
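The same guard recurs in grouped_linear.py above and in layernorm_mlp.py and linear.py below: once FSDP2's all-gather has already reassembled the weight as a QuantizedTensor, the module must leave the quantizer's usage flags alone, since the gathered tensor carries exactly the usages that were communicated. A minimal sketch of the shared pattern (helper name hypothetical):

```python
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor

def configure_weight_quantizer_if_needed(weight, quantizer, is_grad_enabled):
    # Skip already-quantized weights (e.g. produced by fsdp_post_all_gather):
    # their row-/column-wise copies are fixed by what was all-gathered.
    if quantizer is not None and not isinstance(weight, QuantizedTensor):
        quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
```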
8 changes: 6 additions & 2 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -71,6 +71,7 @@
from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
@@ -347,8 +348,11 @@ def forward(
# which handles weight caching etc.
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
# No need to set the quantizer states if weights are already quantized
if not isinstance(fc1_weight, QuantizedTensor):
fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
if not isinstance(fc2_weight, QuantizedTensor):
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
fc1_weight_final = module.get_weight_workspace(
tensor=fc1_weight,
quantizer=fc1_weight_quantizer,
8 changes: 2 additions & 6 deletions transformer_engine/pytorch/module/linear.py
@@ -240,7 +240,8 @@ def forward(
weightmat = weight
if fp8 or debug:
# Configure quantizer
if weight_quantizer is not None:
# No need to set the quantizer states if weight is already quantized
if weight_quantizer is not None and not isinstance(weight, QuantizedTensor):
columnwise_usage = is_grad_enabled and inp.requires_grad
if not columnwise_usage:
columnwise_usage = (
@@ -389,11 +390,6 @@ def forward(
if backward_needs_input:
saved_inputmat = inputmat

# Weight with column-wise usage is needed for dgrad GEMM.
if inp.requires_grad:
if isinstance(weightmat, QuantizedTensorStorage):
weightmat.update_usage(columnwise_usage=True)

if cpu_offloading and saved_inputmat is not None:
mark_activation_offload(saved_inputmat)
